imerge 'github/tensorflow': automatic merge 125-7
diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def
index 983107f..329d744 100644
--- a/include/swift/AST/DiagnosticsSema.def
+++ b/include/swift/AST/DiagnosticsSema.def
@@ -3391,6 +3391,16 @@
ERROR(missing_builtin_precedence_group,none,
"broken standard library: missing builtin precedence group %0",
(Identifier))
+// Comparison with '.nan' via '==' or '!=' where only one operand is '.nan'.
+// %0: the operator; %1: true when the operator is '!=' (always true),
+// false otherwise (always false); %2: the expression text to which
+// '.isNaN' is appended (prefixed with '!' for '!='); %3: the expression text.
+WARNING(nan_comparison, none,
+ "comparison with '.nan' using %0 is always %select{false|true}1, use "
+ "'%2.isNaN' to check if '%3' %select{is not a number|is a number}1",
+ (Identifier, bool, StringRef, StringRef))
+// Comparison with '.nan' where no '.isNaN' fix-up is suggested
+// (e.g. '<', '<=', '>', '>='). %0: the operator; %1: true for '!='.
+WARNING(nan_comparison_without_isnan, none,
+ "comparison with '.nan' using %0 is always %select{false|true}1",
+ (Identifier, bool))
+// Both operands are '.nan'. %0: the operator spelling; %1: true for '!='.
+WARNING(nan_comparison_both_nan, none,
+ "'.nan' %0 '.nan' is always %select{false|true}1",
+ (StringRef, bool))
// If you change this, also change enum TryKindForDiagnostics.
#define TRY_KIND_SELECT(SUB) "%select{try|try!|try?|await}" #SUB
diff --git a/include/swift/AST/Identifier.h b/include/swift/AST/Identifier.h
index 85e4101..2aa585e 100644
--- a/include/swift/AST/Identifier.h
+++ b/include/swift/AST/Identifier.h
@@ -109,7 +109,14 @@
// Handle the high unicode case out of line.
return isOperatorSlow();
}
-
+
+  /// Returns true if this identifier spells one of the standard comparison
+  /// operators: '==', '!=', '===', '!==', '<', '>', '<=' or '>='.
+  bool isStandardComparisonOperator() const {
+    static const char *const ComparisonOperators[] = {
+        "==", "!=", "===", "!==", "<", ">", "<=", ">="};
+    for (const char *Op : ComparisonOperators)
+      if (is(Op))
+        return true;
+    return false;
+  }
+
/// isOperatorStartCodePoint - Return true if the specified code point is a
/// valid start of an operator.
static bool isOperatorStartCodePoint(uint32_t C) {
diff --git a/include/swift/AST/KnownProtocols.def b/include/swift/AST/KnownProtocols.def
index 57ebf2e..0166353 100644
--- a/include/swift/AST/KnownProtocols.def
+++ b/include/swift/AST/KnownProtocols.def
@@ -87,6 +87,8 @@
PROTOCOL(AdditiveArithmetic)
PROTOCOL(Differentiable)
+// Used to recognize floating-point operands in the '.nan' comparison
+// diagnostic.
+PROTOCOL(FloatingPoint)
+
// SWIFT_ENABLE_TENSORFLOW
PROTOCOL(PointwiseMultiplicative)
PROTOCOL(ElementaryFunctions)
diff --git a/lib/IRGen/GenMeta.cpp b/lib/IRGen/GenMeta.cpp
index 845a725..1d0814b 100644
--- a/lib/IRGen/GenMeta.cpp
+++ b/lib/IRGen/GenMeta.cpp
@@ -5044,6 +5044,7 @@
case KnownProtocolKind::StringInterpolationProtocol:
case KnownProtocolKind::AdditiveArithmetic:
case KnownProtocolKind::Differentiable:
+ case KnownProtocolKind::FloatingPoint:
// SWIFT_ENABLE_TENSORFLOW
case KnownProtocolKind::PointwiseMultiplicative:
case KnownProtocolKind::ElementaryFunctions:
diff --git a/lib/Sema/CSDiagnostics.cpp b/lib/Sema/CSDiagnostics.cpp
index 6c48dd1..1579371 100644
--- a/lib/Sema/CSDiagnostics.cpp
+++ b/lib/Sema/CSDiagnostics.cpp
@@ -165,7 +165,7 @@
+// Forwards to constraints::conformsToKnownProtocol, using the declaration
+// context of this diagnostic's constraint system.
bool FailureDiagnostic::conformsToKnownProtocol(
Type type, KnownProtocolKind protocol) const {
auto &cs = getConstraintSystem();
- return constraints::conformsToKnownProtocol(cs, type, protocol);
+ return constraints::conformsToKnownProtocol(cs.DC, type, protocol);
}
Type RequirementFailure::getOwnerType() const {
diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp
index 44cbfd0..1035ef1 100644
--- a/lib/Sema/CSGen.cpp
+++ b/lib/Sema/CSGen.cpp
@@ -1767,11 +1767,11 @@
auto type = contextualType->lookThroughAllOptionalTypes();
if (conformsToKnownProtocol(
- CS, type, KnownProtocolKind::ExpressibleByArrayLiteral))
+ CS.DC, type, KnownProtocolKind::ExpressibleByArrayLiteral))
return false;
return conformsToKnownProtocol(
- CS, type, KnownProtocolKind::ExpressibleByDictionaryLiteral);
+ CS.DC, type, KnownProtocolKind::ExpressibleByDictionaryLiteral);
};
if (isDictionaryContextualType(contextualType)) {
diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp
index 2fd8123..f7aae30 100644
--- a/lib/Sema/ConstraintSystem.cpp
+++ b/lib/Sema/ConstraintSystem.cpp
@@ -4191,11 +4191,11 @@
doesMemberRefApplyCurriedSelf(baseType, decl);
}
-bool constraints::conformsToKnownProtocol(ConstraintSystem &cs, Type type,
+bool constraints::conformsToKnownProtocol(DeclContext *dc, Type type,
KnownProtocolKind protocol) {
+  // Resolve the known protocol's declaration, then check the conformance
+  // relative to dc.
if (auto *proto =
- TypeChecker::getProtocol(cs.getASTContext(), SourceLoc(), protocol))
- return (bool)TypeChecker::conformsToProtocol(type, proto, cs.DC);
+ TypeChecker::getProtocol(dc->getASTContext(), SourceLoc(), protocol))
+ return (bool)TypeChecker::conformsToProtocol(type, proto, dc);
+  // The known protocol's declaration could not be found; report no
+  // conformance.
return false;
}
@@ -4220,7 +4220,8 @@
ConstraintSystem &cs, Type type,
KnownProtocolKind rawRepresentableProtocol) {
Type rawTy = isRawRepresentable(cs, type);
- if (!rawTy || !conformsToKnownProtocol(cs, rawTy, rawRepresentableProtocol))
+ if (!rawTy ||
+ !conformsToKnownProtocol(cs.DC, rawTy, rawRepresentableProtocol))
return Type();
return rawTy;
@@ -4514,9 +4515,7 @@
if (!expr) return false;
if (auto opName = getOperatorName(expr)) {
- return opName->is("==") || opName->is("!=") || opName->is("===") ||
- opName->is("!==") || opName->is("<") || opName->is(">") ||
- opName->is("<=") || opName->is(">=");
+ return opName->isStandardComparisonOperator();
}
return false;
}
diff --git a/lib/Sema/ConstraintSystem.h b/lib/Sema/ConstraintSystem.h
index 019fffc..4adce22 100644
--- a/lib/Sema/ConstraintSystem.h
+++ b/lib/Sema/ConstraintSystem.h
@@ -5521,7 +5521,7 @@
llvm::function_ref<Type(Type)> getFixedType);
/// Check whether type conforms to a given known protocol.
+/// Conformance is looked up relative to \p dc; returns false if the known
+/// protocol's declaration cannot be found.
-bool conformsToKnownProtocol(ConstraintSystem &cs, Type type,
+bool conformsToKnownProtocol(DeclContext *dc, Type type,
KnownProtocolKind protocol);
/// Check whether given type conforms to `RawPepresentable` protocol
diff --git a/lib/Sema/MiscDiagnostics.cpp b/lib/Sema/MiscDiagnostics.cpp
index ed1bd72..4aa6289 100644
--- a/lib/Sema/MiscDiagnostics.cpp
+++ b/lib/Sema/MiscDiagnostics.cpp
@@ -15,8 +15,9 @@
//===----------------------------------------------------------------------===//
#include "MiscDiagnostics.h"
-#include "TypeChecker.h"
+#include "ConstraintSystem.h"
#include "TypeCheckAvailability.h"
+#include "TypeChecker.h"
#include "swift/AST/ASTWalker.h"
#include "swift/AST/NameLookup.h"
#include "swift/AST/NameLookupRequests.h"
@@ -34,6 +35,7 @@
#define DEBUG_TYPE "Sema"
using namespace swift;
+using namespace constraints;
/// Return true if this expression is an implicit promotion from T to T?.
static Expr *isImplicitPromotionToOptional(Expr *E) {
@@ -4438,6 +4440,131 @@
const_cast<Expr *>(E)->walk(Walker);
}
+/// Warn about comparisons against a floating-point '.nan' static property,
+/// such as `x == .nan` or `Double.nan < y`.
+///
+/// The warnings state that '!=' is always true and every other standard
+/// comparison operator is always false. When exactly one operand is '.nan'
+/// and the operator is '==' or '!=', the warning also suggests using
+/// '.isNaN' on the other operand.
+///
+/// \param E  The type-checked expression to walk.
+/// \param DC The declaration context used for conformance lookups.
+static void diagnoseComparisonWithNaN(const Expr *E, const DeclContext *DC) {
+  class ComparisonWithNaNFinder : public ASTWalker {
+    const ASTContext &C;
+    const DeclContext *DC;
+
+    /// Returns the declaration an operand refers to, or null if the operand
+    /// is not a plain declaration or member reference. Parentheses are
+    /// looked through so that `(.nan)` is recognized too.
+    static ValueDecl *extractArgumentDecl(Expr *arg) {
+      arg = arg->getSemanticsProvidingExpr();
+      if (auto *DRE = dyn_cast<DeclRefExpr>(arg))
+        return DRE->getDecl();
+      if (auto *MRE = dyn_cast<MemberRefExpr>(arg))
+        return MRE->getMember().getDecl();
+      return nullptr;
+    }
+
+    /// Returns true if \p VD is a static property named 'nan'. Requiring
+    /// 'static' keeps an ordinary local or instance variable that merely
+    /// happens to be called 'nan' from being reported.
+    static bool isNanDecl(ValueDecl *VD) {
+      auto *var = dyn_cast_or_null<VarDecl>(VD);
+      return var && var->isStatic() && var->getBaseIdentifier().is("nan");
+    }
+
+  public:
+    ComparisonWithNaNFinder(const DeclContext *dc)
+        : C(dc->getASTContext()), DC(dc) {}
+
+    void tryDiagnoseComparisonWithNaN(BinaryExpr *BE) {
+      // Comparison functions like == or <= take two arguments.
+      auto *args = BE->getArg();
+      if (args->getNumElements() != 2)
+        return;
+
+      // Dig out the operator's declaration. A DotSyntaxCallExpr callee is
+      // asked for its called value directly.
+      ValueDecl *comparisonDecl = nullptr;
+      if (auto *Fn = BE->getFn()) {
+        if (auto *DSCE = dyn_cast<DotSyntaxCallExpr>(Fn))
+          comparisonDecl = DSCE->getCalledValue();
+        else
+          comparisonDecl = BE->getCalledValue();
+      }
+
+      // Bail out if it isn't a function.
+      if (!comparisonDecl || !isa<FuncDecl>(comparisonDecl))
+        return;
+
+      // We're only interested in comparison operators like == or <=.
+      Identifier comparisonDeclName = comparisonDecl->getBaseIdentifier();
+      if (!comparisonDeclName.isStandardComparisonOperator())
+        return;
+
+      Expr *firstArg = args->getElement(0);
+      Expr *secondArg = args->getElement(1);
+
+      // Both arguments must conform to the FloatingPoint protocol.
+      auto *dc = const_cast<DeclContext *>(DC);
+      if (!constraints::conformsToKnownProtocol(
+              dc, firstArg->getType(), KnownProtocolKind::FloatingPoint) ||
+          !constraints::conformsToKnownProtocol(
+              dc, secondArg->getType(), KnownProtocolKind::FloatingPoint))
+        return;
+
+      // At least one operand has to be '.nan' for there to be anything to
+      // diagnose.
+      bool firstIsNan = isNanDecl(extractArgumentDecl(firstArg));
+      bool secondIsNan = isNanDecl(extractArgumentDecl(secondArg));
+      if (!firstIsNan && !secondIsNan)
+        return;
+
+      // Only '!=' is reported as always true; every other operator is
+      // reported as always false.
+      bool alwaysTrue = comparisonDeclName.is("!=");
+
+      // Both sides are '.nan': there is no other operand to suggest
+      // '.isNaN' on.
+      if (firstIsNan && secondIsNan) {
+        C.Diags.diagnose(BE->getLoc(), diag::nan_comparison_both_nan,
+                         comparisonDeclName.str(), alwaysTrue);
+        return;
+      }
+
+      // Suggest '.isNaN' on the operand that is *not* '.nan', so that
+      // '.nan == x' suggests 'x.isNaN' rather than '.nan.isNaN'. The
+      // suggestion only applies to '==' / '!=' and needs the operand's
+      // source text, so fall back to the shorter warning otherwise.
+      Expr *otherArg = firstIsNan ? secondArg : firstArg;
+      bool isEqualityCheck =
+          comparisonDeclName.is("==") || comparisonDeclName.is("!=");
+      if (!isEqualityCheck || otherArg->getSourceRange().isInvalid()) {
+        C.Diags.diagnose(BE->getLoc(), diag::nan_comparison_without_isnan,
+                         comparisonDeclName, alwaysTrue);
+        return;
+      }
+
+      std::string exprStr =
+          C.SourceMgr
+              .extractText(Lexer::getCharSourceRangeFromSourceRange(
+                  C.SourceMgr, otherArg->getSourceRange()))
+              .str();
+      std::string prefix = alwaysTrue ? "!" + exprStr : exprStr;
+      C.Diags.diagnose(BE->getLoc(), diag::nan_comparison, comparisonDeclName,
+                       alwaysTrue, prefix, exprStr);
+    }
+
+    std::pair<bool, Expr *> walkToExprPre(Expr *E) override {
+      if (!E || isa<ErrorExpr>(E) || !E->getType())
+        return {false, E};
+
+      // NOTE(review): multi-statement closure bodies are presumably
+      // type-checked (and therefore diagnosed) on their own; skipping them
+      // here avoids emitting the same warning twice.
+      if (auto *CE = dyn_cast<AbstractClosureExpr>(E))
+        if (!CE->hasSingleExpressionBody())
+          return {false, E};
+
+      if (auto *BE = dyn_cast<BinaryExpr>(E))
+        tryDiagnoseComparisonWithNaN(BE);
+
+      // Keep walking into the operands so that comparisons nested inside
+      // other expressions (e.g. 'x == .nan || y == .nan') are found too.
+      return {true, E};
+    }
+  };
+
+  ComparisonWithNaNFinder Walker(DC);
+  const_cast<Expr *>(E)->walk(Walker);
+}
+
//===----------------------------------------------------------------------===//
// High-level entry points.
//===----------------------------------------------------------------------===//
@@ -4454,6 +4581,7 @@
diagnoseUnintendedOptionalBehavior(E, DC);
maybeDiagnoseCallToKeyValueObserveMethod(E, DC);
diagnoseExplicitUseOfLazyVariableStorage(E, DC);
+ diagnoseComparisonWithNaN(E, DC);
if (!ctx.isSwiftVersionAtLeast(5))
diagnoseDeprecatedWritableKeyPath(E, DC);
if (!ctx.LangOpts.DisableAvailabilityChecking)
diff --git a/test/decl/var/nan_comparisons.swift b/test/decl/var/nan_comparisons.swift
new file mode 100644
index 0000000..3dc8ddf
--- /dev/null
+++ b/test/decl/var/nan_comparisons.swift
@@ -0,0 +1,31 @@
+// RUN: %target-typecheck-verify-swift
+
+//////////////////////////////////////////////////////////////////////////////////////////////////
+/////// Comparison with '.nan' static property instead of using '.isNaN' instance property ///////
+//////////////////////////////////////////////////////////////////////////////////////////////////
+
+// One side is '.nan' and the other isn't.
+// Using '==' or '!=' for comparison should suggest using '.isNaN'.
+
+let double: Double = 0.0
+_ = double == .nan // expected-warning {{comparison with '.nan' using '==' is always false, use 'double.isNaN' to check if 'double' is not a number}}
+_ = double != .nan // expected-warning {{comparison with '.nan' using '!=' is always true, use '!double.isNaN' to check if 'double' is a number}}
+_ = 0.0 == .nan // expected-warning {{comparison with '.nan' using '==' is always false, use '0.0.isNaN' to check if '0.0' is not a number}}
+
+// One side is '.nan' and the other isn't. Using '>=', '>', '<', '<=' for comparison:
+// We can't suggest using '.isNaN' here.
+
+_ = 0.0 >= .nan // expected-warning {{comparison with '.nan' using '>=' is always false}}
+_ = .nan > 1.1 // expected-warning {{comparison with '.nan' using '>' is always false}}
+_ = .nan < 2.2 // expected-warning {{comparison with '.nan' using '<' is always false}}
+_ = 3.3 <= .nan // expected-warning {{comparison with '.nan' using '<=' is always false}}
+
+// Both sides are '.nan':
+// We can't suggest using '.isNaN' here.
+
+_ = Double.nan == Double.nan // expected-warning {{'.nan' == '.nan' is always false}}
+_ = Double.nan != Double.nan // expected-warning {{'.nan' != '.nan' is always true}}
+_ = Double.nan < Double.nan // expected-warning {{'.nan' < '.nan' is always false}}
+_ = Double.nan <= Double.nan // expected-warning {{'.nan' <= '.nan' is always false}}
+_ = Double.nan > Double.nan // expected-warning {{'.nan' > '.nan' is always false}}
+_ = Double.nan >= Double.nan // expected-warning {{'.nan' >= '.nan' is always false}}