Merge github/tensorflow into copy of github/master (using imerge)
diff --git a/.gitignore b/.gitignore
index 2c30c45..33a3ec0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,6 +18,7 @@
# vim swap files
.*.sw[a-z]
.sw?
+*.vscode
#==============================================================================#
# Explicit files to ignore (only matches one).
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a84adb0..2d834ba 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -86,6 +86,12 @@
"If not building stdlib, controls whether to build 'stdlib/toolchain' content"
TRUE)
+# `SWIFT_ENABLE_TENSOFRFLOW` affects lit tests, adding "tensorflow" as an
+# available feature.
+option(SWIFT_ENABLE_TENSORFLOW
+ "Enable TensorFlow in the compiler"
+ FALSE)
+
# In many cases, the CMake build system needs to determine whether to include
# a directory, or perform other actions, based on whether the stdlib or SDK is
# being built at all -- statically or dynamically. Please note that these
@@ -374,6 +380,19 @@
FALSE)
#
+# User-configurable TensorFlow specific options.
+#
+set(SWIFT_TENSORFLOW_HOST_LIB_DIR "" CACHE PATH
+ "If set, directory with TensorFlow host libraries to be linked into swiftc.")
+set(SWIFT_TENSORFLOW_TARGET_LIB_DIR "" CACHE PATH
+ "If set, directory with TensorFlow target libraries to be linked into swift programs.")
+
+set(SWIFT_TENSORFLOW_HOST_INCLUDE_DIR "" CACHE PATH
+ "If set, directory with host TensorFlow headers.")
+set(SWIFT_TENSORFLOW_TARGET_INCLUDE_DIR "" CACHE PATH
+ "If set, directory with target TensorFlow headers.")
+
+#
# User-configurable experimental options. Do not use in production builds.
#
@@ -1084,6 +1103,11 @@
endif()
endif()
+# Enable TensorFlow in compiler code.
+if(SWIFT_ENABLE_TENSORFLOW)
+ add_definitions(-DSWIFT_ENABLE_TENSORFLOW)
+endif()
+
# Add all of the subdirectories, where we actually do work.
###############
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 0283424..24a2d10 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -5,7 +5,23 @@
---
-Before submitting the pull request, please make sure you have [tested your
-changes](https://github.com/apple/swift/blob/master/docs/ContinuousIntegration.md)
-and that they follow the Swift project [guidelines for contributing
-code](https://swift.org/contributing/#contributing-code).
+Before submitting a pull request, please make sure that your changes follow the
+[Swift project guidelines for contributing code](https://swift.org/contributing/#contributing-code).
+
+We request that code changes unrelated to Swift for TensorFlow be submitted to
+the [upstream Swift repository](https://github.com/apple/swift). For example,
+code formatting changes that do not affect Swift for TensorFlow source code
+should be submitted upstream.
+
+It is a good idea to discuss any non-trivial submissions with the project
+maintainers before submitting a pull request: please join the
+[swift@tensorflow.org](https://groups.google.com/a/tensorflow.org/d/forum/swift)
+mailing list to participate in discussions.
+
+All changes to existing Swift source code should be marked clearly with a
+`SWIFT_ENABLE_TENSORFLOW` comment at the top of every diff hunk. This makes
+it easier to merge from upstream.
+
+Continuous integration (CI) is still being set up, so if you submit a pull
+request with non-trivial changes that require test coverage, please anticipate
+some delay before we can properly review and merge it.
diff --git a/README.md b/README.md
index 6f11e2b..7edd94e 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,9 @@
-<img src="https://swift.org/assets/images/swift.svg" alt="Swift logo" height="70" >
+# Swift for TensorFlow
-# Swift Programming Language
-
+| OS | CI platform | x86_64 | GPU |
+|---|:---:|:---:|:---:|
+| **macOS** | Google Kokoro |  | - |
+| **Ubuntu 16.04** | Swift.org CI | [](https://ci-external.swift.org/job/oss-swift-RA-linux-ubuntu-16.04-tensorflow) | [](https://ci-external.swift.org/job/oss-swift-RA-linux-ubuntu-16.04-tensorflow-gpu) |
| | **Architecture** | **Master** | **Package** |
|---|:---:|:---:|:---:|
@@ -34,6 +36,13 @@
|**[macOS 10.13](https://github.com/apple/swift-community-hosted-continuous-integration/blob/master/nodes/x86_64_macos_high_sierra_tensorflow.json)** | x86_64 |[](https://ci-external.swift.org/job/oss-swift-RA-macOS-tensorflow)|
|**[Ubuntu 16.04 (GPU)](https://github.com/apple/swift-community-hosted-continuous-integration/blob/master/nodes/x86_64_ubuntu_16_04_tensorflow_gpu.json)** | x86_64 |[](https://ci-external.swift.org/job/oss-swift-RA-linux-ubuntu-16.04-tensorflow-gpu)|
+<!-- SWIFT_ENABLE_TENSORFLOW -->
+
+Swift for TensorFlow is a new programming language for TensorFlow. It is a copy of the compiler for the [Swift Programming Language](https://swift.org) that adds first-class compiler and language support for machine learning.
+
+This repository covers the compiler and standard libraries. Please visit the [documentation repository](https://github.com/tensorflow/swift) for more information about the project, including a project overview, technical details, and guidelines for contributing. To use Swift for TensorFlow out of the box, follow the [installation instructions](https://github.com/tensorflow/swift/blob/master/Installation.md). To build from source, follow the instructions below.
+<!-- SWIFT_ENABLE_TENSORFLOW END -->
+
## Welcome to Swift
Swift is a high-performance system programming language. It has a clean
@@ -82,37 +91,36 @@
[Getting Started guide]: /docs/HowToGuides/GettingStarted.md
-### Swift Toolchains
+### Swift For TensorFlow Toolchains
#### Building
-Swift toolchains are created using the script
-[build-toolchain](https://github.com/apple/swift/blob/master/utils/build-toolchain). This
-script is used by swift.org's CI to produce snapshots and can allow for one to
+Swift for TensorFlow toolchains are created using the script
+[build-toolchain-tensorflow](https://github.com/apple/swift/blob/tensorflow/utils/build-toolchain-tensorflow).
+This script is used by swift.org's CI to produce snapshots and can allow for one to
locally reproduce such builds for development or distribution purposes. A typical
invocation looks like the following:
```
- $ ./swift/utils/build-toolchain $BUNDLE_PREFIX
+ $ ./swift/utils/build-toolchain-tensorflow $BUNDLE_PREFIX
```
-where ``$BUNDLE_PREFIX`` is a string that will be prepended to the build
-date to give the bundle identifier of the toolchain's ``Info.plist``. For
-instance, if ``$BUNDLE_PREFIX`` was ``com.example``, the toolchain
-produced will have the bundle identifier ``com.example.YYYYMMDD``. It
-will be created in the directory you run the script with a filename
-of the form: ``swift-LOCAL-YYYY-MM-DD-a-osx.tar.gz``.
+where ``$BUNDLE_PREFIX`` is a string that will be prepended to the build
+date to give the bundle identifier of the toolchain's ``Info.plist``. For
+instance, if ``$BUNDLE_PREFIX`` was ``com.example``, the toolchain
+produced will have the bundle identifier ``com.example.YYYYMMDD``. It
+will be created in the directory you run the script with a filename
+ of the form: ``swift-tensorflow-LOCAL-YYYY-MM-DD-a-osx.tar.gz``.
-Beyond building the toolchain, ``build-toolchain`` also supports the
-following (non-exhaustive) set of useful options::
+Beyond building the toolchain, ``build-toolchain-tensorflow`` also supports the
+following (non-exhaustive) set of useful options:
- ``--dry-run``: Perform a dry run build. This is off by default.
- ``--test``: Test the toolchain after it has been compiled. This is off by default.
-- ``--distcc``: Use distcc to speed up the build by distributing the c++ part of
- the swift build. This is off by default.
+- ``--pkg`` (macOS only): Build a toolchain installer package (`.pkg`). This is off by default.
More options may be added over time. Please pass ``--help`` to
-``build-toolchain`` to see the full set of options.
+``build-toolchain-tensorflow`` to see the full set of options.
#### Installing into Xcode
diff --git a/cmake/modules/FindTensorFlow.cmake b/cmake/modules/FindTensorFlow.cmake
new file mode 100644
index 0000000..9b4c0fe
--- /dev/null
+++ b/cmake/modules/FindTensorFlow.cmake
@@ -0,0 +1,23 @@
+# SWIFT_ENABLE_TENSORFLOW
+# Find TensorFlow.
+
+include(FindPackageHandleStandardArgs)
+
+find_path(TF_INCLUDE_DIR
+ NAMES third_party/tensorflow/c tensorflow/c
+ HINTS ${SWIFT_TENSORFLOW_TARGET_INCLUDE_DIR} /usr/include /usr/local/include)
+if (EXISTS ${TF_INCLUDE_DIR}/third_party/tensorflow/c/c_api.h)
+ # This is experimental and not covered by CI.
+ set(TF_PATH_ADJUSTMENT "third_party")
+else()
+ # Note: This is the normal workflow.
+ set(TF_PATH_ADJUSTMENT "")
+endif()
+
+find_library(TF_LIBRARY
+ NAMES tensorflow
+ HINTS ${SWIFT_TENSORFLOW_TARGET_LIB_DIR} /usr/lib /usr/local/lib)
+set(TF_LIBRARIES ${TF_LIBRARY})
+
+find_package_handle_standard_args(TensorFlow DEFAULT_MSG TF_INCLUDE_DIR TF_LIBRARIES)
+mark_as_advanced(${TF_INCLUDE_DIR} ${TF_LIBRARIES})
diff --git a/docs/ContinuousIntegration.md b/docs/ContinuousIntegration.md
index 545b08a..e83a28a 100644
--- a/docs/ContinuousIntegration.md
+++ b/docs/ContinuousIntegration.md
@@ -28,7 +28,21 @@
In order for the Swift project to be able to advance quickly, it is important that we maintain a green build [[1]](#footnote-1). In order to help maintain this green build, the Swift project heavily uses pull request (PR) testing. Specifically, an important general rule is that **all** non-trivial checkins to any Swift Project repository should at least perform a [smoke test](#smoke-testing) if simulators will not be impacted *or* a full [validation test](#validation-testing) if simulators may be impacted. If in addition one is attempting to make a source breaking change across multiple repositories, one should follow the cross repo source breaking changes workflow. We now continue by describing the Swift system for Pull Request testing, @swift-ci:
-### @swift-ci
+### @swift-ci (Swift for TensorFlow)
+
+
+Platform | Comment
+------------ | -------
+All supported platforms | @swift-ci Please test tensorflow
+All supported platforms | @swift-ci Please clean test tensorflow
+Linux | @swift-ci Please test tensorflow Linux
+Linux GPU | @swift-ci Please test tensorflow Linux GPU
+macOS | @swift-ci Please test tensorflow macOS
+
+Status can be checked at https://ci-external.swift.org/view/Pull%20Request/.
+
+
+### @swift-ci (general Swift)
Users with [commit access](https://swift.org/contributing/#commit-access) can trigger pull request testing by writing a comment on a PR addressed to the GitHub user @swift-ci. Different tests will run depending on the specific comment used. The current test types are:
diff --git a/include/swift/AST/ASTContext.h b/include/swift/AST/ASTContext.h
index 264af7e..c8aa0bc 100644
--- a/include/swift/AST/ASTContext.h
+++ b/include/swift/AST/ASTContext.h
@@ -296,7 +296,7 @@
/// Cached mapping from types to their associated tangent spaces.
llvm::DenseMap<Type, Optional<TangentSpace>> AutoDiffTangentSpaces;
-
+
/// A cache of derivative function types per configuration.
llvm::DenseMap<SILAutoDiffDerivativeFunctionKey, CanSILFunctionType>
SILAutoDiffDerivativeFunctions;
@@ -309,8 +309,10 @@
/// Cache of `@derivative` attributes keyed by parameter indices and
/// derivative function kind. Used to diagnose duplicate `@derivative`
/// attributes for the same key.
- // TODO(TF-1042): remove `DerivativeAttrs` from `ASTContext`. Serialize
- // derivative function configurations per original `AbstractFunctionDecl`.
+ // NOTE(TF-680): relaxing the uniqueness condition to use derivative generic
+ // signature as a key is possible. It requires derivative generic signature
+ // mangling to avoid name collisions for SIL derivative functions with the
+ // same parameter indices but different derivative generic signatures.
llvm::DenseMap<
std::tuple<Decl *, IndexSubset *, AutoDiffDerivativeFunctionKind>,
llvm::SmallPtrSet<DerivativeAttr *, 1>>
@@ -336,6 +338,7 @@
/// Cache of module names that fail the 'canImport' test in this context.
llvm::SmallPtrSet<Identifier, 8> FailedModuleImportNames;
+public:
/// Retrieve the allocator for the given arena.
llvm::BumpPtrAllocator &
getAllocator(AllocationArena arena = AllocationArena::Permanent) const;
@@ -485,6 +488,35 @@
/// Retrieve the type Swift.AnyObject.
CanType getAnyObjectType() const;
+ // SWIFT_ENABLE_TENSORFLOW
+ /// Retrieve the decl for TensorFlow.TensorHandle iff the TensorFlow module
+ /// has been imported. Otherwise, this returns null.
+ ClassDecl *getTensorHandleDecl() const;
+
+ /// Retrieve the decl for TensorFlow.TensorShape iff the TensorFlow module
+ /// has been imported. Otherwise, this returns null.
+ StructDecl *getTensorShapeDecl() const;
+
+ /// Retrieve the decl for TensorFlow.TensorDataType iff the TensorFlow module
+ /// has been imported. Otherwise, this returns null.
+ StructDecl *getTensorDataTypeDecl() const;
+
+ /// Retrieve the decl for the Quote module iff it has been imported.
+ /// Otherwise, this returns null.
+ ModuleDecl *getQuoteModule() const;
+
+ /// Retrieve the decl for Quote.Tree iff the Quote module has been imported.
+ /// Otherwise, this returns null.
+ ProtocolDecl *getTreeDecl() const;
+
+ /// Retrieve the decl for Quote.Quote iff the Quote module has been imported.
+ /// Otherwise, this returns null.
+ ClassDecl *getQuoteDecl() const;
+
+ /// Retrieve the decl for Quote.FunctionQuoteN iff the Quote module has been
+ /// imported. Otherwise, this returns null.
+ ClassDecl *getFunctionQuoteDecl(unsigned n) const;
+
/// Retrieve the type Swift.Never.
CanType getNeverType() const;
diff --git a/include/swift/AST/ASTMangler.h b/include/swift/AST/ASTMangler.h
index b244e76..a62730a 100644
--- a/include/swift/AST/ASTMangler.h
+++ b/include/swift/AST/ASTMangler.h
@@ -342,7 +342,7 @@
/// \returns \c true if a generic signature was appended, \c false
/// if it was empty.
bool appendGenericSignature(GenericSignature sig,
- GenericSignature contextSig = nullptr);
+ GenericSignature contextSig = GenericSignature());
void appendRequirement(const Requirement &reqt);
diff --git a/include/swift/AST/Attr.def b/include/swift/AST/Attr.def
index 2c744e4..ba351d8 100644
--- a/include/swift/AST/Attr.def
+++ b/include/swift/AST/Attr.def
@@ -566,6 +566,13 @@
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,
101)
+// SWIFT_ENABLE_TENSORFLOW
+SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
+ OnAccessor | OnFunc | OnConstructor | OnSubscript |
+ ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove |
+ NotSerialized, 102)
+// SWIFT_ENABLE_TENSORFLOW END
+
#undef TYPE_ATTR
#undef DECL_ATTR_ALIAS
#undef CONTEXTUAL_DECL_ATTR_ALIAS
diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h
index 0c67ece..009a151 100644
--- a/include/swift/AST/Attr.h
+++ b/include/swift/AST/Attr.h
@@ -51,6 +51,7 @@
class AbstractFunctionDecl;
class FuncDecl;
class ClassDecl;
+class FuncDecl;
class GenericFunctionType;
class LazyConformanceLoader;
class LazyMemberLoader;
diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h
index 79f3bce..55da4b4 100644
--- a/include/swift/AST/Decl.h
+++ b/include/swift/AST/Decl.h
@@ -3502,6 +3502,20 @@
/// Retrieve information about this type as a property wrapper.
PropertyWrapperTypeInfo getPropertyWrapperTypeInfo() const;
+private:
+ /// Predicate used to filter StoredPropertyRange.
+ struct ToStoredProperty {
+ ToStoredProperty(bool skipInaccessible = false) :
+ skipUserInaccessible(skipInaccessible) {}
+ bool skipUserInaccessible;
+ Optional<VarDecl *> operator()(Decl *decl) const;
+ };
+
+public:
+ /// A range for iterating the stored member variables of a structure.
+ using StoredPropertyRange = OptionalTransformRange<DeclRange,
+ ToStoredProperty>;
+
/// Return a collection of the stored member variables of this type.
ArrayRef<VarDecl *> getStoredProperties() const;
@@ -4198,6 +4212,15 @@
Decodable,
AdditiveArithmetic,
Differentiable,
+ // SWIFT_ENABLE_TENSORFLOW
+ PointwiseMultiplicative,
+ ElementaryFunctions,
+ KeyPathIterable,
+ TensorArrayProtocol,
+ TensorGroup,
+ VectorProtocol,
+ EuclideanDifferentiable,
+ // SWIFT_ENABLE_TENSORFLOW END
};
/// ProtocolDecl - A declaration of a protocol, for example:
@@ -6586,8 +6609,9 @@
/// This represents a 'case' declaration in an 'enum', which may declare
/// one or more individual comma-separated EnumElementDecls.
-class EnumCaseDecl final : public Decl,
- private llvm::TrailingObjects<EnumCaseDecl, EnumElementDecl *> {
+class EnumCaseDecl final
+ : public Decl,
+ private llvm::TrailingObjects<EnumCaseDecl, EnumElementDecl *> {
friend TrailingObjects;
friend class Decl;
SourceLoc CaseLoc;
diff --git a/include/swift/AST/DiagnosticsDriver.def b/include/swift/AST/DiagnosticsDriver.def
index 17b5b0d..119322b 100644
--- a/include/swift/AST/DiagnosticsDriver.def
+++ b/include/swift/AST/DiagnosticsDriver.def
@@ -190,5 +190,12 @@
"SDK settings were ignored because 'SDKSettings.json' could not be parsed",
())
+// SWIFT_ENABLE_TENSORFLOW
+ERROR(error_tensorflow_toolchain_repl_not_supported, none,
+ "The Swift for TensorFlow toolchain does not support the Swift REPL. Colab "
+ "(https://github.com/tensorflow/swift/blob/master/Usage.md#colaboratory) and Swift-Jupyter "
+ "(https://github.com/google/swift-jupyter) are supported alternatives.",
+ ())
+
#define UNDEFINE_DIAGNOSTIC_MACROS
#include "DefineDiagnosticMacros.h"
diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def
index eaae1b7..7fa3199 100644
--- a/include/swift/AST/DiagnosticsParse.def
+++ b/include/swift/AST/DiagnosticsParse.def
@@ -1765,5 +1765,22 @@
ERROR(expected_multiple_closures_block_rbrace,none,
"expected '}' at the end of a trailing closures block", ())
+// SWIFT_ENABLE_TENSORFLOW
+ERROR(sil_const_expected_int_datatype,PointsToFirstBadToken,
+ "expected integer datatype ('i[0-9]+', e.g. 'i32')", ())
+ERROR(sil_const_expected_int_value,PointsToFirstBadToken,
+ "expected integer value in SIL constant value", ())
+ERROR(sil_const_expected_fp_datatype,PointsToFirstBadToken,
+ "expected floating point datatype ('f32' or 'f64')", ())
+ERROR(sil_const_expected_fp_value,PointsToFirstBadToken,
+ "expected floating point value in SIL constant value", ())
+ERROR(sil_const_array_expected_rsquare,PointsToFirstBadToken,
+ "expected ']' at end of array 'SymbolicValue'", ())
+ERROR(sil_const_aggregate_expected_rparen,PointsToFirstBadToken,
+ "expected ')' at end of aggregate 'SymbolicValue'", ())
+ERROR(sil_const_expected_fn_sub_conv,PointsToFirstBadToken,
+ "expected '(N)' or '(W)' function substitution convention", ())
+// SWIFT_ENABLE_TENSORFLOW END
+
#define UNDEFINE_DIAGNOSTIC_MACROS
#include "DefineDiagnosticMacros.h"
diff --git a/include/swift/AST/DiagnosticsSIL.def b/include/swift/AST/DiagnosticsSIL.def
index 36bd43a..a6feaa2 100644
--- a/include/swift/AST/DiagnosticsSIL.def
+++ b/include/swift/AST/DiagnosticsSIL.def
@@ -246,6 +246,10 @@
NOTE(designated_init_c_struct_fix,none,
"use \"self.init()\" to initialize the struct with zero values", ())
+NOTE(constexpr_unknown_reason, none, "%0", (StringRef))
+NOTE(constexpr_called_from, none, "when called from here", ())
+NOTE(constexpr_not_evaluable, none,
+ "expression not evaluable as constant here", ())
// Control flow diagnostics.
ERROR(missing_return,none,
@@ -338,6 +342,7 @@
ERROR(static_report_error, none,
"static report error", ())
+// SWIFT_ENABLE_TENSORFLOW
ERROR(pound_assert_condition_not_constant,none,
"#assert condition not constant", ())
ERROR(pound_assert_failure,none,
@@ -416,6 +421,13 @@
"%select{for this call|for a witness-method invoked during this call}0",
(bool))
+NOTE(constexpr_witness_call_with_no_conformance_found, none,
+ "cannot find concrete conformance for a witness method call", ())
+NOTE(constexpr_witness_call_with_no_target_found, none,
+ "cannot resolve a witness method call to a concrete function", ())
+NOTE(constexpr_witness_call_found_here, none,
+ "witness method call found here", ())
+
NOTE(constexpr_unknown_control_flow_due_to_skip,none, "branch depends on "
"non-constant value produced by an unevaluated instructions", ())
NOTE(constexpr_returned_by_unevaluated_instruction,none,
diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def
index 13717b1..329d744 100644
--- a/include/swift/AST/DiagnosticsSema.def
+++ b/include/swift/AST/DiagnosticsSema.def
@@ -1258,6 +1258,7 @@
"consider using an existing high level algorithm, "
"str.startIndex.advanced(by: n), or a projection like str.utf8", ())
+
ERROR(invalid_c_function_pointer_conversion_expr,none,
"a C function pointer can only be formed from a reference to a 'func' or "
"a literal closure", ())
@@ -2838,6 +2839,22 @@
"add an explicit '@noDerivative' attribute "
"%select{|, or conform %0 to 'AdditiveArithmetic'}1",
(/*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool))
+// SWIFT_ENABLE_TENSORFLOW
+ERROR(broken_pointwise_multiplicative_requirement,none,
+ "PointwiseMultiplicative protocol is broken: unexpected requirement", ())
+ERROR(broken_elementary_functions_requirement,none,
+ "ElementaryFunctions protocol is broken: unexpected requirement", ())
+ERROR(broken_vector_protocol_requirement,none,
+ "VectorProtocol protocol is broken: unexpected requirement", ())
+ERROR(broken_euclidean_differentiable_requirement,none,
+ "EuclideanDifferentiable protocol is broken: unexpected requirement", ())
+ERROR(broken_key_path_iterable_requirement,none,
+ "KeyPathIterable protocol is broken: unexpected requirement", ())
+ERROR(broken_tensor_array_protocol_requirement,none,
+ "TensorArrayProtocol protocol is broken: unexpected requirement", ())
+ERROR(broken_tensor_group_requirement,none,
+ "TensorGroup protocol is broken: unexpected requirement", ())
+// SWIFT_ENABLE_TENSORFLOW END
NOTE(codable_extraneous_codingkey_case_here,none,
"CodingKey case %0 does not match any stored properties", (Identifier))
@@ -3102,6 +3119,10 @@
"class; consider making %0 final", (Type))
ERROR(differentiable_attr_empty_where_clause,none,
"empty 'where' clause in '@differentiable' attribute", ())
+// SWIFT_ENABLE_TENSORFLOW
+ERROR(differentiable_attr_nongeneric_trailing_where,none,
+ "trailing 'where' clause in '@differentiable' attribute of non-generic "
+ "function %0", (DeclName))
ERROR(differentiable_attr_where_clause_for_nongeneric_original,none,
"'where' clause is valid only when original function is generic %0",
(DeclName))
@@ -3248,6 +3269,25 @@
"can only differentiate with respect to parameters that conform to "
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
+// @compilerEvaluable
+ERROR(compiler_evaluable_bad_context,none,
+ "@compilerEvaluable functions not allowed here", ())
+ERROR(compiler_evaluable_loop,none,
+ "loops not allowed in @compilerEvaluable functions", ())
+ERROR(compiler_evaluable_forbidden_expression,none,
+ "expression not allowed in @compilerEvaluable functions", ())
+ERROR(compiler_evaluable_non_local_mutable,none,
+ "referencing non-local mutable variables not allowed in @compilerEvaluable functions", ())
+ERROR(compiler_evaluable_forbidden_type,none,
+ "type %0 cannot be used in @compilerEvaluable functions", (Type))
+ERROR(compiler_evaluable_ref_non_compiler_evaluable,none,
+ "@compilerEvaluable functions may not reference non-@compilerEvaluable functions", ())
+
+// @noDerivative attribute
+ERROR(noderivative_only_on_differentiable_struct_or_class_fields,none,
+ "'@noDerivative' is only allowed on stored properties in structure or "
+ "class types that declare a conformance to 'Differentiable'", ())
+
//------------------------------------------------------------------------------
// MARK: Type Check Expressions
//------------------------------------------------------------------------------
@@ -3409,6 +3449,22 @@
ERROR(object_literal_broken_proto,none,
"object literal protocol is broken", ())
+ERROR(quote_literal_no_function_quote_class,none,
+ "cannot expand quote literal: Quote.FunctionQuote%0 not found", (int))
+ERROR(quote_literal_no_quote_class,none,
+ "cannot expand quote literal: Quote.Quote not found", ())
+ERROR(quote_literal_no_quote_module,none,
+ "cannot expand quote literal: Quote not imported", ())
+ERROR(quote_literal_no_tree_proto,none,
+ "cannot expand quote literal: Quote.Tree not found", ())
+WARNING(quote_literal_unsupported_brief,none,
+ "unsupported code snippet", ())
+WARNING(quote_literal_unsupported_detailed,none,
+ "unsupported code snippet\n%0", (StringRef))
+
+ERROR(unquote_wrong_type,none,
+ "can only unquote expressions of type Quote and FunctionQuoteN", ())
+
ERROR(discard_expr_outside_of_assignment,none,
"'_' can only appear in a pattern or on the left side of an assignment",
())
@@ -4319,6 +4375,9 @@
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
"function type is '@differentiable%select{|(linear)}1'",
(StringRef, bool))
+ERROR(attr_differentiable_no_vjp_or_jvp_when_linear,none,
+ "cannot specify 'vjp:' or 'jvp:' for linear functions; use "
+ "'transpose:' instead", ())
// SIL
ERROR(opened_non_protocol,none,
diff --git a/include/swift/AST/IRGenOptions.h b/include/swift/AST/IRGenOptions.h
index f805562..3d1facf 100644
--- a/include/swift/AST/IRGenOptions.h
+++ b/include/swift/AST/IRGenOptions.h
@@ -354,7 +354,9 @@
PrespecializeGenericMetadata(false), UseIncrementalLLVMCodeGen(true),
UseSwiftCall(false), UseTypeLayoutValueHandling(true),
GenerateProfile(false), EnableDynamicReplacementChaining(false),
- DisableRoundTripDebugTypes(false), DisableDebuggerShadowCopies(false),
+ // SWIFT_ENABLE_TENSORFLOW
+ // TODO(TF-486): Reenable round type debug types.
+ DisableRoundTripDebugTypes(true), DisableDebuggerShadowCopies(false),
DisableConcreteTypeMetadataMangledNameAccessors(false), CmdArgs(),
SanitizeCoverage(llvm::SanitizerCoverageOptions()),
TypeInfoFilter(TypeInfoDumpFilter::All) {}
diff --git a/include/swift/AST/KnownIdentifiers.def b/include/swift/AST/KnownIdentifiers.def
index 3789816..27b678a 100644
--- a/include/swift/AST/KnownIdentifiers.def
+++ b/include/swift/AST/KnownIdentifiers.def
@@ -225,6 +225,35 @@
IDENTIFIER(zero)
IDENTIFIER(zeroTangentVectorInitializer)
+// SWIFT_ENABLE_TENSORFLOW
+IDENTIFIER(TensorFlow)
+// Module that supports #quote(...) literals.
+IDENTIFIER(Quote)
+// KeyPathIterable
+IDENTIFIER(AllKeyPaths)
+IDENTIFIER(allKeyPaths)
+IDENTIFIER(recursivelyAllKeyPaths)
+IDENTIFIER(allWritableKeyPaths)
+IDENTIFIER(recursivelyAllWritableKeyPaths)
+// TensorArrayProtocol
+IDENTIFIER_(unpackTensorHandles)
+IDENTIFIER_(tensorHandleCount)
+// TensorGroup
+IDENTIFIER_(typeList)
+// AdditiveArithmetic, PointwiseMultiplicative, VectorProtocol
+IDENTIFIER(one)
+IDENTIFIER(reciprocal)
+IDENTIFIER(VectorSpaceScalar)
+IDENTIFIER(adding)
+IDENTIFIER(subtracting)
+IDENTIFIER(scaled)
+IDENTIFIER(by)
+IDENTIFIER(scale)
+IDENTIFIER(x)
+// Differentiable
+IDENTIFIER(differentiableVectorView)
+// SWIFT_ENABLE_TENSORFLOW END
+
#undef IDENTIFIER
#undef IDENTIFIER_
#undef IDENTIFIER_WITH_NAME
diff --git a/include/swift/AST/KnownProtocols.def b/include/swift/AST/KnownProtocols.def
index e530e38..0166353 100644
--- a/include/swift/AST/KnownProtocols.def
+++ b/include/swift/AST/KnownProtocols.def
@@ -89,6 +89,17 @@
PROTOCOL(FloatingPoint)
+// SWIFT_ENABLE_TENSORFLOW
+PROTOCOL(PointwiseMultiplicative)
+PROTOCOL(ElementaryFunctions)
+PROTOCOL(KeyPathIterable)
+PROTOCOL(TensorArrayProtocol)
+PROTOCOL(TensorGroup)
+PROTOCOL(VectorProtocol)
+PROTOCOL(EuclideanDifferentiable)
+PROTOCOL(Expression)
+// SWIFT_ENABLE_TENSORFLOW END
+
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByArrayLiteral, "Array", false)
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByBooleanLiteral, "BooleanLiteralType", true)
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByDictionaryLiteral, "Dictionary", false)
@@ -103,6 +114,9 @@
EXPRESSIBLE_BY_LITERAL_PROTOCOL_(ExpressibleByColorLiteral, "_ColorLiteralType", true)
EXPRESSIBLE_BY_LITERAL_PROTOCOL_(ExpressibleByImageLiteral, "_ImageLiteralType", true)
EXPRESSIBLE_BY_LITERAL_PROTOCOL_(ExpressibleByFileReferenceLiteral, "_FileReferenceLiteralType", true)
+// SWIFT_ENABLE_TENSORFLOW
+// TODO(TF-735): Implement ExpressibleByQuoteLiteral.
+// EXPRESSIBLE_BY_LITERAL_PROTOCOL_(ExpressibleByQuoteLiteral, "_QuoteLiteralType", true)
BUILTIN_EXPRESSIBLE_BY_LITERAL_PROTOCOL_(ExpressibleByBuiltinBooleanLiteral)
BUILTIN_EXPRESSIBLE_BY_LITERAL_PROTOCOL_(ExpressibleByBuiltinExtendedGraphemeClusterLiteral)
diff --git a/include/swift/AST/ModuleLoader.h b/include/swift/AST/ModuleLoader.h
index 3660f00..c78f31f 100644
--- a/include/swift/AST/ModuleLoader.h
+++ b/include/swift/AST/ModuleLoader.h
@@ -36,6 +36,9 @@
namespace swift {
+// SWIFT_ENABLE_TENSORFLOW
+struct AutoDiffConfig;
+// SWIFT_ENABLE_TENSORFLOW END
class AbstractFunctionDecl;
struct AutoDiffConfig;
class ClangImporterOptions;
diff --git a/include/swift/AST/SourceFile.h b/include/swift/AST/SourceFile.h
index 5fc32b5..52170e1 100644
--- a/include/swift/AST/SourceFile.h
+++ b/include/swift/AST/SourceFile.h
@@ -181,11 +181,18 @@
/// been validated.
llvm::SetVector<ValueDecl *> UnvalidatedDeclsWithOpaqueReturnTypes;
+// SWIFT_ENABLE_TENSORFLOW
+// For TensorFlow, keep `Decls` public because SwiftCodeCompletion needs it.
+public:
+// SWIFT_ENABLE_TENSORFLOW END
/// The list of top-level declarations in the source file. This is \c None if
/// they have not yet been parsed.
/// FIXME: Once addTopLevelDecl/prependTopLevelDecl
/// have been removed, this can become an optional ArrayRef.
Optional<std::vector<Decl *>> Decls;
+// SWIFT_ENABLE_TENSORFLOW
+private:
+// SWIFT_ENABLE_TENSORFLOW END
/// The list of hoisted declarations. See Decl::isHoisted().
/// This is only used by lldb.
diff --git a/include/swift/Basic/LangOptions.h b/include/swift/Basic/LangOptions.h
index 6810676..5996248 100644
--- a/include/swift/Basic/LangOptions.h
+++ b/include/swift/Basic/LangOptions.h
@@ -30,6 +30,10 @@
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/VersionTuple.h"
+
+// SWIFT_ENABLE_TENSORFLOW
+#include "clang/Basic/InMemoryOutputFileSystem.h"
+
#include <string>
#include <vector>
@@ -183,6 +187,8 @@
/// options. Disabled by default because there is no way to control the
/// language mode of clang on a per-header or even per-module basis. Also
/// disabled because it is not complete.
+ /// options.
+ /// FIXME: Disabled by default until this is fully baked.
bool EnableCXXInterop = false;
/// On Darwin platforms, use the pre-stable ABI's mark bit for Swift
@@ -350,7 +356,10 @@
/// Whether to enable experimental differentiable programming features:
/// `@differentiable` declaration attribute, etc.
- bool EnableExperimentalDifferentiableProgramming = false;
+ // SWIFT_ENABLE_TENSORFLOW
+ // Use default value true on `tensorflow` branch.
+ bool EnableExperimentalDifferentiableProgramming = true;
+ // SWIFT_ENABLE_TENSORFLOW END
/// Whether to enable forward mode differentiation.
bool EnableExperimentalForwardModeDifferentiation = false;
@@ -649,6 +658,13 @@
/// contains the full option set.
bool ExtraArgsOnly = false;
+ // SWIFT_ENABLE_TENSORFLOW
+ /// When set, clang writes its output files (module caches) to this instead
+ /// of to the real filesystem.
+ llvm::IntrusiveRefCntPtr<clang::InMemoryOutputFileSystem>
+ InMemoryOutputFileSystem;
+ // SWIFT_ENABLE_TENSORFLOW END
+
/// Return a hash code of any components from these options that should
/// contribute to a Swift Bridging PCH hash.
llvm::hash_code getPCHHashComponents() const {
diff --git a/include/swift/Basic/SourceManager.h b/include/swift/Basic/SourceManager.h
index a12aa16..f76c23d 100644
--- a/include/swift/Basic/SourceManager.h
+++ b/include/swift/Basic/SourceManager.h
@@ -89,6 +89,11 @@
CodeCompletionOffset = Offset;
}
+ // SWIFT_ENABLE_TENSORFLOW
+ void clearCodeCompletionPoint() {
+ CodeCompletionBufferID = 0U;
+ }
+
bool hasCodeCompletionBuffer() const {
return CodeCompletionBufferID != 0U;
}
diff --git a/include/swift/Demangling/TypeDecoder.h b/include/swift/Demangling/TypeDecoder.h
index f7a5a22..35b62c5 100644
--- a/include/swift/Demangling/TypeDecoder.h
+++ b/include/swift/Demangling/TypeDecoder.h
@@ -629,6 +629,17 @@
flags = flags.withDifferentiabilityKind(
FunctionMetadataDifferentiabilityKind::Linear);
}
+ // SWIFT_ENABLE_TENSORFLOW
+ else if (Node->getKind() == NodeKind::DifferentiableFunctionType ||
+ Node->getKind() ==
+ NodeKind::EscapingDifferentiableFunctionType) {
+ flags = flags.withDifferentiabilityKind(
+ FunctionMetadataDifferentiabilityKind::Normal);
+ } else if (Node->getKind() == NodeKind::LinearFunctionType ||
+ Node->getKind() == NodeKind::EscapingLinearFunctionType) {
+ flags = flags.withDifferentiabilityKind(
+ FunctionMetadataDifferentiabilityKind::Linear);
+ }
unsigned firstChildIdx = 0;
bool isThrow = false;
diff --git a/include/swift/IDE/Utils.h b/include/swift/IDE/Utils.h
index c916f77..44ebdb0 100644
--- a/include/swift/IDE/Utils.h
+++ b/include/swift/IDE/Utils.h
@@ -22,6 +22,9 @@
#include "swift/IDE/SourceEntityWalker.h"
#include "swift/Parse/Token.h"
#include "llvm/ADT/StringRef.h"
+// SWIFT_ENABLE_TENSORFLOW
+#include "clang/Basic/InMemoryOutputFileSystem.h"
+// SWIFT_ENABLE_TENSORFLOW END
#include <memory>
#include <string>
#include <functional>
@@ -85,6 +88,9 @@
CompilerInvocation &Invocation, ArrayRef<const char *> OrigArgs,
DiagnosticEngine &Diags, StringRef UnresolvedPrimaryFile,
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
+ // SWIFT_ENABLE_TENSORFLOW
+ llvm::IntrusiveRefCntPtr<clang::InMemoryOutputFileSystem> InMemoryOutputFileSystem,
+ // SWIFT_ENABLE_TENSORFLOW END
const std::string &runtimeResourcePath,
const std::string &diagnosticDocumentationPath,
bool shouldOptimizeForIDE, time_t sessionTimestamp, std::string &Error);
diff --git a/include/swift/Parse/Parser.h b/include/swift/Parse/Parser.h
index 1c4584e..fae3ada 100644
--- a/include/swift/Parse/Parser.h
+++ b/include/swift/Parse/Parser.h
@@ -1604,6 +1604,9 @@
parseTrailingClosures(bool isExprBasic, SourceRange calleeRange,
SmallVectorImpl<TrailingClosure> &closures);
+ ParserResult<Expr> parseExprQuoteLiteral();
+ ParserResult<Expr> parseExprUnquote();
+ ParserResult<Expr> parseExprPoundAssert();
/// Parse an object literal.
///
/// \param LK The literal kind as determined by the first token.
diff --git a/include/swift/SIL/Dominance.h b/include/swift/SIL/Dominance.h
index 505c1cd..89dc818 100644
--- a/include/swift/SIL/Dominance.h
+++ b/include/swift/SIL/Dominance.h
@@ -96,10 +96,12 @@
/// domOrder.pushChildren(block);
/// }
/// \endcode
-class DominanceOrder {
+// SWIFT_ENABLE_TENSORFLOW
+template <class DomInfo>
+class DominanceOrderBase {
SmallVector<SILBasicBlock *, 16> buffer;
- DominanceInfo *DT;
+ DomInfo *DT;
size_t srcIdx = 0;
public:
@@ -109,7 +111,8 @@
/// \p DT The dominance info of the function.
/// \p capacity Should be the number of basic blocks in the dominator tree to
/// reduce memory allocation.
- DominanceOrder(SILBasicBlock *root, DominanceInfo *DT, int capacity = 0) :
+ // SWIFT_ENABLE_TENSORFLOW
+ DominanceOrderBase(SILBasicBlock *root, DomInfo *DT, int capacity = 0) :
DT(DT) {
buffer.reserve(capacity);
buffer.push_back(root);
@@ -183,6 +186,9 @@
using super::properlyDominates;
};
+// SWIFT_ENABLE_TENSORFLOW
+using DominanceOrder = DominanceOrderBase<DominanceInfo>;
+using PostDominanceOrder = DominanceOrderBase<PostDominanceInfo>;
} // end namespace swift
diff --git a/include/swift/SIL/SILLocation.h b/include/swift/SIL/SILLocation.h
index 9e0d07a..c2637bc 100644
--- a/include/swift/SIL/SILLocation.h
+++ b/include/swift/SIL/SILLocation.h
@@ -241,6 +241,11 @@
friend class CleanupLocation;
void setLocationKind(LocationKind K) { KindData |= (K & LocationKindMask); }
+ // SWIFT_ENABLE_TENSORFLOW
+ // TODO: Assess if this API can be unified with the one above.
+ void setAndOverwriteLocationKind(LocationKind K) {
+ KindData = (K & LocationKindMask);
+ }
void setStorageKind(StorageKind K) { KindData |= (K & StorageKindMask); }
unsigned getSpecialFlags() const { return KindData & SpecialFlagsMask; }
void setSpecialFlags(unsigned Flags) {
@@ -435,6 +440,18 @@
return RegularLoc;
}
+ // SWIFT_ENABLE_TENSORFLOW
+ /// Convert a specialized location kind into a regular location, by completely
+ /// overwriting the existing location kind in `RegularLoc`. In contrast, the
+ /// above function does a bitwise OR on the existing location kind and
+ /// RegularKind.
+ // TODO: Assess if these two APIs can be unified.
+ SILLocation getAsRegularLocationWithOverwrite() {
+ SILLocation RegularLoc = *this;
+ RegularLoc.setAndOverwriteLocationKind(RegularKind);
+ return RegularLoc;
+ }
+
SourceLoc getDebugSourceLoc() const;
SourceLoc getSourceLoc() const;
SourceLoc getStartSourceLoc() const;
diff --git a/include/swift/SIL/SILModule.h b/include/swift/SIL/SILModule.h
index 988fe9f..eb58807 100644
--- a/include/swift/SIL/SILModule.h
+++ b/include/swift/SIL/SILModule.h
@@ -143,6 +143,9 @@
friend SILProperty;
friend SILUndef;
friend SILWitnessTable;
+ // SWIFT_ENABLE_TENSORFLOW
+ friend SILDifferentiabilityWitness;
+ // SWIFT_ENABLE_TENSORFLOW END
friend Lowering::SILGenModule;
friend Lowering::TypeConverter;
class SerializationCallback;
diff --git a/include/swift/SIL/SILNode.h b/include/swift/SIL/SILNode.h
index 829ac35..2138510 100644
--- a/include/swift/SIL/SILNode.h
+++ b/include/swift/SIL/SILNode.h
@@ -334,6 +334,8 @@
UIWTDOB_BITFIELD(ConvertFunctionInst, ConversionInst, 1,
WithoutActuallyEscaping : 1);
+ // SWIFT_ENABLE_TENSORFLOW
+ UIWTDOB_BITFIELD_EMPTY(GradientInst, SingleValueInstruction);
UIWTDOB_BITFIELD_EMPTY(PointerToThinFunctionInst, ConversionInst);
UIWTDOB_BITFIELD_EMPTY(UnconditionalCheckedCastInst, ConversionInst);
UIWTDOB_BITFIELD_EMPTY(UpcastInst, ConversionInst);
diff --git a/include/swift/SILOptimizer/Differentiation/Common.h b/include/swift/SILOptimizer/Differentiation/Common.h
index 90e3c26..54269ad 100644
--- a/include/swift/SILOptimizer/Differentiation/Common.h
+++ b/include/swift/SILOptimizer/Differentiation/Common.h
@@ -21,6 +21,7 @@
#include "swift/AST/Expr.h"
#include "swift/AST/SemanticAttrs.h"
#include "swift/SIL/SILDifferentiabilityWitness.h"
+#include "swift/SIL/SILType.h"
#include "swift/SIL/SILFunction.h"
#include "swift/SIL/SILModule.h"
#include "swift/SIL/TypeSubstCloner.h"
diff --git a/include/swift/Serialization/SerializedSILLoader.h b/include/swift/Serialization/SerializedSILLoader.h
index 1707f54..6d23220 100644
--- a/include/swift/Serialization/SerializedSILLoader.h
+++ b/include/swift/Serialization/SerializedSILLoader.h
@@ -13,7 +13,9 @@
#ifndef SWIFT_SERIALIZATION_SILLOADER_H
#define SWIFT_SERIALIZATION_SILLOADER_H
+// SWIFT_ENABLE_TENSORFLOW
#include "swift/AST/AutoDiff.h"
+// SWIFT_ENABLE_TENSORFLOW END
#include "swift/AST/Decl.h"
#include "swift/AST/Identifier.h"
#include "swift/SIL/Notifications.h"
@@ -99,8 +101,10 @@
/// Deserialize all Properties in all SILModules.
void getAllProperties();
+ // SWIFT_ENABLE_TENSORFLOW
/// Deserialize all DifferentiabilityWitnesses in all SILModules.
void getAllDifferentiabilityWitnesses();
+ // SWIFT_ENABLE_TENSORFLOW END
SerializedSILLoader(const SerializedSILLoader &) = delete;
SerializedSILLoader(SerializedSILLoader &&) = delete;
diff --git a/include/swift/Subsystems.h b/include/swift/Subsystems.h
index a4e545e..00c7404 100644
--- a/include/swift/Subsystems.h
+++ b/include/swift/Subsystems.h
@@ -25,6 +25,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Mutex.h"
+#include "llvm/Support/raw_ostream.h"
#include <memory>
@@ -188,6 +189,23 @@
std::unique_ptr<llvm::MemoryBuffer> *moduleSourceInfoBuffer,
const SILModule *M = nullptr);
+ // SWIFT_ENABLE_TENSORFLOW
+ /// Serializes a module or single source file to a memory buffer, and returns
+ /// the memory buffer in an output parameter. Does not write to the
+ /// filesystem.
+ ///
+ /// \param moduleBuffer will be set to a pointer to the serialized module
+ /// buffer. nullptr is allowed, in which case the module
+ /// will not be serialized.
+ /// \param moduleDocBuffer will be set to a pointer to the serialized module
+ /// doc buffer. nullptr is allowed, in which case the
+ /// module doc will not be serialized.
+ void serializeToMemory(ModuleOrSourceFile DC,
+ const SerializationOptions &options,
+ std::unique_ptr<llvm::MemoryBuffer> *moduleBuffer,
+ std::unique_ptr<llvm::MemoryBuffer> *moduleDocBuffer,
+ const SILModule *M = nullptr);
+
/// Get the CPU, subtarget feature options, and triple to use when emitting code.
std::tuple<llvm::TargetOptions, std::string, std::vector<std::string>,
std::string>
diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp
index 76ddd21..ac6fac8 100644
--- a/lib/AST/ASTContext.cpp
+++ b/lib/AST/ASTContext.cpp
@@ -186,6 +186,21 @@
/// The AnyObject type.
CanType AnyObjectType;
+ // SWIFT_ENABLE_TENSORFLOW
+ /// The declaration of TensorFlow.TensorHandle<T>.
+ ClassDecl *TensorHandleDecl = nullptr;
+ /// The declaration of TensorFlow.TensorShape.
+ StructDecl *TensorShapeDecl = nullptr;
+ /// The declaration of TensorFlow.TensorDataType.
+ StructDecl *TensorDataTypeDecl = nullptr;
+
+ /// The declaration of Quote.Tree.
+ ProtocolDecl *TreeDecl = nullptr;
+ /// The declaration of Quote.Quote.
+ ClassDecl *QuoteDecl = nullptr;
+ /// The declarations of Quote.FunctionQuoteN.
+ SmallVector<ClassDecl *, 16> FunctionQuoteDecls;
+
#define KNOWN_STDLIB_TYPE_DECL(NAME, DECL_CLASS, NUM_GENERIC_PARAMS) \
/** The declaration of Swift.NAME. */ \
DECL_CLASS *NAME##Decl = nullptr;
@@ -874,6 +889,150 @@
return getImpl().AnyObjectType;
}
+// SWIFT_ENABLE_TENSORFLOW
+
+/// Retrieve the decl for TensorFlow.TensorHandle iff the TensorFlow module has
+/// been imported. Otherwise, this returns null.
+ClassDecl *ASTContext::getTensorHandleDecl() const {
+ if (getImpl().TensorHandleDecl)
+ return getImpl().TensorHandleDecl;
+
+ // See if the TensorFlow module was imported. If not, return null.
+ auto tfModule = getLoadedModule(Id_TensorFlow);
+ if (!tfModule)
+ return nullptr;
+
+ SmallVector<ValueDecl *, 1> results;
+ tfModule->lookupValue(getIdentifier("TensorHandle"),
+ NLKind::UnqualifiedLookup, results);
+
+ for (auto result : results)
+ if (auto CD = dyn_cast<ClassDecl>(result))
+ return getImpl().TensorHandleDecl = CD;
+ return nullptr;
+}
+
+/// Retrieve the decl for TensorFlow.TensorShape iff the TensorFlow module has
+/// been imported. Otherwise, this returns null.
+StructDecl *ASTContext::getTensorShapeDecl() const {
+ if (getImpl().TensorShapeDecl)
+ return getImpl().TensorShapeDecl;
+
+ // See if the TensorFlow module was imported. If not, return null.
+ auto tfModule = getLoadedModule(Id_TensorFlow);
+ if (!tfModule)
+ return nullptr;
+
+ SmallVector<ValueDecl *, 1> results;
+ tfModule->lookupValue(getIdentifier("TensorShape"),
+ NLKind::UnqualifiedLookup, results);
+
+ for (auto result : results)
+ if (auto CD = dyn_cast<StructDecl>(result))
+ return getImpl().TensorShapeDecl = CD;
+ return nullptr;
+}
+
+/// Retrieve the decl for TensorFlow.TensorDataType iff the TensorFlow module has
+/// been imported. Otherwise, this returns null.
+StructDecl *ASTContext::getTensorDataTypeDecl() const {
+ if (getImpl().TensorDataTypeDecl)
+ return getImpl().TensorDataTypeDecl;
+
+ // See if the TensorFlow module was imported. If not, return null.
+ auto tfModule = getLoadedModule(Id_TensorFlow);
+ if (!tfModule)
+ return nullptr;
+
+ SmallVector<ValueDecl *, 1> results;
+ tfModule->lookupValue(getIdentifier("TensorDataType"),
+ NLKind::UnqualifiedLookup, results);
+
+ for (auto result : results)
+ if (auto CD = dyn_cast<StructDecl>(result))
+ return getImpl().TensorDataTypeDecl = CD;
+ return nullptr;
+}
+
+/// Retrieve the decl for the Quote module iff it has been imported.
+/// Otherwise, this returns null.
+ModuleDecl *ASTContext::getQuoteModule() const {
+ return getLoadedModule(Id_Quote);
+}
+
+/// Retrieve the decl for Quote.Tree iff the Quote module has been imported.
+/// Otherwise, this returns null.
+ProtocolDecl *ASTContext::getTreeDecl() const {
+ if (getImpl().TreeDecl)
+ return getImpl().TreeDecl;
+
+ auto quoteModule = getLoadedModule(Id_Quote);
+ if (!quoteModule)
+ return nullptr;
+
+ SmallVector<ValueDecl *, 1> results;
+ quoteModule->lookupValue(getIdentifier("Tree"), NLKind::UnqualifiedLookup,
+ results);
+
+ for (auto result : results)
+ if (auto CD = dyn_cast<ProtocolDecl>(result))
+ return getImpl().TreeDecl = CD;
+ return nullptr;
+}
+
+/// Retrieve the decl for Quote.Quote iff the Quote module has been imported.
+/// Otherwise, this returns null.
+ClassDecl *ASTContext::getQuoteDecl() const {
+ if (getImpl().QuoteDecl)
+ return getImpl().QuoteDecl;
+
+ auto quoteModule = getLoadedModule(Id_Quote);
+ if (!quoteModule)
+ return nullptr;
+
+ SmallVector<ValueDecl *, 1> results;
+ quoteModule->lookupValue(getIdentifier("Quote"), NLKind::UnqualifiedLookup,
+ results);
+
+ for (auto result : results)
+ if (auto CD = dyn_cast<ClassDecl>(result))
+ return getImpl().QuoteDecl = CD;
+ return nullptr;
+}
+
+/// Retrieve the decl for Quote.FunctionQuoteN iff the Quote module has been
+/// imported. Otherwise, this returns null.
+ClassDecl *ASTContext::getFunctionQuoteDecl(unsigned n) const {
+ auto cache = getImpl().FunctionQuoteDecls;
+ if (cache.size() == 0) {
+ auto quoteModule = getLoadedModule(Id_Quote);
+ if (!quoteModule)
+ return nullptr;
+
+ for (auto i = 0; i < 16; ++i) {
+ llvm::SmallString<16> SS;
+ llvm::raw_svector_ostream OS(SS);
+ OS << "FunctionQuote" << n;
+ auto id = getIdentifier(SS);
+
+ SmallVector<ValueDecl *, 1> results;
+ quoteModule->lookupValue(id, NLKind::UnqualifiedLookup, results);
+
+ for (auto result : results) {
+ if (auto CD = dyn_cast<ClassDecl>(result)) {
+ cache.push_back(CD);
+ break;
+ }
+ }
+ }
+ }
+ if (n < cache.size()) {
+ return cache[n];
+ } else {
+ return nullptr;
+ }
+}
+
CanType ASTContext::getNeverType() const {
auto neverDecl = getNeverDecl();
if (!neverDecl)
@@ -931,9 +1090,18 @@
case KnownProtocolKind::CFObject:
M = getLoadedModule(Id_CoreFoundation);
break;
- case KnownProtocolKind::Differentiable:
- M = getLoadedModule(Id_Differentiation);
+
+ // SWIFT_ENABLE_TENSORFLOW
+ // NOTE: The `Differentiable` protocol is in the stdlib module on tensorflow
+ // branch, but in `_Differentiation` module on the master branch.
+ case KnownProtocolKind::TensorArrayProtocol:
+ case KnownProtocolKind::TensorGroup:
+ M = getLoadedModule(Id_TensorFlow);
break;
+ case KnownProtocolKind::Expression:
+ M = getLoadedModule(Id_Quote);
+ break;
+ // SWIFT_ENABLE_TENSORFLOW END
default:
M = getStdlibModule();
break;
diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp
index 65b3ea8..cc04850 100644
--- a/lib/AST/ASTPrinter.cpp
+++ b/lib/AST/ASTPrinter.cpp
@@ -38,6 +38,7 @@
#include "swift/Basic/Defer.h"
#include "swift/Basic/PrimitiveParsing.h"
#include "swift/Basic/QuotedString.h"
+#include "swift/Basic/SourceManager.h"
#include "swift/Basic/STLExtras.h"
#include "swift/Basic/StringExtras.h"
#include "swift/Config.h"
@@ -3731,6 +3732,22 @@
}
bool shouldPrintFullyQualified(TypeBase *T) {
+ // SWIFT_ENABLE_TENSORFLOW
+ // NOTE(TF-590): Workaround for REPL qualified module name bug.
+ // Do not print qualified LLDB module names.
+ {
+ Decl *D;
+ if (auto *TAT = dyn_cast<TypeAliasType>(T))
+ D = TAT->getDecl();
+ else
+ D = T->getAnyGeneric();
+
+ ModuleDecl *M = D->getDeclContext()->getParentModule();
+ if (isLLDBExpressionModule(M))
+ return false;
+ }
+ // SWIFT_ENABLE_TENSORFLOW END
+
if (Options.FullyQualifiedTypes)
return true;
diff --git a/lib/AST/ASTVerifier.cpp b/lib/AST/ASTVerifier.cpp
index 78ce700..a4ae20e 100644
--- a/lib/AST/ASTVerifier.cpp
+++ b/lib/AST/ASTVerifier.cpp
@@ -2775,6 +2775,11 @@
abort();
}
+ // Skip implicit generic param decls. Their depth and index may not be
+ // consistent with the generic context's parameter list.
+ if (GTPD->isImplicit())
+ return;
+
unsigned currentDepth = DC->getGenericContextDepth();
if (currentDepth < GTPD->getDepth()) {
Out << "GenericTypeParamDecl has incorrect depth\n";
diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp
index 70a1ec3..ec39415 100644
--- a/lib/AST/Attr.cpp
+++ b/lib/AST/Attr.cpp
@@ -20,12 +20,18 @@
#include "swift/AST/Decl.h"
#include "swift/AST/Expr.h"
#include "swift/AST/GenericEnvironment.h"
+// SWIFT_ENABLE_TENSORFLOW
+#include "swift/AST/GenericSignatureBuilder.h"
+// SWIFT_ENABLE_TENSORFLOW END
#include "swift/AST/IndexSubset.h"
#include "swift/AST/LazyResolver.h"
#include "swift/AST/Module.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/TypeRepr.h"
+// SWIFT_ENABLE_TENSORFLOW
+#include "swift/AST/TypeCheckRequests.h"
+// SWIFT_ENABLE_TENSORFLOW END
#include "swift/AST/Types.h"
#include "swift/Basic/Defer.h"
#include "swift/Basic/QuotedString.h"
@@ -1640,7 +1646,15 @@
IndexSubset *parameterIndices,
GenericSignature derivativeGenSig) {
auto &ctx = original->getASTContext();
-
+ // SWIFT_ENABLE_TENSORFLOW
+ // Register derivative function configuration for the given original
+ // declaration.
+ // NOTE(TF-1038): `@differentiable` attributes currently always have
+ // effective result indices `{0}` (the first and only result index).
+ auto *resultIndices = IndexSubset::get(ctx, 1, {0});
+ original->addDerivativeFunctionConfiguration(
+ {parameterIndices, resultIndices, derivativeGenSig});
+ // SWIFT_ENABLE_TENSORFLOW END
size_t size = totalSizeToAlloc<ParsedAutoDiffParameter>(0);
void *mem = ctx.Allocate(size, alignof(DifferentiableAttr));
return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange,
diff --git a/lib/AST/Builtins.cpp b/lib/AST/Builtins.cpp
index f492762..d122189 100644
--- a/lib/AST/Builtins.cpp
+++ b/lib/AST/Builtins.cpp
@@ -1725,6 +1725,13 @@
return BuiltinVectorType::get(Context, eltType, width);
}
+ /// Create a vector type.
+ Type makeVector(Type eltType, llvm::ElementCount width) {
+ // Need an actual element count
+ assert(!width.Scalable);
+ return makeVector(eltType, width.Min);
+ }
+
/// Return the first type or, if the second type is a vector type, a vector
/// of the first type of the same length as the second type.
Type maybeMakeVectorized(Type eltType, Type maybeVectorType) {
@@ -2155,7 +2162,6 @@
return nullptr;
return getLinearFunctionConstructor(Context, Id, arity, throws);
}
-
auto BV = llvm::StringSwitch<BuiltinValueKind>(OperationName)
#define BUILTIN(id, name, Attrs) .Case(name, BuiltinValueKind::id)
#include "swift/AST/Builtins.def"
diff --git a/lib/AST/ConcreteDeclRef.cpp b/lib/AST/ConcreteDeclRef.cpp
index 248d57c..330c3f7 100644
--- a/lib/AST/ConcreteDeclRef.cpp
+++ b/lib/AST/ConcreteDeclRef.cpp
@@ -15,13 +15,14 @@
//
//===----------------------------------------------------------------------===//
-#include "swift/AST/ASTContext.h"
#include "swift/AST/ConcreteDeclRef.h"
+#include "swift/AST/ASTContext.h"
#include "swift/AST/Decl.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/SubstitutionMap.h"
#include "swift/AST/Types.h"
+#include "swift/AST/USRGeneration.h"
#include "llvm/Support/raw_ostream.h"
using namespace swift;
diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp
index 979b144..36c7e37 100644
--- a/lib/AST/Decl.cpp
+++ b/lib/AST/Decl.cpp
@@ -46,6 +46,7 @@
#include "swift/AST/TypeLoc.h"
#include "swift/AST/SwiftNameTranslation.h"
#include "swift/Parse/Lexer.h" // FIXME: Bad dependency
+#include "clang/Basic/Module.h"
#include "clang/Lex/MacroInfo.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
@@ -566,6 +567,16 @@
}
for (auto Attr : getAttrs()) {
+ // SWIFT_ENABLE_TENSORFLOW
+ // Skip implicitly `@differentiable` attribute generated during
+ // `@derivative` attribute type-checking.
+ // TODO(TF-835): Instead of generating implicit `@differentiable`
+ // attributes, lower `@derivative` attributes to differentiability witnesses
+ // for the referenced declaration.
+ if (auto *diffAttr = dyn_cast<DifferentiableAttr>(Attr))
+ if (diffAttr->isImplicit())
+ continue;
+ // SWIFT_ENABLE_TENSORFLOW END
if (Attr->getRange().isValid())
Range.widen(Attr->getRangeWithAt());
}
@@ -642,7 +653,8 @@
if (isa<ModuleDecl>(this))
return SourceLoc();
// When the decl is context-free, we should get loc from source buffer.
- if (!getDeclContext())
+ if (!getDeclContext() ||
+ !isa<FileUnit>(getDeclContext()->getModuleScopeContext()))
return getLocFromSource();
auto *File = cast<FileUnit>(getDeclContext()->getModuleScopeContext());
switch(File->getKind()) {
@@ -1365,9 +1377,31 @@
auto PBE = PatternBindingEntry(Pat, EqualLoc, E, BindingInitContext);
auto *Result = create(Ctx, StaticLoc, StaticSpelling, VarLoc, PBE, Parent);
- if (BindingInitContext)
+ if (BindingInitContext) {
cast<PatternBindingInitializer>(BindingInitContext)->setBinding(Result, 0);
+ // If the expression contains any closures, then we must change the
+ // closures' parents to `BindingInitContext`, because the closures are now
+ // children of `BindingInitContext`.
+ if (E) {
+ class Walker : public ASTWalker {
+ public:
+ DeclContext *NewParent;
+ explicit Walker(DeclContext *NewParent) : NewParent(NewParent) {}
+ virtual std::pair<bool, Expr *> walkToExprPre(Expr *E) override {
+ if (auto *ACE = dyn_cast<AbstractClosureExpr>(E)) {
+ ACE->setParent(NewParent);
+ // Don't set the parents of nested closures.
+ return { false, E };
+ }
+ return { true, E };
+ }
+ };
+ Walker walker(BindingInitContext);
+ E->walk(walker);
+ }
+ }
+
return Result;
}
@@ -5021,7 +5055,10 @@
auto module = getModuleContext();
if (module != module->getASTContext().getStdlibModule() &&
!module->getName().is("Foundation") &&
- !module->getName().is("_Differentiation")) {
+ !module->getName().is("_Differentiation") &&
+ // SWIFT_ENABLE_TENSORFLOW
+ !module->getName().is("TensorFlow")) {
+ // SWIFT_ENABLE_TENSORFLOW END
const_cast<ProtocolDecl *>(this)->Bits.ProtocolDecl.KnownProtocol = 1;
return;
}
@@ -5067,6 +5104,22 @@
return KnownDerivableProtocolKind::AdditiveArithmetic;
case KnownProtocolKind::Differentiable:
return KnownDerivableProtocolKind::Differentiable;
+ // SWIFT_ENABLE_TENSORFLOW
+ case KnownProtocolKind::PointwiseMultiplicative:
+ return KnownDerivableProtocolKind::PointwiseMultiplicative;
+ case KnownProtocolKind::ElementaryFunctions:
+ return KnownDerivableProtocolKind::ElementaryFunctions;
+ case KnownProtocolKind::KeyPathIterable:
+ return KnownDerivableProtocolKind::KeyPathIterable;
+ case KnownProtocolKind::TensorArrayProtocol:
+ return KnownDerivableProtocolKind::TensorArrayProtocol;
+ case KnownProtocolKind::TensorGroup:
+ return KnownDerivableProtocolKind::TensorGroup;
+ case KnownProtocolKind::VectorProtocol:
+ return KnownDerivableProtocolKind::VectorProtocol;
+ case KnownProtocolKind::EuclideanDifferentiable:
+ return KnownDerivableProtocolKind::EuclideanDifferentiable;
+ // SWIFT_ENABLE_TENSORFLOW END
default: return None;
}
}
diff --git a/lib/AST/TypeCheckRequests.cpp b/lib/AST/TypeCheckRequests.cpp
index b6bf413..b2dabac 100644
--- a/lib/AST/TypeCheckRequests.cpp
+++ b/lib/AST/TypeCheckRequests.cpp
@@ -380,6 +380,10 @@
if (auto attr = source.dyn_cast<SpecializeAttr *>())
return attr->getLocation();
+ // SWIFT_ENABLE_TENSORFLOW
+ if (auto attr = source.dyn_cast<DifferentiableAttr *>())
+ return attr->getLocation();
+
return source.get<GenericParamList *>()->getWhereLoc();
}
@@ -420,6 +424,7 @@
return whereClause->getRequirements();
}
+
return { };
}
diff --git a/lib/AST/TypeRepr.cpp b/lib/AST/TypeRepr.cpp
index 6490258..60cbddb 100644
--- a/lib/AST/TypeRepr.cpp
+++ b/lib/AST/TypeRepr.cpp
@@ -148,7 +148,6 @@
Printer.printSimpleAttr("@escaping") << " ";
if (hasAttr(TAK_noDerivative))
Printer.printSimpleAttr("@noDerivative") << " ";
-
if (hasAttr(TAK_differentiable)) {
if (Attrs.isLinear()) {
Printer.printSimpleAttr("@differentiable(linear)") << " ";
diff --git a/lib/Basic/Platform.cpp b/lib/Basic/Platform.cpp
index ad1db82..526f0c2 100644
--- a/lib/Basic/Platform.cpp
+++ b/lib/Basic/Platform.cpp
@@ -81,9 +81,10 @@
bool swift::tripleRequiresRPathForSwiftInOS(const llvm::Triple &triple) {
if (triple.isMacOSX()) {
- // macOS 10.14.4 contains a copy of Swift, but the linker will still use an
- // rpath-based install name until 10.15.
- return triple.isMacOSXVersionLT(10, 15);
+ // SWIFT_ENABLE_TENSORFLOW
+ // For TensorFlow, use the toolchain libs, not system ones
+ return false;
+ // SWIFT_ENABLE_TENSORFLOW END
}
if (triple.isiOS()) {
diff --git a/lib/ClangImporter/ClangImporter.cpp b/lib/ClangImporter/ClangImporter.cpp
index 0e4978a..ab2612c 100644
--- a/lib/ClangImporter/ClangImporter.cpp
+++ b/lib/ClangImporter/ClangImporter.cpp
@@ -629,6 +629,21 @@
}
}
+ // SWIFT_ENABLE_TENSORFLOW
+ // Include platform-specific standard library modulemaps.
+ SmallString<128> platformSpecificModuleMapDir;
+ platformSpecificModuleMapDir = searchPathOpts.RuntimeResourcePath;
+ llvm::sys::path::append(
+ platformSpecificModuleMapDir,
+ swift::getPlatformNameForTriple(triple),
+ swift::getMajorArchitectureName(triple),
+ "modulemaps");
+ if (llvm::sys::fs::exists(platformSpecificModuleMapDir)) {
+ invocationArgStrs.push_back("-I");
+ invocationArgStrs.push_back(
+ std::string(platformSpecificModuleMapDir.str()));
+ }
+
if (searchPathOpts.SDKPath.empty()) {
invocationArgStrs.push_back("-Xclang");
invocationArgStrs.push_back("-nostdsysteminc");
@@ -1059,10 +1074,22 @@
// Set up the file manager.
{
+ // SWIFT_ENABLE_TENSORFLOW
+ auto clangFileSystem = ctx.SourceMgr.getFileSystem();
+ if (importerOpts.InMemoryOutputFileSystem) {
+ instance.setInMemoryOutputFileSystem(
+ importerOpts.InMemoryOutputFileSystem);
+ llvm::IntrusiveRefCntPtr<llvm::vfs::OverlayFileSystem> overlayFileSystem(
+ new llvm::vfs::OverlayFileSystem(clangFileSystem));
+ overlayFileSystem->pushOverlay(importerOpts.InMemoryOutputFileSystem);
+ clangFileSystem = overlayFileSystem;
+ }
+ // TODO(asuhan): Check this, CompilerInstance::setVirtualFileSystem removed in
+ // https://github.com/apple/swift-clang/commit/5f92395a7e64526f6f94f31a9d143f81a69e9209.
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> VFS =
clang::createVFSFromCompilerInvocation(instance.getInvocation(),
instance.getDiagnostics(),
- ctx.SourceMgr.getFileSystem());
+ clangFileSystem);
instance.createFileManager(std::move(VFS));
}
diff --git a/lib/ClangImporter/ImportDecl.cpp b/lib/ClangImporter/ImportDecl.cpp
index db5dd76..9bf8f81 100644
--- a/lib/ClangImporter/ImportDecl.cpp
+++ b/lib/ClangImporter/ImportDecl.cpp
@@ -3295,6 +3295,12 @@
// Create the struct declaration and record it.
auto name = importedName.getDeclName().getBaseIdentifier();
+ {
+ // Hack for nested types (They produce cycles)...
+ auto Known = Impl.ImportedDecls.find({decl->getCanonicalDecl(), getVersion()});
+ if (Known != Impl.ImportedDecls.end())
+ return Known->second;
+ }
auto result = Impl.createDeclWithClangNode<StructDecl>(decl,
AccessLevel::Public,
Impl.importSourceLoc(decl->getBeginLoc()),
diff --git a/lib/Driver/DarwinToolChains.cpp b/lib/Driver/DarwinToolChains.cpp
index 35026fd..58ab251 100644
--- a/lib/Driver/DarwinToolChains.cpp
+++ b/lib/Driver/DarwinToolChains.cpp
@@ -412,8 +412,16 @@
Arguments.push_back(context.Args.MakeArgString(path));
}
- if (context.Args.hasFlag(options::OPT_toolchain_stdlib_rpath,
- options::OPT_no_toolchain_stdlib_rpath, false)) {
+ // SWIFT_ENABLE_TENSORFLOW
+ // NOTE(TF-797): default true for toolchain stdlib rpath to prevent linker
+ // issues. This works around the fact that TensorFlow/Python modules do not
+ // exist in `/usr/lib/swift` on Darwin platforms.
+ // Relevant Swift-in-Darwin-OSs patch:
+ // https://github.com/apple/swift/pull/24787.
+ if (!context.Args.hasArg(options::OPT_no_stdlib_rpath) &&
+ context.Args.hasFlag(options::OPT_toolchain_stdlib_rpath,
+ options::OPT_no_toolchain_stdlib_rpath, true)) {
+ // SWIFT_ENABLE_TENSORFLOW END
// If the user has explicitly asked for a toolchain stdlib, we should
// provide one using -rpath. This used to be the default behaviour but it
// was considered annoying in at least the SwiftPM scenario (see
diff --git a/lib/Driver/Driver.cpp b/lib/Driver/Driver.cpp
index 2f6fab9..dc4656a 100644
--- a/lib/Driver/Driver.cpp
+++ b/lib/Driver/Driver.cpp
@@ -2368,6 +2368,15 @@
}
}
+ // SWIFT_ENABLE_TENSORFLOW
+ for (const Action *A : TopLevelActions) {
+ if (A->getKind() == Action::Kind::REPLJob) {
+ Diags.diagnose(SourceLoc(),
+ diag::error_tensorflow_toolchain_repl_not_supported);
+ return;
+ }
+ }
+
for (const Action *A : TopLevelActions) {
if (auto *JA = dyn_cast<JobAction>(A)) {
(void)buildJobsForAction(C, JA, OFM, workingDirectory, /*TopLevel=*/true,
diff --git a/lib/Driver/ToolChains.cpp b/lib/Driver/ToolChains.cpp
index 0cdcd7a..352bc79 100644
--- a/lib/Driver/ToolChains.cpp
+++ b/lib/Driver/ToolChains.cpp
@@ -272,6 +272,11 @@
options::OPT_enable_direct_intramodule_dependencies,
options::OPT_disable_direct_intramodule_dependencies);
+ // SWIFT_ENABLE_TENSORFLOW
+ inputArgs.AddLastArg(
+ arguments, options::OPT_enable_experimental_forward_mode_differentiation);
+ // SWIFT_ENABLE_TENSORFLOW END
+
// Pass on any build config options
inputArgs.AddAllArgs(arguments, options::OPT_D);
diff --git a/lib/Frontend/ModuleInterfaceLoader.cpp b/lib/Frontend/ModuleInterfaceLoader.cpp
index 1f2df7b..d2d3923 100644
--- a/lib/Frontend/ModuleInterfaceLoader.cpp
+++ b/lib/Frontend/ModuleInterfaceLoader.cpp
@@ -1288,16 +1288,14 @@
// required by sourcekitd.
subClangImporterOpts.DetailedPreprocessingRecord =
clangImporterOpts.DetailedPreprocessingRecord;
- // We need to add these extra clang flags because explict module building
- // related flags are all there: -fno-implicit-modules, -fmodule-map-file=,
- // and -fmodule-file=.
- // If we don't add these flags, the interface will be built with implicit
- // PCMs.
- subClangImporterOpts.ExtraArgs = clangImporterOpts.ExtraArgs;
- for (auto arg: subClangImporterOpts.ExtraArgs) {
- GenericArgs.push_back("-Xcc");
- GenericArgs.push_back(ArgSaver.save(arg));
- }
+ // SWIFT_ENABLE_TENSORFLOW
+ // If the ClangModuleLoader is using an InMemoryOutputFileSystem, the
+ // subinstance loader should use it as well, as files written to the file
+ // system may not be visible to read, causing subinvocations to fail loading
+ // dependencies.
+ subClangImporterOpts.InMemoryOutputFileSystem =
+ clangImporterOpts.InMemoryOutputFileSystem;
+ // SWIFT_ENABLE_TENSORFLOW END
// Tell the genericSubInvocation to serialize dependency hashes if asked to do so.
auto &frontendOpts = genericSubInvocation.getFrontendOptions();
diff --git a/lib/IDE/CodeCompletion.cpp b/lib/IDE/CodeCompletion.cpp
index d3b9362..d6fe81c 100644
--- a/lib/IDE/CodeCompletion.cpp
+++ b/lib/IDE/CodeCompletion.cpp
@@ -2926,6 +2926,12 @@
DynamicLookupInfo dynamicLookupInfo) {
if (FD->getBaseIdentifier().empty())
return;
+
+ // Suppress "sequenced" as a result, because it crashes completions.
+ // TODO(TF-315): Fix properly and then remove this.
+ if (FD->getBaseIdentifier().str() == "sequenced")
+ return;
+
foundFunction(FD);
const Identifier Name = FD->getBaseIdentifier();
diff --git a/lib/IDE/Utils.cpp b/lib/IDE/Utils.cpp
index 0ee903e..40433cd 100644
--- a/lib/IDE/Utils.cpp
+++ b/lib/IDE/Utils.cpp
@@ -274,6 +274,9 @@
CompilerInvocation &Invocation, ArrayRef<const char *> OrigArgs,
DiagnosticEngine &Diags, StringRef UnresolvedPrimaryFile,
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
+ // SWIFT_ENABLE_TENSORFLOW
+ llvm::IntrusiveRefCntPtr<clang::InMemoryOutputFileSystem> InMemoryOutputFileSystem,
+ // SWIFT_ENABLE_TENSORFLOW END
const std::string &runtimeResourcePath,
const std::string &diagnosticDocumentationPath,
bool shouldOptimizeForIDE, time_t sessionTimestamp, std::string &Error) {
@@ -315,6 +318,9 @@
ClangImporterOptions &ImporterOpts = Invocation.getClangImporterOptions();
ImporterOpts.DetailedPreprocessingRecord = true;
+ // SWIFT_ENABLE_TENSORFLOW
+ ImporterOpts.InMemoryOutputFileSystem = InMemoryOutputFileSystem;
+ // SWIFT_ENABLE_TENSORFLOW END
assert(!Invocation.getModuleName().empty());
diff --git a/lib/IRGen/GenDecl.cpp b/lib/IRGen/GenDecl.cpp
index 5c7c205..83cf02b 100644
--- a/lib/IRGen/GenDecl.cpp
+++ b/lib/IRGen/GenDecl.cpp
@@ -1907,21 +1907,19 @@
info.UseDLLStorage ? llvm::GlobalValue::DLLImportStorageClass
: llvm::GlobalValue::DefaultStorageClass;
+ // SWIFT_ENABLE_TENSORFLOW: Cases slightly modified to fix TF-587.
switch (linkage) {
case SILLinkage::Public:
return {llvm::GlobalValue::ExternalLinkage, PublicDefinitionVisibility,
ExportedStorage};
- case SILLinkage::PublicNonABI:
- return isDefinition ? RESULT(WeakODR, Hidden, Default)
- : RESULT(External, Hidden, Default);
-
case SILLinkage::Shared:
case SILLinkage::SharedExternal:
return isDefinition ? RESULT(LinkOnceODR, Hidden, Default)
: RESULT(External, Hidden, Default);
case SILLinkage::Hidden:
+ case SILLinkage::PublicNonABI:
return RESULT(External, Hidden, Default);
case SILLinkage::Private: {
diff --git a/lib/IRGen/GenFunc.cpp b/lib/IRGen/GenFunc.cpp
index 45fce96..98232c7 100644
--- a/lib/IRGen/GenFunc.cpp
+++ b/lib/IRGen/GenFunc.cpp
@@ -536,7 +536,6 @@
// contexts into the pointer value, so let's not take any spare bits from
// it.
spareBits.appendClearBits(IGM.getPointerSize().getValueInBits());
-
if (T->isNoEscape()) {
// @noescape thick functions are trivial types.
return FuncTypeInfo::create(
@@ -767,6 +766,16 @@
// Create a new explosion for potentially reabstracted parameters.
Explosion args;
+ // SWIFT_ENABLE_TENSORFLOW
+ // The witness method self argument comes after polymorphic arguments (and is
+ // followed by the self type and the witness table). However, we may encounter
+ // the witness method self value before reaching the polymorphic arguments. So
+ // we create a special explosion for storing the witness method self value
+ // until it's time to add it to 'args'.
+ bool isWitnessMethodCallee = origType->getRepresentation() ==
+ SILFunctionTypeRepresentation::WitnessMethod;
+ Explosion witnessMethodSelfValue;
+
Address resultValueAddr;
{
@@ -813,6 +822,10 @@
// Reemit the parameters as unsubstituted.
for (unsigned i = 0; i < outType->getParameters().size(); ++i) {
+ // SWIFT_ENABLE_TENSORFLOW
+ bool isWitnessMethodCalleeSelf =
+ (isWitnessMethodCallee && i + 1 == origType->getParameters().size());
+
auto origParamInfo = origType->getParameters()[i];
auto &ti = IGM.getTypeInfoForLowered(origParamInfo.getArgumentType(
IGM.getSILModule(), origType, IGM.getMaximalTypeExpansionContext()));
@@ -828,7 +841,8 @@
if (addr->getType() != ti.getStorageType()->getPointerTo())
addr = subIGF.Builder.CreateBitCast(addr,
ti.getStorageType()->getPointerTo());
- args.add(addr);
+ // SWIFT_ENABLE_TENSORFLOW
+ (isWitnessMethodCalleeSelf ? witnessMethodSelfValue : args).add(addr);
continue;
}
@@ -841,7 +855,9 @@
origParamInfo,
outType,
outTypeParamInfo,
- origParams, args);
+ origParams,
+ // SWIFT_ENABLE_TENSORFLOW
+ (isWitnessMethodCalleeSelf ? witnessMethodSelfValue : args));
continue;
}
@@ -872,7 +888,10 @@
Explosion nativeApplyArg = nativeSchemaOrigParam.mapIntoNative(
subIGF.IGM, subIGF, nonNativeApplyArg, origParamSILType, false);
assert(nonNativeApplyArg.empty());
- nativeApplyArg.transferInto(args, nativeApplyArg.size());
+ // SWIFT_ENABLE_TENSORFLOW
+ nativeApplyArg.transferInto(
+ (isWitnessMethodCalleeSelf ? witnessMethodSelfValue : args),
+ nativeApplyArg.size());
}
}
@@ -983,12 +1002,14 @@
auto haveContextArgument =
calleeHasContext || hasSelfContextParameter(origType);
+#if 0 // from master - is this okay?
// Witness method calls expect self, followed by the self type followed by,
// the witness table at the end of the parameter list. But polymorphic
// arguments come before this.
bool isWitnessMethodCallee = origType->getRepresentation() ==
SILFunctionTypeRepresentation::WitnessMethod;
Explosion witnessMethodSelfValue;
+#endif
llvm::Value *lastCapturedFieldPtr = nullptr;
diff --git a/lib/IRGen/GenMeta.cpp b/lib/IRGen/GenMeta.cpp
index d18859e..1d0814b 100644
--- a/lib/IRGen/GenMeta.cpp
+++ b/lib/IRGen/GenMeta.cpp
@@ -5045,6 +5045,16 @@
case KnownProtocolKind::AdditiveArithmetic:
case KnownProtocolKind::Differentiable:
case KnownProtocolKind::FloatingPoint:
+ // SWIFT_ENABLE_TENSORFLOW
+ case KnownProtocolKind::PointwiseMultiplicative:
+ case KnownProtocolKind::ElementaryFunctions:
+ case KnownProtocolKind::KeyPathIterable:
+ case KnownProtocolKind::TensorArrayProtocol:
+ case KnownProtocolKind::TensorGroup:
+ case KnownProtocolKind::VectorProtocol:
+ case KnownProtocolKind::EuclideanDifferentiable:
+ case KnownProtocolKind::Expression:
+ // SWIFT_ENABLE_TENSORFLOW END
return SpecialProtocol::None;
}
diff --git a/lib/IRGen/IRGenModule.cpp b/lib/IRGen/IRGenModule.cpp
index dd90ae9..91eefd6 100644
--- a/lib/IRGen/IRGenModule.cpp
+++ b/lib/IRGen/IRGenModule.cpp
@@ -225,6 +225,10 @@
Int32Ty = llvm::Type::getInt32Ty(getLLVMContext());
Int32PtrTy = Int32Ty->getPointerTo();
Int64Ty = llvm::Type::getInt64Ty(getLLVMContext());
+ // SWIFT_ENABLE_TENSORFLOW
+ DoubleTy = llvm::Type::getDoubleTy(getLLVMContext());
+ FloatTy = llvm::Type::getFloatTy(getLLVMContext());
+
Int8PtrTy = llvm::Type::getInt8PtrTy(getLLVMContext());
Int8PtrPtrTy = Int8PtrTy->getPointerTo(0);
SizeTy = DataLayout.getIntPtrType(getLLVMContext(), /*addrspace*/ 0);
@@ -694,6 +698,12 @@
static bool isReturnedAttribute(llvm::Attribute::AttrKind Attr) {
return Attr == llvm::Attribute::Returned;
}
+// SWIFT_ENABLE_TENSORFLOW
+// Similar to the 'return' attribute we assume that the 'sret' attributed is
+// associated with the first function parameter.
+static bool isStructRetAttribute(llvm::Attribute::AttrKind Attr) {
+ return Attr == llvm::Attribute::StructRet;
+}
namespace {
bool isStandardLibrary(const llvm::Module &M) {
@@ -792,7 +802,8 @@
for (auto Attr : attrs) {
if (isReturnAttribute(Attr))
buildRetAttr.addAttribute(Attr);
- else if (isReturnedAttribute(Attr))
+ // SWIFT_ENABLE_TENSORFLOW
+ else if (isReturnedAttribute(Attr) || isStructRetAttribute(Attr))
buildFirstParamAttr.addAttribute(Attr);
else
buildFnAttr.addAttribute(Attr);
diff --git a/lib/IRGen/IRGenSIL.cpp b/lib/IRGen/IRGenSIL.cpp
index 50b443d..1ea62f4 100644
--- a/lib/IRGen/IRGenSIL.cpp
+++ b/lib/IRGen/IRGenSIL.cpp
@@ -848,7 +848,7 @@
}
}
}
-
+
//===--------------------------------------------------------------------===//
// SIL instruction lowering
//===--------------------------------------------------------------------===//
diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp
index 7e1ec01..7feccc2 100644
--- a/lib/Parse/ParseDecl.cpp
+++ b/lib/Parse/ParseDecl.cpp
@@ -34,6 +34,7 @@
#include "swift/Basic/Defer.h"
#include "swift/Basic/Statistic.h"
#include "swift/Basic/StringExtras.h"
+#include "clang/Basic/CharInfo.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
@@ -41,6 +42,7 @@
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/ADT/StringSet.h"
#include <algorithm>
using namespace swift;
@@ -994,6 +996,7 @@
.fixItReplace(withRespectToRange, "wrt:");
return errorAndSkipUntilConsumeRightParen(*this, AttrName);
}
+
// Parse the optional 'wrt' differentiability parameters clause.
if (isIdentifier(Tok, "wrt")) {
if (parseDifferentiabilityParametersClause(parameters, AttrName))
@@ -4005,8 +4008,11 @@
if (DeclResult.isNonNull()) {
Decl *D = DeclResult.get();
- if (!declWasHandledAlready(D))
+
+ if (!declWasHandledAlready(D)) {
Handler(D);
+ }
+
setOriginalDeclarationForDifferentiableAttributes(D->getAttrs(), D);
}
@@ -6039,6 +6045,7 @@
pattern->forEachVariable([&](VarDecl *VD) {
VD->setStatic(StaticLoc.isValid());
VD->getAttrs() = Attributes;
+
setLocalDiscriminator(VD);
VD->setTopLevelGlobal(topLevelDecl);
diff --git a/lib/SIL/IR/SILFunction.cpp b/lib/SIL/IR/SILFunction.cpp
index dd118be..d51c3c6 100644
--- a/lib/SIL/IR/SILFunction.cpp
+++ b/lib/SIL/IR/SILFunction.cpp
@@ -68,6 +68,9 @@
if (!name.empty()) {
entry = &*M.FunctionTable.insert(std::make_pair(name, nullptr)).first;
PrettyStackTraceSILFunction trace("creating", entry->getValue());
+ if (entry->getValue()) {
+ entry->getValue()->dump();
+ }
assert(!entry->getValue() && "function already exists");
name = entry->getKey();
}
@@ -113,6 +116,11 @@
validateSubclassScope(classSubclassScope, isThunk, nullptr);
setDebugScope(DebugScope);
+ // SWIFT_ENABLE_TENSORFLOW
+ // Function type cannot be @differentiable.
+ assert(!LoweredType->isDifferentiable() &&
+ "SIL function declarations cannot have an @differentiable type");
+
if (InsertBefore)
Module.functions.insert(SILModule::iterator(InsertBefore), this);
else
@@ -229,6 +237,10 @@
SILType GenericEnvironment::mapTypeIntoContext(SILModule &M,
SILType type) const {
+ if (type.hasArchetype()) {
+ llvm::errs() << "TYPE HAS ARCHETYPE\n";
+ type.dump();
+ }
assert(!type.hasArchetype());
auto genericSig = getGenericSignature().getCanonicalSignature();
diff --git a/lib/SIL/Parser/ParseSIL.cpp b/lib/SIL/Parser/ParseSIL.cpp
index f12b216..b6603a1 100644
--- a/lib/SIL/Parser/ParseSIL.cpp
+++ b/lib/SIL/Parser/ParseSIL.cpp
@@ -616,7 +616,6 @@
return fn;
}
-
/// getBBForDefinition - Return the SILBasicBlock for a definition of the
/// specified block.
SILBasicBlock *SILParser::getBBForDefinition(Identifier Name, SourceLoc Loc) {
@@ -2767,6 +2766,7 @@
ResultVal = B.createFunctionRef(InstLoc, Fn);
break;
}
+
case SILInstructionKind::DynamicFunctionRefInst: {
SILFunction *Fn;
if (parseSILFunctionRef(InstLoc, Fn) || parseSILDebugLocation(InstLoc, B))
@@ -2789,6 +2789,7 @@
ResultVal = B.createPreviousDynamicFunctionRef(InstLoc, Fn);
break;
}
+
case SILInstructionKind::BuiltinInst: {
if (P.Tok.getKind() != tok::string_literal) {
P.diagnose(P.Tok, diag::expected_tok_in_sil_instr, "builtin name");
@@ -6526,6 +6527,8 @@
}
SILDeclRef Ref;
+
+
Identifier FuncName;
SourceLoc FuncLoc;
if (witnessState.parseSILDeclRef(Ref, true) ||
@@ -6547,6 +6550,7 @@
return true;
}
}
+
witnessEntries.push_back(SILWitnessTable::MethodWitness{
Ref, Func
});
diff --git a/lib/SIL/Verifier/SILVerifier.cpp b/lib/SIL/Verifier/SILVerifier.cpp
index c2f3de4..ff29fdb 100644
--- a/lib/SIL/Verifier/SILVerifier.cpp
+++ b/lib/SIL/Verifier/SILVerifier.cpp
@@ -66,9 +66,16 @@
"verify-abort-on-failure",
llvm::cl::init(true));
+// SWIFT_ENABLE_TENSORFLOW
+// This flag is temporarily set to false because debug scope verification does
+// not handle inlined call sites. This is problematic for deabstraction, which
+// does performance inlining at -Onone.
+// When debug scope verification handles inlined call sites, set this flag to
+// true.
+// Documented at SR-8114.
static llvm::cl::opt<bool> VerifyDIHoles(
"verify-di-holes",
- llvm::cl::init(true));
+ llvm::cl::init(false));
static llvm::cl::opt<bool> SkipConvertEscapeToNoescapeAttributes(
"verify-skip-convert-escape-to-noescape-attributes", llvm::cl::init(false));
@@ -1572,6 +1579,12 @@
}
void checkPartialApplyInst(PartialApplyInst *PAI) {
+ if (PAI->getModule().getStage() != SILStage::Raw) {
+ require(!PAI->getFunctionType()->isDifferentiable(),
+ "partial_apply of differentiable funtions is only allowed "
+ "in raw SIL");
+ }
+
auto resultInfo = requireObjectType(SILFunctionType, PAI,
"result of partial_apply");
verifySILFunctionType(resultInfo);
@@ -1719,7 +1732,7 @@
return;
}
}
-
+
void checkFunctionRefBaseInst(FunctionRefBaseInst *FRI) {
auto fnType = requireObjectType(SILFunctionType, FRI,
"result of function_ref");
@@ -5564,6 +5577,70 @@
}
}
+// SWIFT_ENABLE_TENSORFLOW
+/// Verify that a differentiability witness follows invariants.
+void SILDifferentiabilityWitness::verify(const SILModule &M) const {
+ // FIXME(TF-1197): Re-enable verification after substituted SIL function
+ // types.
+ return;
+#if 0
+#ifdef NDEBUG
+ if (!M.getOptions().VerifyAll)
+ return;
+#endif
+ // Skip lowered SIL: LoadableByAddress changes parameter/result conventions.
+ // TODO: Check that derivative function types match excluding
+ // parameter/result conventions in lowered SIL.
+ if (M.getStage() == SILStage::Lowered)
+ return;
+ auto *origFn = getOriginalFunction();
+ auto origFnType = origFn->getLoweredFunctionType();
+ bool origIsReabstractionThunk = origFn->isThunk() == IsReabstractionThunk;
+ CanGenericSignature derivativeCanGenSig;
+ if (auto derivativeGenSig = getDerivativeGenericSignature())
+ derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
+ auto requireSameType =
+ [&](CanSILFunctionType type1, CanSILFunctionType type2,
+ const Twine &complaint) {
+ if (type1 == type2)
+ return;
+ llvm::dbgs() << "SIL verification failed: " << complaint << "\n";
+ llvm::dbgs() << " " << type1 << "\n " << type2 << "\n\n";
+ llvm::dbgs() << "In differentiability witness:\n";
+ print(llvm::dbgs());
+ // We abort by default because we want to always crash in
+ // the debugger.
+ if (AbortOnFailure)
+ abort();
+ else
+ exit(1);
+ };
+ if (auto *jvp = getJVP()) {
+ // TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType`
+ // to accept result indices.
+ auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType(
+ getParameterIndices(), /*resultIndex*/ *getResultIndices()->begin(),
+ AutoDiffDerivativeFunctionKind::JVP, M.Types,
+ LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig,
+ origIsReabstractionThunk);
+ requireSameType(jvp->getLoweredFunctionType(), expectedJVPType,
+ "JVP type does not match expected JVP type");
+ }
+ if (auto *vjp = getVJP()) {
+ // TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType`
+ // to result indices.
+ auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType(
+ getParameterIndices(), /*resultIndex*/ *getResultIndices()->begin(),
+ AutoDiffDerivativeFunctionKind::VJP, M.Types,
+ LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig,
+ origIsReabstractionThunk);
+ requireSameType(vjp->getLoweredFunctionType(), expectedVJPType,
+ "VJP type does not match expected VJP type");
+ }
+#endif
+}
+// SWIFT_ENABLE_TENSORFLOW END
+
/// Verify the module.
void SILModule::verify() const {
if (!verificationEnabled(*this))
@@ -5641,6 +5718,22 @@
}
wt.verify(*this);
}
+
+ // SWIFT_ENABLE_TENSORFLOW
+ // Check all differentiability witnesses.
+ LLVM_DEBUG(llvm::dbgs() <<
+ "*** Checking differentiability witnesses for duplicates ***\n");
+ llvm::DenseSet<SILDifferentiabilityWitnessKey> diffWitnesses;
+ for (auto &dw : getDifferentiabilityWitnesses()) {
+ LLVM_DEBUG(llvm::dbgs() << "Differentiability Witness:\n"; dw.dump());
+ if (!diffWitnesses.insert(dw.getKey()).second) {
+ llvm::errs() << "Differentiability witness redefined: ";
+ dw.dump();
+ assert(false && "triggering standard assertion failure routine");
+ }
+ dw.verify(*this);
+ }
+ // SWIFT_ENABLE_TENSORFLOW END
// Check property descriptors.
LLVM_DEBUG(llvm::dbgs() << "*** Checking property descriptors ***\n");
diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp
index 49433302..0fc857d 100644
--- a/lib/SILGen/SILGen.cpp
+++ b/lib/SILGen/SILGen.cpp
@@ -1006,6 +1006,7 @@
assert(!F->isExternalDeclaration() && "did not emit any function body?!");
LLVM_DEBUG(llvm::dbgs() << "lowered sil:\n";
F->print(llvm::dbgs()));
+
F->verify();
emitDifferentiabilityWitnessesForFunction(constant, F);
diff --git a/lib/SILOptimizer/IPO/CapturePropagation.cpp b/lib/SILOptimizer/IPO/CapturePropagation.cpp
index 2a9f2e9..06c5d19 100644
--- a/lib/SILOptimizer/IPO/CapturePropagation.cpp
+++ b/lib/SILOptimizer/IPO/CapturePropagation.cpp
@@ -408,6 +408,17 @@
return nullptr;
}
+ // SWIFT_ENABLE_TENSORFLOW
+ // Disable specialization for instructions that are operands of
+ // `differentiable_function` instructions. `differentiable_function`
+ // requires derivative function operand types to match expected derivative
+ // function types computed from the original function operand's type, so
+ // operands cannot be specialized individually without specializing the
+ // others.
+ if (!PAI->getUsersOfType<DifferentiableFunctionInst>().empty())
+ return nullptr;
+ // SWIFT_ENABLE_TENSORFLOW END
+
auto Rep = Specialized->getLoweredFunctionType()->getRepresentation();
if (getSILFunctionLanguage(Rep) != SILFunctionLanguage::Swift)
return nullptr;
diff --git a/lib/SILOptimizer/Mandatory/CMakeLists.txt b/lib/SILOptimizer/Mandatory/CMakeLists.txt
index 8249f146..72aac6f 100644
--- a/lib/SILOptimizer/Mandatory/CMakeLists.txt
+++ b/lib/SILOptimizer/Mandatory/CMakeLists.txt
@@ -5,6 +5,9 @@
ClosureLifetimeFixup.cpp
ConstantPropagation.cpp
DefiniteInitialization.cpp
+ # SWIFT_ENABLE_TENSORFLOW
+ Differentiation.cpp
+ # SWIFT_ENABLE_TENSORFLOW END
DIMemoryUseCollector.cpp
DataflowDiagnostics.cpp
DiagnoseInfiniteRecursion.cpp
diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp
index 083f840..6c5f962 100644
--- a/lib/SILOptimizer/Mandatory/Differentiation.cpp
+++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp
@@ -16,6 +16,9 @@
#define DEBUG_TYPE "differentiation"
+#include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/CommandLine.h"
#include "swift/AST/ASTMangler.h"
#include "swift/AST/ASTPrinter.h"
#include "swift/AST/AnyFunctionRef.h"
@@ -35,6 +38,7 @@
#include "swift/SIL/PrettyStackTrace.h"
#include "swift/SIL/SILBuilder.h"
#include "swift/SIL/TypeSubstCloner.h"
+#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
#include "swift/SILOptimizer/Differentiation/ADContext.h"
#include "swift/SILOptimizer/Differentiation/JVPCloner.h"
@@ -1414,7 +1418,7 @@
diag::autodiff_internal_swift_not_imported);
return;
}
- if (!astCtx.getLoadedModule(astCtx.Id_Differentiation)) {
+ if (!astCtx.getProtocol(KnownProtocolKind::Differentiable)) {
SourceLoc loc;
if (!context.getInvokers().empty()) {
loc = context.getInvokers().front().second.getLocation();
diff --git a/lib/SILOptimizer/PassManager/PassManager.cpp b/lib/SILOptimizer/PassManager/PassManager.cpp
index e6bee28..463fd6d 100644
--- a/lib/SILOptimizer/PassManager/PassManager.cpp
+++ b/lib/SILOptimizer/PassManager/PassManager.cpp
@@ -383,6 +383,17 @@
llvm::dbgs() << '\n';
}
+// SWIFT_ENABLE_TENSORFLOW
+static void logS4TFPassEvent(long long Delta, llvm::sys::TimePoint<> StartTime,
+ StringRef passName, bool isFunctionPass,
+ StringRef funcName) {
+ auto tt = llvm::sys::toTimeT(StartTime);
+ auto strTime = ctime(&tt);
+ strTime[strlen(strTime) - 1] = '\0';
+ llvm::dbgs() << "S4TF," << Delta << "," << strTime << "," << passName << ","
+ << (isFunctionPass ? "F" : "M") << "," << funcName << "\n";
+}
+
void SILPassManager::runPassOnFunction(unsigned TransIdx, SILFunction *F) {
assert(analysesUnlocked() && "Expected all analyses to be unlocked!");
@@ -460,6 +471,13 @@
if (SILPrintPassTime) {
llvm::dbgs() << Delta << " (" << SFT->getID() << "," << F->getName()
<< ")\n";
+
+ // SWIFT_ENABLE_TENSORFLOW
+ // Write CSV-formatted events, so that we can do aggregate analysis. Format:
+ // [S4TF] Delta,StartTime,PassName,PassType,FuncName
+ // Here PassType is F since it's a function pass.
+ logS4TFPassEvent(Delta, StartTime, SFT->getID(), /*isFunctionPass*/ true,
+ F->getName());
}
// If this pass invalidated anything, print and verify.
@@ -648,6 +666,13 @@
auto Delta = (std::chrono::system_clock::now() - StartTime).count();
if (SILPrintPassTime) {
llvm::dbgs() << Delta << " (" << SMT->getID() << ",Module)\n";
+
+ // SWIFT_ENABLE_TENSORFLOW
+ // Write CSV-formatted events, so that we can do aggregate analysis. Format:
+ // [S4TF] Delta,StartTime,PassName,PassType
+ // Here PassType is M since it's a module pass.
+ logS4TFPassEvent(Delta, StartTime, SMT->getID(), /*isFunctionPass*/ false,
+ /*funcName*/ "");
}
// If this pass invalidated anything, print and verify.
diff --git a/lib/SILOptimizer/Transforms/PruneVTables.cpp b/lib/SILOptimizer/Transforms/PruneVTables.cpp
index f1c37c8..1548d38 100644
--- a/lib/SILOptimizer/Transforms/PruneVTables.cpp
+++ b/lib/SILOptimizer/Transforms/PruneVTables.cpp
@@ -104,6 +104,10 @@
}
void run() override {
+ // SWIFT_ENABLE_TENSORFLOW
+ return;
+ // SWIFT_ENABLE_TENSORFLOW END
+
SILModule *M = getModule();
for (auto &vtable : M->getVTables()) {
diff --git a/lib/SILOptimizer/Transforms/SILMem2Reg.cpp b/lib/SILOptimizer/Transforms/SILMem2Reg.cpp
index 332af38..95cdcc3 100644
--- a/lib/SILOptimizer/Transforms/SILMem2Reg.cpp
+++ b/lib/SILOptimizer/Transforms/SILMem2Reg.cpp
@@ -176,6 +176,10 @@
/// Promote memory to registers. Return True on change.
bool run();
+
+ // SWIFT_ENABLE_TENSORFLOW
+ /// Promote specific allocations.
+ void promoteAllocs(ArrayRef<AllocStackInst*> allocs);
};
} // end anonymous namespace
@@ -950,6 +954,26 @@
return Changed;
}
+
+/// SWIFT_ENABLE_TENSORFLOW
+/// Promote specific allocations.
+void MemoryToRegisters::promoteAllocs(ArrayRef<AllocStackInst*> allocs) {
+ F.verifyCriticalEdges();
+
+ // Compute dominator tree node levels for the function.
+ DomTreeLevelMap DomTreeLevels;
+ computeDomTreeLevels(DT, DomTreeLevels);
+
+ for (auto alloc : allocs) {
+ if (!promoteSingleAllocation(alloc, DomTreeLevels))
+ continue;
+
+ if (alloc->use_empty())
+ alloc->eraseFromParent();
+ }
+}
+
+
namespace {
class SILMem2Reg : public SILFunctionTransform {
diff --git a/lib/SILOptimizer/Utils/Generics.cpp b/lib/SILOptimizer/Utils/Generics.cpp
index 2bde614..7740d08 100644
--- a/lib/SILOptimizer/Utils/Generics.cpp
+++ b/lib/SILOptimizer/Utils/Generics.cpp
@@ -536,6 +536,20 @@
return false;
}
+ // SWIFT_ENABLE_TENSORFLOW
+ // Disable specialization for instructions that are operands of
+ // `differentiable_function` instructions. `differentiable_function`
+ // requires derivative function operand types to match expected derivative
+ // function types computed from the original function operand's type, so
+ // operands cannot be specialized individually without specializing the
+ // others.
+ if (Apply.getInstruction())
+ for (auto result : Apply.getInstruction()->getResults())
+ for (auto use : result->getUses())
+ if (isa<DifferentiableFunctionInst>(use->getUser()))
+ return false;
+ // SWIFT_ENABLE_TENSORFLOW END
+
return true;
}
diff --git a/lib/Sema/CMakeLists.txt b/lib/Sema/CMakeLists.txt
index d35c849..b9782c5 100644
--- a/lib/Sema/CMakeLists.txt
+++ b/lib/Sema/CMakeLists.txt
@@ -25,6 +25,16 @@
DerivedConformanceEquatableHashable.cpp
DerivedConformanceComparable.cpp
DerivedConformanceError.cpp
+ # SWIFT_ENABLE_TENSORFLOW
+ DerivedConformanceAdditiveArithmetic.cpp
+ DerivedConformancePointwiseMultiplicative.cpp
+ DerivedConformanceElementaryFunctions.cpp
+ DerivedConformanceVectorProtocol.cpp
+ DerivedConformanceDifferentiable.cpp
+ DerivedConformanceKeyPathIterable.cpp
+ DerivedConformanceTensorArrayProtocol.cpp
+ DerivedConformanceTensorGroup.cpp
+ # SWIFT_ENABLE_TENSORFLOW END
DerivedConformanceRawRepresentable.cpp
DerivedConformances.cpp
ImportResolution.cpp
@@ -41,6 +51,9 @@
TypeCheckCaptures.cpp
TypeCheckCircularity.cpp
TypeCheckCodeCompletion.cpp
+ # SWIFT_ENABLE_TENSORFLOW
+ TypeCheckCompilerEvaluable.cpp
+ # SWIFT_ENABLE_TENSORFLOW END
TypeCheckConcurrency.cpp
TypeCheckConstraints.cpp
TypeCheckDecl.cpp
diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp
index 97dd019..2c34866 100644
--- a/lib/Sema/CSApply.cpp
+++ b/lib/Sema/CSApply.cpp
@@ -3904,10 +3904,20 @@
Expr *visitCoerceExpr(CoerceExpr *expr, Optional<unsigned> choice) {
// Simplify and update the type we're coercing to.
- assert(expr->getCastTypeRepr());
- const auto toType = simplifyType(cs.getType(expr->getCastTypeRepr()));
- expr->setCastType(toType);
- cs.setType(expr->getCastTypeRepr(), toType);
+ // SWIFT_ENABLE_TENSORFLOW
+ // Handle implicit `CoerceExpr` with null `TypeRepr`.
+ // Created by `KeyPathIterable` derived conformances.
+ if (expr->getCastType()) {
+ const auto toType = simplifyType(expr->getCastType());
+ expr->setCastType(toType);
+ } else {
+ assert(expr->getCastTypeRepr());
+ const auto toType = simplifyType(cs.getType(expr->getCastTypeRepr()));
+ expr->setCastType(toType);
+ cs.setType(expr->getCastTypeRepr(), toType);
+ }
+ const auto toType = expr->getCastType();
+ // SWIFT_ENABLE_TENSORFLOW END
// If this is a literal that got converted into constructor call
// lets put proper source information in place.
diff --git a/lib/Sema/CSDiagnostics.cpp b/lib/Sema/CSDiagnostics.cpp
index fe04b9a..1579371 100644
--- a/lib/Sema/CSDiagnostics.cpp
+++ b/lib/Sema/CSDiagnostics.cpp
@@ -340,6 +340,8 @@
bool RequirementFailure::diagnoseAsError() {
const auto *reqDC = getRequirementDC();
auto *genericCtx = getGenericContext();
+ if (!genericCtx)
+ return false;
auto lhs = getLHS();
auto rhs = getRHS();
diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp
index 342891f..cb8f547 100644
--- a/lib/Sema/CSGen.cpp
+++ b/lib/Sema/CSGen.cpp
@@ -2766,9 +2766,16 @@
Type visitCoerceExpr(CoerceExpr *expr) {
// Validate the resulting type.
auto *const repr = expr->getCastTypeRepr();
- const auto toType = resolveTypeReferenceInExpression(
+ // SWIFT_ENABLE_TENSORFLOW
+ // Handle implicit `CoerceExpr` with null `TypeRepr`.
+ // Created by `KeyPathIterable` derived conformances.
+ auto toType = expr->getCastType();
+ if (!toType) {
+ toType = resolveTypeReferenceInExpression(
repr, TypeResolverContext::ExplicitCastExpr,
CS.getConstraintLocator(expr));
+ }
+ // SWIFT_ENABLE_TENSORFLOW END
if (!toType)
return nullptr;
diff --git a/lib/Sema/CodeSynthesis.cpp b/lib/Sema/CodeSynthesis.cpp
index 384cad9..20e07a5 100644
--- a/lib/Sema/CodeSynthesis.cpp
+++ b/lib/Sema/CodeSynthesis.cpp
@@ -1305,6 +1305,10 @@
continue;
assert(!memberwiseInitDecl && "Memberwise initializer already found");
memberwiseInitDecl = initDecl;
+ // Overwrite access level only for implicit initializers, not user-defined
+ // initializers.
+ if (memberwiseInitDecl->isImplicit())
+ memberwiseInitDecl->overwriteAccess(accessLevel);
}
// Otherwise, create a memberwise initializer, set its access level, and
diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp
index 74dbdd7..f7aae30 100644
--- a/lib/Sema/ConstraintSystem.cpp
+++ b/lib/Sema/ConstraintSystem.cpp
@@ -1460,7 +1460,10 @@
auto hasAppliedSelf = doesMemberRefApplyCurriedSelf(baseObjTy, value);
baseObjTy = baseObjTy->getMetatypeInstanceType();
- FunctionType::Param baseObjParam(baseObjTy);
+ // SWIFT_ENABLE_TENSORFLOW
+ FunctionType::Param baseObjParam(
+ baseObjTy->getInOutObjectType(), Identifier(),
+ ParameterTypeFlags().withInOut(baseObjTy->is<InOutType>()));
if (auto *typeDecl = dyn_cast<TypeDecl>(value)) {
assert(!isa<ModuleDecl>(typeDecl) && "Nested module?");
@@ -3012,6 +3015,21 @@
// If there are multiple solutions, try to diagnose an ambiguity.
if (viable.size() > 1) {
+ // SWIFT_ENABLE_TENSORFLOW
+ if (DC->getParentModule()->getNameStr().startswith("__lldb_expr")) {
+ // TODO(https://bugs.swift.org/browse/SR-9814):
+ // If in LLDB repl mode, patch up the solution if we have ambiguity.
+ //
+ // This is a *temporary* short-term hack that simply returns the last
+ // solution. It seems to work for now and returns the lastly added
+ // definition during the repl session. However, this is extremely brittle and
+ // is not expected to work correctly all the time.
+ viable[0] = std::move(viable.back());
+ viable.erase(viable.begin() + 1, viable.end());
+ return SolutionResult::forSolved(std::move(viable[0]));
+ }
+ // SWIFT_ENABLE_TENSORFLOW
+
if (isDebugMode()) {
auto &log = llvm::errs();
log << "---Ambiguity error: " << viable.size()
diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp
index 5ab7ffc..15a981f 100644
--- a/lib/Sema/DerivedConformanceDifferentiable.cpp
+++ b/lib/Sema/DerivedConformanceDifferentiable.cpp
@@ -123,6 +123,23 @@
return tanType->hasArchetype() ? tanType->mapTypeOutOfContext() : tanType;
}
+// SWIFT_ENABLE_TENSORFLOW
+/// Returns the `Differentiable.TangentVector` associated type witness
+/// for the given property declaration and declaration context.
+static Type getTangentVectorType(VarDecl *varDecl, DeclContext *DC) {
+ auto &C = varDecl->getASTContext();
+ auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
+ assert(diffableProto && "`Differentiable` protocol not found");
+ auto contextualType = DC->mapTypeIntoContext(varDecl->getValueInterfaceType());
+ auto conf =
+ TypeChecker::conformsToProtocol(contextualType, diffableProto, DC);
+ assert(conf && "Contextual type must conform to `Differentiable`");
+ if (!conf)
+ return nullptr;
+ return conf.getTypeWitnessByName(contextualType, C.Id_TangentVector);
+}
+// SWIFT_ENABLE_TENSORFLOW END
+
/// Returns true iff the given nominal type declaration can derive
/// `TangentVector` as `Self` in the given conformance context.
static bool canDeriveTangentVectorAsSelf(NominalTypeDecl *nominal,
@@ -647,9 +664,25 @@
// Otherwise, synthesize a new struct.
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
- auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredInterfaceType());
+ auto diffableType =
+ TypeLoc::withoutLoc(diffableProto->getDeclaredInterfaceType());
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
- auto addArithType = TypeLoc::withoutLoc(addArithProto->getDeclaredInterfaceType());
+ auto addArithType =
+ TypeLoc::withoutLoc(addArithProto->getDeclaredInterfaceType());
+ // SWIFT_ENABLE_TENSORFLOW
+ auto *pointMulProto =
+ C.getProtocol(KnownProtocolKind::PointwiseMultiplicative);
+ auto pointMulType =
+ TypeLoc::withoutLoc(pointMulProto->getDeclaredInterfaceType());
+ auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions);
+ auto mathType = TypeLoc::withoutLoc(mathProto->getDeclaredInterfaceType());
+ auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
+ auto vectorType =
+ TypeLoc::withoutLoc(vectorProto->getDeclaredInterfaceType());
+ auto *kpIterableProto = C.getProtocol(KnownProtocolKind::KeyPathIterable);
+ auto kpIterableType =
+ TypeLoc::withoutLoc(kpIterableProto->getDeclaredInterfaceType());
+ // SWIFT_ENABLE_TENSORFLOW END
// By definition, `TangentVector` must conform to `Differentiable` and
// `AdditiveArithmetic`.
@@ -659,6 +692,74 @@
SmallVector<VarDecl *, 8> diffProperties;
getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);
+ // SWIFT_ENABLE_TENSORFLOW
+ // Add ad-hoc implicit conformances for `TangentVector`.
+ // TODO(TF-632): Remove this implicit conformance logic when synthesized
+ // member types can be extended.
+
+ // `TangentVector` struct can derive `PointwiseMultiplicative` if the
+ // `TangentVector` types of all stored properties conform to
+ // `PointwiseMultiplicative`.
+ bool canDerivePointwiseMultiplicative =
+ llvm::all_of(diffProperties, [&](VarDecl *vd) {
+ return TypeChecker::conformsToProtocol(
+ getTangentVectorType(vd, parentDC), pointMulProto, parentDC);
+ });
+
+ // `TangentVector` struct can derive `ElementaryFunctions` if the
+ // `TangentVector` types of all stored properties conform to
+ // `ElementaryFunctions`.
+ bool canDeriveElementaryFunctions =
+ llvm::all_of(diffProperties, [&](VarDecl *vd) {
+ return TypeChecker::conformsToProtocol(
+ getTangentVectorType(vd, parentDC), mathProto, parentDC);
+ });
+
+ // `TangentVector` struct can derive `VectorProtocol` if the `TangentVector`
+ // types of all members conform to `VectorProtocol` and share the same
+ // `VectorSpaceScalar` type.
+ Type sameScalarType;
+ bool canDeriveVectorProtocol =
+ !diffProperties.empty() && llvm::all_of(diffProperties, [&](VarDecl *vd) {
+ auto tanType = getTangentVectorType(vd, parentDC);
+ auto conf = TypeChecker::conformsToProtocol(tanType, vectorProto,
+ nominal);
+ if (!conf)
+ return false;
+ auto scalarType =
+ conf.getTypeWitnessByName(tanType, C.Id_VectorSpaceScalar);
+ if (!sameScalarType) {
+ sameScalarType = scalarType;
+ return true;
+ }
+ return scalarType->isEqual(sameScalarType);
+ });
+
+ // `TangentVector` struct should derive `KeyPathIterable` if the parent struct
+ // conforms to `KeyPathIterable`.
+ bool shouldDeriveKeyPathIterable =
+ !TypeChecker::conformsToProtocol(nominal->getDeclaredInterfaceType(),
+ kpIterableProto, parentDC)
+ .isInvalid();
+
+ // If all members conform to `PointwiseMultiplicative`, make the
+ // `TangentVector` struct conform to `PointwiseMultiplicative`.
+ if (canDerivePointwiseMultiplicative)
+ inherited.push_back(pointMulType);
+ // If all members conform to `ElementaryFunctions`, make the `TangentVector`
+ // struct conform to `ElementaryFunctions`.
+ if (canDeriveElementaryFunctions)
+ inherited.push_back(mathType);
+ // If all members also conform to `VectorProtocol` with the same `Scalar`
+ // type, make the `TangentVector` struct conform to `VectorProtocol`.
+ if (canDeriveVectorProtocol)
+ inherited.push_back(vectorType);
+ // If parent type conforms to `KeyPathIterable`, make the `TangentVector`
+ // struct conform to `KeyPathIterable`.
+ if (shouldDeriveKeyPathIterable)
+ inherited.push_back(kpIterableType);
+ // SWIFT_ENABLE_TENSORFLOW END
+
auto *structDecl =
new (C) StructDecl(SourceLoc(), C.Id_TangentVector, SourceLoc(),
/*Inherited*/ C.AllocateCopy(inherited),
@@ -938,3 +1039,142 @@
// Otherwise, return nullptr.
return std::make_pair(nullptr, nullptr);
}
+
+// SWIFT_ENABLE_TENSORFLOW
+bool DerivedConformance::canDeriveEuclideanDifferentiable(
+ NominalTypeDecl *nominal, DeclContext *DC) {
+ auto &C = nominal->getASTContext();
+ auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
+ auto tangentVectorRequirement = getProtocolRequirement(diffableProto, C.Id_TangentVector);
+ if (!canDeriveDifferentiable(nominal, DC, tangentVectorRequirement))
+ return false;
+ auto *eucDiffProto =
+ C.getProtocol(KnownProtocolKind::EuclideanDifferentiable);
+ // Return true if all differentiation stored properties conform to
+ // `EuclideanDifferentiable`.
+ SmallVector<VarDecl *, 16> diffProperties;
+ getStoredPropertiesForDifferentiation(nominal, DC, diffProperties);
+ return llvm::all_of(diffProperties, [&](VarDecl *member) {
+ if (member->getInterfaceType()->hasError())
+ return false;
+ auto varType = DC->mapTypeIntoContext(member->getValueInterfaceType());
+ return (bool)TypeChecker::conformsToProtocol(varType, eucDiffProto, DC);
+ });
+}
+
+/// Synthesize the `differentiableVectorView` property declaration.
+static ValueDecl *deriveEuclideanDifferentiable_differentiableVectorView(
+ DerivedConformance &derived) {
+ auto &C = derived.Context;
+ auto *parentDC = derived.getConformanceContext();
+
+ auto tangentType =
+ getTangentVectorInterfaceType(parentDC->getSelfTypeInContext(), parentDC);
+ auto tangentContextualType = parentDC->mapTypeIntoContext(tangentType);
+ auto *tangentDecl = cast<StructDecl>(tangentType->getAnyNominal());
+
+ VarDecl *vectorViewDecl;
+ PatternBindingDecl *pbDecl;
+ std::tie(vectorViewDecl, pbDecl) = derived.declareDerivedProperty(
+ C.Id_differentiableVectorView, tangentType, tangentContextualType,
+ /*isStatic*/ false, /*isFinal*/ true);
+
+ struct GetterSynthesizerContext {
+ StructDecl *tangentDecl;
+ Type tangentContextualType;
+ };
+
+ auto getterSynthesizer = [](AbstractFunctionDecl *getterDecl,
+ void *ctx) -> std::pair<BraceStmt *, bool> {
+ auto *context = reinterpret_cast<GetterSynthesizerContext *>(ctx);
+ assert(context && "Invalid context");
+ auto *parentDC = getterDecl->getParent();
+ auto *nominal = parentDC->getSelfNominalTypeDecl();
+ auto *module = nominal->getModuleContext();
+ auto &C = nominal->getASTContext();
+ auto *eucDiffProto =
+ C.getProtocol(KnownProtocolKind::EuclideanDifferentiable);
+ auto *vectorViewReq =
+ eucDiffProto->lookupDirect(C.Id_differentiableVectorView).front();
+
+ SmallVector<VarDecl *, 8> diffProperties;
+ getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);
+
+ // Create a reference to the memberwise initializer: `TangentVector.init`.
+ auto *memberwiseInitDecl =
+ context->tangentDecl->getEffectiveMemberwiseInitializer();
+ assert(memberwiseInitDecl && "Memberwise initializer must exist");
+ assert(diffProperties.size() ==
+ memberwiseInitDecl->getParameters()->size());
+ // `TangentVector`
+ auto *tangentTypeExpr =
+ TypeExpr::createImplicit(context->tangentContextualType, C);
+ // `TangentVector.init`
+ auto *initDRE = new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(),
+ /*Implicit*/ true);
+ initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
+ auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, tangentTypeExpr);
+ initExpr->setThrows(false);
+ initExpr->setImplicit();
+
+ // Create a call:
+ // TangentVector.init(
+ // <property_name_1...>:
+ // self.<property_name_1>.differentiableVectorView,
+ // <property_name_2...>:
+ // self.<property_name_2>.differentiableVectorView,
+ // ...
+ // )
+ SmallVector<Identifier, 8> argLabels;
+ SmallVector<Expr *, 8> memberRefs;
+ for (auto *member : diffProperties) {
+ auto *selfDRE =
+ new (C) DeclRefExpr(getterDecl->getImplicitSelfDecl(), DeclNameLoc(),
+ /*Implicit*/ true);
+ auto *memberExpr = new (C) MemberRefExpr(
+ selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true);
+ auto memberType =
+ parentDC->mapTypeIntoContext(member->getValueInterfaceType());
+ auto confRef = module->lookupConformance(memberType, eucDiffProto);
+ assert(confRef &&
+ "Member missing conformance to `EuclideanDifferentiable`");
+ ConcreteDeclRef memberDeclRef = vectorViewReq;
+ if (confRef.isConcrete())
+ memberDeclRef = confRef.getConcrete()->getWitnessDecl(vectorViewReq);
+ argLabels.push_back(member->getName());
+ memberRefs.push_back(new (C) MemberRefExpr(memberExpr, SourceLoc(),
+ memberDeclRef, DeclNameLoc(),
+ /*Implicit*/ true));
+ }
+ assert(memberRefs.size() == argLabels.size());
+ CallExpr *callExpr =
+ CallExpr::createImplicit(C, initExpr, memberRefs, argLabels);
+
+ // Create a return statement: `return TangentVector.init(...)`.
+ ASTNode retStmt =
+ new (C) ReturnStmt(SourceLoc(), callExpr, /*implicit*/ true);
+ auto *braceStmt = BraceStmt::create(C, SourceLoc(), retStmt, SourceLoc(),
+ /*implicit*/ true);
+ return std::make_pair(braceStmt, false);
+ };
+ auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty(
+ vectorViewDecl, tangentContextualType);
+ getterDecl->setBodySynthesizer(
+ getterSynthesizer, /*context*/ C.AllocateObjectCopy(
+ GetterSynthesizerContext{tangentDecl, tangentContextualType}));
+ derived.addMembersToConformanceContext({vectorViewDecl, pbDecl});
+ return vectorViewDecl;
+}
+
+ValueDecl *
+DerivedConformance::deriveEuclideanDifferentiable(ValueDecl *requirement) {
+ // Diagnose conformances in disallowed contexts.
+ if (checkAndDiagnoseDisallowedContext(requirement))
+ return nullptr;
+ if (requirement->getName() == Context.Id_differentiableVectorView)
+ return deriveEuclideanDifferentiable_differentiableVectorView(*this);
+ Context.Diags.diagnose(requirement->getLoc(),
+ diag::broken_euclidean_differentiable_requirement);
+ return nullptr;
+}
+// SWIFT_ENABLE_TENSORFLOW END
diff --git a/lib/Sema/DerivedConformanceElementaryFunctions.cpp b/lib/Sema/DerivedConformanceElementaryFunctions.cpp
new file mode 100644
index 0000000..d64b681
--- /dev/null
+++ b/lib/Sema/DerivedConformanceElementaryFunctions.cpp
@@ -0,0 +1,297 @@
+//===--- DerivedConformanceElementaryFunctions.cpp ------------------------===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements explicit derivation of the ElementaryFunctions protocol
+// for struct types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodeSynthesis.h"
+#include "TypeChecker.h"
+#include "swift/AST/Decl.h"
+#include "swift/AST/Expr.h"
+#include "swift/AST/GenericSignature.h"
+#include "swift/AST/Module.h"
+#include "swift/AST/ParameterList.h"
+#include "swift/AST/Pattern.h"
+#include "swift/AST/ProtocolConformance.h"
+#include "swift/AST/Stmt.h"
+#include "swift/AST/Types.h"
+#include "DerivedConformances.h"
+
+using namespace swift;
+
+// Represents synthesizable `ElementaryFunction` protocol requirements.
+enum ElementaryFunction {
+#define ELEMENTARY_FUNCTION(ID, NAME) ID,
+#include "DerivedConformanceElementaryFunctions.def"
+#undef ELEMENTARY_FUNCTION
+};
+
+static StringRef getElementaryFunctionName(ElementaryFunction op) {
+ switch (op) {
+#define ELEMENTARY_FUNCTION(ID, NAME) case ElementaryFunction::ID: return NAME;
+#include "DerivedConformanceElementaryFunctions.def"
+#undef ELEMENTARY_FUNCTION
+ }
+}
+
+// Return the `ElementaryFunction` protocol requirement corresponding to the
+// given elementary function.
+static ValueDecl *getElementaryFunctionRequirement(
+ ASTContext &C, ElementaryFunction op) {
+ auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions);
+ auto operatorId = C.getIdentifier(getElementaryFunctionName(op));
+ switch (op) {
+#define ELEMENTARY_FUNCTION_UNARY(ID, NAME) \
+ case ID: \
+ return getProtocolRequirement(mathProto, operatorId);
+#include "DerivedConformanceElementaryFunctions.def"
+#undef ELEMENTARY_FUNCTION_UNARY
+ case Root:
+ return getProtocolRequirement(mathProto, operatorId);
+ case Pow:
+ case PowInt:
+ auto lookup = mathProto->lookupDirect(operatorId);
+ lookup.erase(std::remove_if(lookup.begin(), lookup.end(),
+ [](ValueDecl *v) {
+ return !isa<ProtocolDecl>(
+ v->getDeclContext()) ||
+ !v->isProtocolRequirement();
+ }),
+ lookup.end());
+ assert(lookup.size() == 2 && "Expected two 'pow' functions");
+ auto *powFuncDecl = cast<FuncDecl>(lookup.front());
+ auto secondParamType =
+ powFuncDecl->getParameters()->get(1)->getInterfaceType();
+ if (secondParamType->getAnyNominal() == C.getIntDecl())
+ return op == PowInt ? lookup.front() : lookup[1];
+ else
+ return op == PowInt ? lookup[1] : lookup.front();
+ }
+}
+
+bool DerivedConformance::canDeriveElementaryFunctions(NominalTypeDecl *nominal,
+ DeclContext *DC) {
+ // Nominal type must be a struct. (Zero stored properties is okay.)
+ auto *structDecl = dyn_cast<StructDecl>(nominal);
+ if (!structDecl)
+ return false;
+ // Must not have any `let` stored properties with an initial value.
+ // - This restriction may be lifted later with support for "true" memberwise
+ // initializers that initialize all stored properties, including initial
+ // value information.
+ if (hasLetStoredPropertyWithInitialValue(nominal))
+ return false;
+ // All stored properties must conform to `ElementaryFunctions`.
+ auto &C = nominal->getASTContext();
+ auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions);
+ return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
+ if (v->getInterfaceType()->hasError())
+ return false;
+ auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType());
+ return (bool)TypeChecker::conformsToProtocol(varType, mathProto, DC);
+ });
+}
+
+// Synthesize body for the given `ElementaryFunction` protocol requirement.
+static std::pair<BraceStmt *, bool>
+deriveBodyElementaryFunction(AbstractFunctionDecl *funcDecl,
+ ElementaryFunction op) {
+ auto *parentDC = funcDecl->getParent();
+ auto *nominal = parentDC->getSelfNominalTypeDecl();
+ auto &C = nominal->getASTContext();
+
+ // Create memberwise initializer: `Nominal.init(...)`.
+ auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer();
+ assert(memberwiseInitDecl && "Memberwise initializer must exist");
+ auto *initDRE =
+ new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true);
+ initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
+ auto *nominalTypeExpr = TypeExpr::createImplicitForDecl(
+ DeclNameLoc(), nominal, funcDecl,
+ funcDecl->mapTypeIntoContext(nominal->getInterfaceType()));
+ auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr);
+
+ // Get operator protocol requirement.
+ auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions);
+ auto *operatorReq = getElementaryFunctionRequirement(C, op);
+
+ // Create reference(s) to operator parameters: one for unary functions and two
+ // for binary functions.
+ auto params = funcDecl->getParameters();
+
+ // Create call expression combining lhs and rhs members using member operator.
+ auto createMemberOpCallExpr = [&](VarDecl *member) -> Expr * {
+ auto module = nominal->getModuleContext();
+ auto memberType =
+ parentDC->mapTypeIntoContext(member->getValueInterfaceType());
+ auto confRef = module->lookupConformance(memberType, mathProto);
+ assert(confRef && "Member does not conform to math protocol");
+
+ // Get member type's elementary function, e.g. `Member.cos`.
+ // Use protocol requirement declaration for the operator by default: this
+ // will be dynamically dispatched.
+ ValueDecl *memberOpDecl = operatorReq;
+ // If conformance reference is concrete, then use concrete witness
+ // declaration for the operator.
+ if (confRef.isConcrete())
+ memberOpDecl = confRef.getConcrete()->getWitnessDecl(
+ operatorReq);
+ assert(memberOpDecl && "Member operator declaration must exist");
+ auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
+ auto memberOpExpr =
+ new (C) MemberRefExpr(memberTypeExpr, SourceLoc(), memberOpDecl,
+ DeclNameLoc(), /*Implicit*/ true);
+
+ // - For unary ops, create expression:
+ // `<op>(x.member)`.
+ // - For `pow(_ x: Self, _ y: Self)`, create expression:
+ // `<op>(x.member, y.member)`.
+ // - For `pow(_ x: Self, _ n: Int)` and `root(_ x: Self, n: Int)`, create:
+ // `<op>(x.member, n)`.
+ // NOTE(TF-1054): create new `DeclRefExpr`s per loop iteration to avoid
+ // `ConstraintSystem::resolveOverload` error.
+ auto *firstParamDRE =
+ new (C) DeclRefExpr(params->get(0), DeclNameLoc(), /*Implicit*/ true);
+ Expr *secondParamDRE = nullptr;
+ if (params->size() == 2)
+ secondParamDRE =
+ new (C) DeclRefExpr(params->get(1), DeclNameLoc(), /*Implicit*/ true);
+ Expr *firstArg = new (C) MemberRefExpr(firstParamDRE, SourceLoc(), member,
+ DeclNameLoc(), /*Implicit*/ true);
+ Expr *secondArg = nullptr;
+ if (secondParamDRE) {
+ if (op == PowInt || op == Root)
+ secondArg = secondParamDRE;
+ else
+ secondArg = new (C) MemberRefExpr(secondParamDRE, SourceLoc(), member,
+ DeclNameLoc(), /*Implicit*/ true);
+ }
+ SmallVector<Expr *, 2> memberOpArgs{firstArg};
+ if (secondArg)
+ memberOpArgs.push_back(secondArg);
+ SmallVector<Identifier, 2> memberOpArgLabels(memberOpArgs.size());
+ auto *memberOpCallExpr = CallExpr::createImplicit(
+ C, memberOpExpr, memberOpArgs, memberOpArgLabels);
+ return memberOpCallExpr;
+ };
+
+ // Create array of member operator call expressions.
+ llvm::SmallVector<Expr *, 2> memberOpCallExprs;
+ llvm::SmallVector<Identifier, 2> memberNames;
+ for (auto member : nominal->getStoredProperties()) {
+ memberOpCallExprs.push_back(createMemberOpCallExpr(member));
+ memberNames.push_back(member->getName());
+ }
+ // Call memberwise initializer with member operator call expressions.
+ auto *callExpr =
+ CallExpr::createImplicit(C, initExpr, memberOpCallExprs, memberNames);
+ ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true);
+ auto* braceStmt = BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true);
+ return std::pair<BraceStmt *, bool>(braceStmt, false);
+}
+
+#define ELEMENTARY_FUNCTION(ID, NAME) \
+static std::pair<BraceStmt *, bool> deriveBodyElementaryFunctions_##ID( \
+ AbstractFunctionDecl *funcDecl, void *) { \
+ return deriveBodyElementaryFunction(funcDecl, ID); \
+}
+#include "DerivedConformanceElementaryFunctions.def"
+#undef ELEMENTARY_FUNCTION
+
+// Synthesize function declaration for the given math operator.
+static ValueDecl *deriveElementaryFunction(DerivedConformance &derived,
+ElementaryFunction op) {
+ auto nominal = derived.Nominal;
+ auto parentDC = derived.getConformanceContext();
+ auto &C = derived.Context;
+ auto selfInterfaceType = parentDC->getDeclaredInterfaceType();
+
+ // Create parameter declaration with the given name and type.
+ auto createParamDecl = [&](StringRef name, Type type) -> ParamDecl * {
+ auto *param = new (C)
+ ParamDecl(SourceLoc(), SourceLoc(), Identifier(), SourceLoc(),
+ C.getIdentifier(name), parentDC);
+ param->setSpecifier(ParamDecl::Specifier::Default);
+ param->setInterfaceType(type);
+ return param;
+ };
+
+ ParameterList *params = nullptr;
+
+ switch (op) {
+#define ELEMENTARY_FUNCTION_UNARY(ID, NAME) \
+ case ID: \
+ params = \
+ ParameterList::create(C, {createParamDecl("x", selfInterfaceType)}); \
+ break;
+#include "DerivedConformanceElementaryFunctions.def"
+#undef ELEMENTARY_FUNCTION_UNARY
+ case Pow:
+ params =
+ ParameterList::create(C, {createParamDecl("x", selfInterfaceType),
+ createParamDecl("y", selfInterfaceType)});
+ break;
+ case PowInt:
+ case Root:
+ params = ParameterList::create(
+ C, {createParamDecl("x", selfInterfaceType),
+ createParamDecl("n", C.getIntDecl()->getDeclaredInterfaceType())});
+ break;
+ }
+
+ auto operatorId = C.getIdentifier(getElementaryFunctionName(op));
+ DeclName operatorDeclName(C, operatorId, params);
+ auto operatorDecl =
+ FuncDecl::createImplicit(C, StaticSpellingKind::KeywordStatic,
+ operatorDeclName, SourceLoc(), /*Async*/ false,
+ /*Throws*/ false, /*GenericParams*/ nullptr,
+ params, selfInterfaceType, parentDC);
+ switch (op) {
+#define ELEMENTARY_FUNCTION(ID, NAME) \
+ case ID: \
+ operatorDecl->setBodySynthesizer(deriveBodyElementaryFunctions_##ID, \
+ nullptr); \
+ break;
+#include "DerivedConformanceElementaryFunctions.def"
+#undef ELEMENTARY_FUNCTION
+ }
+ operatorDecl->setGenericSignature(parentDC->getGenericSignatureOfContext());
+ operatorDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
+
+ derived.addMembersToConformanceContext({operatorDecl});
+ return operatorDecl;
+}
+
+ValueDecl *
+DerivedConformance::deriveElementaryFunctions(ValueDecl *requirement) {
+ // Diagnose conformances in disallowed contexts.
+ if (checkAndDiagnoseDisallowedContext(requirement))
+ return nullptr;
+#define ELEMENTARY_FUNCTION_UNARY(ID, NAME) \
+ if (requirement->getBaseName() == Context.getIdentifier(NAME)) \
+ return deriveElementaryFunction(*this, ID);
+#include "DerivedConformanceElementaryFunctions.def"
+#undef ELEMENTARY_FUNCTION_UNARY
+ if (requirement->getBaseName() == Context.getIdentifier("root"))
+ return deriveElementaryFunction(*this, Root);
+ if (requirement->getBaseName() == Context.getIdentifier("pow")) {
+ auto *powFuncDecl = cast<FuncDecl>(requirement);
+ return powFuncDecl->getParameters()->get(1)->getName().str() == "n"
+ ? deriveElementaryFunction(*this, PowInt)
+ : deriveElementaryFunction(*this, Pow);
+ }
+ Context.Diags.diagnose(requirement->getLoc(),
+ diag::broken_elementary_functions_requirement);
+ return nullptr;
+}
diff --git a/lib/Sema/DerivedConformanceElementaryFunctions.def b/lib/Sema/DerivedConformanceElementaryFunctions.def
new file mode 100644
index 0000000..692ad2e
--- /dev/null
+++ b/lib/Sema/DerivedConformanceElementaryFunctions.def
@@ -0,0 +1,63 @@
+//===--- DerivedConformanceElementaryFunctions.def ------------------------===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines macros used for macro-metaprogramming with
+// ElementaryFunction protocol requirements. Currently used only by derived
+// conformances.
+//
+//===----------------------------------------------------------------------===//
+
+/// ELEMENTARY_FUNCTION(Id, Name)
+/// - Id is an elementary function identifier, used for the enum case
+/// `ElementaryFunctions::Id`.
+/// - Name is the name of the elementary function.
+
+// One macro must be defined by the includer.
+#if !defined(ELEMENTARY_FUNCTION) && !defined(ELEMENTARY_FUNCTION_UNARY)
+#error "Macro must be defined by includer"
+#endif
+
+#ifndef ELEMENTARY_FUNCTION
+#define ELEMENTARY_FUNCTION(Id, Name)
+#endif
+
+#ifndef ELEMENTARY_FUNCTION_UNARY
+#define ELEMENTARY_FUNCTION_UNARY(Id, Name) ELEMENTARY_FUNCTION(Id,Name)
+#endif
+
+ELEMENTARY_FUNCTION_UNARY(Sqrt, "sqrt")
+ELEMENTARY_FUNCTION_UNARY(Cos, "cos")
+ELEMENTARY_FUNCTION_UNARY(Sin, "sin")
+ELEMENTARY_FUNCTION_UNARY(Tan, "tan")
+ELEMENTARY_FUNCTION_UNARY(Cosh, "cosh")
+ELEMENTARY_FUNCTION_UNARY(Sinh, "sinh")
+ELEMENTARY_FUNCTION_UNARY(Tanh, "tanh")
+ELEMENTARY_FUNCTION_UNARY(Acos, "acos")
+ELEMENTARY_FUNCTION_UNARY(Asin, "asin")
+ELEMENTARY_FUNCTION_UNARY(Atan, "atan")
+ELEMENTARY_FUNCTION_UNARY(Acosh, "acosh")
+ELEMENTARY_FUNCTION_UNARY(Asinh, "asinh")
+ELEMENTARY_FUNCTION_UNARY(Atanh, "atanh")
+ELEMENTARY_FUNCTION_UNARY(Exp, "exp")
+ELEMENTARY_FUNCTION_UNARY(Exp2, "exp2")
+ELEMENTARY_FUNCTION_UNARY(Exp10, "exp10")
+ELEMENTARY_FUNCTION_UNARY(Expm1, "expm1")
+ELEMENTARY_FUNCTION_UNARY(Log, "log")
+ELEMENTARY_FUNCTION_UNARY(Log2, "log2")
+ELEMENTARY_FUNCTION_UNARY(Log10, "log10")
+ELEMENTARY_FUNCTION_UNARY(Log1p, "log1p")
+ELEMENTARY_FUNCTION(Pow, "pow")
+ELEMENTARY_FUNCTION(PowInt, "pow")
+ELEMENTARY_FUNCTION(Root, "root")
+
+#undef ELEMENTARY_FUNCTION_UNARY
+#undef ELEMENTARY_FUNCTION
diff --git a/lib/Sema/DerivedConformanceKeyPathIterable.cpp b/lib/Sema/DerivedConformanceKeyPathIterable.cpp
new file mode 100644
index 0000000..efc9e6a
--- /dev/null
+++ b/lib/Sema/DerivedConformanceKeyPathIterable.cpp
@@ -0,0 +1,195 @@
+//===--- DerivedConformanceKeyPathIterable.cpp ----------------------------===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements explicit derivation of the KeyPathIterable protocol for
+// a nominal type.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodeSynthesis.h"
+#include "TypeChecker.h"
+#include "swift/AST/Decl.h"
+#include "swift/AST/Expr.h"
+#include "swift/AST/GenericSignature.h"
+#include "swift/AST/Module.h"
+#include "swift/AST/ParameterList.h"
+#include "swift/AST/Pattern.h"
+#include "swift/AST/ProtocolConformance.h"
+#include "swift/AST/Stmt.h"
+#include "swift/AST/Types.h"
+#include "DerivedConformances.h"
+
+using namespace swift;
+
+bool DerivedConformance::canDeriveKeyPathIterable(NominalTypeDecl *nominal) {
+ // Note: we could extend synthesis to support classes.
+ // Subclasses need to append `allKeyPaths` to `super.allKeyPaths`.
+ return isa<StructDecl>(nominal);
+}
+
+// Compute `PartialKeyPath<Nominal>`, bound to the given nominal
+// declaration's type.
+static Type computePartialKeyPathType(NominalTypeDecl *nominal) {
+ auto &C = nominal->getASTContext();
+ auto nominalType = nominal->getDeclaredInterfaceType();
+ if (!nominalType || nominalType->hasError())
+ return nullptr;
+ auto *partialKeyPathDecl = cast<ClassDecl>(C.getPartialKeyPathDecl());
+ return BoundGenericClassType::get(partialKeyPathDecl, /*parent*/ Type(),
+ {nominal->getDeclaredInterfaceType()});
+}
+
+// Compute `AllKeyPaths` associated type for the given nominal declaration.
+// It should be `[PartialKeyPath<Nominal>]`.
+static ArraySliceType *computeAllKeyPathsType(NominalTypeDecl *nominal) {
+ auto partialKeyPathType = computePartialKeyPathType(nominal);
+ return ArraySliceType::get(partialKeyPathType);
+}
+
+// Compute `KeyPath<Nominal, Member>`.
+static Type computeKeyPathType(NominalTypeDecl *nominal, Type memberType) {
+ auto &C = nominal->getASTContext();
+ auto nominalType = nominal->getDeclaredInterfaceType();
+ if (!nominalType || nominalType->hasError())
+ return nullptr;
+ auto *keyPathDecl = cast<ClassDecl>(C.getKeyPathDecl());
+ return BoundGenericClassType::get(
+ keyPathDecl, /*parent*/ Type(),
+ {nominal->getDeclaredInterfaceType(), memberType});
+}
+
+// Mark the given `ValueDecl` as `@inlinable`, if the conformance context's
+// module is not resilient and the `ValueDecl` is effectively public.
+// TODO: Dedupe with DerivedConformanceRawRepresentable.cpp.
+static void maybeMarkAsInlinable(DerivedConformance &derived, ValueDecl *decl) {
+ ASTContext &C = derived.Context;
+ auto parentDC = derived.getConformanceContext();
+ if (!parentDC->getParentModule()->isResilient()) {
+ auto access = decl->getFormalAccessScope(
+ nullptr, /*treatUsableFromInlineAsPublic*/ true);
+ if (access.isPublic()) {
+ decl->getAttrs().add(new (C) InlinableAttr(/*implicit*/ false));
+ if (auto *attr = decl->getAttrs().getAttribute<UsableFromInlineAttr>())
+ attr->setInvalid();
+ }
+ }
+}
+
+// Synthesize body for the `allKeyPaths` computed property getter.
+static std::pair<BraceStmt *, bool>
+deriveBodyKeyPathIterable_allKeyPaths(AbstractFunctionDecl *funcDecl, void *) {
+ auto *parentDC = funcDecl->getDeclContext();
+ auto *nominal = parentDC->getSelfNominalTypeDecl();
+ auto &C = nominal->getASTContext();
+
+ auto *nominalTypeExpr = TypeExpr::createImplicitForDecl(
+ DeclNameLoc(), nominal, funcDecl,
+ funcDecl->mapTypeIntoContext(nominal->getInterfaceType()));
+
+ // Create array of key path expressions to stored properties.
+ llvm::SmallVector<Expr *, 2> keyPathExprs;
+ for (auto member : nominal->getStoredProperties()) {
+ // FIXME(TF-123): Skip generating keypaths to `@differentiable` functions
+ // because of SILGen crash. Robust fix involves changing
+ // `createAutoDiffThunk`.
+ if (auto fnType = member->getType()->getAs<AnyFunctionType>())
+ if (fnType->getExtInfo().isDifferentiable())
+ continue;
+
+ auto *dotExpr = new (C)
+ UnresolvedDotExpr(nominalTypeExpr, SourceLoc(),
+ DeclNameRef(member->getName()), DeclNameLoc(),
+ /*Implicit*/ true);
+ Expr *keyPathExpr =
+ new (C) KeyPathExpr(SourceLoc(), dotExpr, nullptr, /*Implicit*/ true);
+ // NOTE(TF-575): Adding an explicit coercion expression to
+ // `KeyPath<Nominal, Member>` here is necessary due to type-checker changes.
+ auto keyPathInterfaceType =
+ computeKeyPathType(nominal, member->getInterfaceType());
+ auto keyPathType = parentDC->mapTypeIntoContext(keyPathInterfaceType);
+ keyPathExpr = CoerceExpr::createImplicit(C, keyPathExpr, keyPathType);
+ keyPathExprs.push_back(keyPathExpr);
+ }
+ // Return array of all key path expressions.
+ Expr *keyPathsArrayExpr =
+ ArrayExpr::create(C, SourceLoc(), keyPathExprs, {}, SourceLoc());
+ keyPathsArrayExpr->setImplicit();
+ auto *returnStmt = new (C) ReturnStmt(SourceLoc(), keyPathsArrayExpr);
+ auto *body = BraceStmt::create(C, SourceLoc(), {returnStmt}, SourceLoc(),
+ /*Implicit*/ true);
+ auto *braceStmt = BraceStmt::create(C, SourceLoc(), {body}, SourceLoc(),
+ /*Implicit*/ true);
+ return std::pair<BraceStmt *, bool>(braceStmt, false);
+}
+
+// Synthesize the `allKeyPaths` computed property declaration.
+static ValueDecl *
+deriveKeyPathIterable_allKeyPaths(DerivedConformance &derived) {
+ auto nominal = derived.Nominal;
+ auto &C = derived.Context;
+
+ auto returnInterfaceTy = computeAllKeyPathsType(nominal);
+ auto returnTy =
+ derived.getConformanceContext()->mapTypeIntoContext(returnInterfaceTy);
+
+ // Create `allKeyPaths` property declaration.
+ VarDecl *allKeyPathsDecl;
+ PatternBindingDecl *pbDecl;
+ std::tie(allKeyPathsDecl, pbDecl) = derived.declareDerivedProperty(
+ C.Id_allKeyPaths, returnInterfaceTy, returnTy, /*isStatic*/ false,
+ /*isFinal*/ true);
+
+ // Maybe add `@inlinable` to the `allKeyPaths` declaration.
+ if (llvm::all_of(nominal->getStoredProperties(), [](VarDecl *vd) {
+ return vd->getFormalAccessScope(
+ nullptr, /*treatUsableFromInlineAsPublic*/ true).isPublic();
+ })) {
+ maybeMarkAsInlinable(derived, allKeyPathsDecl);
+ }
+
+ // Create `allKeyPaths` getter.
+ auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty(
+ allKeyPathsDecl, returnTy);
+ getterDecl->setBodySynthesizer(
+ deriveBodyKeyPathIterable_allKeyPaths, nullptr);
+ derived.addMembersToConformanceContext({allKeyPathsDecl, pbDecl});
+
+ return allKeyPathsDecl;
+}
+
+static Type deriveKeyPathIterable_AllKeyPaths(DerivedConformance &derived) {
+ auto *rawInterfaceType = computeAllKeyPathsType(derived.Nominal);
+ return derived.getConformanceContext()->mapTypeIntoContext(rawInterfaceType);
+}
+
+ValueDecl *DerivedConformance::deriveKeyPathIterable(ValueDecl *requirement) {
+ // Diagnose conformances in disallowed contexts.
+ if (checkAndDiagnoseDisallowedContext(requirement))
+ return nullptr;
+ if (requirement->getBaseName() == Context.Id_allKeyPaths)
+ return deriveKeyPathIterable_allKeyPaths(*this);
+ Context.Diags.diagnose(requirement->getLoc(),
+ diag::broken_key_path_iterable_requirement);
+ return nullptr;
+}
+
+Type DerivedConformance::deriveKeyPathIterable(
+ AssociatedTypeDecl *requirement) {
+ // Diagnose conformances in disallowed contexts.
+ if (checkAndDiagnoseDisallowedContext(requirement))
+ return nullptr;
+ if (requirement->getBaseName() == Context.Id_AllKeyPaths)
+ return deriveKeyPathIterable_AllKeyPaths(*this);
+ Context.Diags.diagnose(requirement->getLoc(),
+ diag::broken_key_path_iterable_requirement);
+ return nullptr;
+}
diff --git a/lib/Sema/DerivedConformancePointwiseMultiplicative.cpp b/lib/Sema/DerivedConformancePointwiseMultiplicative.cpp
new file mode 100644
index 0000000..416eb76
--- /dev/null
+++ b/lib/Sema/DerivedConformancePointwiseMultiplicative.cpp
@@ -0,0 +1,333 @@
+//===--- DerivedConformancePointwiseMultiplicative.cpp --------------------===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements explicit derivation of the PointwiseMultiplicative
+// protocol for struct types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodeSynthesis.h"
+#include "TypeChecker.h"
+#include "swift/AST/Decl.h"
+#include "swift/AST/Expr.h"
+#include "swift/AST/GenericSignature.h"
+#include "swift/AST/Module.h"
+#include "swift/AST/ParameterList.h"
+#include "swift/AST/Pattern.h"
+#include "swift/AST/ProtocolConformance.h"
+#include "swift/AST/Stmt.h"
+#include "swift/AST/Types.h"
+#include "DerivedConformances.h"
+
+using namespace swift;
+
+bool
+DerivedConformance::canDerivePointwiseMultiplicative(NominalTypeDecl *nominal,
+ DeclContext *DC) {
+ // Nominal type must be a struct. (No stored properties is okay.)
+ auto *structDecl = dyn_cast<StructDecl>(nominal);
+ if (!structDecl)
+ return false;
+ // Must not have any `let` stored properties with an initial value.
+ // - This restriction may be lifted later with support for "true" memberwise
+ // initializers that initialize all stored properties, including initial
+ // value information.
+ if (hasLetStoredPropertyWithInitialValue(nominal))
+ return false;
+ // All stored properties must conform to `AdditiveArithmetic`.
+ auto &C = nominal->getASTContext();
+ auto *proto = C.getProtocol(KnownProtocolKind::PointwiseMultiplicative);
+ return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
+ if (v->getInterfaceType()->hasError())
+ return false;
+ auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType());
+ return (bool)TypeChecker::conformsToProtocol(varType, proto, DC);
+ });
+}
+
+// Synthesize body for math operator.
+static std::pair<BraceStmt *, bool>
+deriveBodyMathOperator(AbstractFunctionDecl *funcDecl, void *) {
+ auto *parentDC = funcDecl->getParent();
+ auto *nominal = parentDC->getSelfNominalTypeDecl();
+ auto &C = nominal->getASTContext();
+
+ // Create memberwise initializer: `Nominal.init(...)`.
+ auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer();
+ assert(memberwiseInitDecl && "Memberwise initializer must exist");
+ auto *initDRE =
+ new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true);
+ initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
+ auto *nominalTypeExpr = TypeExpr::createImplicitForDecl(
+ DeclNameLoc(), nominal, funcDecl,
+ funcDecl->mapTypeIntoContext(nominal->getInterfaceType()));
+ auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr);
+
+ // Get operator protocol requirement.
+ auto *proto = C.getProtocol(KnownProtocolKind::PointwiseMultiplicative);
+ auto operatorId = C.getIdentifier(".*");
+ auto *operatorReq = getProtocolRequirement(proto, operatorId);
+
+ // Create reference to operator parameters: lhs and rhs.
+ auto params = funcDecl->getParameters();
+
+ // Create expression combining lhs and rhs members using member operator.
+ auto createMemberOpExpr = [&](VarDecl *member) -> Expr * {
+ auto module = nominal->getModuleContext();
+ auto memberType =
+ parentDC->mapTypeIntoContext(member->getValueInterfaceType());
+ auto confRef = module->lookupConformance(memberType, proto);
+ assert(confRef && "Member does not conform to math protocol");
+
+ // Get member type's math operator, e.g. `Member.+`.
+ // Use protocol requirement declaration for the operator by default: this
+ // will be dynamically dispatched.
+ ValueDecl *memberOpDecl = operatorReq;
+ // If conformance reference is concrete, then use concrete witness
+ // declaration for the operator.
+ if (confRef.isConcrete())
+ if (auto *concreteMemberMethodDecl =
+ confRef.getConcrete()->getWitnessDecl(operatorReq))
+ memberOpDecl = concreteMemberMethodDecl;
+ assert(memberOpDecl && "Member operator declaration must exist");
+ auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
+ auto memberOpExpr =
+ new (C) MemberRefExpr(memberTypeExpr, SourceLoc(), memberOpDecl,
+ DeclNameLoc(), /*Implicit*/ true);
+
+ // Create expression `lhs.member <op> rhs.member`.
+ // NOTE(TF-1054): create new `DeclRefExpr`s per loop iteration to avoid
+ // `ConstraintSystem::resolveOverload` error.
+ auto *lhsDRE =
+ new (C) DeclRefExpr(params->get(0), DeclNameLoc(), /*Implicit*/ true);
+ auto *rhsDRE =
+ new (C) DeclRefExpr(params->get(1), DeclNameLoc(), /*Implicit*/ true);
+ Expr *lhsArg = new (C) MemberRefExpr(lhsDRE, SourceLoc(), member,
+ DeclNameLoc(), /*Implicit*/ true);
+ auto *rhsArg = new (C) MemberRefExpr(rhsDRE, SourceLoc(), member,
+ DeclNameLoc(), /*Implicit*/ true);
+ auto *memberOpArgs =
+ TupleExpr::create(C, SourceLoc(), {lhsArg, rhsArg}, {}, {}, SourceLoc(),
+ /*HasTrailingClosure*/ false,
+ /*Implicit*/ true);
+ auto *memberOpCallExpr =
+ new (C) BinaryExpr(memberOpExpr, memberOpArgs, /*Implicit*/ true);
+ return memberOpCallExpr;
+ };
+
+ // Create array of member operator call expressions.
+ llvm::SmallVector<Expr *, 2> memberOpExprs;
+ llvm::SmallVector<Identifier, 2> memberNames;
+ for (auto member : nominal->getStoredProperties()) {
+ memberOpExprs.push_back(createMemberOpExpr(member));
+ memberNames.push_back(member->getName());
+ }
+ // Call memberwise initializer with member operator call expressions.
+ auto *callExpr =
+ CallExpr::createImplicit(C, initExpr, memberOpExprs, memberNames);
+ ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true);
+ return std::pair<BraceStmt *, bool>(
+ BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true), false);
+}
+
+// Synthesize function declaration for the given math operator.
+static ValueDecl *
+derivePointwiseMultiplicative_multiply(DerivedConformance &derived) {
+ auto nominal = derived.Nominal;
+ auto parentDC = derived.getConformanceContext();
+ auto &C = derived.Context;
+ auto selfInterfaceType = parentDC->getDeclaredInterfaceType();
+
+ // Create parameter declaration with the given name and type.
+ auto createParamDecl = [&](StringRef name, Type type) -> ParamDecl * {
+ auto *param = new (C)
+ ParamDecl(SourceLoc(), SourceLoc(), Identifier(), SourceLoc(),
+ C.getIdentifier(name), parentDC);
+ param->setSpecifier(ParamDecl::Specifier::Default);
+ param->setInterfaceType(type);
+ return param;
+ };
+
+ ParameterList *params =
+ ParameterList::create(C, {createParamDecl("lhs", selfInterfaceType),
+ createParamDecl("rhs", selfInterfaceType)});
+
+ auto operatorId = C.getIdentifier(".*");
+ DeclName operatorDeclName(C, operatorId, params);
+ auto operatorDecl = FuncDecl::createImplicit(
+ C, StaticSpellingKind::KeywordStatic, operatorDeclName, SourceLoc(),
+ /*Async*/ false,
+ /*Throws*/ false,
+ /*GenericParams=*/nullptr, params, selfInterfaceType, parentDC);
+ operatorDecl->setImplicit();
+ operatorDecl->setBodySynthesizer(&deriveBodyMathOperator);
+ operatorDecl->setGenericSignature(parentDC->getGenericSignatureOfContext());
+ operatorDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
+
+ derived.addMembersToConformanceContext({operatorDecl});
+ return operatorDecl;
+}
+
+// Synthesize body for a computed property getter.
+static std::pair<BraceStmt *, bool>
+deriveComputedPropertyGetter(AbstractFunctionDecl *funcDecl,
+ ProtocolDecl *proto, ValueDecl *reqDecl) {
+ auto *parentDC = funcDecl->getParent();
+ auto *nominal = parentDC->getSelfNominalTypeDecl();
+ auto &C = nominal->getASTContext();
+
+ auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer();
+ assert(memberwiseInitDecl && "Memberwise initializer must exist");
+ auto *initDRE =
+ new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true);
+ initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
+
+ auto *nominalTypeExpr = TypeExpr::createImplicitForDecl(
+ DeclNameLoc(), nominal, funcDecl,
+ funcDecl->mapTypeIntoContext(nominal->getInterfaceType()));
+ auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr);
+
+ auto createMemberPropertyExpr = [&](VarDecl *member) -> Expr * {
+ auto memberType =
+ parentDC->mapTypeIntoContext(member->getValueInterfaceType());
+ Expr *memberExpr = nullptr;
+ // If the property is static, create a type expression: `Member`.
+ if (reqDecl->isStatic()) {
+ memberExpr = TypeExpr::createImplicit(memberType, C);
+ }
+ // If the property is not static, create a member ref expression:
+ // `self.member`.
+ else {
+ auto *selfDecl = funcDecl->getImplicitSelfDecl();
+ auto *selfDRE =
+ new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
+ memberExpr =
+ new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
+ /*Implicit*/ true);
+ }
+ auto module = nominal->getModuleContext();
+ auto confRef = module->lookupConformance(memberType, proto);
+ assert(confRef && "Member does not conform to `PointwiseMultiplicative`");
+ // If conformance reference is not concrete, then concrete witness
+ // declaration for property cannot be resolved. Return reference to protocol
+ // requirement: this will be dynamically dispatched.
+ if (!confRef.isConcrete()) {
+ return new (C) MemberRefExpr(memberExpr, SourceLoc(), reqDecl,
+ DeclNameLoc(), /*Implicit*/ true);
+ }
+ // Otherwise, return reference to concrete witness declaration.
+ auto conf = confRef.getConcrete();
+ auto witnessDecl = conf->getWitnessDecl(reqDecl);
+ return new (C) MemberRefExpr(memberExpr, SourceLoc(), witnessDecl,
+ DeclNameLoc(), /*Implicit*/ true);
+ };
+
+ // Create array of `member.<property>` expressions.
+ llvm::SmallVector<Expr *, 2> memberPropExprs;
+ llvm::SmallVector<Identifier, 2> memberNames;
+ for (auto member : nominal->getStoredProperties()) {
+ memberPropExprs.push_back(createMemberPropertyExpr(member));
+ memberNames.push_back(member->getName());
+ }
+ // Call memberwise initializer with member property expressions.
+ auto *callExpr =
+ CallExpr::createImplicit(C, initExpr, memberPropExprs, memberNames);
+ ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true);
+ auto *braceStmt =
+ BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true);
+ return std::pair<BraceStmt *, bool>(braceStmt, false);
+}
+
+// Synthesize body for the `PointwiseMultiplicative.one` computed property
+// getter.
+static std::pair<BraceStmt *, bool>
+deriveBodyPointwiseMultiplicative_one(AbstractFunctionDecl *funcDecl, void *) {
+ auto &C = funcDecl->getASTContext();
+ auto *pointMulProto =
+ C.getProtocol(KnownProtocolKind::PointwiseMultiplicative);
+ auto *oneReq = getProtocolRequirement(pointMulProto, C.Id_one);
+ return deriveComputedPropertyGetter(funcDecl, pointMulProto, oneReq);
+}
+
+// Synthesize body for the `PointwiseMultiplicative.reciprocal` computed
+// property getter.
+static std::pair<BraceStmt *, bool>
+deriveBodyPointwiseMultiplicative_reciprocal(AbstractFunctionDecl *funcDecl,
+ void *) {
+ auto &C = funcDecl->getASTContext();
+ auto *pointMulProto =
+ C.getProtocol(KnownProtocolKind::PointwiseMultiplicative);
+ auto *reciprocalReq = getProtocolRequirement(pointMulProto, C.Id_reciprocal);
+ return deriveComputedPropertyGetter(funcDecl, pointMulProto, reciprocalReq);
+}
+
+// Synthesize a `PointwiseMultiplicative` property declaration.
+static ValueDecl *
+deriveProperty(DerivedConformance &derived, Identifier propertyName,
+ bool isStatic,
+ AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
+ auto *nominal = derived.Nominal;
+ auto *parentDC = derived.getConformanceContext();
+
+ auto returnInterfaceTy = nominal->getDeclaredInterfaceType();
+ auto returnTy = parentDC->mapTypeIntoContext(returnInterfaceTy);
+
+ // Create property declaration.
+ VarDecl *propDecl;
+ PatternBindingDecl *pbDecl;
+ std::tie(propDecl, pbDecl) = derived.declareDerivedProperty(
+ propertyName, returnInterfaceTy, returnTy, /*isStatic*/ isStatic,
+ /*isFinal*/ true);
+
+ // Create property getter.
+ auto *getterDecl =
+ derived.addGetterToReadOnlyDerivedProperty(propDecl, returnTy);
+ getterDecl->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context);
+ derived.addMembersToConformanceContext({propDecl, pbDecl});
+
+ return propDecl;
+}
+
+// Synthesize the static property declaration for
+// `PointwiseMultiplicative.one`.
+static ValueDecl *
+derivePointwiseMultiplicative_one(DerivedConformance &derived) {
+ auto &C = derived.Context;
+ return deriveProperty(derived, C.Id_one, /*isStatic*/ true,
+ {deriveBodyPointwiseMultiplicative_one, nullptr});
+}
+
+// Synthesize the instance property declaration for
+// `PointwiseMultiplicative.reciprocal`.
+static ValueDecl *
+derivePointwiseMultiplicative_reciprocal(DerivedConformance &derived) {
+ auto &C = derived.Context;
+ return deriveProperty(
+ derived, C.Id_reciprocal, /*isStatic*/ false,
+ {deriveBodyPointwiseMultiplicative_reciprocal, nullptr});
+}
+
+ValueDecl *
+DerivedConformance::derivePointwiseMultiplicative(ValueDecl *requirement) {
+ // Diagnose conformances in disallowed contexts.
+ if (checkAndDiagnoseDisallowedContext(requirement))
+ return nullptr;
+ if (requirement->getBaseName() == Context.getIdentifier(".*"))
+ return derivePointwiseMultiplicative_multiply(*this);
+ if (requirement->getBaseName() == Context.Id_one)
+ return derivePointwiseMultiplicative_one(*this);
+ if (requirement->getBaseName() == Context.Id_reciprocal)
+ return derivePointwiseMultiplicative_reciprocal(*this);
+ Context.Diags.diagnose(requirement->getLoc(),
+ diag::broken_pointwise_multiplicative_requirement);
+ return nullptr;
+}
diff --git a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp
new file mode 100644
index 0000000..9e26bd1
--- /dev/null
+++ b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp
@@ -0,0 +1,644 @@
+//===--- DerivedConformanceTensorArrayProtocol.cpp ------------------------===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements explicit derivation of the TensorArrayProtocol protocol
+// for a nominal type.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodeSynthesis.h"
+#include "TypeChecker.h"
+#include "swift/AST/Decl.h"
+#include "swift/AST/Expr.h"
+#include "swift/AST/GenericSignature.h"
+#include "swift/AST/Module.h"
+#include "swift/AST/ParameterList.h"
+#include "swift/AST/Pattern.h"
+#include "swift/AST/ProtocolConformance.h"
+#include "swift/AST/Stmt.h"
+#include "swift/AST/Types.h"
+#include "DerivedConformances.h"
+
+using namespace swift;
+
+bool DerivedConformance::canDeriveTensorArrayProtocol(NominalTypeDecl *nominal,
+ DeclContext *DC) {
+ // Nominal type must be a struct (zero stored properties is okay).
+ // Note: we could extend synthesis to support classes.
+ auto *structDecl = dyn_cast<StructDecl>(nominal);
+ if (!structDecl)
+ return false;
+ // All stored properties must conform to `TensorGroup`.
+ auto &C = nominal->getASTContext();
+ auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
+ return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
+ if (v->getInterfaceType()->hasError())
+ return false;
+ auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType());
+ return (bool)TypeChecker::conformsToProtocol(varType, tensorGroupProto, DC);
+ });
+}
+
+// Return the protocol requirement with the specified name.
+static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, DeclName name) {
+ auto lookup = proto->lookupDirect(name);
+ lookup.erase(std::remove_if(lookup.begin(), lookup.end(),
+ [](ValueDecl *v) {
+ return !isa<ProtocolDecl>(
+ v->getDeclContext()) ||
+ !v->isProtocolRequirement();
+ }),
+ lookup.end());
+ assert(lookup.size() == 1 && "Ambiguous protocol requirement");
+ return lookup.front();
+}
+
+// Synthesize body for `_unpackTensorHandles(into:)`.
+static std::pair<BraceStmt *, bool>
+deriveBodyTensorArrayProtocol_unpackTensorHandles(
+ AbstractFunctionDecl *funcDecl, void *) {
+ auto *parentDC = funcDecl->getParent();
+ auto *nominal = parentDC->getSelfNominalTypeDecl();
+ auto &C = nominal->getASTContext();
+
+ // Obtain the address type.
+ auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
+ Type baseAddressType = BoundGenericType::get(
+ C.getUnsafeMutablePointerDecl(), Type(), {cTensorHandleType});
+ Type addressType = BoundGenericType::get(
+ C.getOptionalDecl(), Type(), {baseAddressType});
+
+ // Get references to `self` and parameter declarations.
+ auto *selfDecl = funcDecl->getImplicitSelfDecl();
+ auto *selfDRE = new (C)
+ DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
+ auto *paramDecl = funcDecl->getParameters()->get(0);
+ auto *paramDRE = new (C)
+ DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true);
+
+ // Create an `if var` statement for the current address.
+ VarDecl *currAddressDecl = new (C) VarDecl(
+ /*IsStatic*/ false, VarDecl::Introducer::Var, /*IsCaptureList*/ false,
+ SourceLoc(), C.getIdentifier("currentAddress"), funcDecl);
+ currAddressDecl->setImplicit();
+ currAddressDecl->setHasNonPatternBindingInit(true);
+ currAddressDecl->setInterfaceType(baseAddressType);
+
+ Pattern *currAddressPat = NamedPattern::createImplicit(C, currAddressDecl);
+ currAddressPat =
+ BindingPattern::createImplicit(C, /*isLet*/ false, currAddressPat);
+ currAddressPat =
+ new (C) OptionalSomePattern(currAddressPat, currAddressPat->getEndLoc());
+ currAddressPat->setImplicit();
+ StmtConditionElement cond[] = {
+ StmtConditionElement(SourceLoc(), currAddressPat, /*Init*/ paramDRE)};
+
+ // Get method protocol requirement.
+ auto *tensorArrayProto = C.getProtocol(
+ KnownProtocolKind::TensorArrayProtocol);
+ auto *methodReq = getProtocolRequirement(
+ tensorArrayProto, C.Id_unpackTensorHandles);
+ auto *countReq = getProtocolRequirement(
+ tensorArrayProto, C.Id_tensorHandleCount);
+
+ Type intType = C.getIntDecl()->getDeclaredType();
+ TypeExpr *intTypeExpr = TypeExpr::createImplicit(intType, C);
+
+ // Iterate through the `TensorArrayProtocol`-conforming members and call
+ // `member._unpackTensorHandles(into:)`.
+ llvm::SmallVector<ASTNode, 2> memberExprs;
+ for (auto member : nominal->getStoredProperties()) {
+ auto memberType = parentDC->mapTypeIntoContext(
+ member->getValueInterfaceType());
+ auto module = nominal->getModuleContext();
+ auto confRef = module->lookupConformance(memberType, tensorArrayProto);
+ assert(confRef && "Member does not conform to `TensorArrayProtocol`");
+
+ // Get member type's method, e.g. `Member._unpackTensorHandles(into:)`.
+ // Use protocol requirement declaration for the method by default: this
+ // will be dynamically dispatched.
+ ValueDecl *memberMethodDecl = methodReq;
+ // If conformance reference is concrete, then use concrete witness
+ // declaration for the operator.
+ if (confRef.isConcrete())
+ memberMethodDecl = confRef.getConcrete()->
+ getWitnessDecl(methodReq);
+ assert(memberMethodDecl && "Member method declaration must exist");
+
+ // Create reference to member method: `Member._unpackTensorHandles(into:)`.
+ auto *memberDRE = new (C) MemberRefExpr(
+ selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true);
+ auto memberMethodExpr =
+ new (C) MemberRefExpr(memberDRE, SourceLoc(), memberMethodDecl,
+ DeclNameLoc(), /*Implicit*/ true);
+
+ // Obtain the method call argument.
+ auto *addressDRE = new (C) DeclRefExpr(
+ currAddressDecl, DeclNameLoc(), /*implicit*/ true);
+ auto *callExpr = CallExpr::createImplicit(C, memberMethodExpr, {addressDRE},
+ {C.getIdentifier("into")});
+
+ // Advance the current address.
+ DeclName advancedName(C, C.getIdentifier("advanced"),
+ {C.getIdentifier("by")});
+ // NOTE(TF-1054): create new `DeclRefExpr` to avoid
+ // `ConstraintSystem::resolveOverload` error.
+ addressDRE = new (C) DeclRefExpr(
+ currAddressDecl, DeclNameLoc(), /*implicit*/ true);
+ auto *advancedMethodExpr =
+ new (C) UnresolvedDotExpr(addressDRE, SourceLoc(),
+ DeclNameRef(advancedName), DeclNameLoc(),
+ /*Implicit*/ true);
+
+ // Obtain `Member._tensorHandleCount`.
+ auto *memberCountMRE = new (C) MemberRefExpr(
+ memberDRE, SourceLoc(), countReq, DeclNameLoc(),
+ /*Implicit*/ true);
+
+ // Cast the tensor handle count to Int.
+ auto intInitName = DeclName(C, DeclBaseName::createConstructor(),
+ {Identifier()});
+ auto *intInitExpr = new (C)
+ UnresolvedDotExpr(intTypeExpr, SourceLoc(), DeclNameRef(intInitName),
+ DeclNameLoc(), /*Implicit*/ true);
+ auto *intInitCallExpr = CallExpr::createImplicit(
+ C, intInitExpr, {memberCountMRE}, {Identifier()});
+
+ // Assign the new address.
+ auto *assignCallExpr = CallExpr::createImplicit(
+ C, advancedMethodExpr, {intInitCallExpr}, {C.getIdentifier("by")});
+ // NOTE(TF-1054): create new `DeclRefExpr` to avoid
+ // `ConstraintSystem::resolveOverload` error.
+ addressDRE = new (C) DeclRefExpr(
+ currAddressDecl, DeclNameLoc(), /*implicit*/ true);
+ auto *assignExpr = new (C) AssignExpr(addressDRE, SourceLoc(),
+ assignCallExpr, /*Implicit*/ true);
+
+ memberExprs.push_back(callExpr);
+ memberExprs.push_back(assignExpr);
+ }
+
+ auto *thenBody = BraceStmt::create(C, SourceLoc(),
+ C.AllocateCopy(memberExprs),
+ SourceLoc(), /*implicit*/ true);
+
+ auto *ifStmt = new (C)
+ IfStmt(LabeledStmtInfo(), /*IfLoc*/ SourceLoc(),
+ /*Cond*/ C.AllocateCopy(cond), /*Then*/ thenBody,
+ /*ElseLoc*/ SourceLoc(), /*Else*/ nullptr, /*implicit*/ true);
+
+ auto *braceStmt = BraceStmt::create(C, SourceLoc(), {ifStmt}, SourceLoc(),
+ /*implicit*/ true);
+ return std::pair<BraceStmt *, bool>(braceStmt, false);
+}
+
+// Synthesize function declaration for a `TensorArrayProtocol`
+// method requirement.
+static ValueDecl *deriveTensorArrayProtocol_method(
+ DerivedConformance &derived, Identifier methodName, Identifier argumentName,
+ Identifier parameterName, Type parameterType, Type returnType,
+ AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
+ auto nominal = derived.Nominal;
+ auto &C = derived.Context;
+ auto parentDC = derived.getConformanceContext();
+
+ auto *param =
+ new (C) ParamDecl(SourceLoc(), SourceLoc(), argumentName, SourceLoc(),
+ parameterName, parentDC);
+ param->setSpecifier(ParamDecl::Specifier::Default);
+ param->setInterfaceType(parameterType);
+ ParameterList *params = ParameterList::create(C, {param});
+
+ DeclName declName(C, methodName, params);
+ auto funcDecl = FuncDecl::createImplicit(
+ C, StaticSpellingKind::None, declName, SourceLoc(), /*Async*/ false,
+ /*Throws*/ false,
+ /*GenericParams*/ nullptr, params, returnType, parentDC);
+ funcDecl->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context);
+
+ funcDecl->setGenericSignature(parentDC->getGenericSignatureOfContext());
+ funcDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
+
+ derived.addMembersToConformanceContext({funcDecl});
+ return funcDecl;
+}
+
+// Synthesize the `_unpackTensorHandles(into:)` function declaration.
+static ValueDecl
+*deriveTensorArrayProtocol_unpackTensorHandles(DerivedConformance &derived) {
+ auto &C = derived.Context;
+
+ // Obtain the address type.
+ auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
+ Type baseAddressType = BoundGenericType::get(
+ C.getUnsafeMutablePointerDecl(), Type(), {cTensorHandleType});
+ Type addressType = BoundGenericType::get(
+ C.getOptionalDecl(), Type(), {baseAddressType});
+ Type voidType = C.getVoidDecl()->getDeclaredInterfaceType();
+
+ return deriveTensorArrayProtocol_method(
+ derived, C.Id_unpackTensorHandles, C.getIdentifier("into"),
+ C.getIdentifier("address"), addressType, voidType,
+ {deriveBodyTensorArrayProtocol_unpackTensorHandles, nullptr});
+}
+
+/// Derive the body for the '_tensorHandleCount' getter.
+static std::pair<BraceStmt *, bool>
+deriveBodyTensorArrayProtocol_tensorHandleCount(AbstractFunctionDecl *funcDecl,
+ void *) {
+ auto *nominal = funcDecl->getDeclContext()->getSelfNominalTypeDecl();
+ auto &C = nominal->getASTContext();
+
+ // Get references to `self`.
+ auto *selfDecl = funcDecl->getImplicitSelfDecl();
+ auto *selfDRE = new (C)
+ DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
+
+ // Get protocol requirement.
+ auto *tensorArrayProto = C.getProtocol(
+ KnownProtocolKind::TensorArrayProtocol);
+ auto *countReq = getProtocolRequirement(
+ tensorArrayProto, C.Id_tensorHandleCount);
+
+ // Concatenate all member `_tensorHandleCount`s.
+ Type intType = C.getInt32Decl()->getDeclaredType();
+ TypeExpr *intTypeExpr = TypeExpr::createImplicit(intType, C);
+ auto plusOpLookup = C.getInt32Decl()->lookupDirect(C.getIdentifier("+"));
+ assert(plusOpLookup.size() == 1 && "Ambiguous 'Int32.+' operator.");
+ ValueDecl *plusOpDecl = plusOpLookup.front();
+ Expr *tensorHandleCountExpr = new (C)
+ IntegerLiteralExpr("0", SourceLoc(), /*implicit*/ true);
+ for (auto member : nominal->getStoredProperties()) {
+ auto plusOpExpr = new (C) MemberRefExpr(
+ intTypeExpr, SourceLoc(), plusOpDecl, DeclNameLoc(), /*Implicit*/ true);
+ auto *memberDRE = new (C) MemberRefExpr(
+ selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true);
+ auto *memberTensorHandleCountExpr = new (C)
+ MemberRefExpr(memberDRE, SourceLoc(), countReq,
+ DeclNameLoc(), /*Implicit*/ true);
+ // Create expression `lhsArg + rhsArg`.
+ auto *plusOpArgs =
+ TupleExpr::create(C, SourceLoc(),
+ {tensorHandleCountExpr, memberTensorHandleCountExpr},
+ {}, {}, SourceLoc(), /*HasTrailingClosure*/ false,
+ /*Implicit*/ true);
+ tensorHandleCountExpr = new (C) BinaryExpr(plusOpExpr, plusOpArgs,
+ /*Implicit*/ true);
+ }
+
+ // Return the resulting data types array.
+ auto *returnStmt = new (C) ReturnStmt(SourceLoc(), tensorHandleCountExpr);
+ auto *body = BraceStmt::create(C, SourceLoc(), {returnStmt}, SourceLoc(),
+ /*Implicit*/ true);
+ auto *braceStmt = BraceStmt::create(C, SourceLoc(), {body}, SourceLoc(),
+ /*Implicit*/ true);
+ return std::pair<BraceStmt *, bool>(braceStmt, false);
+}
+
+/// Derive a `_tensorHandleCount` implementation.
+static ValueDecl *deriveTensorArrayProtocol_tensorHandleCount(
+ DerivedConformance &derived) {
+ auto nominal = derived.Nominal;
+ ASTContext &C = derived.Context;
+
+ auto parentDC = derived.getConformanceContext();
+ Type intType = C.getInt32Decl()->getDeclaredType();
+ auto returnType = parentDC->mapTypeIntoContext(intType);
+
+ // Create `_tensorHandleCount` property declaration.
+ VarDecl *tensorHandleCountDecl;
+ PatternBindingDecl *patDecl;
+ std::tie(tensorHandleCountDecl, patDecl) = derived.declareDerivedProperty(
+ C.Id_tensorHandleCount, returnType, returnType, /*isStatic*/ false,
+ /*isFinal*/ false);
+
+ // Add `@inlinable` to the `_tensorHandleCount` declaration.
+ if (nominal->getEffectiveAccess() > AccessLevel::Internal)
+ tensorHandleCountDecl->getAttrs().add(
+ new (C) InlinableAttr(/*implicit*/ true));
+
+ // Create `_tensorHandleCount` getter.
+ auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty(
+ tensorHandleCountDecl, returnType);
+ getterDecl->setBodySynthesizer(
+ deriveBodyTensorArrayProtocol_tensorHandleCount, nullptr);
+ derived.addMembersToConformanceContext({tensorHandleCountDecl, patDecl});
+
+ return tensorHandleCountDecl;
+}
+
+/// Derive the body for the '_typeList' getter.
+static std::pair<BraceStmt *, bool>
+deriveBodyTensorArrayProtocol_typeList(AbstractFunctionDecl *funcDecl, void *) {
+ auto *parentDC = funcDecl->getParent();
+ auto *nominal = funcDecl->getDeclContext()->getSelfNominalTypeDecl();
+ auto &C = nominal->getASTContext();
+
+ auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
+ auto *typeListReq = getProtocolRequirement(tensorGroupProto, C.Id_typeList);
+
+ // Concatenate all member `_typeList` arrays.
+ Type arrayType = BoundGenericType::get(
+ C.getArrayDecl(), Type(),
+ {C.getTensorDataTypeDecl()->getDeclaredInterfaceType()});
+ auto *arrayTypeExpr = TypeExpr::createImplicit(arrayType, C);
+ auto plusOpLookup = C.getArrayDecl()->lookupDirect(C.getIdentifier("+"));
+ assert(plusOpLookup.size() == 1 && "Ambiguous 'Array.+' operator.");
+ ValueDecl *plusOpDecl = plusOpLookup.front();
+ Expr *typeListExpr = ArrayExpr::create(C, SourceLoc(), {}, {}, SourceLoc());
+ for (auto member : nominal->getStoredProperties()) {
+ auto *plusOpExpr =
+ new (C) MemberRefExpr(arrayTypeExpr, SourceLoc(), plusOpDecl,
+ DeclNameLoc(), /*Implicit*/ true);
+ auto memberType =
+ parentDC->mapTypeIntoContext(member->getValueInterfaceType());
+ auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
+ auto *memberTypeListExpr = new (C)
+ MemberRefExpr(memberTypeExpr, SourceLoc(), typeListReq,
+ DeclNameLoc(), /*Implicit*/ true);
+ // Create expression `lhsArg + rhsArg`.
+ auto *plusOpArgs =
+ TupleExpr::create(C, SourceLoc(), {typeListExpr, memberTypeListExpr},
+ {}, {}, SourceLoc(), /*HasTrailingClosure*/ false,
+ /*Implicit*/ true);
+ typeListExpr = new (C) BinaryExpr(plusOpExpr, plusOpArgs,
+ /*Implicit*/ true);
+ }
+
+ // Return the resulting data types array.
+ auto *returnStmt = new (C) ReturnStmt(SourceLoc(), typeListExpr);
+ auto *body = BraceStmt::create(C, SourceLoc(), {returnStmt}, SourceLoc(),
+ /*Implicit*/ true);
+ auto *braceStmt = BraceStmt::create(C, SourceLoc(), {body}, SourceLoc(),
+ /*Implicit*/ true);
+ return std::pair<BraceStmt *, bool>(braceStmt, false);
+}
+
+/// Derive a `_typeList` implementation.
+static ValueDecl *deriveTensorArrayProtocol_typeList(
+ DerivedConformance &derived) {
+ auto nominal = derived.Nominal;
+ ASTContext &C = derived.Context;
+
+ auto parentDC = derived.getConformanceContext();
+ Type dataTypeArrayType = BoundGenericType::get(
+ C.getArrayDecl(), Type(),
+ {C.getTensorDataTypeDecl()->getDeclaredInterfaceType()});
+ auto returnType = parentDC->mapTypeIntoContext(dataTypeArrayType);
+
+ // Create `_typeList` property declaration.
+ VarDecl *typeListDecl;
+ PatternBindingDecl *patDecl;
+ std::tie(typeListDecl, patDecl) = derived.declareDerivedProperty(
+ C.Id_typeList, returnType, returnType, /*isStatic*/ false,
+ /*isFinal*/ false);
+
+ // Add `@inlinable` to the `_typeList` declaration.
+ if (nominal->getEffectiveAccess() > AccessLevel::Internal)
+ typeListDecl->getAttrs().add(new (C) InlinableAttr(/*implicit*/ true));
+
+ // Create `_typeList` getter.
+ auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty(
+ typeListDecl, returnType);
+ getterDecl->setBodySynthesizer(
+ deriveBodyTensorArrayProtocol_typeList, nullptr);
+ derived.addMembersToConformanceContext({typeListDecl, patDecl});
+
+ return typeListDecl;
+}
+
+// Synthesize body for `init(_owning:count:)`.
+static std::pair<BraceStmt *, bool>
+deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl, void *) {
+ auto *parentDC = funcDecl->getParent();
+ auto *nominal = parentDC->getSelfNominalTypeDecl();
+ auto &C = nominal->getASTContext();
+
+ // Obtain the address type.
+ auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
+ auto baseAddressType = BoundGenericType::get(
+ C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
+ auto addressType = BoundGenericType::get(
+ C.getOptionalDecl(), Type(), {baseAddressType});
+ auto *addressTypeExpr = TypeExpr::createImplicit(addressType, C);
+
+ // Get references to `self` and parameter declarations.
+ auto *selfDecl = funcDecl->getImplicitSelfDecl();
+ auto *selfDRE = new (C)
+ DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
+ auto *paramDecl = funcDecl->getParameters()->get(0);
+ auto *paramDRE = new (C)
+ DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true);
+
+ // Create an `if var` statement for the current address.
+ VarDecl *currAddressDecl = new (C) VarDecl(
+ /*IsStatic*/ false, VarDecl::Introducer::Var, /*IsCaptureList*/ false,
+ SourceLoc(), C.getIdentifier("currentAddress"), funcDecl);
+ currAddressDecl->setImplicit();
+ currAddressDecl->setHasNonPatternBindingInit(true);
+ currAddressDecl->setInterfaceType(baseAddressType);
+
+ Pattern *currAddressPat = NamedPattern::createImplicit(C, currAddressDecl);
+ currAddressPat =
+ BindingPattern::createImplicit(C, /*isLet*/ false, currAddressPat);
+ currAddressPat =
+ new (C) OptionalSomePattern(currAddressPat, currAddressPat->getEndLoc());
+ currAddressPat->setImplicit();
+ StmtConditionElement cond[] = {
+ StmtConditionElement(SourceLoc(), currAddressPat, /*Init*/ paramDRE)};
+
+ // Get the necessary protocol requirements.
+ auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
+ auto *tensorArrayProto = C.getProtocol(
+ KnownProtocolKind::TensorArrayProtocol);
+ auto initName = DeclName(
+ C, DeclBaseName::createConstructor(), {C.getIdentifier("_owning")});
+ auto *initReq = getProtocolRequirement(tensorGroupProto, initName);
+ auto *tensorHandleCountReq = getProtocolRequirement(
+ tensorArrayProto, C.Id_tensorHandleCount);
+
+ Type intType = C.getIntDecl()->getDeclaredType();
+ TypeExpr *intTypeExpr = TypeExpr::createImplicit(intType, C);
+
+ // Iterate over members and call `self.member = MemberType(_owning:)`.
+ llvm::SmallVector<ASTNode, 2> thenMemberExprs;
+ llvm::SmallVector<ASTNode, 2> elseMemberExprs;
+ for (auto member : nominal->getStoredProperties()) {
+ auto memberType = parentDC->mapTypeIntoContext(
+ member->getValueInterfaceType());
+ auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
+ auto module = nominal->getModuleContext();
+ auto confRef = module->lookupConformance(
+ memberType, tensorGroupProto);
+ assert(confRef && "Member does not conform to `TensorGroup`");
+
+ // Get member type's constructor, e.g. `MemberType.init(_owning:)`.
+ // Use protocol requirement declaration for the method by default: this
+ // will be dynamically dispatched.
+ ValueDecl *memberInitDecl = initReq;
+ // If conformance reference is concrete, then use concrete witness
+ // declaration for the constructor.
+ if (confRef.isConcrete())
+ memberInitDecl = confRef.getConcrete()->getWitnessDecl(initReq);
+ assert(memberInitDecl && "Member constructor declaration must exist");
+ auto memberInitDRE = new (C) DeclRefExpr(
+ memberInitDecl, DeclNameLoc(), /*implicit*/ true);
+ memberInitDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
+
+ // Create reference to member constructor: `MemberType.init(_owning:)`.
+ auto *memberInitExpr = new (C) ConstructorRefCallExpr(
+ memberInitDRE, memberTypeExpr);
+
+ auto *addressDRE = new (C) DeclRefExpr(
+ currAddressDecl, DeclNameLoc(), /*implicit*/ true);
+ auto *thenInitCallExpr = CallExpr::createImplicit(
+ C, memberInitExpr, {addressDRE}, {C.getIdentifier("_owning")});
+
+ // Create a nil expression with type `UnsafePointer<CTensorHandle>?` for the
+ // `else` branch.
+ auto *nilDecl = C.getOptionalNoneDecl();
+ auto *elseInitExpr =
+ new (C) MemberRefExpr(addressTypeExpr, SourceLoc(), nilDecl,
+ DeclNameLoc(), /*Implicit*/ true);
+ auto *elseInitCallExpr = CallExpr::createImplicit(
+ C, memberInitExpr, {elseInitExpr}, {C.getIdentifier("_owning")});
+
+ // Assign the current member to the result of the initializer call.
+ auto *memberDRE = new (C) MemberRefExpr(
+ selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true);
+
+ auto *thenAssignMemberExpr = new (C) AssignExpr(
+ memberDRE, SourceLoc(), thenInitCallExpr, /*Implicit*/ true);
+ auto *elseAssignMemberExpr = new (C) AssignExpr(
+ memberDRE, SourceLoc(), elseInitCallExpr, /*Implicit*/ true);
+
+ thenMemberExprs.push_back(thenAssignMemberExpr);
+ elseMemberExprs.push_back(elseAssignMemberExpr);
+
+ // Advance the current address.
+ // NOTE(TF-1054): create new `DeclRefExpr` to avoid
+ // `ConstraintSystem::resolveOverload` error.
+ addressDRE = new (C) DeclRefExpr(
+ currAddressDecl, DeclNameLoc(), /*implicit*/ true);
+ DeclName advancedName(C, C.getIdentifier("advanced"),
+ {C.getIdentifier("by")});
+ auto *advancedMethodExpr =
+ new (C) UnresolvedDotExpr(addressDRE, SourceLoc(),
+ DeclNameRef(advancedName), DeclNameLoc(),
+ /*Implicit*/ true);
+
+ // Obtain `MemberType._tensorHandleCount`.
+ auto *memberCountMRE = new (C) MemberRefExpr(
+ memberDRE, SourceLoc(), tensorHandleCountReq, DeclNameLoc(),
+ /*Implicit*/ true);
+
+ // Cast the tensor handle count to Int.
+ auto intInitName = DeclName(C, DeclBaseName::createConstructor(),
+ {Identifier()});
+ auto *intInitExpr = new (C)
+ UnresolvedDotExpr(intTypeExpr, SourceLoc(), DeclNameRef(intInitName),
+ DeclNameLoc(), /*Implicit*/ true);
+ auto *intInitCallExpr = CallExpr::createImplicit(
+ C, intInitExpr, {memberCountMRE}, {Identifier()});
+
+ // Assign the new address.
+ auto *assignAddrCallExpr = CallExpr::createImplicit(
+ C, advancedMethodExpr, {intInitCallExpr}, {C.getIdentifier("by")});
+ // NOTE(TF-1054): create new `DeclRefExpr` to avoid
+ // `ConstraintSystem::resolveOverload` error.
+ addressDRE = new (C) DeclRefExpr(
+ currAddressDecl, DeclNameLoc(), /*implicit*/ true);
+ auto *assignAddrExpr = new (C) AssignExpr(addressDRE, SourceLoc(),
+ assignAddrCallExpr,
+ /*Implicit*/ true);
+
+ thenMemberExprs.push_back(assignAddrExpr);
+ }
+
+ auto *thenBody = BraceStmt::create(
+ C, SourceLoc(), C.AllocateCopy(thenMemberExprs), SourceLoc(),
+ /*implicit*/ true);
+
+ auto *elseBody = BraceStmt::create(
+ C, SourceLoc(), C.AllocateCopy(elseMemberExprs), SourceLoc(),
+ /*implicit*/ true);
+
+ auto *ifStmt = new (C)
+ IfStmt(LabeledStmtInfo(), /*IfLoc*/ SourceLoc(),
+ /*Cond*/ C.AllocateCopy(cond), /*Then*/ thenBody,
+ /*ElseLoc*/ SourceLoc(), /*Else*/ elseBody, /*implicit*/ true);
+
+ auto *braceStmt = BraceStmt::create(C, SourceLoc(), {ifStmt}, SourceLoc(),
+ /*implicit*/ true);
+ return std::pair<BraceStmt *, bool>(braceStmt, false);
+}
+
+// Synthesize the `init(_owning:count:)` function declaration.
+static ValueDecl
+*deriveTensorArrayProtocol_init(DerivedConformance &derived) {
+ auto &C = derived.Context;
+ auto nominal = derived.Nominal;
+ auto parentDC = derived.getConformanceContext();
+
+ // Obtain the address type.
+ auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
+ Type baseAddressType = BoundGenericType::get(
+ C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
+ Type addressType = BoundGenericType::get(
+ C.getOptionalDecl(), Type(), {baseAddressType});
+ Type intType = C.getIntDecl()->getDeclaredType();
+
+ auto *param1 = new (C) ParamDecl(SourceLoc(), SourceLoc(),
+ C.getIdentifier("_owning"), SourceLoc(), C.getIdentifier("tensorHandles"),
+ parentDC);
+ param1->setSpecifier(ParamDecl::Specifier::Default);
+ param1->setInterfaceType(addressType);
+ auto *param2 = new (C) ParamDecl(SourceLoc(), SourceLoc(),
+ C.getIdentifier("count"), SourceLoc(), C.getIdentifier("count"),
+ parentDC);
+ param2->setSpecifier(ParamDecl::Specifier::Default);
+ param2->setInterfaceType(intType);
+ ParameterList *params = ParameterList::create(C, {param1, param2});
+
+ DeclName name(C, DeclBaseName::createConstructor(), params);
+ auto *initDecl =
+ new (C) ConstructorDecl(name, SourceLoc(), /*Failable*/ false,
+ SourceLoc(), /*Throws*/ false, SourceLoc(),
+ params, /*GenericParams*/ nullptr, parentDC);
+ initDecl->setImplicit();
+ initDecl->setSynthesized();
+ initDecl->setBodySynthesizer(deriveBodyTensorArrayProtocol_init, nullptr);
+
+ initDecl->setGenericSignature(parentDC->getGenericSignatureOfContext());
+ initDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
+
+ derived.addMembersToConformanceContext({initDecl});
+ return initDecl;
+}
+
+ValueDecl *DerivedConformance::deriveTensorArrayProtocol(
+ ValueDecl *requirement) {
+ // Diagnose conformances in disallowed contexts.
+ if (checkAndDiagnoseDisallowedContext(requirement))
+ return nullptr;
+ if (requirement->getBaseName() == Context.Id_unpackTensorHandles)
+ return deriveTensorArrayProtocol_unpackTensorHandles(*this);
+ if (requirement->getBaseName() == Context.Id_tensorHandleCount)
+ return deriveTensorArrayProtocol_tensorHandleCount(*this);
+ if (requirement->getBaseName() == Context.Id_typeList)
+ return deriveTensorArrayProtocol_typeList(*this);
+ if (requirement->getBaseName() == DeclBaseName::createConstructor())
+ return deriveTensorArrayProtocol_init(*this);
+ Context.Diags.diagnose(requirement->getLoc(),
+ diag::broken_tensor_array_protocol_requirement);
+ return nullptr;
+}
diff --git a/lib/Sema/DerivedConformanceTensorGroup.cpp b/lib/Sema/DerivedConformanceTensorGroup.cpp
new file mode 100644
index 0000000..3aa33db
--- /dev/null
+++ b/lib/Sema/DerivedConformanceTensorGroup.cpp
@@ -0,0 +1,370 @@
+//===--- DerivedConformanceTensorGroup.cpp --------------------------------===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements explicit derivation of the TensorGroup protocol for
+// a nominal type.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodeSynthesis.h"
+#include "TypeChecker.h"
+#include "swift/AST/Decl.h"
+#include "swift/AST/Expr.h"
+#include "swift/AST/GenericSignature.h"
+#include "swift/AST/Module.h"
+#include "swift/AST/ParameterList.h"
+#include "swift/AST/Pattern.h"
+#include "swift/AST/ProtocolConformance.h"
+#include "swift/AST/Stmt.h"
+#include "swift/AST/Types.h"
+#include "DerivedConformances.h"
+
+using namespace swift;
+
+bool DerivedConformance::canDeriveTensorGroup(NominalTypeDecl *nominal,
+ DeclContext *DC) {
+ // Nominal type must be a struct (zero stored properties is okay).
+ // Note: we could extend synthesis to support classes.
+ auto *structDecl = dyn_cast<StructDecl>(nominal);
+ if (!structDecl)
+ return false;
+ // All stored properties must conform to `TensorGroup`.
+ auto &C = nominal->getASTContext();
+ auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
+ return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
+ if (v->getInterfaceType()->hasError())
+ return false;
+ auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType());
+ return (bool)TypeChecker::conformsToProtocol(varType, tensorGroupProto, DC);
+ });
+}
+
+// Return the protocol requirement with the specified name.
+static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, DeclName name) {
+ auto lookup = proto->lookupDirect(name);
+ lookup.erase(std::remove_if(lookup.begin(), lookup.end(),
+ [](ValueDecl *v) {
+ return !isa<ProtocolDecl>(
+ v->getDeclContext()) ||
+ !v->isProtocolRequirement();
+ }),
+ lookup.end());
+ assert(lookup.size() == 1 && "Ambiguous protocol requirement");
+ return lookup.front();
+}
+
+/// Derive the body for the '_typeList' getter.
+static std::pair<BraceStmt *, bool>
+deriveBodyTensorGroup_typeList(AbstractFunctionDecl *funcDecl, void *) {
+ auto *parentDC = funcDecl->getParent();
+ auto *nominal = funcDecl->getDeclContext()->getSelfNominalTypeDecl();
+ auto &C = nominal->getASTContext();
+
+ auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
+ auto *typeListReq = getProtocolRequirement(tensorGroupProto, C.Id_typeList);
+
+ // Concatenate all member `_typeList` arrays.
+ Type arrayType = BoundGenericType::get(
+ C.getArrayDecl(), Type(),
+ {C.getTensorDataTypeDecl()->getDeclaredInterfaceType()});
+ auto *arrayTypeExpr = TypeExpr::createImplicit(arrayType, C);
+ auto plusOpLookup = C.getArrayDecl()->lookupDirect(C.getIdentifier("+"));
+ assert(plusOpLookup.size() == 1 && "Ambiguous 'Array.+' operator.");
+ ValueDecl *plusOpDecl = plusOpLookup.front();
+ Expr *typeListExpr = ArrayExpr::create(C, SourceLoc(), {}, {}, SourceLoc());
+ for (auto member : nominal->getStoredProperties()) {
+ auto plusOpExpr =
+ new (C) MemberRefExpr(arrayTypeExpr, SourceLoc(), plusOpDecl,
+ DeclNameLoc(), /*Implicit*/ true);
+ auto memberType =
+ parentDC->mapTypeIntoContext(member->getValueInterfaceType());
+ auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
+ auto *memberTypeListExpr = new (C)
+ MemberRefExpr(memberTypeExpr, SourceLoc(), typeListReq,
+ DeclNameLoc(), /*Implicit*/ true);
+ // Create expression `lhsArg + rhsArg`.
+ auto *plusOpArgs =
+ TupleExpr::create(C, SourceLoc(), {typeListExpr, memberTypeListExpr},
+ {}, {}, SourceLoc(), /*HasTrailingClosure*/ false,
+ /*Implicit*/ true);
+ typeListExpr = new (C) BinaryExpr(plusOpExpr, plusOpArgs,
+ /*Implicit*/ true);
+ }
+
+ // Return the resulting data types array.
+ auto *returnStmt = new (C) ReturnStmt(SourceLoc(), typeListExpr);
+ auto *body = BraceStmt::create(C, SourceLoc(), {returnStmt}, SourceLoc(),
+ /*Implicit*/ true);
+ auto *braceStmt = BraceStmt::create(C, SourceLoc(), {body}, SourceLoc(),
+ /*Implicit*/ true);
+ return std::pair<BraceStmt *, bool>(braceStmt, false);
+}
+
+/// Derive a '_typeList' implementation.
+static ValueDecl *deriveTensorGroup_typeList(DerivedConformance &derived) {
+ auto nominal = derived.Nominal;
+ ASTContext &C = derived.Context;
+
+ auto parentDC = derived.getConformanceContext();
+ Type dataTypeArrayType = BoundGenericType::get(
+ C.getArrayDecl(), Type(),
+ {C.getTensorDataTypeDecl()->getDeclaredInterfaceType()});
+ auto returnType = parentDC->mapTypeIntoContext(dataTypeArrayType);
+
+ // Create `_typeList` property declaration.
+ VarDecl *typeListDecl;
+ PatternBindingDecl *patDecl;
+ std::tie(typeListDecl, patDecl) = derived.declareDerivedProperty(
+ C.Id_typeList, returnType, returnType, /*isStatic*/ true,
+ /*isFinal*/ false);
+
+ // Add `@inlinable` to the `_typeList` declaration.
+ if (nominal->getEffectiveAccess() > AccessLevel::Internal)
+ typeListDecl->getAttrs().add(new (C) InlinableAttr(/*implicit*/ true));
+
+ // Create `_typeList` getter.
+ auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty(
+ typeListDecl, returnType);
+ getterDecl->setBodySynthesizer(deriveBodyTensorGroup_typeList, nullptr);
+ derived.addMembersToConformanceContext({typeListDecl, patDecl});
+
+ return typeListDecl;
+}
+
+// Synthesize body for `init(_owning:)`.
+static std::pair<BraceStmt *, bool>
+deriveBodyTensorGroup_init(AbstractFunctionDecl *funcDecl, void *) {
+ auto *parentDC = funcDecl->getParent();
+ auto *nominal = parentDC->getSelfNominalTypeDecl();
+ auto &C = nominal->getASTContext();
+
+ // Obtain the address type.
+ auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
+ auto baseAddressType = BoundGenericType::get(
+ C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
+ auto addressType = BoundGenericType::get(
+ C.getOptionalDecl(), Type(), {baseAddressType});
+ auto *addressTypeExpr = TypeExpr::createImplicit(addressType, C);
+
+ // Get references to `self` and parameter declarations.
+ auto *selfDecl = funcDecl->getImplicitSelfDecl();
+ auto *selfDRE = new (C)
+ DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
+ auto *paramDecl = funcDecl->getParameters()->get(0);
+ auto *paramDRE = new (C)
+ DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true);
+
+ // Create an `if var` statement for the current address.
+ VarDecl *currAddressDecl = new (C) VarDecl(
+ /*IsStatic*/ false, VarDecl::Introducer::Var, /*IsCaptureList*/ false,
+ SourceLoc(), C.getIdentifier("currentAddress"), funcDecl);
+ currAddressDecl->setImplicit();
+ currAddressDecl->setHasNonPatternBindingInit(true);
+ currAddressDecl->setInterfaceType(baseAddressType);
+
+ Pattern *currAddressPat = NamedPattern::createImplicit(C, currAddressDecl);
+ currAddressPat =
+ BindingPattern::createImplicit(C, /*isLet*/ false, currAddressPat);
+ currAddressPat =
+ new (C) OptionalSomePattern(currAddressPat, currAddressPat->getEndLoc());
+ currAddressPat->setImplicit();
+ StmtConditionElement cond[] = {
+ StmtConditionElement(SourceLoc(), currAddressPat, /*Init*/ paramDRE)};
+
+ // Get the necessary protocol requirements.
+ auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
+ auto *tensorArrayProto = C.getProtocol(
+ KnownProtocolKind::TensorArrayProtocol);
+ auto initName = DeclName(
+ C, DeclBaseName::createConstructor(), {C.getIdentifier("_owning")});
+ auto *initReq = getProtocolRequirement(tensorGroupProto, initName);
+ auto *tensorHandleCountReq = getProtocolRequirement(
+ tensorArrayProto, C.Id_tensorHandleCount);
+
+ Type intType = C.getIntDecl()->getDeclaredType();
+ TypeExpr *intTypeExpr = TypeExpr::createImplicit(intType, C);
+
+ // Iterate through the `TensorGroup`-conforming members and call
+ // `self.member = MemberType(_owning:)`.
+ llvm::SmallVector<ASTNode, 2> thenMemberExprs;
+ llvm::SmallVector<ASTNode, 2> elseMemberExprs;
+ for (auto member : nominal->getStoredProperties()) {
+ auto memberType = parentDC->mapTypeIntoContext(
+ member->getValueInterfaceType());
+ auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
+ auto module = nominal->getModuleContext();
+ auto confRef = module->lookupConformance(
+ memberType, tensorGroupProto);
+ assert(confRef && "Member does not conform to `TensorGroup`");
+
+ // Get member type's constructor, e.g. `MemberType.init(_owning:)`.
+ // Use protocol requirement declaration for the method by default: this
+ // will be dynamically dispatched.
+ ValueDecl *memberInitDecl = initReq;
+ // If conformance reference is concrete, then use concrete witness
+ // declaration for the constructor.
+ if (confRef.isConcrete())
+ memberInitDecl = confRef.getConcrete()->getWitnessDecl(initReq);
+ assert(memberInitDecl && "Member constructor declaration must exist");
+ auto memberInitDRE = new (C) DeclRefExpr(
+ memberInitDecl, DeclNameLoc(), /*implicit*/ true);
+ memberInitDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
+
+ // Create reference to member constructor: `MemberType.init(_owning:)`.
+ auto *memberInitExpr = new (C) ConstructorRefCallExpr(
+ memberInitDRE, memberTypeExpr);
+
+ auto *addressDRE = new (C) DeclRefExpr(
+ currAddressDecl, DeclNameLoc(), /*implicit*/ true);
+ auto *thenInitCallExpr = CallExpr::createImplicit(
+ C, memberInitExpr, {addressDRE}, {C.getIdentifier("_owning")});
+
+ // Create a nil expression with type `UnsafePointer<CTensorHandle>?` for the
+ // `else` branch.
+ auto *nilDecl = C.getOptionalNoneDecl();
+ auto *elseInitExpr =
+ new (C) MemberRefExpr(addressTypeExpr, SourceLoc(), nilDecl,
+ DeclNameLoc(), /*Implicit*/ true);
+ auto *elseInitCallExpr = CallExpr::createImplicit(
+ C, memberInitExpr, {elseInitExpr}, {C.getIdentifier("_owning")});
+
+ // Assign the current member to the result of the initializer call.
+ auto *memberDRE = new (C) MemberRefExpr(
+ selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true);
+
+ auto *thenAssignMemberExpr = new (C) AssignExpr(
+ memberDRE, SourceLoc(), thenInitCallExpr, /*Implicit*/ true);
+ auto *elseAssignMemberExpr = new (C) AssignExpr(
+ memberDRE, SourceLoc(), elseInitCallExpr, /*Implicit*/ true);
+
+ thenMemberExprs.push_back(thenAssignMemberExpr);
+ elseMemberExprs.push_back(elseAssignMemberExpr);
+
+ // Advance the current address.
+ DeclName advancedName(C, C.getIdentifier("advanced"),
+ {C.getIdentifier("by")});
+ // NOTE(TF-1054): create new `DeclRefExpr` to avoid
+ // `ConstraintSystem::resolveOverload` error.
+ addressDRE = new (C) DeclRefExpr(
+ currAddressDecl, DeclNameLoc(), /*implicit*/ true);
+ auto *advancedMethodExpr =
+ new (C) UnresolvedDotExpr(addressDRE, SourceLoc(),
+ DeclNameRef(advancedName), DeclNameLoc(),
+ /*Implicit*/ true);
+
+ // Obtain `MemberType._tensorHandleCount`.
+ auto *memberCountMRE = new (C) MemberRefExpr(
+ memberDRE, SourceLoc(), tensorHandleCountReq, DeclNameLoc(),
+ /*Implicit*/ true);
+
+ // Cast the tensor handle count to Int.
+ auto intInitName = DeclName(C, DeclBaseName::createConstructor(),
+ {Identifier()});
+ auto *intInitExpr = new (C)
+ UnresolvedDotExpr(intTypeExpr, SourceLoc(), DeclNameRef(intInitName),
+ DeclNameLoc(), /*Implicit*/ true);
+ auto *intInitCallExpr = CallExpr::createImplicit(
+ C, intInitExpr, {memberCountMRE}, {Identifier()});
+
+ // Assign the new address.
+ auto *assignAddrCallExpr = CallExpr::createImplicit(
+ C, advancedMethodExpr, {intInitCallExpr}, {C.getIdentifier("by")});
+ // NOTE(TF-1054): create new `DeclRefExpr` to avoid
+ // `ConstraintSystem::resolveOverload` error.
+ addressDRE = new (C) DeclRefExpr(
+ currAddressDecl, DeclNameLoc(), /*implicit*/ true);
+ auto *assignAddrExpr = new (C) AssignExpr(addressDRE, SourceLoc(),
+ assignAddrCallExpr,
+ /*Implicit*/ true);
+
+ thenMemberExprs.push_back(assignAddrExpr);
+ }
+
+ auto *thenBody = BraceStmt::create(
+ C, SourceLoc(), C.AllocateCopy(thenMemberExprs), SourceLoc(),
+ /*implicit*/ true);
+
+ auto *elseBody = BraceStmt::create(
+ C, SourceLoc(), C.AllocateCopy(elseMemberExprs), SourceLoc(),
+ /*implicit*/ true);
+
+ auto *ifStmt = new (C)
+ IfStmt(LabeledStmtInfo(), /*IfLoc*/ SourceLoc(),
+ /*Cond*/ C.AllocateCopy(cond), /*Then*/ thenBody,
+ /*ElseLoc*/ SourceLoc(), /*Else*/ elseBody, /*implicit*/ true);
+
+ auto *braceStmt = BraceStmt::create(C, SourceLoc(), {ifStmt}, SourceLoc(),
+ /*implicit*/ true);
+ return std::pair<BraceStmt *, bool>(braceStmt, false);
+}
+
+// Synthesize a constructor declaration for a `TensorGroup` method requirement.
+static ValueDecl *deriveTensorGroup_constructor(
+ DerivedConformance &derived, Identifier argumentName,
+ Identifier parameterName, Type parameterType, Type returnType,
+ AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
+ auto nominal = derived.Nominal;
+ auto &C = derived.Context;
+ auto parentDC = derived.getConformanceContext();
+
+ auto *param =
+ new (C) ParamDecl(SourceLoc(), SourceLoc(), argumentName, SourceLoc(),
+ parameterName, parentDC);
+ param->setSpecifier(ParamDecl::Specifier::Default);
+ param->setInterfaceType(parameterType);
+ ParameterList *params = ParameterList::create(C, {param});
+
+ DeclName name(C, DeclBaseName::createConstructor(), params);
+ auto *initDecl =
+ new (C) ConstructorDecl(name, SourceLoc(), /*Failable*/ false,
+ SourceLoc(), /*Throws*/ false, SourceLoc(),
+ params, /*GenericParams*/ nullptr, parentDC);
+ initDecl->setImplicit();
+ initDecl->setSynthesized();
+ initDecl->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context);
+
+ initDecl->setGenericSignature(parentDC->getGenericSignatureOfContext());
+ initDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
+
+ derived.addMembersToConformanceContext({initDecl});
+ return initDecl;
+}
+
+// Synthesize the `init(_owning:)` function declaration.
+static ValueDecl *deriveTensorGroup_init(DerivedConformance &derived) {
+ auto &C = derived.Context;
+
+ // Obtain the address type.
+ auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
+ Type baseAddressType = BoundGenericType::get(
+ C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
+ Type addressType = BoundGenericType::get(
+ C.getOptionalDecl(), Type(), {baseAddressType});
+ Type voidType = C.getVoidDecl()->getDeclaredInterfaceType();
+
+ return deriveTensorGroup_constructor(
+ derived, C.getIdentifier("_owning"),
+ C.getIdentifier("tensorHandles"), addressType, voidType,
+ {deriveBodyTensorGroup_init, nullptr});
+}
+
+ValueDecl *DerivedConformance::deriveTensorGroup(ValueDecl *requirement) {
+ // Diagnose conformances in disallowed contexts.
+ if (checkAndDiagnoseDisallowedContext(requirement))
+ return nullptr;
+ if (requirement->getBaseName() == Context.Id_typeList)
+ return deriveTensorGroup_typeList(*this);
+ if (requirement->getBaseName() == DeclBaseName::createConstructor())
+ return deriveTensorGroup_init(*this);
+ Context.Diags.diagnose(requirement->getLoc(), diag::broken_tensor_group_requirement);
+ return nullptr;
+}
diff --git a/lib/Sema/DerivedConformanceVectorProtocol.cpp b/lib/Sema/DerivedConformanceVectorProtocol.cpp
new file mode 100644
index 0000000..5c9de06
--- /dev/null
+++ b/lib/Sema/DerivedConformanceVectorProtocol.cpp
@@ -0,0 +1,267 @@
+//===--- DerivedConformanceVectorProtocol.cpp -----------------------------===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements explicit derivation of the VectorProtocol protocol for
+// struct types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodeSynthesis.h"
+#include "TypeChecker.h"
+#include "swift/AST/Decl.h"
+#include "swift/AST/Expr.h"
+#include "swift/AST/GenericSignature.h"
+#include "swift/AST/Module.h"
+#include "swift/AST/ParameterList.h"
+#include "swift/AST/Pattern.h"
+#include "swift/AST/ProtocolConformance.h"
+#include "swift/AST/Stmt.h"
+#include "swift/AST/Types.h"
+#include "DerivedConformances.h"
+
+using namespace swift;
+
+// Return the `VectorSpaceScalar` associated type for the given `ValueDecl` if
+// it conforms to `VectorProtocol` in the given context. Otherwise, return
+// `nullptr`.
+static Type getVectorProtocolVectorSpaceScalarAssocType(
+ VarDecl *varDecl, DeclContext *DC) {
+ auto &C = varDecl->getASTContext();
+ auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
+ if (varDecl->getInterfaceType()->hasError())
+ return nullptr;
+ auto varType = DC->mapTypeIntoContext(varDecl->getValueInterfaceType());
+ auto conf = TypeChecker::conformsToProtocol(varType, vectorProto, DC);
+ if (!conf)
+ return nullptr;
+ return conf.getTypeWitnessByName(varType, C.Id_VectorSpaceScalar);
+}
+
+// Return the `VectorSpaceScalar` associated type for the given nominal type in
+// the given context, or `nullptr` if `VectorSpaceScalar` cannot be derived.
+static Type deriveVectorProtocol_VectorSpaceScalar(NominalTypeDecl *nominal,
+ DeclContext *DC) {
+ // Nominal type must be a struct. (Zero stored properties is okay.)
+ if (!isa<StructDecl>(nominal))
+ return nullptr;
+ // If all stored properties conform to `VectorProtocol` and have the same
+ // `VectorSpaceScalar` associated type, return that `VectorSpaceScalar`
+ // associated type. Otherwise, the `VectorSpaceScalar` type cannot be derived.
+ Type sameScalarType;
+ for (auto member : nominal->getStoredProperties()) {
+ if (member->getInterfaceType()->hasError())
+ return nullptr;
+ auto scalarType = getVectorProtocolVectorSpaceScalarAssocType(member, DC);
+ // If stored property does not conform to `VectorProtocol`, return nullptr.
+ if (!scalarType)
+ return nullptr;
+ // If same `VectorSpaceScalar` type has not been set, set it for the first
+ // time.
+ if (!sameScalarType) {
+ sameScalarType = scalarType;
+ continue;
+ }
+ // If stored property `VectorSpaceScalar` types do not match, return
+ // nullptr.
+ if (!scalarType->isEqual(sameScalarType))
+ return nullptr;
+ }
+ return sameScalarType;
+}
+
+bool DerivedConformance::canDeriveVectorProtocol(NominalTypeDecl *nominal,
+ DeclContext *DC) {
+ // Must not have any `let` stored properties with an initial value.
+ // - This restriction may be lifted later with support for "true" memberwise
+ // initializers that initialize all stored properties, including initial
+ // value information.
+ if (hasLetStoredPropertyWithInitialValue(nominal))
+ return false;
+ // Must be able to derive `VectorSpaceScalar` associated type.
+ return bool(deriveVectorProtocol_VectorSpaceScalar(nominal, DC));
+}
+
+// Synthesize body for a `VectorProtocol` method requirement.
+static std::pair<BraceStmt *, bool>
+deriveBodyVectorProtocol_method(AbstractFunctionDecl *funcDecl,
+ Identifier methodName,
+ Identifier methodParamLabel) {
+ auto *parentDC = funcDecl->getParent();
+ auto *nominal = parentDC->getSelfNominalTypeDecl();
+ auto &C = nominal->getASTContext();
+
+ // Create memberwise initializer: `Nominal.init(...)`.
+ auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer();
+ assert(memberwiseInitDecl && "Memberwise initializer must exist");
+ auto *initDRE =
+ new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true);
+ initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
+ auto *nominalTypeExpr = TypeExpr::createImplicitForDecl(
+ DeclNameLoc(), nominal, funcDecl,
+ funcDecl->mapTypeIntoContext(nominal->getInterfaceType()));
+ auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr);
+
+ // Get method protocol requirement.
+ auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
+ auto *methodReq = getProtocolRequirement(vectorProto, methodName);
+
+ // Get references to `self` and parameter declarations.
+ auto *selfDecl = funcDecl->getImplicitSelfDecl();
+ auto *paramDecl = funcDecl->getParameters()->get(0);
+
+ // Create call expression applying a member method to the parameter.
+ // Format: `<member>.method(<parameter>)`.
+ // Example: `x.scaled(by: scalar)`.
+ auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * {
+ auto *module = nominal->getModuleContext();
+ auto memberType =
+ parentDC->mapTypeIntoContext(member->getValueInterfaceType());
+ auto confRef = module->lookupConformance(memberType, vectorProto);
+ assert(confRef && "Member does not conform to `VectorNumeric`");
+
+ // Get member type's method, e.g. `Member.scaled(by:)`.
+ // Use protocol requirement declaration for the method by default: this
+ // will be dynamically dispatched.
+ ValueDecl *memberMethodDecl = methodReq;
+ // If conformance reference is concrete, then use concrete witness
+ // declaration for the operator.
+ if (confRef.isConcrete()) {
+ if (auto *concreteMemberMethodDecl =
+ confRef.getConcrete()->getWitnessDecl(methodReq))
+ memberMethodDecl = concreteMemberMethodDecl;
+ assert(memberMethodDecl);
+ }
+ assert(memberMethodDecl && "Member method declaration must exist");
+
+ // Create reference to member method: `x.scaled(by:)`.
+ // NOTE(TF-1054): create new `DeclRefExpr`s per loop iteration to avoid
+ // `ConstraintSystem::resolveOverload` error.
+ auto *selfDRE =
+ new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
+ auto memberExpr =
+ new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
+ /*Implicit*/ true);
+ auto memberMethodExpr =
+ new (C) MemberRefExpr(memberExpr, SourceLoc(), memberMethodDecl,
+ DeclNameLoc(), /*Implicit*/ true);
+ auto *paramDRE =
+ new (C) DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true);
+
+ // Create expression: `x.scaled(by: scalar)`.
+ return CallExpr::createImplicit(C, memberMethodExpr, {paramDRE},
+ {methodParamLabel});
+ };
+
+ // Create array of member method call expressions.
+ llvm::SmallVector<Expr *, 2> memberMethodCallExprs;
+ llvm::SmallVector<Identifier, 2> memberNames;
+ for (auto *member : nominal->getStoredProperties()) {
+ memberMethodCallExprs.push_back(createMemberMethodCallExpr(member));
+ memberNames.push_back(member->getName());
+ }
+ // Call memberwise initializer with member method call expressions.
+ auto *callExpr =
+ CallExpr::createImplicit(C, initExpr, memberMethodCallExprs, memberNames);
+ ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true);
+ auto *braceStmt =
+ BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true);
+ return std::pair<BraceStmt *, bool>(braceStmt, false);
+}
+
+// Synthesize function declaration for a `VectorProtocol` method requirement.
+static ValueDecl *deriveVectorProtocol_method(
+ DerivedConformance &derived, Identifier methodBaseName,
+ Identifier argumentLabel, Identifier parameterName, Type parameterType,
+ Type returnType, AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
+ auto nominal = derived.Nominal;
+ auto &C = derived.Context;
+ auto parentDC = derived.getConformanceContext();
+
+ auto *param =
+ new (C) ParamDecl(SourceLoc(), SourceLoc(), argumentLabel, SourceLoc(),
+ parameterName, parentDC);
+ param->setSpecifier(ParamDecl::Specifier::Default);
+ param->setInterfaceType(parameterType);
+ ParameterList *params = ParameterList::create(C, {param});
+
+ DeclName declName(C, methodBaseName, params);
+ auto funcDecl = FuncDecl::createImplicit(
+ C, StaticSpellingKind::None, declName, SourceLoc(), /*Async*/ false,
+ /*Throws*/ false,
+ /*GenericParams*/ nullptr, params, returnType, parentDC);
+ funcDecl->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context);
+
+ funcDecl->setGenericSignature(parentDC->getGenericSignatureOfContext());
+ funcDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
+
+ derived.addMembersToConformanceContext({funcDecl});
+ return funcDecl;
+}
+
+/// Synthesize a method declaration that has the following signture:
+/// func {methodBaseName}(
+/// {argumentLabel} {parameterName}: VectorSpaceScalar
+/// ) -> Self
+static ValueDecl *deriveVectorProtocol_unaryMethodOnScalar(
+ DerivedConformance &derived, Identifier methodBaseName,
+ Identifier argumentLabel, Identifier parameterName) {
+ auto &C = derived.Context;
+ auto *nominal = derived.Nominal;
+ auto *parentDC = derived.getConformanceContext();
+
+ auto selfInterfaceType = parentDC->getDeclaredInterfaceType();
+ auto scalarType = deriveVectorProtocol_VectorSpaceScalar(nominal, parentDC)
+ ->mapTypeOutOfContext();
+
+ auto bodySynthesizer = [](AbstractFunctionDecl *funcDecl,
+ void *ctx) -> std::pair<BraceStmt *, bool> {
+ auto methodNameAndLabel = reinterpret_cast<Identifier *>(ctx);
+ return deriveBodyVectorProtocol_method(
+ funcDecl, methodNameAndLabel[0], methodNameAndLabel[1]);
+ };
+ Identifier baseNameAndLabel[2] = {methodBaseName, argumentLabel};
+ return deriveVectorProtocol_method(
+ derived, methodBaseName, argumentLabel, parameterName, scalarType,
+ selfInterfaceType,
+ {bodySynthesizer, C.AllocateCopy(baseNameAndLabel).data()});
+}
+
+ValueDecl *DerivedConformance::deriveVectorProtocol(ValueDecl *requirement) {
+ // Diagnose conformances in disallowed contexts.
+ if (checkAndDiagnoseDisallowedContext(requirement))
+ return nullptr;
+ auto &C = requirement->getASTContext();
+ if (requirement->getBaseName() == Context.Id_scaled)
+ return deriveVectorProtocol_unaryMethodOnScalar(
+ *this, C.Id_scaled, C.Id_by, C.Id_scale);
+ if (requirement->getBaseName() == Context.Id_adding)
+ return deriveVectorProtocol_unaryMethodOnScalar(
+ *this, C.Id_adding, Identifier(), C.Id_x);
+ if (requirement->getBaseName() == Context.Id_subtracting)
+ return deriveVectorProtocol_unaryMethodOnScalar(
+ *this, C.Id_subtracting, Identifier(), C.Id_x);
+ Context.Diags.diagnose(requirement->getLoc(),
+ diag::broken_vector_protocol_requirement);
+ return nullptr;
+}
+
+Type DerivedConformance::deriveVectorProtocol(AssociatedTypeDecl *requirement) {
+ // Diagnose conformances in disallowed contexts.
+ if (checkAndDiagnoseDisallowedContext(requirement))
+ return nullptr;
+ if (requirement->getBaseName() == Context.Id_VectorSpaceScalar)
+ return deriveVectorProtocol_VectorSpaceScalar(
+ Nominal, getConformanceContext());
+ Context.Diags.diagnose(requirement->getLoc(),
+ diag::broken_vector_protocol_requirement);
+ return nullptr;
+}
diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp
index b3eb907..4708dec 100644
--- a/lib/Sema/DerivedConformances.cpp
+++ b/lib/Sema/DerivedConformances.cpp
@@ -83,6 +83,29 @@
if (*derivableKind == KnownDerivableProtocolKind::Differentiable)
return true;
+ // SWIFT_ENABLE_TENSORFLOW
+ if (*derivableKind == KnownDerivableProtocolKind::PointwiseMultiplicative)
+ return canDerivePointwiseMultiplicative(Nominal, DC);
+
+ if (*derivableKind == KnownDerivableProtocolKind::ElementaryFunctions)
+ return canDeriveElementaryFunctions(Nominal, DC);
+
+ if (*derivableKind == KnownDerivableProtocolKind::KeyPathIterable)
+ return canDeriveKeyPathIterable(Nominal);
+
+ if (*derivableKind == KnownDerivableProtocolKind::TensorArrayProtocol)
+ return canDeriveTensorArrayProtocol(Nominal, DC);
+
+ if (*derivableKind == KnownDerivableProtocolKind::TensorGroup)
+ return canDeriveTensorGroup(Nominal, DC);
+
+ if (*derivableKind == KnownDerivableProtocolKind::VectorProtocol)
+ return canDeriveVectorProtocol(Nominal, DC);
+
+ if (*derivableKind == KnownDerivableProtocolKind::EuclideanDifferentiable)
+ return canDeriveEuclideanDifferentiable(Nominal, DC);
+ // SWIFT_ENABLE_TENSORFLOW END
+
if (auto *enumDecl = dyn_cast<EnumDecl>(Nominal)) {
switch (*derivableKind) {
// The presence of a raw type is an explicit declaration that
@@ -268,7 +291,10 @@
// Local function that retrieves the requirement with the same name as
// the provided requirement, but within the given known protocol.
- auto getRequirement = [&](KnownProtocolKind kind) -> ValueDecl * {
+ // SWIFT_ENABLE_TENSORFLOW
+ auto getRequirement = [&](KnownProtocolKind kind,
+ llvm::function_ref<bool(ValueDecl *)> filter =
+ nullptr) -> ValueDecl * {
// Dig out the protocol.
auto proto = ctx.getProtocol(kind);
if (!proto) return nullptr;
@@ -283,6 +309,17 @@
}
// Retrieve the requirement.
+ // SWIFT_ENABLE_TENSORFLOW
+ // Filter requirements, if `filter` function is specified.
+ if (filter) {
+ auto results = proto->lookupDirect(name);
+ llvm::erase_if(results, [&](ValueDecl *v) {
+ return !isa<ProtocolDecl>(v->getDeclContext()) ||
+ !v->isProtocolRequirement() || !filter(v);
+ });
+ return results.empty() ? nullptr : results.front();
+ }
+ // SWIFT_ENABLE_TENSORFLOW END
return proto->getSingleRequirement(name);
};
@@ -320,6 +357,36 @@
if (name.isSimpleName(ctx.Id_zero))
return getRequirement(KnownProtocolKind::AdditiveArithmetic);
+ // SWIFT_ENABLE_TENSORFLOW
+ // EuclideanDifferentiable.differentiableVectorView
+ if (name.isSimpleName(ctx.Id_differentiableVectorView))
+ return getRequirement(KnownProtocolKind::EuclideanDifferentiable);
+
+ // PointwiseMultiplicative.one
+ if (name.isSimpleName(ctx.Id_one))
+ return getRequirement(KnownProtocolKind::PointwiseMultiplicative);
+
+ // PointwiseMultiplicative.reciprocal
+ if (name.isSimpleName(ctx.Id_reciprocal))
+ return getRequirement(KnownProtocolKind::PointwiseMultiplicative);
+
+ // KeyPathIterable.allKeyPaths
+ if (name.isSimpleName(ctx.Id_allKeyPaths))
+ return getRequirement(KnownProtocolKind::KeyPathIterable);
+
+ // TensorArrayProtocol._tensorHandleCount
+ if (name.isSimpleName(ctx.Id_tensorHandleCount))
+ return getRequirement(KnownProtocolKind::TensorArrayProtocol);
+
+ // TensorArrayProtocol._typeList
+ if (name.isSimpleName(ctx.Id_typeList) && !requirement->isStatic())
+ return getRequirement(KnownProtocolKind::TensorArrayProtocol);
+
+ // TensorGroup._typeList
+ if (name.isSimpleName(ctx.Id_typeList))
+ return getRequirement(KnownProtocolKind::TensorGroup);
+ // SWIFT_ENABLE_TENSORFLOW END
+
return nullptr;
}
@@ -359,6 +426,94 @@
return getRequirement(KnownProtocolKind::Hashable);
}
+ // SWIFT_ENABLE_TENSORFLOW
+ // AdditiveArithmetic.+
+ // AdditiveArithmetic.-
+ if (func->isOperator() && (name.getBaseName() == "+" ||
+ name.getBaseName() == "-")) {
+ auto argumentNames = name.getArgumentNames();
+ if (argumentNames.size() == 2)
+ return getRequirement(KnownProtocolKind::AdditiveArithmetic);
+ }
+
+ // SWIFT_ENABLE_TENSORFLOW
+ // PointwiseMultiplicative.(.*)
+ if (func->isOperator() && name.getBaseName() == ".*") {
+ auto argumentNames = name.getArgumentNames();
+ if (argumentNames.size() == 2)
+ return getRequirement(KnownProtocolKind::PointwiseMultiplicative);
+ }
+
+ // SWIFT_ENABLE_TENSORFLOW
+ // ElementaryFunctions requirements
+ if (name.isCompoundName()) {
+ auto argumentNames = name.getArgumentNames();
+ if (argumentNames.size() == 1 && (false
+#define ELEMENTARY_FUNCTION_UNARY(ID, NAME) || name.getBaseName() == NAME
+#include "DerivedConformanceElementaryFunctions.def"
+#undef ELEMENTARY_FUNCTION_UNARY
+ )) {
+ return getRequirement(KnownProtocolKind::ElementaryFunctions);
+ }
+ if (argumentNames.size() == 2) {
+ if (name.getBaseName() == "root")
+ return getRequirement(KnownProtocolKind::ElementaryFunctions);
+ if (name.getBaseName() == "pow") {
+ return getRequirement(
+ KnownProtocolKind::ElementaryFunctions,
+ [&](ValueDecl *v) {
+ auto *funcDecl = dyn_cast<FuncDecl>(v);
+ if (!funcDecl)
+ return false;
+ return funcDecl->getParameters()->get(1)->getName() ==
+ func->getParameters()->get(1)->getName();
+ });
+ }
+ }
+ }
+
+ // SWIFT_ENABLE_TENSORFLOW
+ // VectorProtocol.scaled(by:)
+ if (name.isCompoundName() && name.getBaseName() == ctx.Id_scaled) {
+ auto argumentNames = name.getArgumentNames();
+ if (argumentNames.size() == 1 &&
+ argumentNames[0] == ctx.getIdentifier("by"))
+ return getRequirement(KnownProtocolKind::VectorProtocol);
+ }
+
+ // SWIFT_ENABLE_TENSORFLOW
+ // VectorProtocol.adding(_:)
+ // VectorProtocol.subtracting(_:)
+ if (name.isCompoundName() &&
+ (name.getBaseName() == ctx.Id_adding ||
+ name.getBaseName() == ctx.Id_subtracting)) {
+ auto argumentNames = name.getArgumentNames();
+ if (argumentNames.size() == 1 && argumentNames[0].empty())
+ return getRequirement(KnownProtocolKind::VectorProtocol);
+ }
+
+ // SWIFT_ENABLE_TENSORFLOW
+ // TensorArrayProtocol._unpackTensorHandles(into:)
+ if (name.isCompoundName() &&
+ name.getBaseName() == ctx.Id_unpackTensorHandles) {
+ auto argumentNames = name.getArgumentNames();
+ if (argumentNames.size() == 1 &&
+ argumentNames[0] == ctx.getIdentifier("into")) {
+ return getRequirement(KnownProtocolKind::TensorArrayProtocol);
+ }
+ }
+
+ // SWIFT_ENABLE_TENSORFLOW
+ // Differentiable.move(along:)
+ if (name.isCompoundName() &&
+ name.getBaseName() == ctx.Id_move) {
+ auto argumentNames = name.getArgumentNames();
+ if (argumentNames.size() == 1 &&
+ argumentNames[0] == ctx.getIdentifier("along")) {
+ return getRequirement(KnownProtocolKind::Differentiable);
+ }
+ }
+
return nullptr;
}
@@ -379,6 +534,19 @@
// Decodable.init(from: Decoder)
if (argumentNames[0] == ctx.Id_from)
return getRequirement(KnownProtocolKind::Decodable);
+
+ // SWIFT_ENABLE_TENSORFLOW
+ // TensorGroup.init(_owning:)
+ if (argumentNames[0] == ctx.getIdentifier("_owning")) {
+ return getRequirement(KnownProtocolKind::TensorGroup);
+ }
+ } else if (argumentNames.size() == 2) {
+ // SWIFT_ENABLE_TENSORFLOW
+ // TensorArrayProtocol.init(_owning:count)
+ if (argumentNames[0] == ctx.getIdentifier("_owning") &&
+ argumentNames[1] == ctx.getIdentifier("count")) {
+ return getRequirement(KnownProtocolKind::TensorArrayProtocol);
+ }
}
return nullptr;
@@ -398,6 +566,16 @@
if (name.isSimpleName(ctx.Id_TangentVector))
return getRequirement(KnownProtocolKind::Differentiable);
+ // SWIFT_ENABLE_TENSORFLOW
+ // KeyPathIterable.AllKeyPaths
+ if (name.isSimpleName(ctx.Id_AllKeyPaths))
+ return getRequirement(KnownProtocolKind::KeyPathIterable);
+
+ // VectorProtocol.VectorSpaceScalar
+ if (name.isSimpleName(ctx.Id_VectorSpaceScalar))
+ return getRequirement(KnownProtocolKind::VectorProtocol);
+ // SWIFT_ENABLE_TENSORFLOW END
+
return nullptr;
}
@@ -447,6 +625,60 @@
return getterDecl;
}
+// SWIFT_ENABLE_TENSORFLOW
+std::pair<AccessorDecl *, AccessorDecl *>
+DerivedConformance::addGetterAndSetterToMutableDerivedProperty(
+ VarDecl *property, Type propertyContextType) {
+ auto *getter = declareDerivedPropertyGetter(property, propertyContextType);
+ auto *setter = declareDerivedPropertySetter(property, propertyContextType);
+ property->setImplInfo(StorageImplInfo::getMutableComputed());
+ property->setAccessors(SourceLoc(), {getter, setter}, SourceLoc());
+ return std::make_pair(getter, setter);
+}
+// SWIFT_ENABLE_TENSORFLOW END
+
+// SWIFT_ENABLE_TENSORFLOW
+AccessorDecl *
+DerivedConformance::declareDerivedPropertySetter(VarDecl *property,
+ Type propertyContextType) {
+ bool isStatic = property->isStatic();
+ bool isFinal = property->isFinal();
+
+ auto &C = property->getASTContext();
+ auto parentDC = property->getDeclContext();
+
+ auto propertyInterfaceType = property->getInterfaceType();
+ auto propertyParam = new (C) ParamDecl(SourceLoc(), SourceLoc(), Identifier(),
+ property->getLoc(), C.getIdentifier("newValue"), parentDC);
+ propertyParam->setSpecifier(ParamDecl::Specifier::Default);
+ propertyParam->setInterfaceType(propertyInterfaceType);
+
+ ParameterList *params = ParameterList::create(C, propertyParam);
+
+ auto setterDecl = AccessorDecl::create(C,
+ /*FuncLoc*/ SourceLoc(), /*AccessorKeywordLoc*/ SourceLoc(),
+ AccessorKind::Set, property, /*StaticLoc*/ SourceLoc(),
+ StaticSpellingKind::None, /*Throws*/ false, /*ThrowsLoc*/ SourceLoc(),
+ /*GenericParams*/ nullptr, params, propertyInterfaceType, parentDC);
+ setterDecl->setImplicit();
+ setterDecl->setStatic(isStatic);
+ // Set mutating if parent is not a class.
+ if (!parentDC->getSelfClassDecl())
+ setterDecl->setSelfAccessKind(SelfAccessKind::Mutating);
+
+ // If this is supposed to be a final method, mark it as such.
+ assert(isFinal || !parentDC->getSelfClassDecl());
+ if (isFinal && parentDC->getSelfClassDecl() &&
+ !setterDecl->isFinal())
+ setterDecl->getAttrs().add(new (C) FinalAttr(/*Implicit*/ true));
+
+ // Compute the interface type of the setter.
+ setterDecl->setGenericSignature(parentDC->getGenericSignatureOfContext());
+ setterDecl->copyFormalAccessFrom(property);
+
+ return setterDecl;
+}
+
std::pair<VarDecl *, PatternBindingDecl *>
DerivedConformance::declareDerivedProperty(Identifier name,
Type propertyInterfaceType,
@@ -457,6 +689,10 @@
VarDecl *propDecl = new (Context)
VarDecl(/*IsStatic*/ isStatic, VarDecl::Introducer::Var,
/*IsCaptureList*/ false, SourceLoc(), name, parentDC);
+ // SWIFT_ENABLE_TENSORFLOW
+ // TODO: Upstream this change to master.
+ if (isFinal && parentDC->getSelfClassDecl())
+ propDecl->getAttrs().add(new (Context) FinalAttr(/*Implicit*/ true));
propDecl->setImplicit();
propDecl->copyFormalAccessFrom(Nominal, /*sourceIsParentContext*/ true);
propDecl->setInterfaceType(propertyInterfaceType);
diff --git a/lib/Sema/DerivedConformances.h b/lib/Sema/DerivedConformances.h
index c5dba4b..2321b59 100644
--- a/lib/Sema/DerivedConformances.h
+++ b/lib/Sema/DerivedConformances.h
@@ -277,6 +277,94 @@
/// \returns the derived member, which will also be added to the type.
ValueDecl *deriveDecodable(ValueDecl *requirement);
+ // SWIFT_ENABLE_TENSORFLOW
+ /// Determine if a KeyPathIterable requirement can be derived for a type.
+ ///
+ /// \returns True if the requirement can be derived.
+ static bool canDeriveKeyPathIterable(NominalTypeDecl *type);
+
+ /// Derive a KeyPathIterable requirement for a nominal type.
+ ///
+ /// \returns the derived member, which will also be added to the type.
+ ValueDecl *deriveKeyPathIterable(ValueDecl *requirement);
+
+ /// Derive a KeyPathIterable type witness for a nominal type.
+ ///
+ /// \returns the derived member, which will also be added to the type.
+ Type deriveKeyPathIterable(AssociatedTypeDecl *assocType);
+
+ /// Determine if a TensorArrayProtocol requirement can be derived for a type.
+ ///
+ /// \returns True if the requirement can be derived.
+ static bool canDeriveTensorArrayProtocol(NominalTypeDecl *type,
+ DeclContext *DC);
+
+ /// Derive a TensorArrayProtocol requirement for a nominal type.
+ ///
+ /// \returns the derived member, which will also be added to the type.
+ ValueDecl *deriveTensorArrayProtocol(ValueDecl *requirement);
+
+ /// Determine if a TensorGroup requirement can be derived for a type.
+ ///
+ /// \returns True if the requirement can be derived.
+ static bool canDeriveTensorGroup(NominalTypeDecl *type, DeclContext *DC);
+
+ /// Derive a TensorGroup requirement for a nominal type.
+ ///
+ /// \returns the derived member, which will also be added to the type.
+ ValueDecl *deriveTensorGroup(ValueDecl *requirement);
+
+ /// Determine if a PointwiseMultiplicative requirement can be derived for a type.
+ ///
+ /// \returns True if the requirement can be derived.
+ static bool canDerivePointwiseMultiplicative(NominalTypeDecl *type,
+ DeclContext *DC);
+
+ /// Derive an PointwiseMultiplicative requirement for a nominal type.
+ ///
+ /// \returns the derived member, which will also be added to the type.
+ ValueDecl *derivePointwiseMultiplicative(ValueDecl *requirement);
+
+ /// Determine if an ElementaryFunctions requirement can be derived for a
+ /// type.
+ ///
+ /// \returns True if the requirement can be derived.
+ static bool canDeriveElementaryFunctions(NominalTypeDecl *type,
+ DeclContext *DC);
+
+ /// Derive an ElementaryFunctions requirement for a nominal type.
+ ///
+ /// \returns the derived member, which will also be added to the type.
+ ValueDecl *deriveElementaryFunctions(ValueDecl *requirement);
+
+ /// Determine if a VectorProtocol requirement can be derived for a type.
+ ///
+ /// \returns True if the requirement can be derived.
+ static bool canDeriveVectorProtocol(NominalTypeDecl *type,
+ DeclContext *DC);
+
+ /// Derive a VectorProtocol requirement for a nominal type.
+ ///
+ /// \returns the derived member, which will also be added to the type.
+ ValueDecl *deriveVectorProtocol(ValueDecl *requirement);
+
+ /// Derive a VectorProtocol type witness for a nominal type.
+ ///
+ /// \returns the derived member, which will also be added to the type.
+ Type deriveVectorProtocol(AssociatedTypeDecl *assocType);
+
+ /// Determine if a Differentiable requirement can be derived for a type.
+ ///
+ /// \returns True if the requirement can be derived.
+ static bool canDeriveEuclideanDifferentiable(NominalTypeDecl *type,
+ DeclContext *DC);
+
+ /// Derive a EuclideanDifferentiable requirement for a nominal type.
+ ///
+ /// \returns the derived member, which will also be added to the type.
+ ValueDecl *deriveEuclideanDifferentiable(ValueDecl *requirement);
+ // SWIFT_ENABLE_TENSORFLOW END
+
/// Declare a read-only property.
std::pair<VarDecl *, PatternBindingDecl *>
declareDerivedProperty(Identifier name, Type propertyInterfaceType,
@@ -292,6 +380,19 @@
static AccessorDecl *declareDerivedPropertyGetter(VarDecl *property,
Type propertyContextType);
+ // SWIFT_ENABLE_TENSORFLOW
+ /// Add a getter and setter to a derived property. The property becomes
+ /// mutable.
+ static std::pair<AccessorDecl *, AccessorDecl *>
+ addGetterAndSetterToMutableDerivedProperty(VarDecl *property,
+ Type propertyContextType);
+
+ /// Declare a setter for a derived property.
+ /// The setter will not be added to the property yet.
+ static AccessorDecl *declareDerivedPropertySetter(VarDecl *property,
+ Type propertyContextType);
+ // SWIFT_ENABLE_TENSORFLOW END
+
/// Build a reference to the 'self' decl of a derived function.
static DeclRefExpr *createSelfDeclRef(AbstractFunctionDecl *fn);
diff --git a/lib/Sema/LookupVisibleDecls.cpp b/lib/Sema/LookupVisibleDecls.cpp
index 2989603..bcd65d8 100644
--- a/lib/Sema/LookupVisibleDecls.cpp
+++ b/lib/Sema/LookupVisibleDecls.cpp
@@ -850,6 +850,12 @@
void foundDecl(ValueDecl *VD, DeclVisibilityKind Reason,
DynamicLookupInfo dynamicLookupInfo) override {
+ // SWIFT_ENABLE_TENSORFLOW
+ // Suppress "sequenced" as a result, because it crashes completions.
+ // TODO(TF-315): Fix properly and then remove this.
+ if (isa<FuncDecl>(VD) &&
+ cast<FuncDecl>(VD)->getBaseIdentifier().str() == "sequenced")
+ return;
if (!Results.insert({VD, Reason, dynamicLookupInfo}))
return;
diff --git a/lib/Sema/MiscDiagnostics.cpp b/lib/Sema/MiscDiagnostics.cpp
index a0fd2ae..4aa6289 100644
--- a/lib/Sema/MiscDiagnostics.cpp
+++ b/lib/Sema/MiscDiagnostics.cpp
@@ -229,7 +229,7 @@
});
}
}
-
+
// If we have an assignment expression, scout ahead for acceptable _'s.
if (auto *AE = dyn_cast<AssignExpr>(E)) {
auto destExpr = AE->getDest();
diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp
index c090319..443cfdc 100644
--- a/lib/Sema/TypeCheckAttr.cpp
+++ b/lib/Sema/TypeCheckAttr.cpp
@@ -257,6 +257,9 @@
void visitDifferentiableAttr(DifferentiableAttr *attr);
void visitDerivativeAttr(DerivativeAttr *attr);
void visitTransposeAttr(TransposeAttr *attr);
+ // SWIFT_ENABLE_TENSORFLOW
+ void visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr);
+ // SWIFT_ENABLE_TENSORFLOW END
void visitAsyncHandlerAttr(AsyncHandlerAttr *attr) {
if (!Ctx.LangOpts.EnableExperimentalConcurrency) {
@@ -4063,11 +4066,10 @@
}
}
// Non-`get` accessors are not yet supported: `set`, `read`, and `modify`.
- // TODO(TF-129): Enable `set` when differentiation supports inout parameters.
// TODO(TF-1080): Enable `read` and `modify` when differentiation supports
// coroutines.
if (auto *accessor = dyn_cast_or_null<AccessorDecl>(original))
- if (!accessor->isGetter())
+ if (!accessor->isGetter() && !accessor->isSetter())
original = nullptr;
// Diagnose if original `AbstractFunctionDecl` could not be resolved.
if (!original) {
@@ -5178,3 +5180,71 @@
// Set the resolved linearity parameter indices in the attribute.
attr->setParameterIndices(linearParamIndices);
}
+
+// SWIFT_ENABLE_TENSORFLOW
+static bool
+compilerEvaluableAllowedInExtensionDecl(ExtensionDecl *extensionDecl) {
+ auto extendedTypeKind = extensionDecl->getExtendedType()->getKind();
+ return extendedTypeKind == TypeKind::Enum ||
+ extendedTypeKind == TypeKind::Protocol ||
+ extendedTypeKind == TypeKind::Struct ||
+ extendedTypeKind == TypeKind::BoundGenericEnum ||
+ extendedTypeKind == TypeKind::BoundGenericStruct;
+}
+
+void AttributeChecker::visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr) {
+ // Check that the function is defined in an allowed context.
+ // TODO(marcrasi): In many cases, we can probably generate a more informative
+ // error message than just saying that it's "not allowed here". (Like "not
+ // allowed in a class [point at the class decl], put it at the top level or in
+ // a struct instead").
+ auto declContext = D->getDeclContext();
+ switch (declContext->getContextKind()) {
+ case DeclContextKind::AbstractFunctionDecl:
+ // Nested functions are okay.
+ break;
+ case DeclContextKind::ExtensionDecl:
+ // Enum, Protocol, and Struct extensions are okay. For Enums and Structs
+ // extensions, the extended type must be compiler-representable.
+ // TODO(marcrasi): Check that the extended type is compiler-representable.
+ if (!compilerEvaluableAllowedInExtensionDecl(
+ cast<ExtensionDecl>(declContext))) {
+ diagnose(D, diag::compiler_evaluable_bad_context);
+ attr->setInvalid();
+ return;
+ }
+ break;
+ case DeclContextKind::FileUnit:
+ // Top level functions are okay.
+ break;
+ case DeclContextKind::GenericTypeDecl:
+ switch (cast<GenericTypeDecl>(declContext)->getKind()) {
+ case DeclKind::Enum:
+ // Enums are okay, if they are compiler-representable.
+ // TODO(marcrasi): Check that it's compiler-representable.
+ break;
+ case DeclKind::Struct:
+ // Structs are okay, if they are compiler-representable.
+ // TODO(marcrasi): Check that it's compiler-representable.
+ break;
+ default:
+ diagnose(D, diag::compiler_evaluable_bad_context);
+ attr->setInvalid();
+ return;
+ }
+ break;
+ default:
+ diagnose(D, diag::compiler_evaluable_bad_context);
+ attr->setInvalid();
+ return;
+ }
+
+ // Check that the signature only has allowed types.
+ // TODO(marcrasi): Do this.
+
+ // For @compilerEvaluable to be truly valid, the function body must also
+ // follow certain rules. We can only check these rules after the body is type
+ // checked, and it's not type checked yet, so we check these rules later in
+ // TypeChecker::checkFunctionBodyCompilerEvaluable().
+}
+// SWIFT_ENABLE_TENSORFLOW END
diff --git a/lib/Sema/TypeCheckCompilerEvaluable.cpp b/lib/Sema/TypeCheckCompilerEvaluable.cpp
new file mode 100644
index 0000000..9fab534
--- /dev/null
+++ b/lib/Sema/TypeCheckCompilerEvaluable.cpp
@@ -0,0 +1,249 @@
+//===--- TypeCheckCompilerEvaluable.cpp - Check compiler evaluability -----===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+//
+// SWIFT_ENABLE_TENSORFLOW
+// Checks that function bodies follow rules for compiler evaluable functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TypeChecker.h"
+#include "swift/AST/ASTWalker.h"
+#include "swift/AST/Attr.h"
+#include "swift/AST/Decl.h"
+
+#include "llvm/Support/Debug.h"
+
+using namespace swift;
+
+namespace {
+
+/// Checks that a type is compiler representable.
+/// Currently a skeleton implementation that only rejects types named Float,
+/// Double and String.
+/// TODO(marcrasi): Fill in a real implementation.
+static bool checkCompilerRepresentable(const Type &type) {
+ return type.getString() != "Double" && type.getString() != "Float" &&
+ type.getString() != "String";
+}
+
+/// Checks that the body of a function is compiler evaluable.
+class CheckCompilerEvaluableBody : public ASTWalker {
+ ASTContext &Ctx;
+
+ // The function whose body we are checking.
+ const AbstractFunctionDecl *CheckingFunc;
+
+ // Whether the body has passed the check.
+ bool CompilerEvaluable = true;
+
+public:
+ CheckCompilerEvaluableBody(ASTContext &Ctx,
+ const AbstractFunctionDecl *CheckingFunc)
+ : Ctx(Ctx), CheckingFunc(CheckingFunc) {}
+
+ std::pair<bool, Expr *> walkToExprPre(Expr *E) override {
+ // If this is the ignored part of a DotSyntaxBaseIgnored, then we can accept
+ // it without walking it.
+ if (auto *parentDotSyntaxBaseIgnored =
+ dyn_cast_or_null<DotSyntaxBaseIgnoredExpr>(Parent.getAsExpr()))
+ if (parentDotSyntaxBaseIgnored->getLHS() == E)
+ return {false, E};
+
+ if (!checkCompilerRepresentable(E->getType())) {
+ Ctx.Diags.diagnose(E->getLoc(), diag::compiler_evaluable_forbidden_type,
+ E->getType())
+ .highlight(E->getSourceRange());
+ CompilerEvaluable = false;
+ return {false, E};
+ }
+
+ switch (E->getKind()) {
+ #define ALWAYS_ALLOWED(ID) \
+ case ExprKind::ID: \
+ return {true, E};
+ #define SOMETIMES_ALLOWED(ID) \
+ case ExprKind::ID: \
+ return checkExpr##ID(cast<ID##Expr>(E));
+
+ ALWAYS_ALLOWED(NilLiteral)
+ ALWAYS_ALLOWED(IntegerLiteral)
+ ALWAYS_ALLOWED(BooleanLiteral)
+ ALWAYS_ALLOWED(MagicIdentifierLiteral)
+ ALWAYS_ALLOWED(DiscardAssignment)
+ SOMETIMES_ALLOWED(DeclRef)
+ ALWAYS_ALLOWED(Type)
+ SOMETIMES_ALLOWED(OtherConstructorDeclRef)
+ ALWAYS_ALLOWED(DotSyntaxBaseIgnored)
+ ALWAYS_ALLOWED(MemberRef)
+ ALWAYS_ALLOWED(Paren)
+ ALWAYS_ALLOWED(DotSelf)
+ ALWAYS_ALLOWED(Try)
+ ALWAYS_ALLOWED(ForceTry)
+ ALWAYS_ALLOWED(OptionalTry)
+ ALWAYS_ALLOWED(Tuple)
+ ALWAYS_ALLOWED(Subscript)
+ ALWAYS_ALLOWED(TupleElement)
+ ALWAYS_ALLOWED(CaptureList)
+ ALWAYS_ALLOWED(Closure)
+ ALWAYS_ALLOWED(AutoClosure)
+ ALWAYS_ALLOWED(InOut)
+ ALWAYS_ALLOWED(DynamicType)
+ ALWAYS_ALLOWED(RebindSelfInConstructor)
+ ALWAYS_ALLOWED(BindOptional)
+ ALWAYS_ALLOWED(OptionalEvaluation)
+ ALWAYS_ALLOWED(ForceValue)
+ SOMETIMES_ALLOWED(Call)
+ ALWAYS_ALLOWED(PrefixUnary)
+ ALWAYS_ALLOWED(PostfixUnary)
+ ALWAYS_ALLOWED(Binary)
+ ALWAYS_ALLOWED(DotSyntaxCall)
+ ALWAYS_ALLOWED(ConstructorRefCall)
+ ALWAYS_ALLOWED(Load)
+ ALWAYS_ALLOWED(InjectIntoOptional)
+ ALWAYS_ALLOWED(Coerce)
+ ALWAYS_ALLOWED(If)
+ ALWAYS_ALLOWED(Assign)
+ ALWAYS_ALLOWED(CodeCompletion)
+ ALWAYS_ALLOWED(EditorPlaceholder)
+
+ // Allow all errors and unchecked expressions so that we don't put errors
+ // on top of expressions that alrady have errors.
+ ALWAYS_ALLOWED(Error)
+ ALWAYS_ALLOWED(UnresolvedTypeConversion)
+ #define UNCHECKED_EXPR(ID, PARENT) ALWAYS_ALLOWED(ID)
+ #include "swift/AST/ExprNodes.def"
+
+ default:
+ Ctx.Diags.diagnose(E->getStartLoc(),
+ diag::compiler_evaluable_forbidden_expression)
+ .highlight(E->getSourceRange());
+ CompilerEvaluable = false;
+ return {false, E};
+
+ #undef ALWAYS_ALLOWED
+ #undef SOMETIMES_ALLOWED
+ }
+ }
+
+ std::pair<bool, Expr *> checkExprCall(CallExpr *call) {
+ // TODO(SR-8035): Eliminate this special case.
+ // Allow calls to some stdlib assertion functions without walking them
+ // further, because the calls do currently-forbidden things. (They use
+ // Strings and they call functions imported from C).
+ if (auto *calleeRef = dyn_cast<DeclRefExpr>(call->getDirectCallee()))
+ if (auto *callee = dyn_cast<AbstractFunctionDecl>(calleeRef->getDecl()))
+ if (callee->isChildContextOf(Ctx.TheStdlibModule) &&
+ (callee->getNameStr() == "_precondition" ||
+ callee->getNameStr() == "_preconditionFailure" ||
+ callee->getNameStr() == "_sanityCheck" ||
+ callee->getNameStr() == "fatalError"))
+ return {false, call};
+
+ // Otherwise, walk everything in the expression.
+ return {true, call};
+ }
+
+ std::pair<bool, Expr *> checkExprDeclRef(DeclRefExpr *declRef) {
+ auto *decl = declRef->getDeclRef().getDecl();
+ if (auto *varDecl = dyn_cast<VarDecl>(decl)) {
+ // DeclRefs to immutable variables are always allowed.
+ if (varDecl->isLet())
+ return {true, declRef};
+
+ // DeclRefs to mutable variables are only allowed if they are declared
+ // within the @compilerEvaluable function.
+ if (varDecl->getDeclContext() == CheckingFunc ||
+ varDecl->getDeclContext()->isChildContextOf(CheckingFunc))
+ return {true, declRef};
+
+ Ctx.Diags.diagnose(declRef->getLoc(),
+ diag::compiler_evaluable_non_local_mutable);
+ CompilerEvaluable = false;
+ return {false, declRef};
+ } else if (auto *functionDecl = dyn_cast<AbstractFunctionDecl>(decl)) {
+ return checkAbstractFunctionDeclRef(declRef, functionDecl);
+ } else if (isa<EnumElementDecl>(decl)) {
+ return {true, declRef};
+ } else {
+ Ctx.Diags.diagnose(declRef->getLoc(),
+ diag::compiler_evaluable_forbidden_expression)
+ .highlight(declRef->getSourceRange());
+ CompilerEvaluable = false;
+ return {false, declRef};
+ }
+ }
+
+ std::pair<bool, Expr *>
+ checkExprOtherConstructorDeclRef(OtherConstructorDeclRefExpr *declRef) {
+ return checkAbstractFunctionDeclRef(declRef, declRef->getDecl());
+ }
+
+ std::pair<bool, Expr *>
+ checkAbstractFunctionDeclRef(Expr *declRef, AbstractFunctionDecl *decl) {
+ // If the function is @compilerEvaluable, allow it.
+ if (decl->getAttrs().hasAttribute<CompilerEvaluableAttr>(
+ /*AllowInvalid=*/true))
+ return {true, declRef};
+
+ // If the function is nested within the function that we are checking, allow
+ // it.
+ if (decl->isChildContextOf(CheckingFunc))
+ return {true, declRef};
+
+ // For now, allow all builtins.
+ // TODO: Mark which builtins are actually compiler evaluable.
+ if (decl->isChildContextOf(Ctx.TheBuiltinModule))
+ return {true, declRef};
+
+ // Allow all protocol methods. Later, the interpreter looks up the actual
+ // function and emits an error when it is not @compilerEvaluable.
+ if (isa<ProtocolDecl>(decl->getDeclContext()))
+ return {true, declRef};
+
+ Ctx.Diags.diagnose(declRef->getLoc(),
+ diag::compiler_evaluable_ref_non_compiler_evaluable);
+ CompilerEvaluable = false;
+ return {false, declRef};
+ }
+
+ std::pair<bool, Stmt *> walkToStmtPre(Stmt *S) override {
+ if (S->getKind() == StmtKind::While) {
+ Ctx.Diags.diagnose(S->getStartLoc(), diag::compiler_evaluable_loop);
+ CompilerEvaluable = false;
+ return {false, S};
+ }
+ return {true, S};
+ }
+
+ bool getCompilerEvaluable() const { return CompilerEvaluable; }
+};
+
+} // namespace
+
+/// If the function has a valid @compilerEvaluable attribute, checks that the
+/// function body follows all the rules for compiler evaluable functions.
+///
+/// The function body must already be type checked.
+void TypeChecker::checkFunctionBodyCompilerEvaluable(AbstractFunctionDecl *D) {
+ auto compilerEvaluableAttr =
+ D->getAttrs().getAttribute<CompilerEvaluableAttr>();
+ if (!compilerEvaluableAttr || !compilerEvaluableAttr->isValid()) return;
+
+ assert(D->getBodyKind() == AbstractFunctionDecl::BodyKind::TypeChecked &&
+ "cannot check @compilerEvaluable body that is not type checked");
+
+ CheckCompilerEvaluableBody Checker(D->getASTContext(), D);
+ D->getBody()->walk(Checker);
+ if (!Checker.getCompilerEvaluable()) {
+ compilerEvaluableAttr->setInvalid();
+ }
+}
diff --git a/lib/Sema/TypeCheckDeclOverride.cpp b/lib/Sema/TypeCheckDeclOverride.cpp
index 153c37e..61962fa 100644
--- a/lib/Sema/TypeCheckDeclOverride.cpp
+++ b/lib/Sema/TypeCheckDeclOverride.cpp
@@ -1482,6 +1482,10 @@
UNINTERESTING_ATTR(Transpose)
UNINTERESTING_ATTR(NoDerivative)
+ // SWIFT_ENABLE_TENSORFLOW
+ UNINTERESTING_ATTR(CompilerEvaluable)
+ // SWIFT_ENABLE_TENSORFLOW END
+
// These can't appear on overridable declarations.
UNINTERESTING_ATTR(Prefix)
UNINTERESTING_ATTR(Postfix)
diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp
index 60b2543..fc5c525 100644
--- a/lib/Sema/TypeCheckProtocol.cpp
+++ b/lib/Sema/TypeCheckProtocol.cpp
@@ -5760,6 +5760,32 @@
case KnownDerivableProtocolKind::OptionSet:
llvm_unreachable(
"When possible, OptionSet is derived via memberwise init synthesis");
+
+ // SWIFT_ENABLE_TENSORFLOW
+ case KnownDerivableProtocolKind::KeyPathIterable:
+ return derived.deriveKeyPathIterable(Requirement);
+
+ case KnownDerivableProtocolKind::TensorArrayProtocol:
+ return derived.deriveTensorArrayProtocol(Requirement);
+
+ case KnownDerivableProtocolKind::TensorGroup:
+ return derived.deriveTensorGroup(Requirement);
+
+ case KnownDerivableProtocolKind::PointwiseMultiplicative:
+ return derived.derivePointwiseMultiplicative(Requirement);
+
+ case KnownDerivableProtocolKind::ElementaryFunctions:
+ return derived.deriveElementaryFunctions(Requirement);
+
+ case KnownDerivableProtocolKind::VectorProtocol:
+ return derived.deriveVectorProtocol(Requirement);
+
+ case KnownDerivableProtocolKind::EuclideanDifferentiable:
+ return derived.deriveEuclideanDifferentiable(Requirement);
+
+ default:
+ return nullptr;
+ // SWIFT_ENABLE_TENSORFLOW END
}
llvm_unreachable("unknown derivable protocol kind");
}
@@ -5786,6 +5812,12 @@
return std::make_pair(derived.deriveCaseIterable(AssocType), nullptr);
case KnownProtocolKind::Differentiable:
return derived.deriveDifferentiable(AssocType);
+ // SWIFT_ENABLE_TENSORFLOW
+ case KnownProtocolKind::KeyPathIterable:
+ return std::make_pair(derived.deriveKeyPathIterable(AssocType), nullptr);
+ case KnownProtocolKind::VectorProtocol:
+ return std::make_pair(derived.deriveVectorProtocol(AssocType), nullptr);
+ // SWIFT_ENABLE_TENSORFLOW END
default:
return std::make_pair(nullptr, nullptr);
}
diff --git a/lib/Sema/TypeCheckStmt.cpp b/lib/Sema/TypeCheckStmt.cpp
index c465e86..e758492 100644
--- a/lib/Sema/TypeCheckStmt.cpp
+++ b/lib/Sema/TypeCheckStmt.cpp
@@ -1605,6 +1605,13 @@
TypeCheckFunctionBodyRequest{AFD}, true);
TypeChecker::checkFunctionEffects(AFD);
TypeChecker::computeCaptures(AFD);
+ // SWIFT_ENABLE_TENSORFLOW
+ // Check `@compilerEvaluable` function body correctness.
+ // Do this here, rather than in
+ // `AttributeChecker::visitCompilerEvaluableAttr()` because we need the
+ // function bodies to be type checked.
+ TypeChecker::checkFunctionBodyCompilerEvaluable(AFD);
+ // SWIFT_ENABLE_TENSORFLOW END
return res;
}
diff --git a/lib/Sema/TypeCheckType.cpp b/lib/Sema/TypeCheckType.cpp
index 6e3dc8a..2ddef66 100644
--- a/lib/Sema/TypeCheckType.cpp
+++ b/lib/Sema/TypeCheckType.cpp
@@ -3110,7 +3110,6 @@
attrs.clearAttribute(TAK_noDerivative);
differentiability = SILParameterDifferentiability::NotDifferentiable;
}
-
type = resolveAttributedType(attrs, attrRepr->getTypeRepr(), options);
} else {
type = resolveType(repr, options);
diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h
index 356c6c6..38ffdd0 100644
--- a/lib/Sema/TypeChecker.h
+++ b/lib/Sema/TypeChecker.h
@@ -1137,6 +1137,7 @@
void checkEnumElementEffects(EnumElementDecl *D, Expr *expr);
void checkPropertyWrapperEffects(PatternBindingDecl *binding,
Expr *expr);
+void checkFunctionBodyCompilerEvaluable(AbstractFunctionDecl *D);
/// If an expression references 'self.init' or 'super.init' in an
/// initializer context, returns the implicit 'self' decl of the constructor.
diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp
index 6c5d3bf..a008770 100644
--- a/lib/Serialization/Serialization.cpp
+++ b/lib/Serialization/Serialization.cpp
@@ -820,6 +820,13 @@
BLOCK_RECORD(sil_block, SIL_SPECIALIZE_ATTR);
BLOCK_RECORD(sil_block, SIL_ONE_OPERAND_EXTRA_ATTR);
BLOCK_RECORD(sil_block, SIL_TWO_OPERANDS_EXTRA_ATTR);
+ BLOCK_RECORD(sil_block, SIL_INST_DIFFERENTIABLE_FUNCTION);
+ BLOCK_RECORD(sil_block, SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT);
+ // SWIFT_ENABLE_TENSORFLOW
+ BLOCK_RECORD(sil_block, SIL_INST_LINEAR_FUNCTION);
+ BLOCK_RECORD(sil_block, SIL_INST_LINEAR_FUNCTION_EXTRACT);
+ // SWIFT_ENABLE_TENSORFLOW END
+ BLOCK_RECORD(sil_block, SIL_DIFFERENTIABILITY_WITNESS);
// These layouts can exist in both decl blocks and sil blocks.
#define BLOCK_RECORD_WITH_NAMESPACE(K, X) emitRecordID(X, #X, nameBuffer)
@@ -5290,6 +5297,32 @@
}
}
+// SWIFT_ENABLE_TENSORFLOW
+void swift::serializeToMemory(
+ ModuleOrSourceFile DC, const SerializationOptions &options,
+ std::unique_ptr<llvm::MemoryBuffer> *moduleBuffer,
+ std::unique_ptr<llvm::MemoryBuffer> *moduleDocBuffer, const SILModule *M) {
+ if (moduleBuffer) {
+ auto name = "Serialization, swiftmodule, to memory";
+ llvm::NamedRegionTimer timer(name, name, "Swift", "Swift compilation");
+ llvm::SmallString<1024> buf;
+ llvm::raw_svector_ostream stream(buf);
+ Serializer::writeToStream(stream, DC, M, options);
+ *moduleBuffer =
+ std::make_unique<llvm::SmallVectorMemoryBuffer>(std::move(buf));
+ }
+
+ if (moduleDocBuffer) {
+ auto name = "Serialization, swiftdoc, to memory";
+ llvm::NamedRegionTimer timer(name, name, "Swift", "Swift compilation");
+ llvm::SmallString<1024> buf;
+ llvm::raw_svector_ostream stream(buf);
+ writeDocToStream(stream, DC, options.GroupInfoPath);
+ *moduleDocBuffer =
+ std::make_unique<llvm::SmallVectorMemoryBuffer>(std::move(buf));
+ }
+}
+
void swift::serialize(ModuleOrSourceFile DC,
const SerializationOptions &options,
const SILModule *M) {
diff --git a/stdlib/cmake/modules/AddSwiftStdlib.cmake b/stdlib/cmake/modules/AddSwiftStdlib.cmake
index 32a4214..1cff522 100644
--- a/stdlib/cmake/modules/AddSwiftStdlib.cmake
+++ b/stdlib/cmake/modules/AddSwiftStdlib.cmake
@@ -1755,7 +1755,6 @@
endif()
endif()
-
# Collect architecture agnostic SDK linker flags
set(swiftlib_link_flags_all ${SWIFTLIB_LINK_FLAGS})
if(${sdk} STREQUAL IOS_SIMULATOR AND ${name} STREQUAL swiftMediaPlayer)
diff --git a/stdlib/cmake/modules/SwiftSource.cmake b/stdlib/cmake/modules/SwiftSource.cmake
index 170d891..0f2cf1e 100644
--- a/stdlib/cmake/modules/SwiftSource.cmake
+++ b/stdlib/cmake/modules/SwiftSource.cmake
@@ -401,7 +401,16 @@
# The standard library and overlays are always built resiliently.
if(SWIFTFILE_IS_STDLIB)
- list(APPEND swift_flags "-enable-library-evolution")
+ # SWIFT_ENABLE_TENSORFLOW
+ # FIXME(TF-328): Resilience is currently disabled for the TensorFlow module
+ # because it causes compilation to crash during IRGen.
+ # Also, disable resilience for DifferentiationUnittest because resilience
+ # changes generated AD code, leading to additional leaks.
+ if(NOT "${SWIFTFILE_MODULE_NAME}" STREQUAL "TensorFlow" AND
+ NOT "${SWIFTFILE_MODULE_NAME}" STREQUAL "DifferentiationUnittest")
+ list(APPEND swift_flags "-enable-library-evolution")
+ endif()
+ # SWIFT_ENABLE_TENSORFLOW END
endif()
if(SWIFT_STDLIB_SINGLE_THREADED_RUNTIME)
@@ -523,8 +532,8 @@
endif()
if (NOT SWIFTFILE_IS_STDLIB_CORE)
- list(APPEND swift_module_flags
- "-Xfrontend" "-experimental-skip-non-inlinable-function-bodies")
+ # list(APPEND swift_module_flags
+ # "-Xfrontend" "-experimental-skip-non-inlinable-function-bodies")
endif()
set(module_outputs "${module_file}" "${module_doc_file}")
diff --git a/stdlib/linker-support/magic-symbols-for-install-name.c b/stdlib/linker-support/magic-symbols-for-install-name.c
index dc15419..6946ff7 100644
--- a/stdlib/linker-support/magic-symbols-for-install-name.c
+++ b/stdlib/linker-support/magic-symbols-for-install-name.c
@@ -83,6 +83,10 @@
// treat macOS 10.14 as an "older OS."
#if __MAC_OS_X_VERSION_MIN_REQUIRED < __MAC_10_14
RPATH_INSTALL_NAME_DIRECTIVE(10, 14)
+ // SWIFT_ENABLE_TENSORFLOW
+ // For TensorFlow, keep using @rpath instead of system paths
+ RPATH_INSTALL_NAME_DIRECTIVE(10, 15)
+ // SWIFT_ENABLE_TENSORFLOW END
#endif
#else
diff --git a/stdlib/public/Darwin/CoreGraphics/CGFloat.swift.gyb b/stdlib/public/Darwin/CoreGraphics/CGFloat.swift.gyb
index e682a99..f76665a 100644
--- a/stdlib/public/Darwin/CoreGraphics/CGFloat.swift.gyb
+++ b/stdlib/public/Darwin/CoreGraphics/CGFloat.swift.gyb
@@ -513,6 +513,66 @@
}
//===----------------------------------------------------------------------===//
+// Real conformance
+//===----------------------------------------------------------------------===//
+
+%from SwiftMathFunctions import *
+
+extension CGFloat: ElementaryFunctions {
+% for func in ElementaryFunctions + RealFunctions:
+
+ @_alwaysEmitIntoClient
+ public static func ${func.decl('CGFloat')} {
+ return CGFloat(NativeType.${func.swiftName}(${func.params("", ".native")}))
+ }
+% end
+
+ @_alwaysEmitIntoClient
+ public static func pow(_ x: CGFloat, _ y: CGFloat) -> CGFloat {
+ guard x >= 0 else { return .nan }
+ return CGFloat(NativeType.pow(x.native, y.native))
+ }
+
+ @_alwaysEmitIntoClient
+ public static func pow(_ x: CGFloat, _ n: Int) -> CGFloat {
+ // TODO: this implementation isn't quite right for n so large that
+ // the conversion to `CGFloat` rounds. We could also consider using
+ // a multiply-chain implementation for small `n`; this would be faster
+ // for static `n`, but less accurate on platforms with a good `pow`
+ // implementation.
+ return CGFloat(NativeType.pow(x.native, n))
+ }
+
+ @_alwaysEmitIntoClient
+ public static func root(_ x: CGFloat, _ n: Int) -> CGFloat {
+ guard x >= 0 || n % 2 != 0 else { return .nan }
+ // TODO: this implementation isn't quite right for n so large that
+ // the conversion to `CGFloat` rounds.
+ return CGFloat(NativeType.root(x.native, n))
+ }
+
+ @_alwaysEmitIntoClient
+ public static func atan2(_ y: CGFloat, _ x: CGFloat) -> CGFloat {
+ return CGFloat(NativeType.atan2(y.native, x.native))
+ }
+
+ @_alwaysEmitIntoClient
+ public static func logGamma(_ x: CGFloat) -> CGFloat {
+ return CGFloat(NativeType.logGamma(x.native))
+ }
+
+ @_alwaysEmitIntoClient
+ public static func signGamma(_ x: CGFloat) -> FloatingPointSign {
+ if x >= 0 { return .plus }
+ let trunc = x.rounded(.towardZero)
+ if x == trunc { return .plus }
+ let halfTrunc = trunc/2
+ if halfTrunc == halfTrunc.rounded(.towardZero) { return .minus }
+ return .plus
+ }
+}
+
+//===----------------------------------------------------------------------===//
// tgmath
//===----------------------------------------------------------------------===//
@@ -532,10 +592,18 @@
}%
%for ufunc in UnaryFunctions:
+% if ufunc in ['rint','nearbyint']:
+@available(swift, deprecated: 5.1, message: "Swift does not model dynamic rounding modes, use x.rounded(.toNearestOrEven) instead.")
+@_transparent
+public func ${ufunc}(_ x: CGFloat) -> CGFloat {
+ return x.rounded(.toNearestOrEven)
+}
+% else:
@_transparent
public func ${ufunc}(_ x: CGFloat) -> CGFloat {
return CGFloat(${ufunc}(x.native))
}
+% end
%end
diff --git a/stdlib/public/Differentiation/AnyDifferentiable.swift b/stdlib/public/Differentiation/AnyDifferentiable.swift
index 3a11610..090a0f0 100644
--- a/stdlib/public/Differentiation/AnyDifferentiable.swift
+++ b/stdlib/public/Differentiation/AnyDifferentiable.swift
@@ -15,8 +15,6 @@
//
//===----------------------------------------------------------------------===//
-import Swift
-
//===----------------------------------------------------------------------===//
// `AnyDifferentiable`
//===----------------------------------------------------------------------===//
diff --git a/stdlib/public/Differentiation/ArrayDifferentiation.swift b/stdlib/public/Differentiation/ArrayDifferentiation.swift
index 046af6f..26eda34 100644
--- a/stdlib/public/Differentiation/ArrayDifferentiation.swift
+++ b/stdlib/public/Differentiation/ArrayDifferentiation.swift
@@ -10,8 +10,6 @@
//
//===----------------------------------------------------------------------===//
-import Swift
-
//===----------------------------------------------------------------------===//
// Protocol conformances
//===----------------------------------------------------------------------===//
@@ -90,6 +88,16 @@
}
}
+// SWIFT_ENABLE_TENSORFLOW
+extension Array.DifferentiableView: EuclideanDifferentiable
+where Element: EuclideanDifferentiable {
+ public var differentiableVectorView: Array.DifferentiableView.TangentVector {
+ Array.DifferentiableView.TangentVector(
+ base.map { $0.differentiableVectorView })
+ }
+}
+// SWIFT_ENABLE_TENSORFLOW END
+
extension Array.DifferentiableView: Equatable
where Element: Differentiable & Equatable {
public static func == (
@@ -193,6 +201,15 @@
}
}
+// SWIFT_ENABLE_TENSORFLOW
+extension Array: EuclideanDifferentiable
+where Element: EuclideanDifferentiable {
+ public var differentiableVectorView: TangentVector {
+ TangentVector(map { $0.differentiableVectorView })
+ }
+}
+// SWIFT_ENABLE_TENSORFLOW END
+
//===----------------------------------------------------------------------===//
// Derivatives
//===----------------------------------------------------------------------===//
diff --git a/stdlib/public/Differentiation/CMakeLists.txt b/stdlib/public/Differentiation/CMakeLists.txt
index 6e0257a..0b44efa 100644
--- a/stdlib/public/Differentiation/CMakeLists.txt
+++ b/stdlib/public/Differentiation/CMakeLists.txt
@@ -10,18 +10,18 @@
#
#===----------------------------------------------------------------------===#
-add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} IS_STDLIB
- Differentiable.swift
- DifferentialOperators.swift
- DifferentiationUtilities.swift
- AnyDifferentiable.swift
- ArrayDifferentiation.swift
- OptionalDifferentiation.swift
+# SWIFT_ENABLE_TENSORFLOW
+# NOTE: A non-empty `_Differentiation` module is currently created only on
+# master branch, not on tensorflow branch.
+#
+# Instead, on tensorflow branch, the differentiation-related Swift source files
+# in this directory are directly built as part of swiftCore: see
+# stdlib/public/core/CMakeLists.txt. The `_Differentiation` module is created
+# empty to enable `#if canImport(_Differentiation)` guards in tests.
+# SWIFT_ENABLE_TENSORFLOW END
- GYB_SOURCES
- FloatingPointDifferentiation.swift.gyb
- TgmathDerivatives.swift.gyb
- SIMDDifferentiation.swift.gyb
+add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} IS_STDLIB
+ Empty.swift
SWIFT_MODULE_DEPENDS_OSX Darwin
SWIFT_MODULE_DEPENDS_IOS Darwin
diff --git a/stdlib/public/Differentiation/Differentiable.swift b/stdlib/public/Differentiation/Differentiable.swift
index 1341e03..cabe927 100644
--- a/stdlib/public/Differentiation/Differentiable.swift
+++ b/stdlib/public/Differentiation/Differentiable.swift
@@ -20,8 +20,6 @@
//
//===----------------------------------------------------------------------===//
-import Swift
-
/// A type that mathematically represents a differentiable manifold whose
/// tangent spaces are finite-dimensional.
public protocol Differentiable {
diff --git a/stdlib/public/Differentiation/DifferentialOperators.swift b/stdlib/public/Differentiation/DifferentialOperators.swift
index 2435e55..0e6e2d8 100644
--- a/stdlib/public/Differentiation/DifferentialOperators.swift
+++ b/stdlib/public/Differentiation/DifferentialOperators.swift
@@ -14,8 +14,6 @@
//
//===----------------------------------------------------------------------===//
-import Swift
-
// Transpose
@inlinable
diff --git a/stdlib/public/Differentiation/DifferentiationSupport.swift b/stdlib/public/Differentiation/DifferentiationSupport.swift
new file mode 100644
index 0000000..95edc67
--- /dev/null
+++ b/stdlib/public/Differentiation/DifferentiationSupport.swift
@@ -0,0 +1,220 @@
+//===--- DifferentiationSupport.swift -------------------------*- swift -*-===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+//
+// SWIFT_ENABLE_TENSORFLOW
+//
+// This file defines not-yet-upstreamed support for differentiable programming
+// and deep learning APIs.
+//
+//===----------------------------------------------------------------------===//
+
+infix operator .* : MultiplicationPrecedence
+infix operator .*= : AssignmentPrecedence
+
+//===----------------------------------------------------------------------===//
+// Compiler Protocols
+//===----------------------------------------------------------------------===//
+
+/// A type with values that support pointwise multiplication.
+// TODO: Add API documentation.
+public protocol PointwiseMultiplicative : AdditiveArithmetic {
+ /// The one value.
+ ///
+ /// One is the identity element for multiplication. For any value,
+ /// `x .* .one == x` and `.one .* x == x`.
+ static var one: Self { get }
+
+ /// The multiplicative inverse of self.
+ ///
+ /// For any value, `x .* x.reciprocal == .one` and
+ /// `x.reciprocal .* x == .one`.
+ var reciprocal: Self { get }
+
+ /// Multiplies two values and produces their product.
+ ///
+ /// - Parameters:
+ /// - lhs: The first value to multiply.
+ /// - rhs: The second value to multiply.
+ static func .*(lhs: Self, rhs: Self) -> Self
+
+ /// Multiplies two values and produces their product.
+ ///
+ /// - Parameters:
+ /// - lhs: The first value to multiply.
+ /// - rhs: The second value to multiply.
+ static func .*=(lhs: inout Self, rhs: Self)
+}
+
+public extension PointwiseMultiplicative {
+ static func .*=(lhs: inout Self, rhs: Self) {
+ lhs = lhs .* rhs
+ }
+}
+
+public extension PointwiseMultiplicative
+ where Self : ExpressibleByIntegerLiteral {
+ static var one: Self {
+ return 1
+ }
+}
+
+/// A type that represents an unranked vector space. Values of this type are
+/// elements in this vector space and have either no shape or a static shape.
+public protocol VectorProtocol : AdditiveArithmetic {
+ /// The type of scalars in the vector space.
+ associatedtype VectorSpaceScalar : AdditiveArithmetic
+
+ func adding(_ x: VectorSpaceScalar) -> Self
+
+ mutating func add(_ x: VectorSpaceScalar)
+
+ func subtracting(_ x: VectorSpaceScalar) -> Self
+
+ mutating func subtract(_ x: VectorSpaceScalar)
+
+ /// Returns `self` multiplied by the given scalar.
+ func scaled(by scalar: VectorSpaceScalar) -> Self
+
+ /// Multiplies `self` by the given scalar.
+ mutating func scale(by scalar: VectorSpaceScalar)
+}
+
+public extension VectorProtocol {
+ mutating func add(_ x: VectorSpaceScalar) {
+ self = adding(x)
+ }
+
+ mutating func subtract(_ x: VectorSpaceScalar) {
+ self = subtracting(x)
+ }
+
+ mutating func scale(by scalar: VectorSpaceScalar) {
+ self = scaled(by: scalar)
+ }
+}
+
+/*
+// Note: These default-implemented operators will slow down type-checking
+// performance and break existing code.
+
+public extension VectorProtocol {
+ static func + (lhs: Self, rhs: VectorSpaceScalar) -> Self {
+ lhs.adding(rhs)
+ }
+
+ static func + (lhs: VectorSpaceScalar, rhs: Self) -> Self {
+ rhs.adding(lhs)
+ }
+
+ static func += (lhs: inout Self, rhs: VectorSpaceScalar) {
+ lhs.add(rhs)
+ }
+
+ static func - (lhs: Self, rhs: VectorSpaceScalar) -> Self {
+ lhs.subtracting(rhs)
+ }
+
+ static func -= (lhs: inout Self, rhs: VectorSpaceScalar) {
+ lhs.subtract(rhs)
+ }
+
+ static func * (lhs: Self, rhs: VectorSpaceScalar) -> Self {
+ lhs.scaled(by: rhs)
+ }
+
+ static func * (lhs: VectorSpaceScalar, rhs: Self) -> Self {
+ rhs.scaled(by: lhs)
+ }
+
+ static func *= (lhs: inout Self, rhs: VectorSpaceScalar) {
+ lhs.scale(by: rhs)
+ }
+}
+
+public extension VectorProtocol where VectorSpaceScalar : SignedNumeric {
+ static func - (lhs: VectorSpaceScalar, rhs: Self) -> Self {
+ -rhs.adding(lhs)
+ }
+
+ static prefix func - (x: Self) -> Self {
+ .zero - x
+ }
+}
+*/
+
+/// A type that is differentiable in the Euclidean space.
+/// The type may represent a vector space, or consist of a vector space and some
+/// other non-differentiable component.
+///
+/// Mathematically, this represents a product manifold that consists of
+/// a differentiable vector space and some arbitrary manifold, where the tangent
+/// bundle of the entire product manifold is equal to the vector space
+/// component.
+///
+/// This abstraction is useful for representing common differentiable data
+/// structures that contain both differentiable vector properties and other
+/// stored properties that do not have a derivative, e.g.
+///
+/// ```swift
+/// struct Perceptron: @memberwise EuclideanDifferentiable {
+/// var weight: SIMD16<Float>
+/// var bias: Float
+/// @noDerivative var useBias: Bool
+/// }
+/// ```
+///
+/// - Note: Conform a type to `EuclideanDifferentiable` if it is differentiable
+/// only with respect to its vector space component and when its
+/// `TangentVector` is equal to its vector space component.
+public protocol EuclideanDifferentiable: Differentiable {
+ /// The differentiable vector component of `self`.
+ var differentiableVectorView: TangentVector { get }
+}
+
+public extension EuclideanDifferentiable where TangentVector == Self {
+ var differentiableVectorView: TangentVector { _read { yield self } }
+}
+
+//===----------------------------------------------------------------------===//
+// Functional utilities
+//===----------------------------------------------------------------------===//
+
+/// Make a function be recomputed in its pullback, known as "checkpointing" in
+/// traditional automatic differentiation.
+@inlinable
+public func withRecomputationInPullbacks<T, U>(
+ _ body: @escaping @differentiable (T) -> U
+) -> @differentiable (T) -> U where T : Differentiable, U : Differentiable {
+ return differentiableFunction { x in
+ (value: body(x), pullback: { v in pullback(at: x, in: body)(v) })
+ }
+}
+
+public extension Differentiable {
+ @inlinable
+ @differentiable(wrt: self)
+ func withRecomputationInPullbacks<Result : Differentiable>(
+ _ body: @escaping @differentiable (Self) -> Result
+ ) -> Result {
+ return body(self)
+ }
+
+ @inlinable
+ @derivative(of: withRecomputationInPullbacks)
+ internal func _vjp_withRecomputationInPullbacks<Result : Differentiable>(
+ _ body: @escaping @differentiable (Self) -> Result
+ ) -> (value: Result, pullback: (Result.TangentVector) -> TangentVector) {
+ return Swift.valueWithPullback(
+ at: self, in: Swift.withRecomputationInPullbacks(body)
+ )
+ }
+}
diff --git a/stdlib/public/Differentiation/DifferentiationUtilities.swift b/stdlib/public/Differentiation/DifferentiationUtilities.swift
index d97af0c..4bbab96 100644
--- a/stdlib/public/Differentiation/DifferentiationUtilities.swift
+++ b/stdlib/public/Differentiation/DifferentiationUtilities.swift
@@ -15,8 +15,6 @@
//
//===----------------------------------------------------------------------===//
-import Swift
-
//===----------------------------------------------------------------------===//
// Differentiable function creation
//===----------------------------------------------------------------------===//
diff --git a/stdlib/public/Differentiation/Empty.swift b/stdlib/public/Differentiation/Empty.swift
new file mode 100644
index 0000000..8160c57
--- /dev/null
+++ b/stdlib/public/Differentiation/Empty.swift
@@ -0,0 +1,3 @@
+// SWIFT_ENABLE_TENSORFLOW
+// Empty Swift file, only for tensorflow branch.
+// See explanation in stdlib/public/Differentiation/CMakeLists.txt.
diff --git a/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb b/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb
index f2ff2fc..26bad25 100644
--- a/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb
+++ b/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb
@@ -10,7 +10,6 @@
//
//===----------------------------------------------------------------------===//
-import Swift
import SwiftShims
% from SwiftFloatingPointTypes import all_floating_point_types
@@ -52,6 +51,50 @@
}
}
+// SWIFT_ENABLE_TENSORFLOW
+${Availability(bits)}
+extension ${Self} : EuclideanDifferentiable {}
+// SWIFT_ENABLE_TENSORFLOW END
+
+// SWIFT_ENABLE_TENSORFLOW
+
+${Availability(bits)}
+extension ${Self} : VectorProtocol {
+ ${Availability(bits)}
+ public typealias VectorSpaceScalar = ${Self}
+
+ ${Availability(bits)}
+ public func adding(_ x: ${Self}) -> ${Self} {
+ self + x
+ }
+
+ ${Availability(bits)}
+ public mutating func add(_ x: ${Self}) {
+ self += x
+ }
+
+ ${Availability(bits)}
+ public func subtracting(_ x: ${Self}) -> ${Self} {
+ self - x
+ }
+
+ ${Availability(bits)}
+ public mutating func subtract(_ x: ${Self}) {
+ self -= x
+ }
+
+ ${Availability(bits)}
+ public func scaled(by scalar: ${Self}) -> ${Self} {
+ self * scalar
+ }
+
+ ${Availability(bits)}
+ public mutating func scale(by scalar: ${Self}) {
+ self *= scalar
+ }
+}
+// SWIFT_ENABLE_TENSORFLOW END
+
//===----------------------------------------------------------------------===//
// Derivatives
//===----------------------------------------------------------------------===//
diff --git a/stdlib/public/Differentiation/OptionalDifferentiation.swift b/stdlib/public/Differentiation/OptionalDifferentiation.swift
index c8aa82f..5e6afce 100644
--- a/stdlib/public/Differentiation/OptionalDifferentiation.swift
+++ b/stdlib/public/Differentiation/OptionalDifferentiation.swift
@@ -10,8 +10,6 @@
//
//===----------------------------------------------------------------------===//
-import Swift
-
extension Optional: Differentiable where Wrapped: Differentiable {
public struct TangentVector: Differentiable, AdditiveArithmetic {
public typealias TangentVector = Self
diff --git a/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb b/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb
index d134867..f875fe2 100644
--- a/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb
+++ b/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb
@@ -10,8 +10,6 @@
//
//===----------------------------------------------------------------------===//
-import Swift
-
%{
storagescalarCounts = [2,4,8,16,32,64]
vectorscalarCounts = storagescalarCounts + [3]
@@ -38,6 +36,15 @@
}
}
+// SWIFT_ENABLE_TENSORFLOW
+extension SIMD${n}: EuclideanDifferentiable
+where
+ Scalar: EuclideanDifferentiable & BinaryFloatingPoint,
+ Scalar.TangentVector: BinaryFloatingPoint
+{
+}
+// SWIFT_ENABLE_TENSORFLOW END
+
//===----------------------------------------------------------------------===//
// Derivatives
//===----------------------------------------------------------------------===//
@@ -412,9 +419,6 @@
}
}
-// FIXME(TF-1103): Derivative registration does not yet support
-// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
-/*
extension SIMD
where
Self: Differentiable,
@@ -439,7 +443,6 @@
return (sum(), { v in Scalar.TangentVector(v.sum()) })
}
}
-*/
extension SIMD
where
diff --git a/stdlib/public/Differentiation/TgmathDerivatives.swift.gyb b/stdlib/public/Differentiation/TgmathDerivatives.swift.gyb
index 1a77090..f7f7df0 100644
--- a/stdlib/public/Differentiation/TgmathDerivatives.swift.gyb
+++ b/stdlib/public/Differentiation/TgmathDerivatives.swift.gyb
@@ -12,6 +12,13 @@
// This file defines derivatives for tgmath functions.
//===----------------------------------------------------------------------===//
+// SWIFT_ENABLE_TENSORFLOW
+// `_Differentiation` module sources are currently compiled as part of the core
+// standard library on tensorflow branch, so import declarations are not
+// necessary.
+#if false
+// SWIFT_ENABLE_TENSORFLOW END
+
import Swift
#if os(macOS) || os(iOS) || os(tvOS) || os(watchOS)
@@ -24,6 +31,10 @@
#error("Unsupported platform")
#endif
+// SWIFT_ENABLE_TENSORFLOW
+#endif
+// SWIFT_ENABLE_TENSORFLOW END
+
@usableFromInline
@derivative(of: fma)
func _jvpFma<T: FloatingPoint & Differentiable> (
diff --git a/stdlib/public/Platform/CMakeLists.txt b/stdlib/public/Platform/CMakeLists.txt
index ae972bc..540e750 100644
--- a/stdlib/public/Platform/CMakeLists.txt
+++ b/stdlib/public/Platform/CMakeLists.txt
@@ -2,6 +2,9 @@
Platform.swift
TiocConstants.swift)
set(swift_platform_gyb_sources
+ # SWIFT_ENABLE_TENSORFLOW_START
+ ../Differentiation/TgmathDerivatives.swift.gyb
+ # SWIFT_ENABLE_TENSORFLOW_END
tgmath.swift.gyb)
set(darwin_depends)
diff --git a/stdlib/public/SwiftShims/LibcShims.h b/stdlib/public/SwiftShims/LibcShims.h
index 6abe688..ef8be44 100644
--- a/stdlib/public/SwiftShims/LibcShims.h
+++ b/stdlib/public/SwiftShims/LibcShims.h
@@ -163,6 +163,319 @@
long double lgammal_r(long double x, int *psigngam);
#endif // defined(__APPLE__)
+// SWIFT_ENABLE_TENSORFLOW
+// These changes were part of `ElementaryFunctions`, which was removed from
+// apple/swift master branch and moved to apple/swift-numerics.
+// TF-1203 tracks eliminating these ad-hoc tensorflow branch changes.
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_tanf(float x) {
+ return __builtin_tanf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_acosf(float x) {
+ return __builtin_acosf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_asinf(float x) {
+ return __builtin_asinf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_atanf(float x) {
+ return __builtin_atanf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_atan2f(float y, float x) {
+ return __builtin_atan2f(y, x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_coshf(float x) {
+ return __builtin_coshf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_sinhf(float x) {
+ return __builtin_sinhf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_tanhf(float x) {
+ return __builtin_tanhf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_acoshf(float x) {
+ return __builtin_acoshf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_asinhf(float x) {
+ return __builtin_asinhf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_atanhf(float x) {
+ return __builtin_atanhf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_exp10f(float x) {
+#if defined __APPLE__
+ extern float __exp10f(float);
+ return __exp10f(x);
+#else
+ return __builtin_powf(10, x);
+#endif
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_expm1f(float x) {
+ return __builtin_expm1f(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_log1pf(float x) {
+ return __builtin_log1pf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_hypotf(float x, float y) {
+#if defined(_WIN32)
+ extern float _hypotf(float, float);
+ return _hypotf(x, y);
+#else
+ return __builtin_hypotf(x, y);
+#endif
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_erff(float x) {
+ return __builtin_erff(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_erfcf(float x) {
+ return __builtin_erfcf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_tgammaf(float x) {
+ return __builtin_tgammaf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+float _stdlib_lgammaf(float x) {
+ extern float lgammaf_r(float x, int *psigngam);
+ int dontCare;
+ return lgammaf_r(x, &dontCare);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_tan(double x) {
+ return __builtin_tan(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_acos(double x) {
+ return __builtin_acos(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_asin(double x) {
+ return __builtin_asin(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_atan(double x) {
+ return __builtin_atan(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_atan2(double y, double x) {
+ return __builtin_atan2(y, x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_cosh(double x) {
+ return __builtin_cosh(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_sinh(double x) {
+ return __builtin_sinh(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_tanh(double x) {
+ return __builtin_tanh(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_acosh(double x) {
+ return __builtin_acosh(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_asinh(double x) {
+ return __builtin_asinh(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_atanh(double x) {
+ return __builtin_atanh(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_exp10(double x) {
+#if defined __APPLE__
+ extern double __exp10(double);
+ return __exp10(x);
+#else
+ return __builtin_pow(10, x);
+#endif
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_expm1(double x) {
+ return __builtin_expm1(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_log1p(double x) {
+ return __builtin_log1p(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_hypot(double x, double y) {
+ return __builtin_hypot(x, y);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_erf(double x) {
+ return __builtin_erf(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_erfc(double x) {
+ return __builtin_erfc(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_tgamma(double x) {
+ return __builtin_tgamma(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+double _stdlib_lgamma(double x) {
+ extern double lgamma_r(double x, int *psigngam);
+ int dontCare;
+ return lgamma_r(x, &dontCare);
+}
+
+#if !(defined(_WIN32) || defined(ANDROID)) && (defined(__i386__) || defined(__x86_64__))
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_tanl(long double x) {
+ return __builtin_tanl(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_acosl(long double x) {
+ return __builtin_acosl(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_asinl(long double x) {
+ return __builtin_asinl(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_atanl(long double x) {
+ return __builtin_atanl(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_atan2l(long double y, long double x) {
+ return __builtin_atan2l(y, x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_coshl(long double x) {
+ return __builtin_coshl(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_sinhl(long double x) {
+ return __builtin_sinhl(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_tanhl(long double x) {
+ return __builtin_tanhl(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_acoshl(long double x) {
+ return __builtin_acoshl(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_asinhl(long double x) {
+ return __builtin_asinhl(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_atanhl(long double x) {
+ return __builtin_atanhl(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_exp10l(long double x) {
+ return __builtin_powl(10, x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_expm1l(long double x) {
+ return __builtin_expm1l(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_log1pl(long double x) {
+ return __builtin_log1pl(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_hypotl(long double x, long double y) {
+ return __builtin_hypotl(x, y);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_erfl(long double x) {
+ return __builtin_erfl(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_erfcl(long double x) {
+ return __builtin_erfcl(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_tgammal(long double x) {
+ return __builtin_tgammal(x);
+}
+
+static inline SWIFT_ALWAYS_INLINE
+long double _stdlib_lgammal(long double x) {
+ extern long double lgammal_r(long double x, int *psigngam);
+ int dontCare;
+ return lgammal_r(x, &dontCare);
+}
+#endif // !(defined(_WIN32) || defined(ANDROID)) && (defined(__i386__) || defined(__x86_64__))
+// SWIFT_ENABLE_TENSORFLOW END
+
#ifdef __cplusplus
} // extern "C"
#endif
diff --git a/stdlib/public/core/CMakeLists.txt b/stdlib/public/core/CMakeLists.txt
index ad74e0e..7733351 100644
--- a/stdlib/public/core/CMakeLists.txt
+++ b/stdlib/public/core/CMakeLists.txt
@@ -86,6 +86,9 @@
Integers.swift
Join.swift
KeyPath.swift
+ # SWIFT_ENABLE_TENSORFLOW
+ KeyPathIterable.swift
+ # SWIFT_ENABLE_TENSORFLOW END
KeyValuePairs.swift
LazyCollection.swift
LazySequence.swift
@@ -201,6 +204,26 @@
UnsafeRawBufferPointer.swift.gyb
)
+# SWIFT_ENABLE_TENSORFLOW
+# Compile differentiable programming sources only if enabled.
+set(SWIFTLIB_DIFFERENTIABLE_PROGRAMMING_SOURCES)
+set(SWIFTLIB_DIFFERENTIABLE_PROGRAMMING_GYB_SOURCES)
+if(SWIFT_ENABLE_EXPERIMENTAL_DIFFERENTIABLE_PROGRAMMING)
+ list(APPEND SWIFTLIB_DIFFERENTIABLE_PROGRAMMING_SOURCES
+ ../Differentiation/Differentiable.swift
+ ../Differentiation/DifferentialOperators.swift
+ ../Differentiation/DifferentiationUtilities.swift
+ ../Differentiation/DifferentiationSupport.swift
+ ../Differentiation/AnyDifferentiable.swift
+ ../Differentiation/ArrayDifferentiation.swift
+ ../Differentiation/OptionalDifferentiation.swift)
+ list(APPEND SWIFTLIB_DIFFERENTIABLE_PROGRAMMING_GYB_SOURCES
+ ../Differentiation/FloatingPointDifferentiation.swift.gyb
+ ../Differentiation/SIMDDifferentiation.swift.gyb)
+ message(STATUS "Differentiable programming standard library additions enabled.")
+endif()
+# SWIFT_ENABLE_TENSORFLOW END
+
# The complete list of sources in the core standard library. Includes
# all the essential sources listed above.
set(SWIFTLIB_SOURCES
@@ -223,10 +246,17 @@
VarArgs.swift
Zip.swift
"${SWIFT_SOURCE_DIR}/stdlib/linker-support/magic-symbols-for-install-name.c"
+ # SWIFT_ENABLE_TENSORFLOW
+ ${SWIFTLIB_DIFFERENTIABLE_PROGRAMMING_SOURCES}
+ # SWIFT_ENABLE_TENSORFLOW END
)
set(SWIFTLIB_GYB_SOURCES
${SWIFTLIB_ESSENTIAL_GYB_SOURCES}
+ # SWIFT_ENABLE_TENSORFLOW
+ MathFunctions.swift.gyb
+ ${SWIFTLIB_DIFFERENTIABLE_PROGRAMMING_GYB_SOURCES}
+ # SWIFT_ENABLE_TENSORFLOW END
SIMDVectorTypes.swift.gyb
Tuple.swift.gyb
)
diff --git a/stdlib/public/core/GroupInfo.json b/stdlib/public/core/GroupInfo.json
index 8b4fbca..e28c328 100644
--- a/stdlib/public/core/GroupInfo.json
+++ b/stdlib/public/core/GroupInfo.json
@@ -149,7 +149,8 @@
"MemoryLayout.swift",
],
"KeyPaths": [
- "KeyPath.swift"
+ "KeyPath.swift",
+ "KeyPathIterable.swift"
],
"Reflection": [
"Dump.swift",
@@ -161,6 +162,7 @@
"Math": [
"SetAlgebra.swift",
"BuiltinMath.swift",
+ "MathFunctions.swift",
{
"Integers": [
"Integers.swift",
@@ -234,5 +236,17 @@
],
"Result": [
"Result.swift"
+ ],
+ "DifferentiableProgramming": [
+ "Differentiable.swift",
+ "DifferentialOperators.swift",
+ "DifferentiationUtilities.swift",
+ "DifferentiationSupport.swift",
+ "AnyDifferentiable.swift",
+ "ArrayDifferentiation.swift",
+ "FloatingPointDifferentiation.swift",
+ "OptionalDifferentiation.swift",
+ "SIMDDifferentiation.swift",
+ "TgmathDerivatives.swift"
]
}
diff --git a/stdlib/public/core/KeyPathIterable.swift b/stdlib/public/core/KeyPathIterable.swift
new file mode 100644
index 0000000..97aeb25
--- /dev/null
+++ b/stdlib/public/core/KeyPathIterable.swift
@@ -0,0 +1,151 @@
+//===-- KeyPathIterable.swift ---------------------------------*- swift -*-===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the KeyPathIterable protocol.
+//
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// KeyPathIterable
+//===----------------------------------------------------------------------===//
+
+/// An implementation detail of `KeyPathIterable`; do not use this protocol
+/// directly.
+public protocol _KeyPathIterableBase {
+ var _allKeyPathsTypeErased: [AnyKeyPath] { get }
+ var _recursivelyAllKeyPathsTypeErased: [AnyKeyPath] { get }
+}
+
+/// A type whose values provides custom key paths to properties or elements.
+public protocol KeyPathIterable: _KeyPathIterableBase {
+ /// A type that can represent a collection of all key paths of this type.
+ associatedtype AllKeyPaths: Collection
+ where AllKeyPaths.Element == PartialKeyPath<Self>
+
+ /// A collection of all custom key paths of this value.
+ var allKeyPaths: AllKeyPaths { get }
+}
+
+public extension KeyPathIterable {
+ /// An array of all custom key paths of this value and any custom key paths
+ /// nested within each of what this value's key paths refers to.
+ var recursivelyAllKeyPaths: [PartialKeyPath<Self>] {
+ var result: [PartialKeyPath<Self>] = []
+ for kp in allKeyPaths {
+ result.append(kp)
+ if let nested = self[keyPath: kp] as? _KeyPathIterableBase {
+ for nkp in nested._recursivelyAllKeyPathsTypeErased {
+ result.append(kp.appending(path: nkp)!)
+ }
+ }
+ }
+ return result
+ }
+}
+
+public extension KeyPathIterable {
+ var _allKeyPathsTypeErased: [AnyKeyPath] {
+ return allKeyPaths.map { $0 as AnyKeyPath }
+ }
+ var _recursivelyAllKeyPathsTypeErased: [AnyKeyPath] {
+ return recursivelyAllKeyPaths.map { $0 as AnyKeyPath }
+ }
+}
+
+public extension KeyPathIterable {
+ /// Returns an array of all custom key paths of this value, to the specified
+ /// type.
+ func allKeyPaths<T>(to _: T.Type) -> [KeyPath<Self, T>] {
+ return allKeyPaths.compactMap { $0 as? KeyPath<Self, T> }
+ }
+
+ /// Returns an array of all custom key paths of this value and any custom key
+ /// paths nested within each of what this value's key paths refers to, to
+ /// the specified type.
+ func recursivelyAllKeyPaths<T>(to _: T.Type) -> [KeyPath<Self, T>] {
+ return recursivelyAllKeyPaths.compactMap { $0 as? KeyPath<Self, T> }
+ }
+
+ /// Returns an array of all custom writable key paths of this value, to the
+ /// specified type.
+ func allWritableKeyPaths<T>(to _: T.Type) -> [WritableKeyPath<Self, T>] {
+ return allKeyPaths(to: T.self)
+ .compactMap { $0 as? WritableKeyPath<Self, T> }
+ }
+
+ /// Returns an array of all custom writable key paths of this value and any
+ /// custom writable key paths nested within each of what this value's key
+ /// paths refers to, to the specified type.
+ func recursivelyAllWritableKeyPaths<T>(
+ to _: T.Type
+ ) -> [WritableKeyPath<Self, T>] {
+ return recursivelyAllKeyPaths(to: T.self)
+ .compactMap { $0 as? WritableKeyPath<Self, T> }
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Collection conformances
+//===----------------------------------------------------------------------===//
+
+/// Returns `true` if all of the given key paths are instances of
+/// `WritableKeyPath<Root, Value>`.
+private func areWritable<Root, Value>(
+ _ keyPaths: [PartialKeyPath<Root>], valueType: Value.Type
+) -> Bool {
+ return !keyPaths.contains(
+ where: { kp in !(kp is WritableKeyPath<Root, Value>) }
+ )
+}
+
+extension Array: KeyPathIterable {
+ public typealias AllKeyPaths = [PartialKeyPath<Array>]
+ public var allKeyPaths: [PartialKeyPath<Array>] {
+ let result = indices.map { \Array[$0] }
+ _internalInvariant(areWritable(result, valueType: Element.self))
+ return result
+ }
+}
+
+// TODO(TF-938): Remove this conformance after removing
+// `Element: Differentiable` requirement.
+//
+// Currently necessary to avoid error:
+//
+// error: conditional conformance of type 'Array<Element>.DifferentiableView'
+// to protocol 'KeyPathIterable' does not imply conformance to inherited
+// protocol '_KeyPathIterableBase'.
+extension Array.DifferentiableView: _KeyPathIterableBase
+where Element: Differentiable {}
+
+// TODO(TF-938): Remove `Element: Differentiable` requirement.
+extension Array.DifferentiableView: KeyPathIterable
+where Element: Differentiable {
+ public typealias AllKeyPaths = [PartialKeyPath<Array.DifferentiableView>]
+ public var allKeyPaths: [PartialKeyPath<Array.DifferentiableView>] {
+ let result = [\Array.DifferentiableView.base]
+ _internalInvariant(areWritable(result, valueType: Array.self))
+ return result
+ }
+}
+
+extension Dictionary: KeyPathIterable {
+ public typealias AllKeyPaths = [PartialKeyPath<Dictionary>]
+ public var allKeyPaths: [PartialKeyPath<Dictionary>] {
+ // Note: `Dictionary.subscript(_: Key)` returns `Value?` and can be used to
+ // form `WritableKeyPath<Self, Value>` key paths.
+ // Force-unwrapping the result is necessary.
+ let result = keys.map { \Dictionary[$0]! }
+ _internalInvariant(areWritable(result, valueType: Value.self))
+ return result
+ }
+}
diff --git a/stdlib/public/core/MathFunctions.swift.gyb b/stdlib/public/core/MathFunctions.swift.gyb
new file mode 100644
index 0000000..67b9735
--- /dev/null
+++ b/stdlib/public/core/MathFunctions.swift.gyb
@@ -0,0 +1,180 @@
+//===--- MathFunctions.swift ----------------------------------*- swift -*-===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2019 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+
+import SwiftShims
+
+%from SwiftMathFunctions import *
+%from SwiftFloatingPointTypes import all_floating_point_types
+
+%# Skip `Float16` for now until it's clear how to conform it to
+%# `ElementaryFunctions`, i.e. after apple/swift-numerics adds the conformance.
+%floating_point_types = [type for type in all_floating_point_types() if type.bits != 16]
+
+/// A type that has elementary functions available.
+///
+/// An ["elementary function"][elfn] is a function built up from powers, roots,
+/// exponentials, logarithms, trigonometric functions (sin, cos, tan) and
+/// their inverses, and the hyperbolic functions (sinh, cosh, tanh) and their
+/// inverses.
+///
+/// Conformance to this protocol means that all of these building blocks are
+/// available as static functions on the type.
+///
+/// ```swift
+/// let x: Float = 1
+/// let y = Float.sin(x) // 0.84147096
+/// ```
+///
+/// [elfn]: http://en.wikipedia.org/wiki/Elementary_function
+// SWIFT_ENABLE_TENSORFLOW
+// NOTE(TF-796): Make `ElementaryFunctions` available on macOS.
+// @available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *)
+public protocol ElementaryFunctions {
+
+%for func in ElementaryFunctions:
+
+ ${func.comment}
+ static func ${func.decl("Self")}
+%end
+
+ /// `exp(y log(x))` computed without loss of intermediate precision.
+ ///
+ /// For real types, if `x` is negative the result is NaN, even if `y` has
+ /// an integral value. For complex types, there is a branch cut on the
+ /// negative real axis.
+ static func pow(_ x: Self, _ y: Self) -> Self
+
+ /// `x` raised to the `n`th power.
+ static func pow(_ x: Self, _ n: Int) -> Self
+
+ /// The `n`th root of `x`.
+ ///
+ /// For real types, if `x` is negative and `n` is even, the result is NaN.
+ /// For complex types, there is a branch cut along the negative real axis.
+ static func root(_ x: Self, _ n: Int) -> Self
+}
+
+%for type in floating_point_types:
+% if type.bits == 80:
+#if (arch(i386) || arch(x86_64)) && !(os(Windows) || os(Android))
+% end
+% Self = type.stdlib_name
+extension ${Self}: ElementaryFunctions {
+% for func in ElementaryFunctions + RealFunctions:
+
+ @_alwaysEmitIntoClient
+ public static func ${func.decl(Self)} {
+ return ${func.impl(type)}
+ }
+% end
+
+ @_alwaysEmitIntoClient
+ public static func pow(_ x: ${Self}, _ y: ${Self}) -> ${Self} {
+ guard x >= 0 else { return .nan }
+ return ${Self}(Builtin.int_pow_FPIEEE${type.bits}(x._value, y._value))
+ }
+
+ @_alwaysEmitIntoClient
+ public static func pow(_ x: ${Self}, _ n: Int) -> ${Self} {
+ // TODO: this implementation isn't quite right for n so large that
+ // the conversion to `${Self}` rounds. We could also consider using
+ // a multiply-chain implementation for small `n`; this would be faster
+ // for static `n`, but less accurate on platforms with a good `pow`
+ // implementation.
+ return ${Self}(Builtin.int_pow_FPIEEE${type.bits}(x._value, ${Self}(n)._value))
+ }
+
+ @_alwaysEmitIntoClient
+ public static func root(_ x: ${Self}, _ n: Int) -> ${Self} {
+ guard x >= 0 || n % 2 != 0 else { return .nan }
+ // TODO: this implementation isn't quite right for n so large that
+ // the conversion to `${Self}` rounds.
+ return ${Self}(signOf: x, magnitudeOf: pow(x.magnitude, 1/${Self}(n)))
+ }
+
+ @_alwaysEmitIntoClient
+ public static func atan2(_ y: ${Self}, _ x: ${Self}) -> ${Self} {
+ return _stdlib_atan2${type.cFuncSuffix}(y, x)
+ }
+
+#if !os(Windows)
+ @_alwaysEmitIntoClient
+ public static func logGamma(_ x: ${Self}) -> ${Self} {
+ return _stdlib_lgamma${type.cFuncSuffix}(x)
+ }
+
+ @_alwaysEmitIntoClient
+ public static func signGamma(_ x: ${Self}) -> FloatingPointSign {
+ if x >= 0 { return .plus }
+ let trunc = x.rounded(.towardZero)
+ if x == trunc { return .plus }
+ let halfTrunc = trunc/2
+ if halfTrunc == halfTrunc.rounded(.towardZero) { return .minus }
+ return .plus
+ }
+#endif
+}
+% if type.bits == 80:
+#endif
+% end
+%end
+
+// SWIFT_ENABLE_TENSORFLOW
+// NOTE(TF-796): Make `ElementaryFunctions` available on macOS.
+// @available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *)
+extension SIMD where Scalar: ElementaryFunctions {
+% for func in ElementaryFunctions:
+
+ @_alwaysEmitIntoClient
+ public static func ${func.decl("Self")} {
+ var r = Self()
+ for i in r.indices {
+ r[i] = Scalar.${func.swiftName}(${func.params(suffix="[i]")})
+ }
+ return r
+ }
+% end
+
+ @_alwaysEmitIntoClient
+ public static func pow(_ x: Self, _ y: Self) -> Self {
+ var r = Self()
+ for i in r.indices {
+ r[i] = Scalar.pow(x[i], y[i])
+ }
+ return r
+ }
+
+ @_alwaysEmitIntoClient
+ public static func pow(_ x: Self, _ n: Int) -> Self {
+ var r = Self()
+ for i in r.indices {
+ r[i] = Scalar.pow(x[i], n)
+ }
+ return r
+ }
+
+ @_alwaysEmitIntoClient
+ public static func root(_ x: Self, _ n: Int) -> Self {
+ var r = Self()
+ for i in r.indices {
+ r[i] = Scalar.root(x[i], n)
+ }
+ return r
+ }
+}
+
+%for n in [2,3,4,8,16,32,64]:
+// SWIFT_ENABLE_TENSORFLOW
+// NOTE(TF-796): Make `ElementaryFunctions` available on macOS.
+// @available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *)
+extension SIMD${n}: ElementaryFunctions where Scalar: ElementaryFunctions { }
+%end
diff --git a/stdlib/public/core/SIMDVector.swift b/stdlib/public/core/SIMDVector.swift
index 7c84c91..d9e6633 100644
--- a/stdlib/public/core/SIMDVector.swift
+++ b/stdlib/public/core/SIMDVector.swift
@@ -842,7 +842,10 @@
}
/// Returns the sum of the scalars in the vector.
- @_alwaysEmitIntoClient
+ // SWIFT_ENABLE_TENSORFLOW: We have changed `@_alwaysEmitIntoClient` to `@inlinable`, to work
+ // around TF-1103.
+ // TODO(TF-1103): Change `@inlinable` back to `@_alwaysEmitIntoClient`.
+ @inlinable
public func sum() -> Scalar {
// Implementation note: this eventually be defined to lower to either
// llvm.experimental.vector.reduce.fadd or an explicit tree-sum. Open-
diff --git a/test/AutoDiff/SILGen/vtable.swift b/test/AutoDiff/SILGen/vtable.swift
index 782a08b..42f41eb 100644
--- a/test/AutoDiff/SILGen/vtable.swift
+++ b/test/AutoDiff/SILGen/vtable.swift
@@ -113,8 +113,8 @@
// CHECK: #Super.method!jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable5SuperC6methodyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk
// CHECK: #Super.method!vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable5SuperC6methodyS2f_SftF__vjp_src_0_wrt_0_vtable_entry_thunk
// CHECK: #Super.genericMethod: <T> (Super) -> (T, T) -> T : @$s6vtable5SuperC13genericMethodyxx_xtlF
-// CHECK: #Super.genericMethod!jvp.SUU.<T where T : Differentiable>: <T> (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__jvp_src_0_wrt_0_16_Differentiation14DifferentiableRzl_vtable_entry_thunk
-// CHECK: #Super.genericMethod!vjp.SUU.<T where T : Differentiable>: <T> (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__vjp_src_0_wrt_0_16_Differentiation14DifferentiableRzl_vtable_entry_thunk
+// CHECK: #Super.genericMethod!jvp.SUU.<T where T : Differentiable>: <T> (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__jvp_src_0_wrt_0_{{s|16_Differentiation}}14DifferentiableRzl_vtable_entry_thunk
+// CHECK: #Super.genericMethod!vjp.SUU.<T where T : Differentiable>: <T> (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__vjp_src_0_wrt_0_{{s|16_Differentiation}}14DifferentiableRzl_vtable_entry_thunk
// CHECK: #Super.property!getter: (Super) -> () -> Float : @$s6vtable5SuperC8propertySfvg
// CHECK: #Super.property!getter.jvp.S: (Super) -> () -> Float : @AD__$s6vtable5SuperC8propertySfvg__jvp_src_0_wrt_0_vtable_entry_thunk
// CHECK: #Super.property!getter.vjp.S: (Super) -> () -> Float : @AD__$s6vtable5SuperC8propertySfvg__vjp_src_0_wrt_0_vtable_entry_thunk
@@ -128,8 +128,8 @@
// CHECK: #Super.method!jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable3SubC6methodyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk [override]
// CHECK: #Super.method!vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable3SubC6methodyS2f_SftF__vjp_src_0_wrt_0_vtable_entry_thunk [override]
// CHECK: #Super.genericMethod: <T> (Super) -> (T, T) -> T : @$s6vtable5SuperC13genericMethodyxx_xtlF [inherited]
-// CHECK: #Super.genericMethod!jvp.SUU.<T where T : Differentiable>: <T> (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__jvp_src_0_wrt_0_16_Differentiation14DifferentiableRzl_vtable_entry_thunk [inherited]
-// CHECK: #Super.genericMethod!vjp.SUU.<T where T : Differentiable>: <T> (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__vjp_src_0_wrt_0_16_Differentiation14DifferentiableRzl_vtable_entry_thunk [inherited]
+// CHECK: #Super.genericMethod!jvp.SUU.<T where T : Differentiable>: <T> (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__jvp_src_0_wrt_0_{{s|16_Differentiation}}14DifferentiableRzl_vtable_entry_thunk [inherited]
+// CHECK: #Super.genericMethod!vjp.SUU.<T where T : Differentiable>: <T> (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__vjp_src_0_wrt_0_{{s|16_Differentiation}}14DifferentiableRzl_vtable_entry_thunk [inherited]
// CHECK: #Super.property!getter: (Super) -> () -> Float : @$s6vtable3SubC8propertySfvg [override]
// CHECK: #Super.property!getter.jvp.S: (Super) -> () -> Float : @AD__$s6vtable3SubC8propertySfvg__jvp_src_0_wrt_0_vtable_entry_thunk [override]
// CHECK: #Super.property!getter.vjp.S: (Super) -> () -> Float : @AD__$s6vtable3SubC8propertySfvg__vjp_src_0_wrt_0_vtable_entry_thunk [override]
@@ -145,8 +145,8 @@
// CHECK: #Super.method!jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable3SubC6methodyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk [inherited]
// CHECK: #Super.method!vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable3SubC6methodyS2f_SftF__vjp_src_0_wrt_0_vtable_entry_thunk [inherited]
// CHECK: #Super.genericMethod: <T> (Super) -> (T, T) -> T : @$s6vtable5SuperC13genericMethodyxx_xtlF [inherited]
-// CHECK: #Super.genericMethod!jvp.SUU.<T where T : Differentiable>: <T> (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__jvp_src_0_wrt_0_16_Differentiation14DifferentiableRzl_vtable_entry_thunk [inherited]
-// CHECK: #Super.genericMethod!vjp.SUU.<T where T : Differentiable>: <T> (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__vjp_src_0_wrt_0_16_Differentiation14DifferentiableRzl_vtable_entry_thunk [inherited]
+// CHECK: #Super.genericMethod!jvp.SUU.<T where T : Differentiable>: <T> (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__jvp_src_0_wrt_0_{{s|16_Differentiation}}14DifferentiableRzl_vtable_entry_thunk [inherited]
+// CHECK: #Super.genericMethod!vjp.SUU.<T where T : Differentiable>: <T> (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__vjp_src_0_wrt_0_{{s|16_Differentiation}}14DifferentiableRzl_vtable_entry_thunk [inherited]
// CHECK: #Super.property!getter: (Super) -> () -> Float : @$s6vtable3SubC8propertySfvg [inherited]
// CHECK: #Super.property!getter.jvp.S: (Super) -> () -> Float : @AD__$s6vtable3SubC8propertySfvg__jvp_src_0_wrt_0_vtable_entry_thunk [inherited]
// CHECK: #Super.property!getter.vjp.S: (Super) -> () -> Float : @AD__$s6vtable3SubC8propertySfvg__vjp_src_0_wrt_0_vtable_entry_thunk [inherited]
diff --git a/test/AutoDiff/SILOptimizer/differentiation_sil.swift b/test/AutoDiff/SILOptimizer/differentiation_sil.swift
index 37b7348..5e11bc0 100644
--- a/test/AutoDiff/SILOptimizer/differentiation_sil.swift
+++ b/test/AutoDiff/SILOptimizer/differentiation_sil.swift
@@ -1,6 +1,12 @@
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s --check-prefix=CHECK-SILGEN
// RUN: %target-swift-frontend -enable-experimental-forward-mode-differentiation -emit-sil %s | %FileCheck %s --check-prefix=CHECK-SIL
+// SWIFT_ENABLE_TENSORFLOW
+// Note: this test is currently for master branch. It can be enabled on
+// tensorflow branch after more AutoDiff upstreaming.
+// UNSUPPORTED: tensorflow
+// SWIFT_ENABLE_TENSORFLOW END
+
// Simple differentiation transform test: check SIL before and after the transform.
import _Differentiation
diff --git a/test/AutoDiff/Sema/DerivedConformances/derived_differentiable.swift b/test/AutoDiff/Sema/DerivedConformances/derived_differentiable.swift
index e933855..9221e31 100644
--- a/test/AutoDiff/Sema/DerivedConformances/derived_differentiable.swift
+++ b/test/AutoDiff/Sema/DerivedConformances/derived_differentiable.swift
@@ -70,7 +70,9 @@
// CHECK-AST-LABEL: @frozen public struct FrozenStruct : Differentiable {
// CHECK-AST: internal init()
-// CHECK-AST: @frozen public struct TangentVector : Differentiable, AdditiveArithmetic {
+// SWIFT_ENABLE_TENSORFLOW
+// CHECK-AST: @frozen public struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative, ElementaryFunctions {
+// SWIFT_ENABLE_TENSORFLOW END
@usableFromInline
struct UsableFromInlineStruct: Differentiable {}
@@ -79,7 +81,9 @@
// CHECK-AST: struct UsableFromInlineStruct : Differentiable {
// CHECK-AST: internal init()
// CHECK-AST: @usableFromInline
-// CHECK-AST: struct TangentVector : Differentiable, AdditiveArithmetic {
+// SWIFT_ENABLE_TENSORFLOW
+// CHECK-AST: struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative, ElementaryFunctions {
+// SWIFT_ENABLE_TENSORFLOW END
// Test property wrappers.
@@ -96,7 +100,7 @@
}
// CHECK-AST-LABEL: internal struct WrappedPropertiesStruct : Differentiable {
-// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
+// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol {
// CHECK-AST: internal var x: Float.TangentVector
// CHECK-AST: internal var y: Float.TangentVector
// CHECK-AST: internal var z: Float.TangentVector
@@ -111,7 +115,7 @@
}
// CHECK-AST-LABEL: internal class WrappedPropertiesClass : Differentiable {
-// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
+// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol {
// CHECK-AST: internal var x: Float.TangentVector
// CHECK-AST: internal var y: Float.TangentVector
// CHECK-AST: internal var z: Float.TangentVector
diff --git a/test/AutoDiff/Sema/DerivedConformances/derived_zero_tangent_vector_initializer.swift b/test/AutoDiff/Sema/DerivedConformances/derived_zero_tangent_vector_initializer.swift
index ef536b5..98d6428 100644
--- a/test/AutoDiff/Sema/DerivedConformances/derived_zero_tangent_vector_initializer.swift
+++ b/test/AutoDiff/Sema/DerivedConformances/derived_zero_tangent_vector_initializer.swift
@@ -115,10 +115,10 @@
// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}MemberwiseTangentVectorStructV0bgH11InitializerAC0gH0Vycvg : $@convention(method) (MemberwiseTangentVectorStruct) -> @owned @callee_guaranteed () -> MemberwiseTangentVectorStruct.TangentVector {
// CHECK: bb0([[SELF:%.*]] : $MemberwiseTangentVectorStruct):
// CHECK: [[X_PROP:%.*]] = struct_extract [[SELF]] : $MemberwiseTangentVectorStruct, #MemberwiseTangentVectorStruct.x
-// CHECK: [[X_ZERO_INIT_FN:%.*]] = function_ref @$sSf{{.*}}E28zeroTangentVectorInitializerSfycvg : $@convention(method) (Float) -> @owned @callee_guaranteed () -> Float
+// CHECK: [[X_ZERO_INIT_FN:%.*]] = function_ref @$sSf{{.*}}28zeroTangentVectorInitializerSfycvg : $@convention(method) (Float) -> @owned @callee_guaranteed () -> Float
// CHECK: [[X_ZERO_INIT:%.*]] = apply [[X_ZERO_INIT_FN]]([[X_PROP]])
// CHECK: [[Y_PROP:%.*]] = struct_extract [[SELF]] : $MemberwiseTangentVectorStruct, #MemberwiseTangentVectorStruct.y
-// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}E28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double
+// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double
// CHECK: [[Y_ZERO_INIT:%.*]] = apply [[Y_ZERO_INIT_FN]]([[Y_PROP]])
// CHECK: // function_ref closure #1 in MemberwiseTangentVectorStruct.zeroTangentVectorInitializer.getter
// CHECK: [[CLOSURE_FN:%.*]] = function_ref @${{.*}}MemberwiseTangentVectorStructV0bgH11InitializerAC0gH0VycvgAFycfU_
@@ -132,10 +132,10 @@
// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}SelfTangentVectorStructV0bgH11InitializerACycvg : $@convention(method) (SelfTangentVectorStruct) -> @owned @callee_guaranteed () -> SelfTangentVectorStruct {
// CHECK: bb0([[SELF:%.*]] : $SelfTangentVectorStruct):
// CHECK: [[X_PROP:%.*]] = struct_extract [[SELF]] : $SelfTangentVectorStruct, #SelfTangentVectorStruct.x
-// CHECK: [[X_ZERO_INIT_FN:%.*]] = function_ref @$sSf{{.*}}E28zeroTangentVectorInitializerSfycvg : $@convention(method) (Float) -> @owned @callee_guaranteed () -> Float
+// CHECK: [[X_ZERO_INIT_FN:%.*]] = function_ref @$sSf{{.*}}28zeroTangentVectorInitializerSfycvg : $@convention(method) (Float) -> @owned @callee_guaranteed () -> Float
// CHECK: [[X_ZERO_INIT:%.*]] = apply [[X_ZERO_INIT_FN]]([[X_PROP]])
// CHECK: [[Y_PROP:%.*]] = struct_extract [[SELF]] : $SelfTangentVectorStruct, #SelfTangentVectorStruct.y
-// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}E28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double
+// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double
// CHECK: [[Y_ZERO_INIT:%.*]] = apply [[Y_ZERO_INIT_FN]]([[Y_PROP]])
// CHECK: // function_ref closure #1 in SelfTangentVectorStruct.zeroTangentVectorInitializer.getter
// CHECK: [[CLOSURE_FN:%.*]] = function_ref @${{.*}}SelfTangentVectorStructV0bgH11InitializerACycvgACycfU_
@@ -157,11 +157,11 @@
// CHECK: bb0([[SELF:%.*]] : @guaranteed $MemberwiseTangentVectorClass):
// CHECK: [[X_PROP_METHOD:%.*]] = class_method [[SELF]] : $MemberwiseTangentVectorClass, #MemberwiseTangentVectorClass.x!getter
// CHECK: [[X_PROP:%.*]] = apply [[X_PROP_METHOD]]([[SELF]])
-// CHECK: [[X_ZERO_INIT_FN:%.*]] = function_ref @$sSf{{.*}}E28zeroTangentVectorInitializerSfycvg : $@convention(method) (Float) -> @owned @callee_guaranteed () -> Float
+// CHECK: [[X_ZERO_INIT_FN:%.*]] = function_ref @$sSf{{.*}}28zeroTangentVectorInitializerSfycvg : $@convention(method) (Float) -> @owned @callee_guaranteed () -> Float
// CHECK: [[X_ZERO_INIT:%.*]] = apply [[X_ZERO_INIT_FN]]([[X_PROP]])
// CHECK: [[Y_PROP_METHOD:%.*]] = class_method [[SELF]] : $MemberwiseTangentVectorClass, #MemberwiseTangentVectorClass.y!getter
// CHECK: [[Y_PROP:%.*]] = apply [[Y_PROP_METHOD]]([[SELF]])
-// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}E28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double
+// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double
// CHECK: [[Y_ZERO_INIT:%.*]] = apply [[Y_ZERO_INIT_FN]]([[Y_PROP]])
// CHECK: // function_ref closure #1 in MemberwiseTangentVectorClass.zeroTangentVectorInitializer.getter
// CHECK: [[CLOSURE_FN:%.*]] = function_ref @${{.*}}MemberwiseTangentVectorClassC0bgH11InitializerAC0gH0VycvgAFycfU_
diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift
index 451bf4d..4071637 100644
--- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift
+++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift
@@ -241,14 +241,24 @@
}
extension InstanceMethod {
+ // expected-note @+2 {{'foo' previously declared here}}
// expected-note @+1 {{'foo' defined here}}
func foo(_ x: Self) -> Self { x }
+ // expected-note @+2 {{'generic' previously declared here}}
// expected-note @+1 {{'generic' defined here}}
func generic<T: Differentiable>(_ x: T) -> Self { self }
}
extension InstanceMethod {
+ // expected-error @+1 {{invalid redeclaration of 'foo'}}
+ func foo(_ x: Self) -> Self { self }
+
+ // expected-error @+1 {{invalid redeclaration of 'generic'}}
+ func generic<T: Differentiable>(_ x: T) -> Self { self }
+}
+
+extension InstanceMethod {
@derivative(of: foo)
func jvpFoo(x: Self) -> (
value: Self, differential: (TangentVector, TangentVector) -> (TangentVector)
diff --git a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift
index ac54b0b..8021c34 100644
--- a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift
+++ b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift
@@ -171,16 +171,14 @@
subscript(explicit x: Float) -> Float {
@differentiable // ok
get { return x }
- // expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}}
- @differentiable
+ @differentiable // ok
set {}
}
subscript(x: Float, y: Float) -> Float {
@differentiable // ok
get { return x + y }
- // expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}}
- @differentiable
+ @differentiable // ok
set {}
}
}
@@ -700,7 +698,7 @@
var stored: Float
var computed: Float {
- // expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}}
+ // `set` has an `inout` parameter: `(inout Self) -> (Float) -> ()`.
@differentiable
set { stored = newValue }
diff --git a/test/AutoDiff/Sema/differentiable_features_disabled.swift b/test/AutoDiff/Sema/differentiable_features_disabled.swift
index 073b893..7fb7e8d 100644
--- a/test/AutoDiff/Sema/differentiable_features_disabled.swift
+++ b/test/AutoDiff/Sema/differentiable_features_disabled.swift
@@ -1,4 +1,6 @@
// RUN: %target-swift-frontend -typecheck -verify %s
+// SWIFT_ENABLE_TENSORFLOW
+// XFAIL: *
// expected-error @+1 {{'@differentiable' attribute used without importing module '_Differentiation'}}
let _: @differentiable (Float) -> Float
diff --git a/test/AutoDiff/Sema/missing_differentiable_protocol.swift b/test/AutoDiff/Sema/missing_differentiable_protocol.swift
index 11e6d2f..6cdba5b 100644
--- a/test/AutoDiff/Sema/missing_differentiable_protocol.swift
+++ b/test/AutoDiff/Sema/missing_differentiable_protocol.swift
@@ -1,5 +1,11 @@
// RUN: %target-swift-frontend -typecheck -verify %s
+// SWIFT_ENABLE_TENSORFLOW
+// Expected to fail on `tensorflow` branch because the `Differentiable` protocol
+// currently exists in the core stdlib, not in the `_Differentiation` module.
+// XFAIL: tensorflow
+// SWIFT_ENABLE_TENSORFLOW END
+
// Tests that Sema fails gracefully when the `_Differentiation` module is not imported.
// expected-error @+1 {{'@differentiable' attribute used without importing module '_Differentiation'}}
diff --git a/test/AutoDiff/compiler_crashers_fixed/sr12493-differentiable-function-extract-subst-function-type.swift b/test/AutoDiff/compiler_crashers_fixed/sr12493-differentiable-function-extract-subst-function-type.swift
index 359c1d2..e81c54a 100644
--- a/test/AutoDiff/compiler_crashers_fixed/sr12493-differentiable-function-extract-subst-function-type.swift
+++ b/test/AutoDiff/compiler_crashers_fixed/sr12493-differentiable-function-extract-subst-function-type.swift
@@ -3,6 +3,11 @@
// SR-12493: SIL verification error regarding substituted function types and
// `differentiable_function_extract` instruction. Occurs only with `-O`.
+// SWIFT_ENABLE_TENSORFLOW
+// Note: SR-12493 occurs only on master branch, not on tensorflow branch.
+// UNSUPPORTED: tensorflow
+// SWIFT_ENABLE_TENSORFLOW END
+
// FIXME(SR-13021): Disabled due to flakiness on Linux.
// REQUIRES: SR13021
diff --git a/test/AutoDiff/compiler_crashers_fixed/tf1167-differentiable-attr-override-match.swift b/test/AutoDiff/compiler_crashers_fixed/tf1167-differentiable-attr-override-match.swift
index 18792bc..a45d11d 100644
--- a/test/AutoDiff/compiler_crashers_fixed/tf1167-differentiable-attr-override-match.swift
+++ b/test/AutoDiff/compiler_crashers_fixed/tf1167-differentiable-attr-override-match.swift
@@ -1,6 +1,12 @@
// RUN: %target-swift-frontend -typecheck -verify %s
// REQUIRES: asserts
+// SWIFT_ENABLE_TENSORFLOW
+// This test isn't reproducible on `tensorflow` branch because
+// `_Differentiation` module does not currently match `master` branch.
+// REQUIRES: no_tensorflow
+// SWIFT_ENABLE_TENSORFLOW END
+
// TF-1167: `OverrideMatcher::match` crash due to meaningless assertion:
// `assert(false)`. The assertion was triggered when parameter indices
// could not be resolved for neither base nor derived declaration
diff --git a/test/AutoDiff/compiler_crashers_fixed/tf1232-autodiff-generated-declaration-mangling.swift b/test/AutoDiff/compiler_crashers_fixed/tf1232-autodiff-generated-declaration-mangling.swift
index 61bc8ac..e27ada8 100644
--- a/test/AutoDiff/compiler_crashers_fixed/tf1232-autodiff-generated-declaration-mangling.swift
+++ b/test/AutoDiff/compiler_crashers_fixed/tf1232-autodiff-generated-declaration-mangling.swift
@@ -1,6 +1,11 @@
// RUN: %target-build-swift -g %s
// REQUIRES: asserts
+// SWIFT_ENABLE_TENSORFLOW
+// Note: TF-1232 is reproducible only on master branch, not on tensorflow branch.
+// UNSUPPORTED: tensorflow
+// SWIFT_ENABLE_TENSORFLOW END
+
// TF-1232: IRGenDebugInfo crash due to lack of proper mangling for
// AutoDiff-generated declarations: linear map structs and branching trace
// enums.
diff --git a/test/AutoDiff/downstream/Inputs/always_emit_into_client/MultiFileModule/file1.swift b/test/AutoDiff/downstream/Inputs/always_emit_into_client/MultiFileModule/file1.swift
new file mode 100644
index 0000000..c808c0e
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/always_emit_into_client/MultiFileModule/file1.swift
@@ -0,0 +1,4 @@
+@_alwaysEmitIntoClient
+public func f(_ x: Float) -> Float {
+ x
+}
diff --git a/test/AutoDiff/downstream/Inputs/always_emit_into_client/MultiFileModule/file2.swift b/test/AutoDiff/downstream/Inputs/always_emit_into_client/MultiFileModule/file2.swift
new file mode 100644
index 0000000..d6888f7
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/always_emit_into_client/MultiFileModule/file2.swift
@@ -0,0 +1,5 @@
+@derivative(of: f)
+@_alwaysEmitIntoClient
+public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ (x, { 10 * $0 })
+}
diff --git a/test/AutoDiff/downstream/Inputs/always_emit_into_client/SingleFileModule/file.swift b/test/AutoDiff/downstream/Inputs/always_emit_into_client/SingleFileModule/file.swift
new file mode 100644
index 0000000..c006087
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/always_emit_into_client/SingleFileModule/file.swift
@@ -0,0 +1,10 @@
+@_alwaysEmitIntoClient
+public func f(_ x: Float) -> Float {
+ x
+}
+
+@derivative(of: f)
+@_alwaysEmitIntoClient
+public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ (x, { 10 * $0 })
+}
diff --git a/test/AutoDiff/downstream/Inputs/class_method_thunk_other_module.swift b/test/AutoDiff/downstream/Inputs/class_method_thunk_other_module.swift
new file mode 100644
index 0000000..5eb6a8a
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/class_method_thunk_other_module.swift
@@ -0,0 +1,40 @@
+class OtherModuleSuper {
+ @differentiable
+ func f(_ x: Float) -> Float {
+ return 2 * x
+ }
+
+ @derivative(of: f)
+ final func jvpf(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
+ return (f(x), { v in 2 * v })
+ }
+
+ @derivative(of: f)
+ final func vjpf(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ return (f(x), { v in 2 * v })
+ }
+}
+
+class OtherModuleSubOverride : OtherModuleSuper {
+ @differentiable
+ override func f(_ x: Float) -> Float {
+ return 3 * x
+ }
+}
+
+class OtherModuleSubOverrideCustomDerivatives : OtherModuleSuper {
+ @differentiable
+ override func f(_ x: Float) -> Float {
+ return 3 * x
+ }
+
+ @derivative(of: f)
+ final func jvpf2(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
+ return (f(x), { v in 3 * v })
+ }
+
+ @derivative(of: f)
+ final func vjpf2(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ return (f(x), { v in 3 * v })
+ }
+}
diff --git a/test/AutoDiff/downstream/Inputs/cross_module_derivative_attr_e2e/main/main.swift b/test/AutoDiff/downstream/Inputs/cross_module_derivative_attr_e2e/main/main.swift
new file mode 100644
index 0000000..5b47648
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/cross_module_derivative_attr_e2e/main/main.swift
@@ -0,0 +1,12 @@
+import StdlibUnittest
+
+import module1
+
+var Tests = TestSuite("CrossModuleDerivativeAttr")
+
+Tests.test("CrossFile") {
+ let grad = gradient(at: 0, in: fCrossFile)
+ expectEqual(10, grad)
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/Inputs/cross_module_derivative_attr_e2e/module1/module1.swift b/test/AutoDiff/downstream/Inputs/cross_module_derivative_attr_e2e/module1/module1.swift
new file mode 100644
index 0000000..c76929b
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/cross_module_derivative_attr_e2e/module1/module1.swift
@@ -0,0 +1 @@
+public func fCrossFile(_ x: Float) -> Float { x }
diff --git a/test/AutoDiff/downstream/Inputs/cross_module_derivative_attr_e2e/module1/module1_other_file.swift b/test/AutoDiff/downstream/Inputs/cross_module_derivative_attr_e2e/module1/module1_other_file.swift
new file mode 100644
index 0000000..c02963a
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/cross_module_derivative_attr_e2e/module1/module1_other_file.swift
@@ -0,0 +1,4 @@
+@derivative(of: fCrossFile)
+public func vjpCrossFile(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ (x, { 10 * $0 })
+}
diff --git a/test/AutoDiff/downstream/Inputs/differentiable_attr_other_module.swift b/test/AutoDiff/downstream/Inputs/differentiable_attr_other_module.swift
new file mode 100644
index 0000000..5dd77d8
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/differentiable_attr_other_module.swift
@@ -0,0 +1,26 @@
+// Verify that `@differentiable` declarations can be differentiated from other
+// modules.
+
+public struct Foo: Differentiable {
+ public var x: Float
+
+ @differentiable
+ public init(_ x: Float) {
+ self.x = x
+ }
+
+ @differentiable
+ public func method() -> Float {
+ x
+ }
+
+ @differentiable
+ public var computedProperty: Float {
+ x
+ }
+
+ @differentiable
+ public subscript() -> Float {
+ x
+ }
+}
diff --git a/test/AutoDiff/downstream/Inputs/differentiable_attr_silgen_other_module.swift b/test/AutoDiff/downstream/Inputs/differentiable_attr_silgen_other_module.swift
new file mode 100644
index 0000000..a27a68b
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/differentiable_attr_silgen_other_module.swift
@@ -0,0 +1,32 @@
+public struct Wrapper : Differentiable, AdditiveArithmetic {
+ public var x: Float
+ public init(_ x: Float) {
+ self.x = x
+ }
+
+ public static func + (lhs: Wrapper, rhs: Wrapper) -> Wrapper {
+ return Wrapper(lhs.x + rhs.x)
+ }
+
+ @derivative(of: +)
+ public static func vjpAdd(lhs: Wrapper, rhs: Wrapper)
+ -> (value: Wrapper, pullback: (Wrapper) -> (Wrapper, Wrapper)) {
+ return (lhs + rhs, { v in (v, v) })
+ }
+
+ public static func * (lhs: Wrapper, rhs: Wrapper) -> Wrapper {
+ return Wrapper(lhs.x * rhs.x)
+ }
+
+ @derivative(of: *)
+ public static func jvpMultiply(lhs: Wrapper, rhs: Wrapper)
+ -> (value: Wrapper, differential: (Wrapper, Wrapper) -> Wrapper) {
+ return (lhs * rhs, { dlhs, drhs in dlhs * rhs + lhs * drhs })
+ }
+
+ @derivative(of: *)
+ public static func vjpMultiply(lhs: Wrapper, rhs: Wrapper)
+ -> (value: Wrapper, pullback: (Wrapper) -> (Wrapper, Wrapper)) {
+ return (lhs * rhs, { v in (v * rhs, v * lhs) })
+ }
+}
diff --git a/test/AutoDiff/downstream/Inputs/differentiable_attr_type_checking_non_primary_file.swift b/test/AutoDiff/downstream/Inputs/differentiable_attr_type_checking_non_primary_file.swift
new file mode 100644
index 0000000..5dc8e01
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/differentiable_attr_type_checking_non_primary_file.swift
@@ -0,0 +1,27 @@
+public protocol Layer: Differentiable {
+ associatedtype Input: Differentiable
+ associatedtype Output: Differentiable
+
+ @differentiable
+ func instanceMethod(_ input: Input) -> Output
+
+ @differentiable
+ var computedProperty: Output { get }
+}
+
+struct DummyLayer: Layer {
+ @differentiable
+ func instanceMethod(_ input: Float) -> Float {
+ return input
+ }
+
+ @differentiable
+ var computedProperty: Float { 1 }
+}
+
+public extension Differentiable {
+ @differentiable
+ func sequenced<L: Layer>(through layer: L) -> L.Output where L.Input == Self {
+ return layer.instanceMethod(self)
+ }
+}
diff --git a/test/AutoDiff/downstream/Inputs/differentiable_requirement_other_module.swift b/test/AutoDiff/downstream/Inputs/differentiable_requirement_other_module.swift
new file mode 100644
index 0000000..f6c23cc
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/differentiable_requirement_other_module.swift
@@ -0,0 +1,6 @@
+public struct Empty : AdditiveArithmetic {}
+
+public protocol DifferentiableRequirement {
+ @differentiable
+ func foo(float: Float, empty: Empty) -> Float
+}
diff --git a/test/AutoDiff/downstream/Inputs/e2e_cross_module_external_module.swift b/test/AutoDiff/downstream/Inputs/e2e_cross_module_external_module.swift
new file mode 100644
index 0000000..132b164
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/e2e_cross_module_external_module.swift
@@ -0,0 +1,14 @@
+import DifferentiationUnittest
+
+@differentiable
+public func doubleThenApplyDefaultF(_ x: Tracked<Float>) -> Tracked<Float> {
+ return x
+}
+
+@differentiable
+public func doubleThenApply(
+ _ x: Tracked<Float>,
+ _ f: @differentiable (Tracked<Float>) -> Tracked<Float> = doubleThenApplyDefaultF
+) -> Tracked<Float> {
+ return f(2 * x)
+}
diff --git a/test/AutoDiff/downstream/Inputs/loadable_by_address_cross_module.swift b/test/AutoDiff/downstream/Inputs/loadable_by_address_cross_module.swift
new file mode 100644
index 0000000..2618acb
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/loadable_by_address_cross_module.swift
@@ -0,0 +1,16 @@
+public struct LargeLoadableType<T>: AdditiveArithmetic, Differentiable {
+ public var a, b, c, d, e: Float
+
+ public init(a: Float) {
+ self.a = a
+ self.b = 0
+ self.c = 0
+ self.d = 0
+ self.e = 0
+ }
+
+ @differentiable
+ public func externalLBAModifiedFunction(_ x: Float) -> Float {
+ return a * x
+ }
+}
diff --git a/test/AutoDiff/downstream/Inputs/noderivative_attr_other_file.swift b/test/AutoDiff/downstream/Inputs/noderivative_attr_other_file.swift
new file mode 100644
index 0000000..b584253
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/noderivative_attr_other_file.swift
@@ -0,0 +1,5 @@
+@noDerivative
+@_silgen_name("float_to_int_noderivative")
+func floatToIntNoDerivative(_ x: Float) -> Int {
+ Int(x)
+}
diff --git a/test/AutoDiff/downstream/Inputs/nondifferentiable_function_other_module.swift b/test/AutoDiff/downstream/Inputs/nondifferentiable_function_other_module.swift
new file mode 100644
index 0000000..b7415c4
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/nondifferentiable_function_other_module.swift
@@ -0,0 +1,5 @@
+// A public function that is not marked with `@differentiable`.
+// Differentiation of `externalFunction` in other modules should fail.
+public func externalFunction(_ x: Float) -> Float {
+ return x + x
+}
diff --git a/test/AutoDiff/downstream/Inputs/silgen_thunking_other_module.swift b/test/AutoDiff/downstream/Inputs/silgen_thunking_other_module.swift
new file mode 100644
index 0000000..e28cbea
--- /dev/null
+++ b/test/AutoDiff/downstream/Inputs/silgen_thunking_other_module.swift
@@ -0,0 +1,13 @@
+struct TF_619: Differentiable {
+ var p: Float = 1
+
+ @differentiable
+ func foo(_ x: Float) -> Float {
+ return p * x
+ }
+
+ @derivative(of: foo)
+ func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> (TangentVector, Float)) {
+ return (x, { v in (TangentVector(p: v * x), v * self.p) })
+ }
+}
diff --git a/test/AutoDiff/downstream/README.txt b/test/AutoDiff/downstream/README.txt
new file mode 100644
index 0000000..e01e5b4
--- /dev/null
+++ b/test/AutoDiff/downstream/README.txt
@@ -0,0 +1,6 @@
+This directory contains AutoDiff tests on `tensorflow` branch that have not yet
+been upstreamed to `master`.
+
+As tests are upstreamed to `master` branch in organized `test/AutoDiff`
+subdirectories, corresponding tests in this directory can be deleted from
+`tensorflow` branch.
diff --git a/test/AutoDiff/downstream/always_emit_into_client_multi_file.swift b/test/AutoDiff/downstream/always_emit_into_client_multi_file.swift
new file mode 100644
index 0000000..4f1ebcd
--- /dev/null
+++ b/test/AutoDiff/downstream/always_emit_into_client_multi_file.swift
@@ -0,0 +1,21 @@
+// TODO(TF-1103): Fix this test so that there is not a linker error. Then, move this test to
+// cross_module_derivative_attr_e2e.swift.
+
+// RUN: %empty-directory(%t)
+// RUN: %target-build-swift -parse-as-library -emit-module -module-name MultiFileModule -emit-module-path %t/MultiFileModule.swiftmodule -emit-library -o %t/%target-library-name(MultiFileModule) %S/Inputs/always_emit_into_client/MultiFileModule/file1.swift %S/Inputs/always_emit_into_client/MultiFileModule/file2.swift
+// RUN: not %target-build-swift -I%t -L%t %s -o %t/a.out -lm -lMultiFileModule 2>&1 | %FileCheck %s
+
+import StdlibUnittest
+
+import MultiFileModule
+
+var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
+
+AlwaysEmitIntoClientTests.test("registration") {
+ expectEqual(10, gradient(at: 0, in: f))
+}
+
+runAllTests()
+
+// CHECK: {{[Uu]}}ndefined
+// CHECK: AD__$s15MultiFileModule1fyS2fF_PSRS
diff --git a/test/AutoDiff/downstream/always_emit_into_client_single_file.swift b/test/AutoDiff/downstream/always_emit_into_client_single_file.swift
new file mode 100644
index 0000000..0657e26
--- /dev/null
+++ b/test/AutoDiff/downstream/always_emit_into_client_single_file.swift
@@ -0,0 +1,21 @@
+// TODO(TF-1103): Fix this test so that there is not a linker error. Then, move this test to
+// cross_module_derivative_attr_e2e.swift.
+
+// RUN: %empty-directory(%t)
+// RUN: %target-build-swift -Xfrontend -parse-as-library -emit-module -module-name SingleFileModule -emit-module-path %t/SingleFileModule.swiftmodule -emit-library -o %t/%target-library-name(SingleFileModule) %S/Inputs/always_emit_into_client/SingleFileModule/file.swift
+// RUN: not %target-build-swift -I%t -L%t %s -o %t/a.out -lm -lSingleFileModule 2>&1 | %FileCheck %s
+
+import StdlibUnittest
+
+import SingleFileModule
+
+var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
+
+AlwaysEmitIntoClientTests.test("registration") {
+ expectEqual(10, gradient(at: 0, in: f))
+}
+
+runAllTests()
+
+// CHECK: {{[Uu]}}ndefined
+// CHECK: AD__$s16SingleFileModule1fyS2fF_PSRS
diff --git a/test/AutoDiff/downstream/ast_serialization.swift b/test/AutoDiff/downstream/ast_serialization.swift
new file mode 100644
index 0000000..3276695
--- /dev/null
+++ b/test/AutoDiff/downstream/ast_serialization.swift
@@ -0,0 +1,8 @@
+// RUN: %target-swift-frontend -emit-sib -primary-file %s
+
+// Test AST serialization of differentiation generated structs/enums.
+@differentiable
+func TF_623(_ x: Float) -> Float {
+ if x > 0 {}
+ return x
+}
diff --git a/test/AutoDiff/downstream/autodiff_generated_decl_member_loading.swift b/test/AutoDiff/downstream/autodiff_generated_decl_member_loading.swift
new file mode 100644
index 0000000..f05ad9a
--- /dev/null
+++ b/test/AutoDiff/downstream/autodiff_generated_decl_member_loading.swift
@@ -0,0 +1,31 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-swift-frontend -emit-module %s -o %t/autodiff_generated_decl_member_loading_cross_module.swiftmodule
+// RUN: %target-swift-frontend -merge-modules -emit-module %t/autodiff_generated_decl_member_loading_cross_module.swiftmodule
+
+// Tests TF-805.
+//
+// Previously, `IterableDeclContext::loadAllMembers` was disabled for
+// AD-generated structs/enums:
+// https://github.com/apple/swift/commit/7e89bee39bfda4624f04dcc2c8d53599fbde6191
+//
+// This caused a SIL verification failure for because enum members were not loaded
+// when running `-emit-module` then `-merge-modules -sil-merge-partial-modules`:
+//
+// SIL verification failed: switch_enum dispatches on same enum element
+// more than once: unswitchedElts.count(elt)
+
+public struct Tensor<Scalar> {}
+
+extension Tensor: Differentiable where Scalar: Differentiable{
+ public typealias TangentVector = Float
+ public mutating func move(along direction: Float) {}
+}
+
+extension Tensor {
+ @inlinable
+ @differentiable(wrt: self where Scalar: Differentiable)
+ func TF_805(axis: Int) -> Tensor {
+ if axis != axis {}
+ return self
+ }
+}
diff --git a/test/AutoDiff/downstream/class_differentiation.swift b/test/AutoDiff/downstream/class_differentiation.swift
new file mode 100644
index 0000000..038c547
--- /dev/null
+++ b/test/AutoDiff/downstream/class_differentiation.swift
@@ -0,0 +1,169 @@
+// RUN: %target-run-simple-swift
+// NOTE: Verify whether forward-mode differentiation crashes. It currently does.
+// RUN: not --crash %target-swift-frontend -enable-experimental-forward-mode-differentiation -emit-sil %s
+// REQUIRES: executable_test
+
+import StdlibUnittest
+import DifferentiationUnittest
+
+var ClassTests = TestSuite("ClassDifferentiation")
+
+ClassTests.test("TrivialMember") {
+ final class C: Differentiable {
+ @differentiable
+ var float: Float
+
+ @noDerivative
+ final var noDerivative: Float = 1
+
+ @differentiable
+ init(_ float: Float) {
+ self.float = float
+ }
+
+ @differentiable
+ convenience init(convenience x: Float) {
+ self.init(x)
+ }
+
+ @differentiable
+ func method(_ x: Float) -> Float {
+ x * float
+ }
+
+ @differentiable
+ func testNoDerivative() -> Float {
+ noDerivative
+ }
+
+ @differentiable
+ static func controlFlow(_ c1: C, _ c2: C, _ flag: Bool) -> Float {
+ var result: Float = 0
+ if flag {
+ var c3 = C(c1.float * c2.float)
+ result = c3.float
+ } else {
+ result = c2.float * c1.float
+ }
+ return result
+ }
+ }
+ // Test class initializer differentiation.
+ expectEqual(10, pullback(at: 3, in: { C($0) })(.init(float: 10)))
+ expectEqual(10, pullback(at: 3, in: { C(convenience: $0) })(.init(float: 10)))
+ // Test class method differentiation.
+ expectEqual((.init(float: 3), 10), gradient(at: C(10), 3, in: { c, x in c.method(x) }))
+ expectEqual(.init(float: 0), gradient(at: C(10), in: { c in c.testNoDerivative() }))
+ expectEqual((.init(float: 20), .init(float: 10)),
+ gradient(at: C(10), C(20), in: { c1, c2 in C.controlFlow(c1, c2, true) }))
+}
+
+ClassTests.test("NontrivialMember") {
+ final class C: Differentiable {
+ @differentiable
+ var float: Tracked<Float>
+
+ @differentiable
+ init(_ float: Tracked<Float>) {
+ self.float = float
+ }
+
+ @differentiable
+ func method(_ x: Tracked<Float>) -> Tracked<Float> {
+ x * float
+ }
+
+ @differentiable
+ static func controlFlow(_ c1: C, _ c2: C, _ flag: Bool) -> Tracked<Float> {
+ var result: Tracked<Float> = 0
+ if flag {
+ result = c1.float * c2.float
+ } else {
+ result = c2.float * c1.float
+ }
+ return result
+ }
+ }
+ // Test class initializer differentiation.
+ expectEqual(10, pullback(at: 3, in: { C($0) })(.init(float: 10)))
+ // Test class method differentiation.
+ expectEqual((.init(float: 3), 10), gradient(at: C(10), 3, in: { c, x in c.method(x) }))
+ expectEqual((.init(float: 20), .init(float: 10)),
+ gradient(at: C(10), C(20), in: { c1, c2 in C.controlFlow(c1, c2, true) }))
+}
+
+ClassTests.test("GenericNontrivialMember") {
+ final class C<T: Differentiable>: Differentiable where T == T.TangentVector {
+ @differentiable
+ var x: Tracked<T>
+
+ @differentiable
+ init(_ x: T) {
+ self.x = Tracked(x)
+ }
+
+ @differentiable
+ convenience init(convenience x: T) {
+ self.init(x)
+ }
+ }
+ // Test class initializer differentiation.
+ expectEqual(10, pullback(at: 3, in: { C<Float>($0) })(.init(x: 10)))
+ expectEqual(10, pullback(at: 3, in: { C<Float>(convenience: $0) })(.init(x: 10)))
+}
+
+// TF-1149: Test class with loadable type but address-only `TangentVector` type.
+// TODO(TF-1149): Uncomment when supported.
+/*
+ClassTests.test("AddressOnlyTangentVector") {
+ final class C<T: Differentiable>: Differentiable {
+ @differentiable
+ var stored: T
+
+ @differentiable
+ init(_ stored: T) {
+ self.stored = stored
+ }
+
+ @differentiable
+ func method(_ x: T) -> T {
+ stored
+ }
+ }
+ // Test class initializer differentiation.
+ expectEqual(10, pullback(at: 3, in: { C<Float>($0) })(.init(float: 10)))
+ // Test class method differentiation.
+ expectEqual((.init(stored: Float(3)), 10),
+ gradient(at: C<Float>(3), 3, in: { c, x in c.method(x) }))
+}
+*/
+
+// TF-1175: Test whether class-typed arguments are not marked active.
+ClassTests.test("ClassArgumentActivity") {
+ class C: Differentiable {
+ @differentiable
+ var x: Float
+
+ init(_ x: Float) {
+ self.x = x
+ }
+
+ // Note: this method mutates `self`. However, since `C` is a class, the
+ // method type does not involve `inout` arguments: `(C) -> () -> ()`.
+ func square() {
+ x *= x
+ }
+ }
+
+ // Returns `x * x`.
+ func squared(_ x: Float) -> Float {
+ var c = C(x)
+ c.square() // FIXME(TF-1175): doesn't get differentiated!
+ return c.x
+ }
+ // FIXME(TF-1175): Find a robust solution so that derivatives are correct.
+ // expectEqual((100, 20), valueWithGradient(at: 10, in: squared))
+ expectEqual((100, 1), valueWithGradient(at: 10, in: squared))
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/class_method.swift b/test/AutoDiff/downstream/class_method.swift
new file mode 100644
index 0000000..9db7e1a
--- /dev/null
+++ b/test/AutoDiff/downstream/class_method.swift
@@ -0,0 +1,400 @@
+// RUN: %target-run-simple-swift
+// REQUIRES: executable_test
+
+import StdlibUnittest
+import DifferentiationUnittest
+
+var ClassMethodTests = TestSuite("ClassMethods")
+
+ClassMethodTests.test("Final") {
+ final class Final: Differentiable {
+ func method(_ x: Tracked<Float>) -> Tracked<Float> {
+ return x * x
+ }
+ }
+
+ for i in -5...5 {
+ expectEqual(Tracked<Float>(Float(i * 2)),
+ gradient(at: Tracked<Float>(Float(i))) {
+ x in Final().method(x)
+ })
+ }
+}
+
+ClassMethodTests.test("Simple") {
+ class Super {
+ @differentiable(wrt: x)
+ func f(_ x: Tracked<Float>) -> Tracked<Float> {
+ return 2 * x
+ }
+
+ @derivative(of: f)
+ final func jvpf(_ x: Tracked<Float>) -> (value: Tracked<Float>, differential: (Tracked<Float>) -> Tracked<Float>) {
+ return (f(x), { v in 2 * v })
+ }
+
+ @derivative(of: f)
+ final func vjpf(_ x: Tracked<Float>) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
+ return (f(x), { v in 2 * v })
+ }
+ }
+
+ class SubOverride: Super {
+ @differentiable(wrt: x)
+ override func f(_ x: Tracked<Float>) -> Tracked<Float> {
+ return 3 * x
+ }
+ }
+
+ class SubOverrideCustomDerivatives: Super {
+ @differentiable(wrt: x)
+ override func f(_ x: Tracked<Float>) -> Tracked<Float> {
+ return 3 * x
+ }
+
+ @derivative(of: f)
+ final func jvpf2(_ x: Tracked<Float>) -> (value: Tracked<Float>, differential: (Tracked<Float>) -> Tracked<Float>) {
+ return (f(x), { v in 3 * v })
+ }
+
+ @derivative(of: f)
+ final func vjpf2(_ x: Tracked<Float>) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
+ return (f(x), { v in 3 * v })
+ }
+ }
+
+ func classValueWithGradient(_ c: Super) -> (Tracked<Float>, Tracked<Float>) {
+ return valueWithGradient(at: 1) { c.f($0) }
+ }
+ expectEqual((2, 2), classValueWithGradient(Super()))
+ expectEqual((3, 3), classValueWithGradient(SubOverride()))
+ expectEqual((3, 3), classValueWithGradient(SubOverrideCustomDerivatives()))
+}
+
+ClassMethodTests.test("SimpleWrtSelf") {
+ class Super: Differentiable {
+ var base: Tracked<Float>
+ // FIXME(TF-648): Dummy to make `Super.AllDifferentiableVariables` be nontrivial.
+ var _nontrivial: [Tracked<Float>] = []
+
+ init(base: Tracked<Float>) {
+ self.base = base
+ }
+
+ @differentiable(wrt: (self, x))
+ func f(_ x: Tracked<Float>) -> Tracked<Float> {
+ return base * x
+ }
+
+ @derivative(of: f)
+ final func jvpf(
+ _ x: Tracked<Float>
+ ) -> (value: Tracked<Float>, differential: (TangentVector, Tracked<Float>) -> Tracked<Float>) {
+ return (f(x), { (dself, dx) in dself.base * dx })
+ }
+
+ @derivative(of: f)
+ final func vjpf(
+ _ x: Tracked<Float>
+ ) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (TangentVector, Tracked<Float>)) {
+ let base = self.base
+ return (f(x), { v in
+ (TangentVector(base: v * x, _nontrivial: []), base * v)
+ })
+ }
+ }
+
+ final class SubOverride: Super {
+ @differentiable
+ override init(base: Tracked<Float>) {
+ super.init(base: base)
+ }
+
+ // Note: `TangentVector` type is unused.
+ // There is no way to customize `SubOverride: Differentiable` conformance.
+ // The conformance is always inherited from `Super`.
+ struct TangentVector: Differentiable & AdditiveArithmetic {
+ var base: Float
+ }
+
+ @differentiable(wrt: (self, x))
+ override func f(_ x: Tracked<Float>) -> Tracked<Float> {
+ return 3 * x
+ }
+ }
+
+ final class SubOverrideCustomDerivatives: Super {
+ @differentiable
+ override init(base: Tracked<Float>) {
+ super.init(base: base)
+ }
+ @derivative(of: init)
+ static func vjpInit(base: Tracked<Float>) -> (
+ value: SubOverrideCustomDerivatives, pullback: (Super.TangentVector) -> Tracked<Float>
+ ) {
+ return (SubOverrideCustomDerivatives(base: base), { x in x.base * 2 })
+ }
+
+ @differentiable(wrt: (self, x))
+ @differentiable(wrt: x)
+ override func f(_ x: Tracked<Float>) -> Tracked<Float> {
+ return 3 * x
+ }
+ @derivative(of: f, wrt: x)
+ final func jvpf2(_ x: Tracked<Float>) -> (value: Tracked<Float>, differential: (Tracked<Float>) -> Tracked<Float>) {
+ return (f(x), { v in 3 * v })
+ }
+ @derivative(of: f, wrt: x)
+ final func vjpf2(_ x: Tracked<Float>) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
+ return (f(x), { v in 3 * v })
+ }
+ }
+
+ let v = Super.TangentVector(base: 100, _nontrivial: [])
+ expectEqual(100, pullback(at: 1337) { x in Super(base: x) }(v))
+ expectEqual(100, pullback(at: 1337) { x in SubOverride(base: x) }(v))
+ expectEqual(200, pullback(at: 1337) { x in SubOverrideCustomDerivatives(base: x) }(v))
+
+ // `valueWithGradient` is not used because nested tuples cannot be compared
+ // with `expectEqual`.
+ func classGradient(_ c: Super) -> (Super.TangentVector, Tracked<Float>) {
+ return gradient(at: c, 10) { c, x in c.f(x) }
+ }
+ expectEqual((Super.TangentVector(base: 10, _nontrivial: []), 2),
+ classGradient(Super(base: 2)))
+ expectEqual((Super.TangentVector(base: 0, _nontrivial: []), 3),
+ classGradient(SubOverride(base: 2)))
+ expectEqual((Super.TangentVector(base: 0, _nontrivial: []), 3),
+ classGradient(SubOverrideCustomDerivatives(base: 2)))
+}
+
+ClassMethodTests.test("Generics") {
+ class Super<T: Differentiable & FloatingPoint> where T == T.TangentVector {
+ @differentiable(wrt: x)
+ func f(_ x: Tracked<T>) -> Tracked<T> {
+ return Tracked<T>(2) * x
+ }
+ @derivative(of: f)
+ final func jvpf(
+ _ x: Tracked<T>
+ ) -> (value: Tracked<T>, differential: (Tracked<T>.TangentVector) -> Tracked<T>.TangentVector) {
+ return (f(x), { v in Tracked<T>(2) * v })
+ }
+ @derivative(of: f)
+ final func vjpf(
+ _ x: Tracked<T>
+ ) -> (value: Tracked<T>, pullback: (Tracked<T>.TangentVector) -> Tracked<T>.TangentVector) {
+ return (f(x), { v in Tracked<T>(2) * v })
+ }
+ }
+
+ class SubOverride<T: Differentiable & FloatingPoint>: Super<T> where T == T.TangentVector {
+ @differentiable(wrt: x)
+ override func f(_ x: Tracked<T>) -> Tracked<T> {
+ return x
+ }
+ }
+
+ class SubSpecializeOverride: Super<Float> {
+ @differentiable(wrt: x)
+ override func f(_ x: Tracked<Float>) -> Tracked<Float> {
+ return 3 * x
+ }
+ }
+
+ class SubOverrideCustomDerivatives<T: Differentiable & FloatingPoint>: Super<T>
+ where T == T.TangentVector {
+ @differentiable(wrt: x)
+ override func f(_ x: Tracked<T>) -> Tracked<T> {
+ return Tracked<T>(3) * x
+ }
+ @derivative(of: f)
+ final func jvpf2(
+ _ x: Tracked<T>
+ ) -> (value: Tracked<T>, differential: (Tracked<T>.TangentVector) -> Tracked<T>.TangentVector) {
+ return (f(x), { v in Tracked<T>(3) * v })
+ }
+ @derivative(of: f)
+ final func vjpf2(
+ _ x: Tracked<T>
+ ) -> (value: Tracked<T>, pullback: (Tracked<T>.TangentVector) -> Tracked<T>.TangentVector) {
+ return (f(x), { v in Tracked<T>(3) * v })
+ }
+ }
+
+ class SubSpecializeOverrideCustomDerivatives: Super<Float80> {
+ @differentiable(wrt: x)
+ override func f(_ x: Tracked<Float80>) -> Tracked<Float80> {
+ return 3 * x
+ }
+ @derivative(of: f)
+ final func jvpf2(
+ _ x: Tracked<Float80>
+ ) -> (value: Tracked<Float80>, differential: (Tracked<Float80>) -> Tracked<Float80>) {
+ return (f(x), { v in 3 * v })
+ }
+ @derivative(of: f)
+ final func vjpf2(
+ _ x: Tracked<Float80>
+ ) -> (value: Tracked<Float80>, pullback: (Tracked<Float80>) -> Tracked<Float80>) {
+ return (f(x), { v in 3 * v })
+ }
+ }
+
+ func classValueWithGradient<T: Differentiable & FloatingPoint>(
+ _ c: Super<T>
+ ) -> (T, T) where T == T.TangentVector {
+ let (x,y) = valueWithGradient(at: Tracked<T>(1), in: {
+ c.f($0) })
+ return (x.value, y.value)
+ }
+ expectEqual((2, 2), classValueWithGradient(Super<Float>()))
+ expectEqual((1, 1), classValueWithGradient(SubOverride<Float>()))
+ expectEqual((3, 3), classValueWithGradient(SubSpecializeOverride()))
+ expectEqual((3, 3), classValueWithGradient(SubOverrideCustomDerivatives<Float>()))
+ expectEqual((3, 3), classValueWithGradient(SubSpecializeOverrideCustomDerivatives()))
+}
+
+ClassMethodTests.test("Methods") {
+ class Super: Differentiable {
+ var base: Tracked<Float>
+ // Dummy to make `Super.AllDifferentiableVariables` be nontrivial.
+ var _nontrivial: [Tracked<Float>] = []
+
+ init(base: Tracked<Float>) {
+ self.base = base
+ }
+
+ @differentiable
+ func squared() -> Tracked<Float> { base * base }
+
+ @derivative(of: squared)
+ final func vjpSquared() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> TangentVector) {
+ let base = self.base
+ return (base * base, { v in
+ TangentVector(base: 2 * base * v, _nontrivial: [])
+ })
+ }
+ }
+
+ class Sub1: Super {
+ @differentiable
+ override func squared() -> Tracked<Float> { base * base }
+ @derivative(of: squared)
+ final func vjpSquared2() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> TangentVector) {
+ let base = self.base
+ return (base * base, { v in
+ TangentVector(base: 2 * base * v, _nontrivial: [])
+ })
+ }
+ }
+
+ func classValueWithGradient(_ c: Super) -> (Tracked<Float>, Super.TangentVector) {
+ return valueWithGradient(at: c) { c in c.squared() }
+ }
+
+ expectEqual(4, gradient(at: 2) { x in Super(base: x).squared() })
+ expectEqual(4, gradient(at: 2) { x in Sub1(base: x).squared() })
+
+ expectEqual(Super.TangentVector(base: 4, _nontrivial: []),
+ gradient(at: Super(base: 2)) { foo in foo.squared() })
+ expectEqual(Sub1.TangentVector(base: 4, _nontrivial: []),
+ gradient(at: Sub1(base: 2)) { foo in foo.squared() })
+}
+
+ClassMethodTests.test("Properties") {
+ class Super: Differentiable {
+ @differentiable
+ var base: Tracked<Float>
+
+ init(base: Tracked<Float>) { self.base = base }
+
+ @differentiable
+ var squared: Tracked<Float> { base * base }
+
+ @derivative(of: squared)
+ final func vjpSquared() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> TangentVector) {
+ let base = self.base
+ return (base * base, { v in TangentVector(base: 2 * base * v) })
+ }
+ }
+
+ class Sub1: Super {
+ @differentiable
+ override var squared: Tracked<Float> { base * base }
+ }
+
+ func classValueWithGradient(_ c: Super) -> (Tracked<Float>, Super.TangentVector) {
+ return valueWithGradient(at: c) { c in c.squared }
+ }
+
+ expectEqual(4, gradient(at: 2) { x in Super(base: x).squared })
+ expectEqual(Super.TangentVector(base: 4),
+ gradient(at: Super(base: 2)) { foo in foo.squared })
+}
+
+ClassMethodTests.test("Capturing") {
+ class Multiplier {
+ var coefficient: Tracked<Float>
+ init(_ coefficient: Tracked<Float>) {
+ self.coefficient = coefficient
+ }
+
+ // Case 1: generated VJP.
+ @differentiable
+ func apply(to x: Tracked<Float>) -> Tracked<Float> {
+ return coefficient * x
+ }
+
+ // Case 2: custom VJP capturing `self`.
+ @differentiable(wrt: (x))
+ func apply2(to x: Tracked<Float>) -> Tracked<Float> {
+ return coefficient * x
+ }
+ @derivative(of: apply2)
+ final func vjpApply2(
+ to x: Tracked<Float>
+ ) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
+ return (coefficient * x, { v in self.coefficient * v })
+ }
+
+ // Case 3: custom VJP capturing `self.coefficient`.
+ @differentiable(wrt: x)
+ func apply3(to x: Tracked<Float>) -> Tracked<Float> {
+ return coefficient * x
+ }
+ @derivative(of: apply3)
+ final func vjpApply3(
+ to x: Tracked<Float>
+ ) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
+ let coefficient = self.coefficient
+ return (coefficient * x, { v in coefficient * v })
+ }
+ }
+
+ func f(_ x: Tracked<Float>) -> Tracked<Float> {
+ let m = Multiplier(10)
+ let result = m.apply(to: x)
+ m.coefficient += 1
+ return result
+ }
+ expectEqual(10, gradient(at: 1, in: f))
+
+ func f2(_ x: Tracked<Float>) -> Tracked<Float> {
+ let m = Multiplier(10)
+ let result = m.apply2(to: x)
+ m.coefficient += 1
+ return result
+ }
+ expectEqual(11, gradient(at: 1, in: f2))
+
+ func f3(_ x: Tracked<Float>) -> Tracked<Float> {
+ let m = Multiplier(10)
+ let result = m.apply3(to: x)
+ m.coefficient += 1
+ return result
+ }
+ expectEqual(10, gradient(at: 1, in: f3))
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/class_method_thunk/main.swift b/test/AutoDiff/downstream/class_method_thunk/main.swift
new file mode 100644
index 0000000..8989aab
--- /dev/null
+++ b/test/AutoDiff/downstream/class_method_thunk/main.swift
@@ -0,0 +1,22 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-build-swift %S/../Inputs/class_method_thunk_other_module.swift %s -o %t/a.out
+// RUN: %target-codesign %t/a.out
+// RUN: %target-run %t/a.out
+
+// REQUIRES: executable_test
+
+import StdlibUnittest
+
+var ClassMethodThunkTests = TestSuite("ClassMethodThunks")
+
+func classValueWithGradient(_ c: OtherModuleSuper) -> (Float, Float) {
+ return valueWithGradient(at: 1) { c.f($0) }
+}
+
+ClassMethodThunkTests.test("CrossModuleClassMethodThunks") {
+ expectEqual((2, 2), classValueWithGradient(OtherModuleSuper()))
+ expectEqual((3, 3), classValueWithGradient(OtherModuleSubOverride()))
+ expectEqual((3, 3), classValueWithGradient(OtherModuleSubOverrideCustomDerivatives()))
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/closures.swift b/test/AutoDiff/downstream/closures.swift
new file mode 100644
index 0000000..7dbd227
--- /dev/null
+++ b/test/AutoDiff/downstream/closures.swift
@@ -0,0 +1,72 @@
+// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s
+
+struct Foo {
+ var x: Float
+ var f: @differentiable (Float) -> Float
+}
+func diffableClosureInStruct(s: Foo) {
+ _ = gradient(of: s.f)
+}
+
+// CHECK-LABEL: @{{.*}}diffableClosureInStruct{{.*}} : $@convention(thin) (@guaranteed Foo) -> () {
+// CHECK: [[CLOSURE:%.*]] = struct_extract {{%.*}} : $Foo, #Foo.f
+// CHECK: retain_value [[CLOSURE]] : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK: differentiable_function_extract [original] [[CLOSURE]] : $@differentiable @callee_guaranteed (Float) -> Float
+
+struct InoutAliasableCapture {
+ var x: Float = .zero
+ mutating func foo() {
+ func capturesMutableSelf(t: Float) -> Float {
+ self.x = .zero
+ return t
+ }
+ _ = gradient(at: .zero, in: capturesMutableSelf)
+ }
+}
+
+// CHECK-LABEL: @{{.*}}InoutAliasableCapture{{.*}}foo{{.*}} : $@convention(method) (@inout InoutAliasableCapture) -> () {
+// CHECK: bb0([[SELF:%.*]] : $*InoutAliasableCapture):
+// CHECK: [[JVP:%.*]] = differentiability_witness_function [jvp] [parameters 0] [results 0] @{{.*}}capturesMutableSelf{{.*}} : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> Float
+// CHECK-NOT: retain_value_addr [[SELF]]
+// CHECK-NOT: copy_addr [[SELF]]
+// CHECK: [[JVP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[JVP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK: [[VJP:%.*]] = differentiability_witness_function [vjp] [parameters 0] [results 0] @{{.*}}capturesMutableSelf{{.*}} : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> Float
+// CHECK-NOT: retain_value_addr [[SELF]]
+// CHECK-NOT: retain_value_addr [[SELF]]
+// CHECK-NOT: copy_addr [[SELF]]
+// CHECK: [[VJP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[VJP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+
+public func closureCaptureMutable() {
+ var val: Float = 10
+ _ = gradient(at: 0) { (x: Float) -> Float in
+ val += 2
+ return val * x
+ }
+}
+
+// CHECK-LABEL: @AD__{{.*}}closureCaptureMutable{{.*}}___vjp_src_0_wrt_0
+// CHECK: bb0({{%.*}} : $Float, [[INOUT_ARG:%.*]] : ${ var Float }):
+// CHECK: [[ADJOINT:%.*]] = function_ref @AD__{{.*}}closureCaptureMutabley{{.*}}___pullback_src_0_wrt_0
+// CHECK: {{.*}} = partial_apply [callee_guaranteed] [[ADJOINT]]({{.*}})
+
+// TF-30: VJP return value should match the return type.
+struct TF_30 : Differentiable {
+ var x: Float
+ @noDerivative var y: @differentiable (Float) -> Float
+}
+// Make sure this passes SIL verification.
+let _: @differentiable (TF_30) -> Float = { s in s.x }
+
+// Make sure `@noDerivative` gets propagated through SIL.
+// Make sure `@noDerivative` with non-`Differentiable` also works.
+public func nondiffs(_ f: @differentiable (Float, @noDerivative Float) -> Float,
+ _ g: @differentiable (Float, @noDerivative Int) -> Float) {
+ _ = gradient(at: 0) { f($0, 1) }
+ _ = gradient(at: 0) { g($0, 1) }
+}
+nondiffs({ x, y in x }, { x, y in x })
+
+// Crasher when SILGen'ing @differentiable functions with generic @noDerivative parameters.
+func foo<T>(_ f: @differentiable (Float, @noDerivative T) -> Float, _ t: T) -> Float {
+ return f(1, t)
+}
diff --git a/test/AutoDiff/downstream/compiler_crashers/tf1011-differential-unset-tangent-buffer.swift b/test/AutoDiff/downstream/compiler_crashers/tf1011-differential-unset-tangent-buffer.swift
new file mode 100644
index 0000000..e231c23
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers/tf1011-differential-unset-tangent-buffer.swift
@@ -0,0 +1,33 @@
+// RUN: not --crash %target-swift-emit-sil -enable-experimental-forward-mode-differentiation %s -verify
+// REQUIRES: asserts
+
+// TF-1011: Differential generation crash due to unset tangent buffer.
+
+@differentiable
+func arrayLiteral(_ x: Float, _ y: Float) -> [Float] {
+ var result = [x * y, x * y]
+ return result
+}
+
+// [AD] Original bb0: To differentiate or not to differentiate?
+// [ ] debug_value %0 : $Float, let, name "x", argno 1 // id: %2
+// [ ] debug_value %1 : $Float, let, name "y", argno 2 // id: %3
+// [∂] %4 = alloc_stack $Array<Float>, var, name "result" // users: %26, %25, %21, %22
+// [ ] %5 = integer_literal $Builtin.Word, 2 // user: %7
+// [ ] // function_ref _allocateUninitializedArray<A>(_:)
+// %6 = function_ref @$ss27_allocateUninitializedArrayySayxG_BptBwlF : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer) // user: %7
+// [∂] %7 = apply %6<Float>(%5) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer) // user: %8
+// [∂] (%8, %9) = destructure_tuple %7 : $(Array<Float>, Builtin.RawPointer) // users: %21, %10
+// [ ] %10 = pointer_to_address %9 : $Builtin.RawPointer to [strict] $*Float // users: %16, %14
+// [ ] %11 = metatype $@thin Float.Type // user: %13
+// [ ] // function_ref static Float.* infix(_:_:)
+// %12 = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %13
+// [∂] %13 = apply %12(%0, %1, %11) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %14
+// [∂] store %13 to [trivial] %10 : $*Float // id: %14
+// ...
+// [AD] JVPEmitter visited:
+// [ORIG] store %13 to [trivial] %10 : $*Float // id: %14
+// Assertion failed: (!insertion.second && "tangent buffer should already exist"), function getTangentBuffer, file swift/lib/SILOptimizer/Mandatory/Differentiation.cpp, line 4528.
+
+// `store %13 to [trivial] %10` is visited but `%10 = pointer_to_address %9` is
+// not. `%10` does not have a set tangent buffer.
diff --git a/test/AutoDiff/downstream/compiler_crashers/tf1063-transpose-attr-typechecking-static-method.swift b/test/AutoDiff/downstream/compiler_crashers/tf1063-transpose-attr-typechecking-static-method.swift
new file mode 100644
index 0000000..4461f5e
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers/tf1063-transpose-attr-typechecking-static-method.swift
@@ -0,0 +1,28 @@
+// RUN: not --crash %target-swift-frontend-typecheck -primary-file %s
+// REQUIRES: asserts
+
+// TF-1063: `@transpose` attribute type-checking crash for static methods in
+// `AnyFunctionType::getTransposeOriginalFunctionType`.
+
+struct Struct: Differentiable & AdditiveArithmetic {
+ static func staticMethod(x: Struct) -> Struct {
+ x
+ }
+
+ @transpose(of: staticMethod, wrt: 0)
+ static func transposeStaticMethod() -> Struct {
+ self
+ }
+}
+
+// Assertion failed: (!empty()), function back, file llvm-project/llvm/include/llvm/ADT/ArrayRef.h, line 158.
+// Stack dump:
+// ...
+// 1. Swift version 5.2-dev (Swift ace3925395)
+// 2. While evaluating request TypeCheckSourceFileRequest(source_file "test/AutoDiff/compiler_crashers/tf1063-transpose-attr-typechecking-static-method.swift", 0)
+// 3. While type-checking 'Struct' (at test/AutoDiff/compiler_crashers/tf1063-transpose-attr-typechecking-static-method.swift:6:1)
+// 4. While type-checking 'transposeStaticMethod()' (at test/AutoDiff/compiler_crashers/tf1063-transpose-attr-typechecking-static-method.swift:12:3)
+// ...
+// 7 swiftc 0x00000001111003c3 swift::AnyFunctionType::getTransposeOriginalFunctionType(swift::IndexSubset*, bool) (.cold.3) + 35
+// 8 swiftc 0x000000010dbf68e0 swift::AnyFunctionType::getTransposeOriginalFunctionType(swift::IndexSubset*, bool) + 1904
+// 9 swiftc
diff --git a/test/AutoDiff/downstream/compiler_crashers/tf1122-maximally-abstracted-differentiable-function.swift b/test/AutoDiff/downstream/compiler_crashers/tf1122-maximally-abstracted-differentiable-function.swift
new file mode 100644
index 0000000..f15a81e
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers/tf1122-maximally-abstracted-differentiable-function.swift
@@ -0,0 +1,11 @@
+// RUN: %target-build-swift %s -o %t
+// RUN: not --crash %target-run %t
+// REQUIRES: executable_test
+// REQUIRES: swift_test_mode_optimize_none
+
+func id<T>(_ t: T) -> T { t }
+func foo<X: Differentiable>(_ x: X) {
+ let f: @differentiable (X) -> X = { $0 }
+ let _ = id(f)
+}
+foo(Float(1))
diff --git a/test/AutoDiff/downstream/compiler_crashers/tf429-enable-library-evolution.swift b/test/AutoDiff/downstream/compiler_crashers/tf429-enable-library-evolution.swift
new file mode 100644
index 0000000..acdf8f4
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers/tf429-enable-library-evolution.swift
@@ -0,0 +1,20 @@
+// RUN: not --crash %target-swift-emit-sil -enable-library-evolution %s
+// REQUIRES: asserts
+
+// TF-429: Differentiation transform does not support
+// `-enable-library-evolution` because it assumes that differential/pullback
+// structs are always loadable, i.e. have object value category.
+
+// Function must be public to trigger library evolution crash.
+@differentiable
+public func TF_429(_ x: Float) -> Float { x }
+
+// Assertion failed: (mainPullbackStruct->getType() == pbStructLoweredType), function run, file /Users/danielzheng/swift-merge/swift/lib/SILOptimizer/Mandatory/Differentiation.cpp, line 6279.
+// Stack dump:
+// ...
+// 1. Swift version 5.1.1-dev (Swift c3cdcba346)
+// 2. While running pass #17 SILModuleTransform "Differentiation".
+// ...
+// 7 swiftc 0x0000000101620642 (anonymous namespace)::PullbackEmitter::run() + 3122
+// 8 swiftc 0x00000001015cb1e8 (anonymous namespace)::VJPEmitter::run() + 1224
+// 9 swiftc 0x00000001015c3348 (anonymous namespace)::ADContext::processDifferentiableAttribute(swift::SILFunction*, swift::SILDifferentiableAttr*, (anonymous namespace)::DifferentiationInvoker) + 4536
diff --git a/test/AutoDiff/downstream/compiler_crashers/tf756-irgen-witness-method-archetype.swift b/test/AutoDiff/downstream/compiler_crashers/tf756-irgen-witness-method-archetype.swift
new file mode 100644
index 0000000..d567fb3
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers/tf756-irgen-witness-method-archetype.swift
@@ -0,0 +1,36 @@
+// RUN: not --crash %target-swift-emit-ir -primary-file %s
+// REQUIRES: asserts
+
+// TF-756: IRGen crash for `witness_method` instruction generated by the
+// differentiation transform.
+
+struct Tensor<Scalar> {}
+extension Tensor: Differentiable where Scalar == Float {}
+
+extension Tensor where Scalar == Float {
+ // Arbitrary `@differentiable` operation with >1 parameter, so that index
+ // subset thunk may be generated.
+ @differentiable
+ static func + (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
+ return lhs
+ }
+
+ @derivative(of: +)
+ static func _vjpAdd(lhs: Tensor, rhs: Tensor)
+ -> (value: Tensor, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
+ return (lhs + rhs, { v in (v, v) })
+ }
+}
+
+@differentiable
+func TF_756(input: Tensor<Float>) -> Tensor<Float> {
+ let other = Tensor<Float>()
+ return other + input
+}
+
+// Assertion failed: (!type->hasArchetype() && !type->hasTypeParameter()), function getAddrOfTypeMetadataAccessFunction, file /Users/danielzheng/swift-merge/swift/lib/IRGen/GenDecl.cpp, line 3352.
+// Stack dump:
+// ...
+// 1. Swift version 5.1.1-dev (Swift 3943c1e36b)
+// 2. While emitting IR SIL function "@AD__$s4main6TensorVAASfRszlE13TangentVectorVySf_GA2FIegyyd_A2FIegyd_TR_src_0_wrt_1_differential_index_subset_thunk".
+// for expression at [/Users/danielzheng/swift-merge/swift/test/AutoDiff/compiler_crashers/tf756-irgen-witness-method-archetype.swift:27:16 - line:27:16] RangeText=""
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/Inputs/tf1202-eliminating-differentiability-witness-original-function.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/Inputs/tf1202-eliminating-differentiability-witness-original-function.swift
new file mode 100644
index 0000000..9639df1
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/Inputs/tf1202-eliminating-differentiability-witness-original-function.swift
@@ -0,0 +1,7 @@
+@inlinable
+@differentiable(where T: Differentiable)
+public func identity<T>(_ x: T) -> T { x }
+
+public func foo<T: Differentiable>(_ f: @differentiable (T) -> T = identity) -> T {
+ fatalError()
+}
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/Inputs/tf953-invalid-differentiable-attr-other-module.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/Inputs/tf953-invalid-differentiable-attr-other-module.swift
new file mode 100644
index 0000000..0c48009
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/Inputs/tf953-invalid-differentiable-attr-other-module.swift
@@ -0,0 +1,6 @@
+// Invalid `@differentiable` attribute in non-primary-file should not crash
+// SILGen.
+@differentiable
+func foo(_ x: Int) -> Float {
+ return Float(x)
+}
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/Inputs/tf960-irgen-diff-witness-no-derivatives-other-module.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/Inputs/tf960-irgen-diff-witness-no-derivatives-other-module.swift
new file mode 100644
index 0000000..69cbd1c
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/Inputs/tf960-irgen-diff-witness-no-derivatives-other-module.swift
@@ -0,0 +1,6 @@
+struct TF_960: Differentiable {
+ @differentiable
+ func callAsFunction(_ input: Float) -> Float {
+ return input
+ }
+}
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf1017-diff-attr-invalid-original-interface-type.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1017-diff-attr-invalid-original-interface-type.swift
new file mode 100644
index 0000000..9f72c31
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1017-diff-attr-invalid-original-interface-type.swift
@@ -0,0 +1,26 @@
+// RUN: %target-swift-frontend -typecheck %s -verify
+// REQUIRES: asserts
+
+// TF-1017: `@differentiable` attribute type-checking crash when original
+// declaration has an error type.
+
+protocol P: Differentiable {
+ @differentiable
+ init(x: Float)
+}
+extension P {
+ @differentiable
+ // expected-error @+1 {{generic parameter 'U' is not used in function signature}}
+ init<U>(_ x: Float) {
+ self.init(x: x)
+ }
+}
+extension P where Self: FloatingPoint {
+ @differentiable
+ func hello(_ x: Float) -> Self {
+ .init(x: x)
+ }
+}
+
+// Assertion failed: (isa<X>(Val) && "cast<Ty>() argument of incompatible type!"), function cast, file llvm/include/llvm/Support/Casting.h, line 264.
+// Assertion failed: (D), function printDifferentiableAttrArguments, file swift/lib/AST/Attr.cpp, line 493.
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf1033-differential-ownership-uninit-memory.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1033-differential-ownership-uninit-memory.swift
new file mode 100644
index 0000000..21066ec
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1033-differential-ownership-uninit-memory.swift
@@ -0,0 +1,21 @@
+// RUN: %target-swift-emit-sil -enable-experimental-forward-mode-differentiation %s
+// REQUIRES: asserts
+
+// TF-1033: Ownership verification error in differential function generated by
+// the differentiation transform.
+// Differential generation sees a `copy_addr [take]` in original function and
+// emits a `copy_addr [take]` in the differential function, but this is
+// problematic because the source buffer in the differential function is not
+// fully initialized.
+
+@_silgen_name("tuple")
+@differentiable
+func tupleInitialNonactive<T: AdditiveArithmetic & Differentiable>(_ x: T) -> T {
+ var tuple = (T.zero, T.zero)
+ tuple.0 = x
+ return tuple.0
+}
+
+// SIL memory lifetime failure in @AD__tuple__differential_src_0_wrt_0: memory is not initialized, but should
+// memory location: %15 = tuple_element_addr %14 : $*(τ_0_0.TangentVector, τ_0_0.TangentVector), 0 // user: %16
+// at instruction: copy_addr [take] %7 to %10 : $*τ_0_0.TangentVector // id: %11
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf1037-unmangled-derivative-generic-signatures.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1037-unmangled-derivative-generic-signatures.swift
new file mode 100644
index 0000000..41b0c1c
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1037-unmangled-derivative-generic-signatures.swift
@@ -0,0 +1,59 @@
+// RUN: %target-swift-emit-sil %s -verify
+// REQUIRES: asserts
+
+// TF-1037: Differentiation transform crashes due to multiple SIL
+// differentiability witnesses with same parameter indices but different
+// derivative generic signatures. Since derivative generic signatures are
+// currently not mangled in derivative function names (TF-680), there is a
+// name clash.
+// Detailed explanation: https://github.com/apple/swift/pull/28621#issuecomment-562763390
+
+// Test small derivative generic signature difference.
+protocol P1: Differentiable {}
+extension P1 {
+ @differentiable // derivative generic signature: none
+ func foo() -> Float { 1 }
+}
+extension P1 {
+ @derivative(of: foo) // derivative generic signature: `<P1 where Self: P1>`
+ func vjpFoo() -> (value: Float, pullback: (Float) -> (TangentVector)) {
+ fatalError()
+ }
+}
+
+// Test bigger derivative generic signature difference.
+protocol P2: Differentiable {}
+extension P2 {
+ @differentiable // derivative generic signature: none
+ func foo() -> Float { 1 }
+}
+extension P2 where Self: AdditiveArithmetic {
+ // derivative generic signature: `<P2 where Self: P2, Self: AdditiveArithmetic>`
+ @derivative(of: foo)
+ func vjpFoo() -> (value: Float, pullback: (Float) -> (TangentVector)) {
+ fatalError()
+ }
+}
+
+// // AD__$s4main2P1PAAE3fooSfyF__vjp_src_0_wrt_0
+// sil hidden [thunk] [always_inline] [ossa] @AD__$s4main2P1PAAE3fooSfyF__vjp_src_0_wrt_0 : $@convention(method) <τ_0_0 where τ_0_0 : P1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> @out τ_0_0.TangentVector) {
+// // %0 // user: %2
+// bb0(%0 : $*τ_0_0):
+// // function_ref P1.vjpFoo()
+// %1 = function_ref @$s4main2P1PAAE6vjpFooSf5value_13TangentVectorQzSfc8pullbacktyF : $@convention(method) <τ_0_0 where τ_0_0 : P1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> @out τ_0_0.TangentVector) // user: %2
+// %2 = apply %1<τ_0_0>(%0) : $@convention(method) <τ_0_0 where τ_0_0 : P1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> @out τ_0_0.TangentVector) // user: %3
+// return %2 : $(Float, @callee_guaranteed (Float) -> @out τ_0_0.TangentVector) // id: %3
+// } // end sil function 'AD__$s4main2P1PAAE3fooSfyF__vjp_src_0_wrt_0'
+//
+// Assertion failed: (!entry->getValue() && "function already exists"), function create, file /Users/danielzheng/swift-bart/swift/lib/SIL/SILFunction.cpp, line 74.
+// Stack dump:
+// 0. Program arguments: /Users/danielzheng/swift-bart/build/Ninja-ReleaseAssert+stdlib-Release/swift-macosx-x86_64/bin/swift -frontend -interpret swift/test/AutoDiff/compiler_crashers/tf1037-multiple-differentiable-derivative-attributes.swift -enable-objc-interop -sdk /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.15.sdk -color-diagnostics -module-name main
+// 1. Swift version 5.1.1-dev (Swift f70940798d)
+// 2. While running pass #52 SILModuleTransform "Differentiation".
+// 3. While processing // differentiability witness for P1.foo()
+// sil_differentiability_witness hidden [parameters 0] [results 0] @$s4main2P1PAAE3fooSfyF : $@convention(method) <Self where Self : P1> (@in_guaranteed Self) -> Float {
+// }
+// on SIL function "@$s4main2P1PAAE3fooSfyF".
+// for 'foo()' (at swift/test/AutoDiff/compiler_crashers/tf1037-multiple-differentiable-derivative-attributes.swift:15:3)
+// 4. While creating SIL function "@AD__$s4main2P1PAAE3fooSfyF__vjp_src_0_wrt_0".
+// for 'foo()' (at swift/test/AutoDiff/compiler_crashers/tf1037-multiple-differentiable-derivative-attributes.swift:15:3)
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf1039-cloned-curry-thunk-verification.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1039-cloned-curry-thunk-verification.swift
new file mode 100644
index 0000000..e54a119
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1039-cloned-curry-thunk-verification.swift
@@ -0,0 +1,47 @@
+// RUN: %target-swift-emit-sil %s -verify
+// REQUIRES: asserts
+
+// TF-1039: Cloned curry thunks generated by differentiation should create a
+// `differentiable_function` instruction before any `dealloc_stack` instructions
+// to prevent `alloc_stack`/`dealloc_stack` ordering issues.
+
+protocol P {
+ @differentiable
+ func foo(_ x: Float) -> Float
+}
+struct S: P {
+ @differentiable
+ func foo(_ x: Float) -> Float { x }
+}
+func foo<T: P>(_ x: T) {
+ // Curry thunk emitted here for `x.foo`.
+ _ = gradient(at: 1, in: x.foo)
+}
+
+// SIL verification failed: stack dealloc does not match most recent stack alloc: op == state.Stack.back()
+// Verifying instruction:
+// %2 = alloc_stack $τ_0_0 // users: %7, %5, %9, %8, %3
+// -> dealloc_stack %2 : $*τ_0_0 // id: %9
+// In function:
+// // AD__$s5curry1PP3fooyS2fFTc__differentiable_curry_thunk_src_0_wrt_0
+// sil shared [thunk] @AD__$s5curry1PP3fooyS2fFTc__differentiable_curry_thunk_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0) -> @owned @differentiable @callee_guaranteed (Float) -> Float {
+// // %0 // user: %3
+// bb0(%0 : $*τ_0_0):
+// %1 = witness_method $τ_0_0, #P.foo!1 : <Self where Self : P> (Self) -> (Float) -> Float : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (Float, @in_guaranteed τ_0_0) -> Float // user: %8
+// %2 = alloc_stack $τ_0_0 // users: %7, %5, %9, %8, %3
+// copy_addr %0 to [initialization] %2 : $*τ_0_0 // id: %3
+// %4 = alloc_stack $τ_0_0 // users: %15, %11, %5
+// copy_addr %2 to [initialization] %4 : $*τ_0_0 // id: %5
+// %6 = alloc_stack $τ_0_0 // users: %14, %13, %7
+// copy_addr %2 to [initialization] %6 : $*τ_0_0 // id: %7
+// %8 = partial_apply [callee_guaranteed] %1<τ_0_0>(%2) : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (Float, @in_guaranteed τ_0_0) -> Float // user: %16
+// dealloc_stack %2 : $*τ_0_0 // id: %9
+// %10 = witness_method $τ_0_0, #P.foo!1.jvp.SU : <Self where Self : P> (Self) -> (Float) -> Float : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %11
+// %11 = partial_apply [callee_guaranteed] %10<τ_0_0>(%4) : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %16
+// %12 = witness_method $τ_0_0, #P.foo!1.vjp.SU : <Self where Self : P> (Self) -> (Float) -> Float : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %13
+// %13 = partial_apply [callee_guaranteed] %12<τ_0_0>(%6) : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %16
+// dealloc_stack %6 : $*τ_0_0 // id: %14
+// dealloc_stack %4 : $*τ_0_0 // id: %15
+// %16 = differentiable_function [parameters 0] %8 : $@callee_guaranteed (Float) -> Float with_derivative {%11 : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %13 : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} // user: %17
+// return %16 : $@differentiable @callee_guaranteed (Float) -> Float // id: %17
+// } // end sil function 'AD__$s5curry1PP3fooyS2fFTc__differentiable_curry_thunk_src_0_wrt_0'
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf1126-derivative-generic-specialization.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1126-derivative-generic-specialization.swift
new file mode 100644
index 0000000..ad6d462
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1126-derivative-generic-specialization.swift
@@ -0,0 +1,56 @@
+// RUN: %target-swift-frontend -O -emit-sil %s -verify
+// REQUIRES: asserts
+
+// TF-1126: Generic specialization crash during capture propagation.
+// Related to `@differentiable` function with `partial_apply` operands,
+// to be specialized. Occurs only with `-O`.
+
+struct A: Differentiable{
+ var b: SIMD8<Float>
+}
+
+@differentiable
+func function(a: A) -> A {
+ var a = a
+ a.b = a.b - SIMD8<Float>(repeating: 1.0)
+ return a
+}
+
+let masks: [SIMD8<Float>] = [[1,0,0,0,0,0,0,0],
+ [0,1,0,0,0,0,0,0],
+ [0,0,1,0,0,0,0,0],
+ [0,0,0,1,0,0,0,0],
+ [0,0,0,0,1,0,0,0],
+ [0,0,0,0,0,1,0,0],
+ [0,0,0,0,0,0,1,0],
+ [0,0,0,0,0,0,0,1]]
+
+extension SIMD8 where Scalar == Float{
+ @differentiable(where Scalar: Differentiable)
+ func updated(at index: Int, with newValue: Scalar) -> Self {
+ let mask = masks[index]
+ let result = self - (self * mask) + (newValue * mask)
+ return result
+ }
+}
+
+// Looking for a function: $ss4SIMDPss14DifferentiableRzSB6Scalars11SIMDStoragePRpzsAA13TangentVectorsACPRpzSBAhI_AdFRPzrlE12_vjpSubtract3lhs3rhsx5value_AJ_AJtAJc8pullbacktx_xtFZs5SIMD8VySfG_Tg5
+// Expected type: @convention(method) (@in_guaranteed SIMD8<Float>, @in_guaranteed SIMD8<Float>, @thick SIMD8<Float>.Type) -> (@out SIMD8<Float>, @owned @callee_guaranteed (@in_guaranteed SIMD8<Float>) -> (@out SIMD8<Float>, @out SIMD8<Float>))
+// Found type: @convention(method) (SIMD8<Float>, SIMD8<Float>, @thick SIMD8<Float>.Type) -> (@out SIMD8<Float>, @owned @callee_guaranteed (@in_guaranteed SIMD8<Float>) -> (@out SIMD8<Float>, @out SIMD8<Float>))
+// Assertion failed: (ReInfo.getSpecializedType() == SpecializedF->getLoweredFunctionType() && "Previously specialized function does not match expected type."), function lookupSpecialization, file /Users/swiftninjas/s4tf/swift/lib/SILOptimizer/Utils/Generics.cpp, line 1833.
+// Stack dump:
+// ...
+// 1. Swift version 5.2-dev (Swift bf631dc2e4)
+// 2. While running pass #113021 SILFunctionTransform "CapturePropagation" on SILFunction "@AD__$ss5SIMD8V6deleteSfRszrlE7updated2at4withABySfGSi_SftF__vjp_src_0_wrt_1_2".
+// for 'updated(at:with:)' (at /Users/porter/Dropbox (PassiveLogic)/Team/Team Members Scratch Space/Porter/Experiments/Playgrounds/delete/delete/main.swift:75:5)
+// llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 37
+// llvm::sys::RunSignalHandlers() + 85
+// SignalHandler(int) + 278
+// _sigtramp + 29
+// _sigtramp + 2821162056
+// abort + 127
+// basename_r + 0
+// swift::GenericFuncSpecializer::lookupSpecialization() (.cold.1) + 35
+// swift::GenericFuncSpecializer::lookupSpecialization() + 2109
+// (anonymous namespace)::CapturePropagation::optimizePartialApply(swift::PartialApplyInst*) + 1301
+// (anonymous namespace)::CapturePropagation::run() + 265
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf1159-function-conversion-begin-borrow.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1159-function-conversion-begin-borrow.swift
new file mode 100644
index 0000000..043d1f2
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1159-function-conversion-begin-borrow.swift
@@ -0,0 +1,35 @@
+// RUN: %target-swift-frontend -emit-sil %s -verify
+// REQUIRES: asserts
+
+// TF-1159: `begin_borrow` instruction unhandled in the
+// `reapplyFunctionConversion` helper function.
+
+func id<T>(_ x: T) -> T { x }
+
+@differentiable
+func TF_1159(_ x: Float) -> Float {
+ // Note: code below generates `partial_apply` and `begin_borrow`.
+ let fn: (Float) -> Float = id
+ return fn(x)
+}
+
+// Unhandled function conversion instruction
+// UNREACHABLE executed at swift/lib/SILOptimizer/Mandatory/Differentiation.cpp:433!
+// Stack dump:
+// ...
+// 1. Swift version 5.2-dev (Swift 415d33b3f1)
+// 2. While running pass #27 SILModuleTransform "Differentiation".
+// 3. While canonicalizing `differentiable_function` SIL node %12 = differentiable_function [parameters 0] %10 : $@callee_guaranteed (Float) -> Float // users: %17, %13
+// 4. While ...in SIL function "@AD__$s4main3fooyS2fF__vjp_src_0_wrt_0".
+// for 'foo(_:)' (at tf-1159.swift:4:1)
+// 0 swift 0x0000000107b15105 llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 37
+// 1 swift 0x0000000107b14078 llvm::sys::RunSignalHandlers() + 248
+// 2 swift 0x0000000107b15706 SignalHandler(int) + 278
+// 3 libsystem_platform.dylib 0x00007fff68ed2b5d _sigtramp + 29
+// 4 libsystem_platform.dylib 0x0000000000000053 _sigtramp + 2534593811
+// 5 libsystem_c.dylib 0x00007fff68d8c6a6 abort + 127
+// 6 swift 0x0000000108dcc09e llvm::llvm_unreachable_internal(char const*, char const*, unsigned int) + 462
+// 7 swift 0x0000000104047ec2 reapplyFunctionConversion(swift::autodiff::ADContext&, swift::SILValue, swift::SILValue, swift::SILValue, swift::SILBuilder&, swift::SILLocation, llvm::SmallVectorImpl<swift::AllocStackInst*>&, swift::IndexSubset*, swift::GenericSignature) + 1506
+// 8 swift 0x000000010402d5c0 (anonymous namespace)::DifferentiationTransformer::promoteToDifferentiableFunction(swift::DifferentiableFunctionInst*, swift::SILBuilder&, swift::SILLocation, swift::autodiff::DifferentiationInvoker) + 8880
+// 9 swift 0x00000001040292ea (anonymous namespace)::DifferentiationTransformer::processDifferentiableFunctionInst(swift::DifferentiableFunctionInst*) + 426
+// 10 swift 0x0000000104026b06 (anonymous namespace)::Differentiation::run() + 1174
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf1202-eliminating-differentiability-witness-original-function.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1202-eliminating-differentiability-witness-original-function.swift
new file mode 100644
index 0000000..c6d72f0
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1202-eliminating-differentiability-witness-original-function.swift
@@ -0,0 +1,15 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-build-swift -emit-module -module-name tf1202 -emit-module-path %t/tf1202.swiftmodule %S/Inputs/tf1202-eliminating-differentiability-witness-original-function.swift
+// RUN: %target-build-swift -I%t -emit-module -O %s
+
+// This situation exposed a bug where DeadFunctionElimination eliminated the
+// SILFunction for `identity<T>` even though a differentiability witness for it
+// still existed. This causes deserialization of this module to crash when
+// trying to deserialize the differentiability witness because it can't find
+// the original function `identity<T>`.
+
+import tf1202
+
+func callit() -> Float {
+ return foo()
+}
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf1204-pullback-inout-subset-parameters-thunk.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1204-pullback-inout-subset-parameters-thunk.swift
new file mode 100644
index 0000000..d226ed9
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf1204-pullback-inout-subset-parameters-thunk.swift
@@ -0,0 +1,50 @@
+// RUN: %target-swift-emit-sil %s -verify
+// REQUIRES: asserts
+
+// TF-1204: Subset parameters thunk crash for original function with `inout`
+// parameters.
+
+struct Convolution<T>: Differentiable
+where T: Differentiable, T == T.TangentVector {
+ var bias: T
+
+ @differentiable(wrt: self)
+ @differentiable
+ func callAsFunction(_ input: T) -> T {
+ var result = withoutDerivative(at: bias)
+ infer(result: &result, input: input, bias: bias)
+ return result
+ }
+
+ @differentiable
+ func infer(result: inout T, input: T, bias: T) {
+ fatalError()
+ }
+
+ @derivative(of: infer)
+ func _vjpInfer(result: inout T, input: T, bias: T)
+ -> (
+ value: Void, pullback: (inout T) -> (Convolution<T>.TangentVector, T, T)
+ )
+ {
+ fatalError()
+ }
+}
+
+// Original crasher:
+// Assertion failed: (origFnType->getResults().size() == 1), function getOrCreateSubsetParametersThunkForDerivativeFunction, file /Users/swiftninjas/s4tf/swift/lib/SILOptimizer/Utils/Differentiation/Thunk.cpp, line 812.
+// Stack dump:
+// 1. Swift version 5.2-dev (LLVM b3057cffb6, Swift c8bea53782)
+// 2. While running pass #135 SILModuleTransform "Differentiation".
+// 3. While canonicalizing `differentiable_function` SIL node %22 = differentiable_function [parameters 0 2 3] %18 : $@callee_guaranteed (@inout τ_0_0, @in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @in_guaranteed Convolution<τ_0_0>) -> () // users: %27, %23
+// 4. While ...in SIL function "@AD__$s4conv11ConvolutionV14callAsFunctionyxxF__vjp_src_0_wrt_1_s14DifferentiableRz13TangentVectorsAAPQzRszl".
+// for 'callAsFunction(_:)' (at conv.swift:8:5)
+// 0 swift 0x0000000104c08e75 llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 37
+// 1 swift 0x0000000104c080b5 llvm::sys::RunSignalHandlers() + 85
+// 2 swift 0x0000000104c0945c SignalHandler(int) + 268
+// 3 libsystem_platform.dylib 0x00007fff6deebb5d _sigtramp + 29
+// 4 libsystem_platform.dylib 0x0000000000005290 _sigtramp + 18446603338671822672
+// 5 libsystem_c.dylib 0x00007fff6dda56a6 abort + 127
+// 6 libsystem_c.dylib 0x00007fff6dd6e20d basename_r + 0
+// 7 swift 0x0000000104ec58a3 swift::autodiff::getOrCreateSubsetParametersThunkForDerivativeFunction(swift::SILOptFunctionBuilder&, swift::SILValue, swift::SILValue, swift::AutoDiffDerivativeFunctionKind, swift::SILAutoDiffIndices, swift::SILAutoDiffIndices) (.cold.10) + 35
+// 8 swift 0x000000010136f079 swift::autodiff::getOrCreateSubsetParametersThunkForDerivativeFunction(swift::SILOptFunctionBuilder&, swift::SILValue, swift::SILValue, swift::AutoDiffDerivativeFunctionKind, swift::SILAutoDiffIndices, swift::SILAutoDiffIndices) + 7545
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf123-differentiable-function-opaque-reabstraction.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf123-differentiable-function-opaque-reabstraction.swift
new file mode 100644
index 0000000..8eb5faa
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf123-differentiable-function-opaque-reabstraction.swift
@@ -0,0 +1,24 @@
+// RUN: %target-swift-frontend-emit-silgen %s
+// REQUIRES: asserts
+
+// TF-123: SILGen crashes when reabstracting `@differentiable` functions to
+// opaque abstraction patterns.
+// The culprit is `createAutoDiffThunk` in lib/SILGen/SILGenPoly.cpp.
+
+// Reproducer: cast `@differentiable` function-typed value to `Any`.
+let function: @differentiable (Float) -> Float
+_ = function as Any
+
+// SIL verification failed: JVP type does not match expected JVP type
+// $@callee_guaranteed (@in_guaranteed Float) -> @out (Float, @callee_guaranteed (@in_guaranteed Float) -> @out Float)
+// $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
+
+// Reproducer: create key path to `@differentiable` function-typed value.
+struct TF_123: KeyPathIterable {
+ let function: @differentiable (Float) -> Float
+}
+_ = \TF_123.function
+
+// SIL verification failed: JVP type does not match expected JVP type
+// $@callee_guaranteed (@in_guaranteed Float) -> @out (Float, @callee_guaranteed (@in_guaranteed Float) -> @out Float)
+// $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf881-derivative-local-variable-capture.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf881-derivative-local-variable-capture.swift
new file mode 100644
index 0000000..61ff0c2
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf881-derivative-local-variable-capture.swift
@@ -0,0 +1,43 @@
+// RUN: %target-swift-emit-silgen %s -verify
+// REQUIRES: asserts
+
+// TF-881: User-defined Swift derivative functions cannot capture local values.
+// Captured local values become extra SIL function arguments, breaking the
+// expected derivative function type logic.
+//
+// In the short term, we should diagnose these cases to prevent crashes.
+// In the long term, we should investigate supporting these cases.
+
+do {
+ let capturedValue: Int = 3
+
+ func original(_ x: Float) -> Float { x }
+
+ // expected-error @+1 {{attribute '@derivative' can only be used in a non-local scope}}
+ @derivative(of: original)
+ func vjp(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ // Reference a local variable.
+ // This causes the top-level SIL function @vjp to have extra arguments.
+ _ = capturedValue
+ return (x, { $0 })
+ }
+}
+
+// Original crasher:
+// SIL verification failed: apply doesn't have right number of arguments for function: site.getNumArguments() == substConv.getNumSILArguments()
+// Verifying instruction:
+// %0 = argument of bb0 : $Float // user: %2
+// // function_ref vjp #1 (_:) in
+// %1 = function_ref @$s4main3vjpL_ySf5value_S2fc8pullbacktSfF : $@convention(thin) (Float, Int) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %2
+// -> %2 = apply %1(%0) : $@convention(thin) (Float, Int) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %3
+// return %2 : $(Float, @callee_guaranteed (Float) -> Float) // id: %3
+// In function:
+// // AD__$s4main8originalL_yS2fF__vjp_src_0_wrt_0
+// sil hidden [always_inline] [ossa] @AD__$s4main8originalL_yS2fF__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
+// // %0 // user: %2
+// bb0(%0 : $Float):
+// // function_ref vjp #1 (_:) in
+// %1 = function_ref @$s4main3vjpL_ySf5value_S2fc8pullbacktSfF : $@convention(thin) (Float, Int) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %2
+// %2 = apply %1(%0) : $@convention(thin) (Float, Int) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %3
+// return %2 : $(Float, @callee_guaranteed (Float) -> Float) // id: %3
+// } // end sil function 'AD__$s4main8originalL_yS2fF__vjp_src_0_wrt_0'
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf891-protocol-req-capture-propagation.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf891-protocol-req-capture-propagation.swift
new file mode 100644
index 0000000..37c8dd7
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf891-protocol-req-capture-propagation.swift
@@ -0,0 +1,77 @@
+// RUN: %target-swift-frontend -O -emit-ir %s
+// REQUIRES: asserts
+
+// TF-891: Generic specialization crash during capture propagation.
+// Related to `@differentiable` function with `partial_apply` operands,
+// to be specialized. Occurs only with `-O`.
+
+public protocol Protocol: Differentiable {
+ @differentiable
+ func requirement1<T: Protocol>(_ arg: T) -> Float
+
+ @differentiable
+ func requirement2() -> Float
+}
+
+public extension Protocol {
+ @differentiable
+ func requirement1<T: Protocol>(_ arg: T) -> Float {
+ return arg.requirement2()
+ }
+
+ @differentiable
+ func requirement2() -> Float {
+ return 0
+ }
+}
+
+public struct Struct: Protocol {}
+
+@differentiable
+public func func1<T: Protocol>(_ arg1: Float, _ arg2: T) -> Float {
+ return arg2.requirement1(arg2)
+}
+
+@differentiable
+public func func2(_ arg: Struct) -> Float {
+ return func1(0.0, arg)
+}
+
+// swift: /usr/local/google/home/marcrasi/swift-base/swift/lib/AST/ProtocolConformance.cpp:78: swift::ProtocolDecl *swift::ProtocolConformanceRef::getRequirement() const: Assertion `!isInvalid()' failed.
+// Stack dump:
+// 0. Program arguments: /usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift -frontend -target x86_64-unknown-linux-gnu -module-cache-path /usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/swift-test-results/x86_64-unknown-linux-gnu/clang-module-cache -swift-version 4 -ignore-module-source-info -typo-correction-limit 10 -O -emit-ir /usr/local/google/home/marcrasi/swift-base/swift/test/AutoDiff/generated/generated0002.swift
+// 1. Swift version 5.1.1-dev (LLVM 6e04008c7f, Swift 439808dd48)
+// 2. While running pass #10183 SILFunctionTransform "CapturePropagation" on SILFunction "@AD__orig_$s13generated00025func1yS2f_xtAA8ProtocolRzlF_$sSf13generated00026StructVS3fAC13TangentVectorVIegydr_Iegyndo_SfACS2fAEIegyr_Iegyndo_TR_src_0_wrt_1_vjp_subset_parameters_thunk".
+// for expression at [/usr/local/google/home/marcrasi/swift-base/swift/test/AutoDiff/generated/generated0002.swift:35:10 - line:35:24] RangeText="func1(0.0, arg"
+// #0 0x0000000004bebbc4 PrintStackTraceSignalHandler(void*) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x4bebbc4)
+// #1 0x0000000004be97de llvm::sys::RunSignalHandlers() (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x4be97de)
+// #2 0x0000000004bebe78 SignalHandler(int) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x4bebe78)
+// #3 0x00007f1ffb1283a0 __restore_rt (/lib/x86_64-linux-gnu/libpthread.so.0+0x123a0)
+// #4 0x00007f1ffa25fcfb raise (/lib/x86_64-linux-gnu/libc.so.6+0x36cfb)
+// #5 0x00007f1ffa24a8ad abort (/lib/x86_64-linux-gnu/libc.so.6+0x218ad)
+// #6 0x00007f1ffa24a77f (/lib/x86_64-linux-gnu/libc.so.6+0x2177f)
+// #7 0x00007f1ffa258542 (/lib/x86_64-linux-gnu/libc.so.6+0x2f542)
+// #8 0x000000000171bf3d (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x171bf3d)
+// #9 0x0000000000ecf489 swift::WitnessMethodInst::create(swift::SILDebugLocation, swift::CanType, swift::ProtocolConformanceRef, swift::SILDeclRef, swift::SILType, swift::SILFunction*, swift::SILOpenedArchetypesState&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0xecf489)
+// #10 0x0000000000e0cdbc swift::SILCloner<swift::GenericCloner>::visitWitnessMethodInst(swift::WitnessMethodInst*) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0xe0cdbc)
+// #11 0x0000000000e0972c swift::SILCloner<swift::GenericCloner>::visitBlocksDepthFirst(swift::SILBasicBlock*) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0xe0972c)
+// #12 0x0000000000e0761e swift::SILCloner<swift::GenericCloner>::cloneFunctionBody(swift::SILFunction*, swift::SILBasicBlock*, llvm::ArrayRef<swift::SILValue>) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0xe0761e)
+// #13 0x0000000000e07171 swift::GenericCloner::populateCloned() (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0xe07171)
+// #14 0x0000000000b2ea3f swift::GenericCloner::cloneFunction(swift::SILOptFunctionBuilder&, swift::SILFunction*, swift::ReabstractionInfo const&, swift::SubstitutionMap, llvm::StringRef, std::function<void (swift::SILInstruction*, swift::SILInstruction*)>) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0xb2ea3f)
+// #15 0x0000000000b2e7ba swift::GenericFuncSpecializer::tryCreateSpecialization() (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0xb2e7ba)
+// #16 0x0000000000c5569c (anonymous namespace)::CapturePropagation::optimizePartialApply(swift::PartialApplyInst*) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0xc5569c)
+// #17 0x0000000000c55158 (anonymous namespace)::CapturePropagation::run() (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0xc55158)
+// #18 0x000000000099288e swift::SILPassManager::runPassOnFunction(unsigned int, swift::SILFunction*) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x99288e)
+// #19 0x0000000000993913 swift::SILPassManager::runFunctionPasses(unsigned int, unsigned int) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x993913)
+// #20 0x0000000000994c1f swift::SILPassManager::execute() (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x994c1f)
+// #21 0x000000000056892b swift::SILPassManager::executePassPipelinePlan(swift::SILPassPipelinePlan const&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x56892b)
+// #22 0x000000000099d0dc swift::runSILOptimizationPasses(swift::SILModule&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x99d0dc)
+// #23 0x0000000000769d37 swift::CompilerInstance::performSILProcessing(swift::SILModule*, swift::UnifiedStatsReporter*) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x769d37)
+// #24 0x00000000004f3d14 performCompileStepsPostSILGen(swift::CompilerInstance&, swift::CompilerInvocation&, std::unique_ptr<swift::SILModule, std::default_delete<swift::SILModule> >, bool, llvm::PointerUnion<swift::ModuleDecl*, swift::SourceFile*>, swift::PrimarySpecificPaths const&, bool, int&, swift::FrontendObserver*, swift::UnifiedStatsReporter*) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x4f3d14)
+// #25 0x00000000004e9f76 performCompile(swift::CompilerInstance&, swift::CompilerInvocation&, llvm::ArrayRef<char const*>, int&, swift::FrontendObserver*, swift::UnifiedStatsReporter*) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x4e9f76)
+// #26 0x00000000004e7749 swift::performFrontend(llvm::ArrayRef<char const*>, char const*, void*, swift::FrontendObserver*) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x4e7749)
+// #27 0x0000000000487e21 main (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x487e21)
+// #28 0x00007f1ffa24c52b __libc_start_main (/lib/x86_64-linux-gnu/libc.so.6+0x2352b)
+// #29 0x0000000000487a6a _start (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x487a6a)
+// /usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/test-linux-x86_64/AutoDiff/generated/Output/generated0002.swift.script: line 1: 18696 Aborted /usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift -frontend -target x86_64-unknown-linux-gnu -module-cache-path '/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/swift-test-results/x86_64-unknown-linux-gnu/clang-module-cache' -swift-version 4 -ignore-module-source-info -typo-correction-limit 10 -O -emit-ir /usr/local/google/home/marcrasi/swift-base/swift/test/AutoDiff/generated/generated0002.swift
+
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf922-adjoint-value-type-mismatch.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf922-adjoint-value-type-mismatch.swift
new file mode 100644
index 0000000..7d10e86
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf922-adjoint-value-type-mismatch.swift
@@ -0,0 +1,29 @@
+// RUN: %target-swift-emit-sil %s
+// REQUIRES: asserts
+
+// TF-922: Adjoint value type mismatch assertion failure during direct adjoint
+// accumulation in differentiation transform pullback generation.
+
+// The cause is an adjoint value accumulation bug in
+// `PullbackEmitter::visitDestructureTupleInst` for tuple values with
+// non-tuple-typed adjoint values. This is relevant for array literal
+// initialization because `array.uninitialized_intrinsic` returns a tuple
+// with a `destructure_tuple` user.
+
+@differentiable
+func TF_922(_ x: Float) -> [Float] {
+ var result = [x]
+ let result2 = true ? result : result
+ let result3 = true ? result : result2
+ return result3
+}
+
+// Assertion failed: (lhs->getType() == rhs->getType() && "Adjoints must have equal types!"), function accumulateDirect, file swift/lib/SILOptimizer/Mandatory/Differentiation.cpp, line 7654.
+// Stack dump:
+// ...
+// 1. Swift version 5.1.1-dev (Swift d89e9d1881)
+// 2. While running pass #59 SILModuleTransform "Differentiation".
+// 3. While canonicalizing `differentiable_function` SIL node %12 = differentiable_function [parameters 0] %11 : $@callee_guaranteed (Float) -> @owned Array<Float> // users: %30, %13
+// 4. While ...in SIL function "@main".
+// 5. While processing `[differentiable source 0 wrt 0]` attribute on SIL function "@$s4main17oneElementLiteralySaySfGSfF".
+// for 'oneElementLiteral(_:)' (at tf-922-array.swift:2:1)
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf923-pullback-ownership-use-after-free.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf923-pullback-ownership-use-after-free.swift
new file mode 100644
index 0000000..4d839b5
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf923-pullback-ownership-use-after-free.swift
@@ -0,0 +1,44 @@
+// RUN: %target-swift-emit-sil %s
+// REQUIRES: asserts
+
+// TF-923: Ownership verification error in pullback function generated by the
+// differentiation transform.
+// Adjoint value of basic block argument is destroyed at end of pullback block
+// (because it's a temporary value), but is also set as adjoint value of
+// incoming values, causing use-after-free errors in pullback successor blocks.
+
+struct Tensor<Scalar> {
+ class Box {
+ init() {}
+ }
+ var box: Box = Box()
+}
+extension Tensor: Equatable where Scalar: Equatable {
+ static func ==(_: Self, _: Self) -> Bool { fatalError() }
+}
+extension Tensor: AdditiveArithmetic where Scalar: AdditiveArithmetic {
+ static var zero: Self { fatalError() }
+ static func +(_: Self, _: Self) -> Self { fatalError() }
+ static func -(_: Self, _: Self) -> Self { fatalError() }
+}
+extension Tensor: Differentiable where Scalar: Differentiable & AdditiveArithmetic {
+ typealias TangentVector = Self
+}
+
+struct Tuple<T: Differentiable & AdditiveArithmetic>: Differentiable {
+ var first: Tensor<T>
+ @noDerivative var second: Tensor<T>
+}
+
+@differentiable(wrt: (input))
+func TF_923<T>(_ input: Tensor<T>, _ bool: Bool) -> Tuple<T> {
+ let x = bool ? input : input
+ return Tuple(first: x, second: x)
+}
+
+// Function: 'AD__$s4main6TF_923yAA5TupleVyxGAA6TensorVyxG_Sbts18AdditiveArithmeticRzs14DifferentiableRzlF__pullback_src_0_wrt_0'
+// Found use after free due to unvisited non lifetime ending uses?!
+// Value: %22 = load [take] %10 : $*Tensor<τ_0_0> // users: %88, %60, %47
+// Remaining Users:
+// User: %60 = copy_value %22 : $Tensor<τ_0_0> // user: %65
+// User: %88 = copy_value %22 : $Tensor<τ_0_0> // user: %93
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf928-pullback-ownership-memory-leak.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf928-pullback-ownership-memory-leak.swift
new file mode 100644
index 0000000..a39632f
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf928-pullback-ownership-memory-leak.swift
@@ -0,0 +1,36 @@
+// RUN: %target-swift-emit-sil %s
+
+// TF-928: Ownership verification error in pullback function generated by the
+// differentiation transform.
+
+struct Tracked<T> {
+ class Box {
+ init() {}
+ }
+ var box: Box = Box()
+}
+extension Tracked : Equatable where T : Equatable {
+ static func ==(_: Self, _: Self) -> Bool { fatalError() }
+}
+extension Tracked : AdditiveArithmetic where T : AdditiveArithmetic {
+ static var zero: Self { fatalError() }
+ static func +(_: Self, _: Self) -> Self { fatalError() }
+ static func -(_: Self, _: Self) -> Self { fatalError() }
+}
+extension Tracked : Differentiable where T : Differentiable, T == T.TangentVector {
+ typealias TangentVector = Tracked<T.TangentVector>
+}
+
+func TF_928(
+ _ lossFunction: @differentiable (Tracked<Float>, Tracked<Float>) -> Tracked<Float>,
+ _ x: Tracked<Float>
+) {
+ _ = pullback(at: x) { x in lossFunction(x, Tracked<Float>()) }
+ _ = pullback(at: x) { x in lossFunction(Tracked<Float>(), x) }
+}
+
+// Function: 'AD__$s4main6TF_928yyAA7TrackedVySfGAE_AEtXF_AEtFA2EcfU___pullback_src_0_wrt_0'
+// Error! Found a leaked owned value that was never consumed.
+// Value: (%5, **%6**) = destructure_tuple %3 : $(Tracked<Float>, Tracked<Float>)
+// Stack dump:
+// While verifying SIL function "@AD__$s4main6TF_928yyAA7TrackedVySfGAE_AEtXF_AEtFA2EcfU___pullback_src_0_wrt_0".
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf945-activity-analysis-tuple-element-addr.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf945-activity-analysis-tuple-element-addr.swift
new file mode 100644
index 0000000..2f04c18
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf945-activity-analysis-tuple-element-addr.swift
@@ -0,0 +1,14 @@
+// RUN: %target-swift-emit-sil %s
+// REQUIRES: asserts
+
+// TF-945: Activity analysis crash because
+// `DifferentiableActivityInfo::getLookupConformanceFunction` returned a
+// `LookupConformanceFn` (type alias for `llvm::function_ref`), which does not
+// own the underlying callable.
+
+@differentiable
+func TF_945(_ x: Float) -> Float {
+ var result = (x, 1)
+ let (x, y) = result
+ return x
+}
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf953-invalid-differentiable-attr-cross-module.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf953-invalid-differentiable-attr-cross-module.swift
new file mode 100644
index 0000000..66d5354
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf953-invalid-differentiable-attr-cross-module.swift
@@ -0,0 +1,15 @@
+// RUN: not %target-swift-frontend -c -primary-file %s %S/Inputs/tf953-invalid-differentiable-attr-other-module.swift -module-name main
+
+// Verify that invalid `@differentiable` attribute in non-primary file does not
+// crash SILGen.
+
+func bar(_ x: Int) -> Float {
+ return foo(2)
+}
+
+// Assertion failed: (paramIndices && "Parameter indices should have been resolved"), function addFunctionAttributes, file /Users/swiftninjas/s4tf/swift/lib/SIL/SILFunctionBuilder.cpp, line 97.
+// Stack dump:
+// 1. Swift version 5.1.1-dev (Swift e242a8825f)
+// 2. While emitting SIL for 'bar(_:)' (at /Users/danielzheng/swift-merge/swift/test/AutoDiff/compiler_crashers/tf953-invalid-differentiable-attr-cross-module.swift:3:1)
+// 3. While silgen emitFunction SIL function "@$s4main3barySfSiF".
+// for 'bar(_:)' (at /Users/danielzheng/swift-merge/swift/test/AutoDiff/compiler_crashers/tf953-invalid-differentiable-attr-cross-module.swift:3:1)
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf960-irgen-diff-witness-no-derivatives-cross-module.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf960-irgen-diff-witness-no-derivatives-cross-module.swift
new file mode 100644
index 0000000..0987d09
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf960-irgen-diff-witness-no-derivatives-cross-module.swift
@@ -0,0 +1,18 @@
+// RUN: %target-swift-frontend -c %S/Inputs/tf960-irgen-diff-witness-no-derivatives-other-module.swift %s -O -module-name main -num-threads 36
+// REQUIRES: asserts
+
+// TF-960: IRGen crash for uncanonicalized differentiability witnesses.
+// This issue will become obsolete after TF-894, when the differentiation
+// transform canonicalizes differentiability witnesses to have derivative
+// functions and assertions are added to IRGen.
+
+// Stack dump:
+// ...
+// 1. Swift version 5.1.1-dev (Swift af915c09de)
+// 0 swiftc 0x0000000109acda65 llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 37
+// 1 swiftc 0x0000000109acca68 llvm::sys::RunSignalHandlers() + 248
+// 2 swiftc 0x0000000109ace058 SignalHandler(int) + 264
+// 3 libsystem_platform.dylib 0x00007fff728e4b5d _sigtramp + 29
+// 4 libsystem_platform.dylib 0x00007ff639560230 _sigtramp + 3334977264
+// 5 swiftc 0x0000000105b316bb swift::irgen::IRGenerator::emitGlobalTopLevel() + 1307
+// 6 swiftc 0x0000000105bf2572 swift::performIRGeneration(swift::IRGenOptions&, swift::ModuleDecl*, std::__1::unique_ptr<swift::SILModule, std::__1::default_delete<swift::SILModule> >, llvm::StringRef, swift::PrimarySpecificPaths const&, llvm::LLVMContext&, llvm::ArrayRef<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > >, llvm::GlobalVariable**) + 1682
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf961-protocol-req-loadable-by-address.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf961-protocol-req-loadable-by-address.swift
new file mode 100644
index 0000000..d17c7fb
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf961-protocol-req-loadable-by-address.swift
@@ -0,0 +1,63 @@
+// RUN: %target-swift-frontend -emit-ir %s
+// REQUIRES: asserts
+
+public protocol Protocol00023: Differentiable {
+ @differentiable
+ func requirement00024(_ arg00026: Float) -> Float
+}
+
+public extension Protocol00023 {
+ @differentiable
+ func requirement00024(_ arg00026: Float) -> Float {
+ return 0
+ }
+}
+
+public struct Struct00042: Protocol00023 {
+ public var field00043: Int
+ public var field00045: Float
+ public var field00046: Int
+}
+
+public struct Struct00063 {
+ public var field00064: Struct00042
+ public var field00065: Struct00042
+}
+
+// swift: /usr/local/google/home/marcrasi/swift-base/swift/lib/IRGen/LoadableByAddress.cpp:104: bool isLargeLoadableType(swift::GenericEnvironment *, swift::SILType, irgen::IRGenModule &): Assertion `GenericEnv && "Expected a GenericEnv"' failed.
+// Stack dump:
+// 0. Program arguments: /usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift -frontend -target x86_64-unknown-linux-gnu -module-cache-path /usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/swift-test-results/x86_64-unknown-linux-gnu/clang-module-cache -swift-version 4 -ignore-module-source-info -typo-correction-limit 10 -emit-ir /usr/local/google/home/marcrasi/swift-base/swift/test/AutoDiff/generated/generated0001.swift
+// 1. Swift version 5.1.1-dev (LLVM 6e04008c7f, Swift 439808dd48)
+// 2. While running pass #192 SILModuleTransform "LoadableByAddress".
+// #0 0x0000000004bebbc4 PrintStackTraceSignalHandler(void*) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x4bebbc4)
+// #1 0x0000000004be97de llvm::sys::RunSignalHandlers() (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x4be97de)
+// #2 0x0000000004bebe78 SignalHandler(int) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x4bebe78)
+// #3 0x00007f276b2e63a0 __restore_rt (/lib/x86_64-linux-gnu/libpthread.so.0+0x123a0)
+// #4 0x00007f276a41dcfb raise (/lib/x86_64-linux-gnu/libc.so.6+0x36cfb)
+// #5 0x00007f276a4088ad abort (/lib/x86_64-linux-gnu/libc.so.6+0x218ad)
+// #6 0x00007f276a40877f (/lib/x86_64-linux-gnu/libc.so.6+0x2177f)
+// #7 0x00007f276a416542 (/lib/x86_64-linux-gnu/libc.so.6+0x2f542)
+// #8 0x000000000058185a isLargeLoadableType(swift::GenericEnvironment*, swift::SILType, swift::irgen::IRGenModule&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x58185a)
+// #9 0x0000000000580b43 LargeSILTypeMapper::getNewSILType(swift::GenericEnvironment*, swift::SILType, swift::irgen::IRGenModule&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x580b43)
+// #10 0x00000000005819d3 LargeSILTypeMapper::getNewTupleType(swift::GenericEnvironment*, swift::irgen::IRGenModule&, swift::SILType const&, swift::SILType const&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x5819d3)
+// #11 0x0000000000580b6d LargeSILTypeMapper::getNewSILType(swift::GenericEnvironment*, swift::SILType, swift::irgen::IRGenModule&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x580b6d)
+// #12 0x00000000005806cc LargeSILTypeMapper::shouldTransformResults(swift::GenericEnvironment*, swift::CanTypeWrapper<swift::SILFunctionType>, swift::irgen::IRGenModule&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x5806cc)
+// #13 0x00000000005804ec LargeSILTypeMapper::shouldTransformFunctionType(swift::GenericEnvironment*, swift::CanTypeWrapper<swift::SILFunctionType>, swift::irgen::IRGenModule&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x5804ec)
+// #14 0x0000000000580bf5 LargeSILTypeMapper::getNewSILType(swift::GenericEnvironment*, swift::SILType, swift::irgen::IRGenModule&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x580bf5)
+// #15 0x0000000000580e50 LargeSILTypeMapper::getNewResults(swift::GenericEnvironment*, swift::CanTypeWrapper<swift::SILFunctionType>, swift::irgen::IRGenModule&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x580e50)
+// #16 0x00000000005810e0 LargeSILTypeMapper::getNewSILFunctionType(swift::GenericEnvironment*, swift::CanTypeWrapper<swift::SILFunctionType>, swift::irgen::IRGenModule&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x5810e0)
+// #17 0x0000000000584e71 (anonymous namespace)::LoadableByAddress::run() (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x584e71)
+// #18 0x00000000009941ef swift::SILPassManager::runModulePass(unsigned int) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x9941ef)
+// #19 0x0000000000994c3a swift::SILPassManager::execute() (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x994c3a)
+// #20 0x000000000056892b swift::SILPassManager::executePassPipelinePlan(swift::SILPassPipelinePlan const&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x56892b)
+// #21 0x0000000000568694 runIRGenPreparePasses(swift::SILModule&, swift::irgen::IRGenModule&) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x568694)
+// #22 0x00000000005667b0 performIRGeneration(swift::IRGenOptions&, swift::ModuleDecl*, std::unique_ptr<swift::SILModule, std::default_delete<swift::SILModule> >, llvm::StringRef, swift::PrimarySpecificPaths const&, llvm::LLVMContext&, swift::SourceFile*, llvm::GlobalVariable**) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x5667b0)
+// #23 0x00000000005650ae swift::performIRGeneration(swift::IRGenOptions&, swift::ModuleDecl*, std::unique_ptr<swift::SILModule, std::default_delete<swift::SILModule> >, llvm::StringRef, swift::PrimarySpecificPaths const&, llvm::LLVMContext&, llvm::ArrayRef<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, llvm::GlobalVariable**) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x5650ae)
+// #24 0x00000000004f4215 performCompileStepsPostSILGen(swift::CompilerInstance&, swift::CompilerInvocation&, std::unique_ptr<swift::SILModule, std::default_delete<swift::SILModule> >, bool, llvm::PointerUnion<swift::ModuleDecl*, swift::SourceFile*>, swift::PrimarySpecificPaths const&, bool, int&, swift::FrontendObserver*, swift::UnifiedStatsReporter*) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x4f4215)
+// #25 0x00000000004e9f76 performCompile(swift::CompilerInstance&, swift::CompilerInvocation&, llvm::ArrayRef<char const*>, int&, swift::FrontendObserver*, swift::UnifiedStatsReporter*) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x4e9f76)
+// #26 0x00000000004e7749 swift::performFrontend(llvm::ArrayRef<char const*>, char const*, void*, swift::FrontendObserver*) (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x4e7749)
+// #27 0x0000000000487e21 main (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x487e21)
+// #28 0x00007f276a40a52b __libc_start_main (/lib/x86_64-linux-gnu/libc.so.6+0x2352b)
+// #29 0x0000000000487a6a _start (/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift+0x487a6a)
+// /usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/test-linux-x86_64/AutoDiff/generated/Output/generated0001.swift.script: line 1: 18425 Aborted /usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/bin/swift -frontend -target x86_64-unknown-linux-gnu -module-cache-path '/usr/local/google/home/marcrasi/swift-base/build/Ninja-ReleaseAssert/swift-linux-x86_64/swift-test-results/x86_64-unknown-linux-gnu/clang-module-cache' -swift-version 4 -ignore-module-source-info -typo-correction-limit 10 -emit-ir /usr/local/google/home/marcrasi/swift-base/swift/test/AutoDiff/generated/generated0001.swift
+
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf962-pullback-ownership-tuple-over-consume.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf962-pullback-ownership-tuple-over-consume.swift
new file mode 100644
index 0000000..b43e880
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf962-pullback-ownership-tuple-over-consume.swift
@@ -0,0 +1,26 @@
+// RUN: %target-swift-emit-sil %s
+// REQUIRES: asserts
+
+// TF-962: Ownership verification error in pullback function generated by the
+// differentiation transform.
+// Multiple consuming uses for tuple-typed adjoint value:
+// - `destructure_tuple` (from `PullbackEmitter::visitTupleInst`)
+// - `destroy_value` (for pullback temporary values)
+
+// Key: `Tensor` and `Tensor.TangentVector` must both be non-trivial.
+struct Tensor: Differentiable {
+ var x: [Float] = []
+}
+
+@differentiable
+func TF_962(_ x: Tensor) -> Tensor {
+ let tup = (x, x)
+ if true {}
+ return tup.0
+}
+
+// Function: 'AD__$s4main6TF_962yAA6TensorVADF__pullback_src_0_wrt_0'
+// Found over consume?!
+// Value: %81 = argument of bb7 : $(Tensor.TangentVector, Tensor.TangentVector) // users: %154, %84
+// User: destroy_value %81 : $(Tensor.TangentVector, Tensor.TangentVector) // id: %154
+// Block: bb7
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf964-pullback-tuple-nontuple-adjoint-value.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf964-pullback-tuple-nontuple-adjoint-value.swift
new file mode 100644
index 0000000..af5d76b
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf964-pullback-tuple-nontuple-adjoint-value.swift
@@ -0,0 +1,24 @@
+// RUN: %target-swift-emit-sil %s -verify
+// REQUIRES: asserts
+
+// TF-964: `PullbackEmitter::visitTupleInst` crash for `tuple` instructions with
+// non-tuple-typed adjoint values.
+
+@differentiable
+func TF_964(_ x: Float) -> Float {
+ let tuple = (x, 1)
+ return tuple.0
+}
+
+// Original crasher:
+// Assertion failed: (Operand->getType().is<TupleType>() && "Expected a tuple typed operand?!"), function create, file /Users/swiftninjas/s4tf/swift/lib/SIL/SILInstructions.cpp, line 2676.
+// Stack dump:
+// 0. Program arguments: /Library/Developer/Toolchains/swift-tensorflow-RELEASE-0.6.xctoolchain/usr/bin/swift -frontend -interpret tf-964.swift -enable-objc-interop -sdk /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.15.sdk -color-diagnostics -module-name main
+// 1. Swift version 5.1.1-dev (Swift 7b97b0ced0)
+// 2. While running pass #17 SILModuleTransform "Differentiation".
+// 3. While processing `[differentiable source 0 wrt 0]` attribute on SIL function "@$s4main6TF_964yS2fF".
+// for 'TF_964(_:)' (at tf-964.swift:2:1)
+// 4. While generating VJP for SIL function "@$s4main6TF_964yS2fF".
+// for 'TF_964(_:)' (at tf-964.swift:2:1)
+// 5. While generating pullback for SIL function "@$s4main6TF_964yS2fF".
+// for 'TF_964(_:)' (at tf-964.swift:2:1)
diff --git a/test/AutoDiff/downstream/compiler_crashers_fixed/tf984-differential-unset-tangent-buffer.swift b/test/AutoDiff/downstream/compiler_crashers_fixed/tf984-differential-unset-tangent-buffer.swift
new file mode 100644
index 0000000..bc91662
--- /dev/null
+++ b/test/AutoDiff/downstream/compiler_crashers_fixed/tf984-differential-unset-tangent-buffer.swift
@@ -0,0 +1,30 @@
+// RUN: %target-swift-emit-sil -enable-experimental-forward-mode-differentiation %s -verify
+// REQUIRES: asserts
+
+// TF-984: Differential generation crash due to unset tangent buffer.
+
+struct Mut: Differentiable {}
+extension Mut {
+ mutating func mutatingMethodWrtMultipleResults(_ x: Mut) -> Mut {
+ return x
+ }
+}
+
+@differentiable(wrt: x)
+func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) {
+ _ = nonactive.mutatingMethodWrtMultipleResults(x)
+}
+
+// [AD] Original bb0: To differentiate or not to differentiate?
+// [ ] debug_value_addr %0 : $*Mut, var, name "nonactive", argno 1 // id: %2
+// [ ] debug_value %1 : $Mut, let, name "x", argno 2 // id: %3
+// [∂] %4 = alloc_stack $Mut, var, name "result" // users: %18, %6, %8, %12, %15
+// [ ] %5 = begin_access [read] [static] %0 : $*Mut // users: %7, %6
+// [∂] copy_addr %5 to [initialization] %4 : $*Mut // id: %6
+// ...
+// [AD] JVPEmitter visited:
+// [ORIG] copy_addr %5 to [initialization] %4 : $*Mut // id: %6
+// Assertion failed: (!insertion.second && "tangent buffer should already exist"), function getTangentBuffer, file swift/lib/SILOptimizer/Mandatory/Differentiation.cpp, line 4528.
+
+// `copy_addr %5 to [initialization] %4` is visited but `%5 = begin_access` is
+// not. `%5` does not have a set tangent buffer.
diff --git a/test/AutoDiff/downstream/core_builtins.swift b/test/AutoDiff/downstream/core_builtins.swift
new file mode 100644
index 0000000..c5fb13d
--- /dev/null
+++ b/test/AutoDiff/downstream/core_builtins.swift
@@ -0,0 +1,34 @@
+// RUN: %target-swift-frontend -parse-stdlib -typecheck -verify %s
+// RUN: %target-swift-frontend -parse-stdlib -emit-silgen %s | %FileCheck -check-prefix=CHECK-SIL %s
+
+import Swift
+
+func evaldiff<T: Differentiable, U: Differentiable>(_ f: @differentiable (T) -> U, _ x: T) -> (U, (T.TangentVector) -> U.TangentVector)
+ where T == T.TangentVector {
+ return Builtin.applyDerivative_jvp(f, x)
+}
+
+// CHECK-SIL-LABEL: @{{.*}}evaldiff{{.*}}
+// CHECK-SIL: bb0([[ORIG_RES_BUF:%.*]] : $*U, [[ORIG_FN:%.*]] : $@differentiable @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <T, U>, [[ORIG_FN_ARG:%.*]] : $*T):
+// CHECK-SIL: [[ORIG_FN_CONVERTED:%.*]] = convert_function [[ORIG_FN]] : $@differentiable @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <T, U> to $@differentiable @noescape @callee_guaranteed (@in_guaranteed T) -> @out U
+// CHECK-SIL: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[ORIG_FN_CONVERTED]] : $@differentiable @noescape @callee_guaranteed (@in_guaranteed T) -> @out U
+// CHECK-SIL: [[JVP_RES_BUF:%.*]] = alloc_stack $(U, @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <T, U.TangentVector>)
+// CHECK-SIL: [[JVP_RES_BUF_0:%.*]] = tuple_element_addr [[JVP_RES_BUF]] : $*(U, @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <T, U.TangentVector>), 0
+// CHECK-SIL: [[DIFFERENTIAL:%.*]] = apply [[JVP_FN]]([[JVP_RES_BUF_0]], [[ORIG_FN_ARG]]) : $@noescape @callee_guaranteed (@in_guaranteed T) -> (@out U, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <T, U.TangentVector>)
+// CHECK-SIL: [[JVP_RES_BUF_1:%.*]] = tuple_element_addr [[JVP_RES_BUF]] : $*(U, @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <T, U.TangentVector>), 1
+// CHECK-SIL: store [[DIFFERENTIAL]] to [init] [[JVP_RES_BUF_1]] : $*@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <T, U.TangentVector>
+// CHECK-SIL: [[JVP_RES_BUF_0:%.*]] = tuple_element_addr [[JVP_RES_BUF]] : $*(U, @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <T, U.TangentVector>), 0
+// CHECK-SIL: [[JVP_RES_BUF_1:%.*]] = tuple_element_addr [[JVP_RES_BUF]] : $*(U, @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <T, U.TangentVector>), 1
+// CHECK-SIL: [[DIFFERENTIAL:%.*]] = load [take] [[JVP_RES_BUF_1]] : $*@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <T, U.TangentVector>
+// CHECK-SIL: copy_addr [take] [[JVP_RES_BUF_0]] to [initialization] [[ORIG_RES_BUF]] : $*U
+// CHECK-SIL: dealloc_stack [[JVP_RES_BUF]] : $*(U, @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <T, U.TangentVector>)
+// CHECK-SIL: return [[DIFFERENTIAL]] : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <T, U.TangentVector>
+
+func evaldiff2<T: Differentiable, U: Differentiable, V: Differentiable>(_ f: @differentiable (T, U) -> V, _ x: T, _ y: U) -> (V, (T.TangentVector, U.TangentVector) -> V.TangentVector)
+ where T == T.TangentVector, U == U.TangentVector {
+ return Builtin.applyDerivative_jvp_arity2(f, x, y)
+}
+
+// CHECK-LABEL: @{{.*}}evaldiff2{{.*}}
+// CHECK: bb0({{.*}} : $*V, [[DIFFED:%.*]] : $@differentiable @noescape @callee_guaranteed (@in_guaranteed T, @in_guaranteed U) -> @out V, {{.*}} : $*T, {{.*}} : $*U):
+// CHECK: differentiable_function_extract [jvp] [[DIFFED]] : $@differentiable @noescape @callee_guaranteed (@in_guaranteed T, @in_guaranteed U) -> @out V // user: %14
diff --git a/test/AutoDiff/downstream/cross_module_derivative_attr_e2e.swift b/test/AutoDiff/downstream/cross_module_derivative_attr_e2e.swift
new file mode 100644
index 0000000..f44806f
--- /dev/null
+++ b/test/AutoDiff/downstream/cross_module_derivative_attr_e2e.swift
@@ -0,0 +1,5 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-build-swift -working-directory %t -I%t -parse-as-library -emit-module -module-name module1 -emit-module-path %t/module1.swiftmodule -emit-library -static %S/Inputs/cross_module_derivative_attr_e2e/module1/module1.swift %S/Inputs/cross_module_derivative_attr_e2e/module1/module1_other_file.swift -Xfrontend -validate-tbd-against-ir=none
+// RUN: %target-build-swift -I%t -L%t %S/Inputs/cross_module_derivative_attr_e2e/main/main.swift -o %t/a.out -lm -lmodule1 -Xfrontend -validate-tbd-against-ir=none
+// RUN: %target-run %t/a.out
+// REQUIRES: executable_test
diff --git a/test/AutoDiff/downstream/currying.swift b/test/AutoDiff/downstream/currying.swift
new file mode 100644
index 0000000..6b182b5
--- /dev/null
+++ b/test/AutoDiff/downstream/currying.swift
@@ -0,0 +1,24 @@
+// RUN: %target-run-simple-swift
+
+import StdlibUnittest
+import DifferentiationUnittest
+
+var CurryingAutodiffTests = TestSuite("CurryingAutodiff")
+
+CurryingAutodiffTests.testWithLeakChecking("StructMember") {
+ struct A {
+ @differentiable(wrt: (value))
+ func instanceMethod(_ value: Tracked<Float>) -> Tracked<Float> { return value * value }
+ }
+
+ let a = A()
+ // Referencing `a.instanceMethod` implicitly applies the curried function
+ // `A.instanceMethod` of type `(A) -> (Tracked<Float>) -> Tracked<Float>` to
+ // the value `a`, producing a `(Tracked<Float>) -> Tracked<Float>` value.
+ // This value is then converted to a `@differentiable` function-typed value.
+ let g: @differentiable (Tracked<Float>) -> Tracked<Float> = a.instanceMethod
+
+ expectEqual(Tracked<Float>(6.0), gradient(at: 3, in: g))
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/custom_derivatives.swift b/test/AutoDiff/downstream/custom_derivatives.swift
new file mode 100644
index 0000000..676548a
--- /dev/null
+++ b/test/AutoDiff/downstream/custom_derivatives.swift
@@ -0,0 +1,99 @@
+// RUN: %target-run-simple-swift
+// REQUIRES: executable_test
+
+import StdlibUnittest
+#if os(macOS)
+import Darwin.C
+#else
+import Glibc
+#endif
+import DifferentiationUnittest
+
+var CustomDerivativesTests = TestSuite("CustomDerivatives")
+
+// Specify non-differentiable functions.
+// These will be wrapped in `differentiableFunction` and tested.
+
+func unary(_ x: Tracked<Float>) -> Tracked<Float> {
+ var x = x
+ x *= 2
+ return x
+}
+
+func binary(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
+ var x = x
+ x *= y
+ return x
+}
+
+CustomDerivativesTests.testWithLeakChecking("differentiableFunction-unary") {
+ let diffableUnary = differentiableFunction { x in
+ (value: unary(x), pullback: { v in v * x * 2 })
+ }
+ expectEqual(20, gradient(at: 10, in: diffableUnary))
+ // Test differentiation of @differentiable function.
+ expectEqual(20, gradient(at: 10, in: { diffableUnary($0) }))
+ expectEqual(40, gradient(at: 10, in: { diffableUnary($0) * 2 }))
+}
+
+CustomDerivativesTests.testWithLeakChecking("differentiableFunction-binary") {
+ let diffableBinary = differentiableFunction { (x, y) in
+ (value: binary(x, y), pullback: { v in (v * y, v * x) })
+ }
+ expectEqual((10, 5), gradient(at: 5, 10, in: diffableBinary))
+ // Test differentiation of @differentiable function.
+ expectEqual((10, 5), gradient(at: 5, 10, in: { diffableBinary($0, $1) }))
+ expectEqual((20, 10), gradient(at: 5, 10, in: { diffableBinary($0, $1) * 2 }))
+}
+
+CustomDerivativesTests.testWithLeakChecking("Checkpointing") {
+ var count = 0
+ func cube(_ x: Tracked<Float>) -> Tracked<Float> {
+ count += 1
+ return x * x * x
+ }
+ // Test the top-level function variant of the checkpointing API.
+ expectEqual(324, gradient(at: 3) { (x: Tracked<Float>) -> Tracked<Float> in
+ expectEqual(0, count)
+ let y = withRecomputationInPullbacks(cube)(x)
+ expectEqual(1, count)
+ return y * 3 * x
+ })
+ expectEqual(2, count)
+ // Reset and test the method variant.
+ count = 0
+ expectEqual(324, gradient(at: 3) { (x: Tracked<Float>) -> Tracked<Float> in
+ expectEqual(0, count)
+ let y = x.withRecomputationInPullbacks(cube)
+ expectEqual(1, count)
+ return y * 3 * x
+ })
+ expectEqual(2, count)
+}
+
+CustomDerivativesTests.testWithLeakChecking("SumOfGradPieces") {
+ var grad: Tracked<Float> = 0
+ func addToGrad(_ x: inout Tracked<Float>) { grad += x }
+ _ = gradient(at: 4) { (x: Tracked<Float>) in
+ x.withDerivative(addToGrad)
+ * x.withDerivative(addToGrad)
+ * x.withDerivative(addToGrad)
+ }
+ expectEqual(48, grad)
+}
+
+CustomDerivativesTests.testWithLeakChecking("ModifyGradientOfSum") {
+ expectEqual(30, gradient(at: 4) { (x: Tracked<Float>) in
+ x.withDerivative { $0 *= 10 } + x.withDerivative { $0 *= 20 }
+ })
+}
+
+CustomDerivativesTests.testWithLeakChecking("WithoutDerivative") {
+ expectEqual(0, gradient(at: Tracked<Float>(4)) { x in
+ withoutDerivative(at: x) { x in
+ Tracked<Float>(sinf(x.value) + cosf(x.value))
+ }
+ })
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/derivative_generic_signature.swift b/test/AutoDiff/downstream/derivative_generic_signature.swift
new file mode 100644
index 0000000..49ee8df
--- /dev/null
+++ b/test/AutoDiff/downstream/derivative_generic_signature.swift
@@ -0,0 +1,197 @@
+// RUN: %target-swift-emit-sil -enable-experimental-forward-mode-differentiation -verify -module-name main %s | %FileCheck %s
+
+// Test derivative generic signatures:
+// - In `@differentiable` and `@derivative` attributes.
+// - In SIL differentiability witnesses.
+// - In generated derivative functions and derivative function thunks.
+
+//===----------------------------------------------------------------------===//
+// Same-type requirements
+//===----------------------------------------------------------------------===//
+
+// If all generic parameters are concrete (e.g. bound via same-type requirements
+// to concrete types), `@differentiable` attribute should not have a derivative
+// generic signature.
+
+// Test `@differentiable` attribute where original declaration has generic
+// signature and all generic parameters are concrete (e.g. bound to concrete
+// types via same-type requirements). SILGen lowers the original declaration to
+// a function with no generic signature, so the differentiability witness should
+// have no derivative generic signature.
+
+// NOTE(SR-11950): SILParser crashes for SILGen round-trip.
+
+// Test same-type requirements.
+
+// If all generic parameters are concrete (e.g. bound via same-type requirements
+// to concrete types), `@differentiable` attribute should have no derivative
+// generic signature. Otherwise, :
+
+struct AllConcrete<T>: Differentiable {}
+
+extension AllConcrete {
+ // Original generic signature: `<T>`.
+ // Where clause generic signature: `<T where T == Float>`.
+ @_silgen_name("allconcrete_where_gensig_constrained")
+ @differentiable(where T == Float)
+ func whereClauseGenericSignatureConstrained() -> AllConcrete {
+ return self
+ }
+
+// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig_constrained
+// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T where T == Float> @allconcrete_where_gensig_constrained : $@convention(method) <T> (AllConcrete<T>) -> AllConcrete<T> {
+// CHECK-NEXT: jvp: @AD__allconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszl : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
+// CHECK-NEXT: vjp: @AD__allconcrete_where_gensig_constrained__vjp_src_0_wrt_0_SfRszl : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
+// CHECK-NEXT: }
+}
+extension AllConcrete where T == Float {
+ @derivative(of: whereClauseGenericSignatureConstrained)
+ func jvpWhereClauseGenericSignatureConstrained() -> (
+ value: AllConcrete, differential: (TangentVector) -> TangentVector
+ ) {
+ (whereClauseGenericSignatureConstrained(), { $0 })
+ }
+}
+
+extension AllConcrete where T == Float {
+ // Original generic signature: `<T where T == Float>`.
+ // Where clause generic signature: none.
+ @_silgen_name("allconcrete_original_gensig")
+ @differentiable
+ func originalGenericSignature() -> AllConcrete {
+ return self
+ }
+
+ @derivative(of: originalGenericSignature)
+ func jvpOriginalGenericSignature() -> (
+ value: AllConcrete, differential: (TangentVector) -> TangentVector
+ ) {
+ (originalGenericSignature(), { $0 })
+ }
+
+// CHECK-LABEL: // differentiability witness for allconcrete_original_gensig
+// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_original_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
+// CHECK-NEXT: jvp: @AD__allconcrete_original_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
+// CHECK-NEXT: vjp: @AD__allconcrete_original_gensig__vjp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
+// CHECK-NEXT: }
+
+ // Original generic signature: `<T where T == Float>`.
+ // Where clause generic signature: `<T where T == Float>`.
+ @_silgen_name("allconcrete_where_gensig")
+ @differentiable(where T == Float)
+ func whereClauseGenericSignature() -> AllConcrete {
+ return self
+ }
+
+ @derivative(of: whereClauseGenericSignature)
+ func jvpWhereClauseGenericSignature() -> (
+ value: AllConcrete, differential: (TangentVector) -> TangentVector
+ ) {
+ (whereClauseGenericSignature(), { $0 })
+ }
+
+// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig
+// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_where_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
+// CHECK-NEXT: jvp: @AD__allconcrete_where_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
+// CHECK-NEXT: vjp: @AD__allconcrete_where_gensig__vjp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
+// CHECK-NEXT: }
+}
+
+extension AllConcrete where T == Float {
+ func testDifferentiability() {
+ let _: @differentiable (AllConcrete) -> AllConcrete =
+ { $0.originalGenericSignature() }
+ let _: @differentiable (AllConcrete) -> AllConcrete =
+ { $0.whereClauseGenericSignature() }
+ let _: @differentiable (AllConcrete) -> AllConcrete =
+ { $0.whereClauseGenericSignatureConstrained() }
+ }
+}
+
+// Test `@differentiable` attribute where original declaration has generic
+// signature and not all generic parameters are concrete. The lowered SIL
+// function and the differentiability witness should both have a derivative
+// generic signature.
+
+// NOTE(SR-11950): SILParser crashes for SILGen round-trip.
+
+struct NotAllConcrete<T, U>: Differentiable {}
+
+extension NotAllConcrete {
+ // Original generic signature: `<T, U>`.
+ // Where clause generic signature: `<T, U where T == Float>`.
+ @_silgen_name("notallconcrete_where_gensig_constrained")
+ @differentiable(where T == Float)
+ func whereClauseGenericSignatureConstrained() -> NotAllConcrete {
+ return self
+ }
+
+// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig_constrained
+// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig_constrained : $@convention(method) <T, U> (NotAllConcrete<T, U>) -> NotAllConcrete<T, U> {
+// CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
+// CHECK-NEXT: vjp: @AD__notallconcrete_where_gensig_constrained__vjp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
+// CHECK-NEXT: }
+}
+extension NotAllConcrete where T == Float {
+ @derivative(of: whereClauseGenericSignatureConstrained)
+ func jvpWhereClauseGenericSignatureConstrained() -> (
+ value: NotAllConcrete, differential: (TangentVector) -> TangentVector
+ ) {
+ (whereClauseGenericSignatureConstrained(), { $0 })
+ }
+}
+
+extension NotAllConcrete where T == Float {
+ // Original generic signature: `<T, U where T == Float>`.
+ // Where clause generic signature: none.
+ @_silgen_name("notallconcrete_original_gensig")
+ @differentiable
+ func originalGenericSignature() -> NotAllConcrete {
+ return self
+ }
+
+ @derivative(of: originalGenericSignature)
+ func jvpOriginalGenericSignature() -> (
+ value: NotAllConcrete, differential: (TangentVector) -> TangentVector
+ ) {
+ (originalGenericSignature(), { $0 })
+ }
+
+// CHECK-LABEL: // differentiability witness for notallconcrete_original_gensig
+// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_original_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
+// CHECK-NEXT: jvp: @AD__notallconcrete_original_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
+// CHECK-NEXT: vjp: @AD__notallconcrete_original_gensig__vjp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
+// CHECK-NEXT: }
+
+ // Original generic signature: `<T, U where T == Float>`.
+ // Where clause generic signature: `<T, U where T == Float>`.
+ @_silgen_name("notallconcrete_where_gensig")
+ @differentiable(where T == Float)
+ func whereClauseGenericSignature() -> NotAllConcrete {
+ return self
+ }
+
+ @derivative(of: whereClauseGenericSignature)
+ func jvpWhereClauseGenericSignature() -> (
+ value: NotAllConcrete, differential: (TangentVector) -> TangentVector
+ ) {
+ (whereClauseGenericSignature(), { $0 })
+ }
+
+// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig
+// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
+// CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
+// CHECK-NEXT: vjp: @AD__notallconcrete_where_gensig__vjp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
+// CHECK-NEXT: }
+}
+
+extension NotAllConcrete where T == Float {
+ func testDifferentiability() {
+ let _: @differentiable (NotAllConcrete) -> NotAllConcrete =
+ { $0.originalGenericSignature() }
+ let _: @differentiable (NotAllConcrete) -> NotAllConcrete =
+ { $0.whereClauseGenericSignature() }
+ let _: @differentiable (NotAllConcrete) -> NotAllConcrete =
+ { $0.whereClauseGenericSignatureConstrained() }
+ }
+}
diff --git a/test/AutoDiff/downstream/derivative_registration_foreign/Inputs/Foreign.c b/test/AutoDiff/downstream/derivative_registration_foreign/Inputs/Foreign.c
new file mode 100644
index 0000000..2a3151b
--- /dev/null
+++ b/test/AutoDiff/downstream/derivative_registration_foreign/Inputs/Foreign.c
@@ -0,0 +1 @@
+float cFunction(float x) { return x; }
diff --git a/test/AutoDiff/downstream/derivative_registration_foreign/Inputs/Foreign.h b/test/AutoDiff/downstream/derivative_registration_foreign/Inputs/Foreign.h
new file mode 100644
index 0000000..cfa6983
--- /dev/null
+++ b/test/AutoDiff/downstream/derivative_registration_foreign/Inputs/Foreign.h
@@ -0,0 +1 @@
+float cFunction(float);
diff --git a/test/AutoDiff/downstream/derivative_registration_foreign/Inputs/module.modulemap b/test/AutoDiff/downstream/derivative_registration_foreign/Inputs/module.modulemap
new file mode 100644
index 0000000..b9668e5
--- /dev/null
+++ b/test/AutoDiff/downstream/derivative_registration_foreign/Inputs/module.modulemap
@@ -0,0 +1,4 @@
+module CForeign {
+ header "Foreign.h"
+ export *
+}
diff --git a/test/AutoDiff/downstream/derivative_registration_foreign/main.swift b/test/AutoDiff/downstream/derivative_registration_foreign/main.swift
new file mode 100644
index 0000000..03ffdd3
--- /dev/null
+++ b/test/AutoDiff/downstream/derivative_registration_foreign/main.swift
@@ -0,0 +1,40 @@
+// RUN: %empty-directory(%t)
+// RUN: %clang -c %S/Inputs/Foreign.c -fmodules -o %t/CForeign.o
+// RUN: %target-swift-emit-silgen -I %S/Inputs -I %t %s | %FileCheck %s --check-prefix=CHECK-SILGEN --check-prefix=CHECK
+// RUN: %target-swift-emit-sil -I %S/Inputs -I %t %s | %FileCheck %s --check-prefix=CHECK-SIL --check-prefix=CHECK
+// RUN: %target-build-swift -I %S/Inputs -I %t %s %t/CForeign.o
+
+import CForeign
+
+// TF-1087: Test derivative registration for foreign declaration (Clang-imported).
+// Original SILDeclRef must have `isForeign` bit set correctly.
+
+// CHECK-SILGEN-LABEL: // differentiability witness for cFunction
+// CHECK-SILGEN: sil_differentiability_witness [serialized] [parameters 0] [results 0] @cFunction : $@convention(c) (Float) -> Float {
+// CHECK-SILGEN: vjp: @AD__cFunction__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK-SILGEN: }
+
+// CHECK-SIL-LABEL: // differentiability witness for cFunction
+// CHECK-SIL: sil_differentiability_witness [serialized] [parameters 0] [results 0] @cFunction : $@convention(c) (Float) -> Float {
+// CHECK-SIL: jvp: @AD__cFunction__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK-SIL: vjp: @AD__cFunction__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK-SIL: }
+
+// Check that original SIL function is correct.
+
+// CHECK-SILGEN-LABEL: sil [serializable] [clang cFunction] @cFunction : $@convention(c) (Float) -> Float
+
+@inlinable
+@derivative(of: cFunction)
+func vjpCFunction(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ (cFunction(x), { $0 })
+}
+
+@_silgen_name("test_derivative")
+@differentiable
+func testDerivative(_ x: Float) -> Float {
+ cFunction(x)
+}
+
+// CHECK-SILGEN-LABEL: sil hidden [ossa] @test_derivative : $@convention(thin) (Float) -> Float {
+// CHECK-SILGEN: {{%.*}} = function_ref @cFunction : $@convention(c) (Float) -> Float
diff --git a/test/AutoDiff/downstream/derived_differentiable.swift b/test/AutoDiff/downstream/derived_differentiable.swift
new file mode 100644
index 0000000..2d358de
--- /dev/null
+++ b/test/AutoDiff/downstream/derived_differentiable.swift
@@ -0,0 +1,143 @@
+// RUN: %target-swift-frontend -print-ast %s | %FileCheck %s --check-prefix=CHECK-AST
+// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s --check-prefix=CHECK-SILGEN
+// RUN: %target-swift-frontend -emit-sil -verify %s
+
+struct PointwiseMultiplicativeDummy : EuclideanDifferentiable, PointwiseMultiplicative {}
+
+public struct Foo : EuclideanDifferentiable {
+ public var a: Float
+}
+
+// CHECK-AST-LABEL: public struct Foo : EuclideanDifferentiable {
+// CHECK-AST: @differentiable
+// CHECK-AST: public var a: Float
+// CHECK-AST: internal init(a: Float)
+// CHECK-AST: public struct TangentVector
+// CHECK-AST: public typealias TangentVector = Foo.TangentVector
+// CHECK-AST: public var differentiableVectorView: Foo.TangentVector { get }
+
+// CHECK-SILGEN-LABEL: // differentiability witness for Foo.a.getter
+// CHECK-SILGEN-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0]
+
+struct AdditiveTangentIsSelf : AdditiveArithmetic, EuclideanDifferentiable {
+ var a: Float
+ var dummy: PointwiseMultiplicativeDummy
+}
+let _: @differentiable (AdditiveTangentIsSelf) -> Float = { x in
+ x.a + x.a
+}
+
+// CHECK-AST-LABEL: internal struct AdditiveTangentIsSelf : AdditiveArithmetic, EuclideanDifferentiable {
+// CHECK-AST: internal var a: Float
+// CHECK-AST: internal var dummy: PointwiseMultiplicativeDummy
+// CHECK-AST: internal init(a: Float, dummy: PointwiseMultiplicativeDummy)
+// CHECK-AST: internal typealias TangentVector = AdditiveTangentIsSelf
+// The following should not exist because when `Self == Self.TangentVector`, `differentiableVectorView` is not synthesized.
+// CHECK-AST-NOT: internal var differentiableVectorView: AdditiveTangentIsSelf { get }
+
+struct TestNoDerivative : EuclideanDifferentiable {
+ var w: Float
+ @noDerivative var technicallyDifferentiable: Float
+}
+
+// CHECK-AST-LABEL: internal struct TestNoDerivative : EuclideanDifferentiable {
+// CHECK-AST: var w: Float
+// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
+// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
+// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions
+// CHECK-AST: internal typealias TangentVector = TestNoDerivative.TangentVector
+// CHECK-AST: internal var differentiableVectorView: TestNoDerivative.TangentVector { get }
+
+struct TestPointwiseMultiplicative : Differentiable {
+ var w: PointwiseMultiplicativeDummy
+ @noDerivative var technicallyDifferentiable: PointwiseMultiplicativeDummy
+}
+
+// CHECK-AST-LABEL: internal struct TestPointwiseMultiplicative : Differentiable {
+// CHECK-AST: var w: PointwiseMultiplicativeDummy
+// CHECK-AST: @noDerivative internal var technicallyDifferentiable: PointwiseMultiplicativeDummy
+// CHECK-AST: internal init(w: PointwiseMultiplicativeDummy, technicallyDifferentiable: PointwiseMultiplicativeDummy)
+// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative
+// CHECK-AST: internal typealias TangentVector = TestPointwiseMultiplicative.TangentVector
+
+
+struct TestKeyPathIterable : Differentiable, KeyPathIterable {
+ var w: Float
+ @noDerivative var technicallyDifferentiable: Float
+}
+
+// CHECK-AST-LABEL: internal struct TestKeyPathIterable : Differentiable, KeyPathIterable {
+// CHECK-AST: var w: Float
+// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
+// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
+// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol, KeyPathIterable
+// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.TangentVector
+
+struct GenericTanMember<T : Differentiable> : Differentiable, AdditiveArithmetic {
+ var x: T.TangentVector
+}
+
+// CHECK-AST-LABEL: internal struct GenericTanMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
+// CHECK-AST: internal var x: T.TangentVector
+// CHECK-AST: internal init(x: T.TangentVector)
+// CHECK-AST: internal typealias TangentVector = GenericTanMember<T>
+// CHECK-AST: internal static var zero: GenericTanMember<T> { get }
+// CHECK-AST: internal static func + (lhs: GenericTanMember<T>, rhs: GenericTanMember<T>) -> GenericTanMember<T>
+// CHECK-AST: internal static func - (lhs: GenericTanMember<T>, rhs: GenericTanMember<T>) -> GenericTanMember<T>
+// CHECK-AST: @_implements(Equatable, ==(_:_:)) internal static func __derived_struct_equals(_ a: GenericTanMember<T>, _ b: GenericTanMember<T>) -> Bool
+
+public struct ConditionallyDifferentiable<T> {
+ public var x: T
+}
+extension ConditionallyDifferentiable : Differentiable where T : Differentiable {}
+
+// CHECK-AST-LABEL: public struct ConditionallyDifferentiable<T> {
+// CHECK-AST: @differentiable(wrt: self where T : Differentiable)
+// CHECK-AST: public var x: T
+// CHECK-AST: internal init(x: T)
+// CHECK-AST: }
+
+// Verify that `TangentVector` is not synthesized to be `Self` for
+// `AdditiveArithmetic`-conforming classes.
+final class AdditiveArithmeticClass<T : AdditiveArithmetic & Differentiable> : AdditiveArithmetic, Differentiable {
+ var x, y: T
+ init(x: T, y: T) {
+ self.x = x
+ self.y = y
+ }
+
+ // Dummy `AdditiveArithmetic` requirements.
+ static func == (lhs: AdditiveArithmeticClass, rhs: AdditiveArithmeticClass) -> Bool {
+ fatalError()
+ }
+ static var zero: AdditiveArithmeticClass {
+ fatalError()
+ }
+ static func + (lhs: AdditiveArithmeticClass, rhs: AdditiveArithmeticClass) -> Self {
+ fatalError()
+ }
+ static func - (lhs: AdditiveArithmeticClass, rhs: AdditiveArithmeticClass) -> Self {
+ fatalError()
+ }
+}
+
+// CHECK-AST-LABEL: final internal class AdditiveArithmeticClass<T> : AdditiveArithmetic, Differentiable where T : AdditiveArithmetic, T : Differentiable {
+// CHECK-AST: final internal var x: T, y: T
+// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic
+// CHECK-AST: }
+
+@frozen
+public struct FrozenStruct: Differentiable {}
+
+// CHECK-AST-LABEL: @frozen public struct FrozenStruct : Differentiable {
+// CHECK-AST: internal init()
+// CHECK-AST: @frozen public struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative, ElementaryFunctions {
+
+@usableFromInline
+struct UsableFromInlineStruct: Differentiable {}
+
+// CHECK-AST-LABEL: @usableFromInline
+// CHECK-AST: struct UsableFromInlineStruct : Differentiable {
+// CHECK-AST: internal init()
+// CHECK-AST: @usableFromInline
+// CHECK-AST: struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative, ElementaryFunctions {
diff --git a/test/AutoDiff/downstream/derived_differentiable_runtime.swift b/test/AutoDiff/downstream/derived_differentiable_runtime.swift
new file mode 100644
index 0000000..3fc507f
--- /dev/null
+++ b/test/AutoDiff/downstream/derived_differentiable_runtime.swift
@@ -0,0 +1,54 @@
+// RUN: %target-run-simple-swift
+// REQUIRES: executable_test
+
+import StdlibUnittest
+#if os(macOS)
+import Darwin.C
+#else
+import Glibc
+#endif
+
+var DerivedConformanceTests = TestSuite("DerivedConformances")
+
+DerivedConformanceTests.test("MemberwiseInitializers") {
+ struct AllVarStoredPropertiesHaveInitialValue: Differentiable, AdditiveArithmetic {
+ var x = Float(100)
+ var y = Float(100)
+ }
+ // Verify that `.zero` actually initializes properties to zero.
+ expectEqual(AllVarStoredPropertiesHaveInitialValue(x: 0, y: 0),
+ AllVarStoredPropertiesHaveInitialValue.zero)
+ expectEqual(AllVarStoredPropertiesHaveInitialValue.zero.x, 0)
+ expectEqual(AllVarStoredPropertiesHaveInitialValue.zero.y, 0)
+
+ struct HasNoDerivativeConstant: Differentiable {
+ @noDerivative let constant1 = Float(1)
+ @noDerivative let constant2 = Double(1)
+ var x = Float(1)
+ }
+ expectEqual(HasNoDerivativeConstant.TangentVector(x: 0),
+ HasNoDerivativeConstant.TangentVector.zero)
+}
+
+DerivedConformanceTests.test("EuclideanVectorView") {
+ do {
+ struct Foo: EuclideanDifferentiable {
+ var x: SIMD4<Float>
+ @noDerivative var y: SIMD4<Int32>
+ init() { x = [1, 2, 3, 4]; y = .zero }
+ }
+ let x = Foo()
+ expectEqual(Foo.TangentVector(x: [1, 2, 3, 4]), x.differentiableVectorView)
+ }
+ do {
+ class FooClass: EuclideanDifferentiable {
+ var x: SIMD4<Float>
+ @noDerivative var y: SIMD4<Int32>
+ init() { x = [1, 2, 3, 4]; y = .zero }
+ }
+ let x = FooClass()
+ expectEqual(FooClass.TangentVector(x: [1, 2, 3, 4]), x.differentiableVectorView)
+ }
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/deserialization_crashers.swift b/test/AutoDiff/downstream/deserialization_crashers.swift
new file mode 100644
index 0000000..fd058a1
--- /dev/null
+++ b/test/AutoDiff/downstream/deserialization_crashers.swift
@@ -0,0 +1,23 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-swift-frontend -emit-sib %s -o %t/tmp.sib
+// RUN: %target-sil-opt %t/tmp.sib
+
+// TF-256: Crashes when deserializing witness thunk for requirement requiring
+// differentiability wrt a subset of parameters.
+protocol DifferentiableWRTSubset : Differentiable {
+ @differentiable(wrt: (self))
+ func f(x: Float) -> Float
+
+ @differentiable(wrt: (x))
+ func g(x: Float) -> Float
+}
+
+struct TF256 : DifferentiableWRTSubset {
+ var param: Float = 0
+
+ @differentiable(wrt: (self))
+ func f(x: Float) -> Float { return x + param }
+
+ @differentiable(wrt: (x))
+ func g(x: Float) -> Float { return x + param }
+}
diff --git a/test/AutoDiff/downstream/differentiability_witness_inlining.sil b/test/AutoDiff/downstream/differentiability_witness_inlining.sil
new file mode 100644
index 0000000..b13a2b4
--- /dev/null
+++ b/test/AutoDiff/downstream/differentiability_witness_inlining.sil
@@ -0,0 +1,42 @@
+// RUN: %target-sil-opt -differentiability-witness-devirtualizer %s -enable-sil-verify-all | %FileCheck %s
+
+sil_stage raw
+
+import Swift
+import Builtin
+
+sil_differentiability_witness [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float {
+ jvp: @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+ vjp: @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+}
+
+sil_differentiability_witness [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float
+
+// This is an example of a witness that is available (via deserialization)
+// even though it is not defined in the current module.
+// witness for static Swift.Float.+ infix(Swift.Float, Swift.Float) -> Swift.Float
+sil_differentiability_witness [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
+
+sil @witness_defined_in_module : $@convention(thin) (Float) -> Float
+
+sil @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+
+sil @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+
+sil @witness_definition_not_available : $@convention(thin) (Float) -> Float
+
+sil public_external [transparent] [serialized] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
+
+sil @test : $@convention(thin) (Float) -> () {
+bb0(%0 : $Float):
+ %1 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float
+ // CHECK: %1 = function_ref @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+
+ %2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float
+ // CHECK: %2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float
+
+ %3 = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
+ // CHECK: %3 = function_ref @AD__$sSf1poiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))
+
+ return undef : $()
+}
diff --git a/test/AutoDiff/downstream/differentiability_witness_swiftinterface.swift b/test/AutoDiff/downstream/differentiability_witness_swiftinterface.swift
new file mode 100644
index 0000000..c959e84
--- /dev/null
+++ b/test/AutoDiff/downstream/differentiability_witness_swiftinterface.swift
@@ -0,0 +1,36 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-swift-frontend -typecheck -emit-module-interface-path - %s -swift-version 5 -enable-library-evolution -module-name Module > %t/Module.swiftinterface
+// RUN: %target-swift-frontend -emit-silgen %t/Module.swiftinterface | %FileCheck %s --check-prefix=CHECK-SILGEN
+// RUN: not %target-swift-frontend -compile-module-from-interface %t/Module.swiftinterface -o %t/Module.swiftmodule 2>&1 | %FileCheck %s --check-prefix=CHECK-COMPILE
+
+// TF-1094: Derivative registration fails for `.swiftinterface` compilation when
+// original `@differentiable` function is serialized but `@derivative` function
+// is unserialized.
+
+@inlinable // serialized
+@differentiable
+@_silgen_name("foo")
+public func foo(_ x: Float) -> Float {
+ fatalError()
+}
+
+@usableFromInline // not serialized
+@derivative(of: foo)
+@_silgen_name("vjp_foo")
+func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ return (x, { $0 })
+}
+
+// Missing differentiability witness VJP entry because `func vjpFoo` is bodiless
+// in the `.swiftinterface` file and not lowered to a SIL function.
+
+// CHECK-SILGEN-LABEL: // differentiability witness for foo
+// CHECK-SILGEN-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @foo : $@convention(thin) (Float) -> Float {
+// CHECK-SILGEN-NEXT: }
+
+// CHECK-SILGEN-LABEL: sil [serialized] [ossa] @foo
+// CHECK-SILGEN-NOT: sil {{.*}} @vjp_foo
+
+// CHECK-COMPILE: Module.swiftinterface:5:2: error: function is not differentiable
+// CHECK-COMPILE: Module.swiftinterface:7:24: note: when differentiating this function definition
+// CHECK-COMPILE: Module.swiftinterface:9:1: note: missing return for differentiation
diff --git a/test/AutoDiff/downstream/differentiable_attr_cross_module/main.swift b/test/AutoDiff/downstream/differentiable_attr_cross_module/main.swift
new file mode 100644
index 0000000..7ce6da0
--- /dev/null
+++ b/test/AutoDiff/downstream/differentiable_attr_cross_module/main.swift
@@ -0,0 +1,26 @@
+// Verify that `@differentiable` declarations can be differentiated from other
+// modules.
+
+// RUN: %empty-directory(%t)
+// RUN: %target-build-swift %S/../Inputs/differentiable_attr_other_module.swift %s -o /dev/null -lm
+// NOTE(TF-892): `-lm` is necessary to prevent linker errors related to `ElementaryFunctions` on Ubuntu.
+
+@differentiable(wrt: x)
+func testInitializer(_ x: Float) -> Float {
+ return Foo(x).x
+}
+
+@differentiable(wrt: foo)
+func testMethod(_ foo: Foo) -> Float {
+ return foo.method()
+}
+
+@differentiable(wrt: foo)
+func testComputedProperty(_ foo: Foo) -> Float {
+ return foo.computedProperty
+}
+
+@differentiable(wrt: foo)
+func testSubscript(_ foo: Foo) -> Float {
+ return foo[]
+}
diff --git a/test/AutoDiff/downstream/differentiable_attr_serialization.swift b/test/AutoDiff/downstream/differentiable_attr_serialization.swift
new file mode 100644
index 0000000..3f08f43
--- /dev/null
+++ b/test/AutoDiff/downstream/differentiable_attr_serialization.swift
@@ -0,0 +1,32 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-swift-frontend -emit-module %s -o %t/differentiable_attr_serialization.swiftmodule
+// RUN: %target-swift-frontend -merge-modules -emit-module %t/differentiable_attr_serialization.swiftmodule
+
+// Test round-trip `@differentiable` attribute AST serialization.
+
+// Motivation: check that `@differentiable` attributes always have original
+// declaration set.
+
+struct Foo: Differentiable {
+ @differentiable
+ func method() -> Self { self }
+
+ @differentiable
+ init(_ x: Float) {}
+
+ @differentiable
+ var computedProperty: Float { 1 }
+
+ var computedPropertyGetter: Float {
+ @differentiable
+ get { 1 }
+ }
+
+ @differentiable
+ subscript() -> Float { 1 }
+
+ subscript(_ x: Float) -> Float {
+ @differentiable
+ get { 1 }
+ }
+}
diff --git a/test/AutoDiff/downstream/differentiable_attr_type_checking_primary_file.swift b/test/AutoDiff/downstream/differentiable_attr_type_checking_primary_file.swift
new file mode 100644
index 0000000..475cc24
--- /dev/null
+++ b/test/AutoDiff/downstream/differentiable_attr_type_checking_primary_file.swift
@@ -0,0 +1,29 @@
+// RUN: %target-swift-frontend -typecheck -verify %S/Inputs/differentiable_attr_type_checking_non_primary_file.swift -primary-file %s
+
+// Test TF-1043: Type-checking protocol requirement `@differentiable` attributes
+// from non-primary files.
+
+struct OuterLayer: Layer {
+ typealias Input = Float
+ typealias Output = Float
+
+ var dummy: DummyLayer
+
+ @differentiable
+ var computedProperty: Output {
+ // NOTE(TF-1043): Old misleading error:
+ // error: 'Int' is not convertible to 'Float'
+ // return Float(1).sequenced(through: dummy)
+ // ^~~~~~~~
+ return Float(1).sequenced(through: dummy)
+ }
+
+ @differentiable
+ func instanceMethod(_ input: Input) -> Output {
+ // NOTE(TF-1043): Old misleading error:
+ // error: type of expression is ambiguous without more context
+ // return input.sequenced(through: dummy)
+ // ~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
+ return input.sequenced(through: dummy)
+ }
+}
diff --git a/test/AutoDiff/downstream/differentiable_func_debuginfo.swift b/test/AutoDiff/downstream/differentiable_func_debuginfo.swift
new file mode 100644
index 0000000..e7da004
--- /dev/null
+++ b/test/AutoDiff/downstream/differentiable_func_debuginfo.swift
@@ -0,0 +1,44 @@
+// RUN: %target-swift-frontend -c %s -g -O -parse-stdlib -parse-as-library -module-name Swift
+
+// TF-597: Exact minimal reproducer for IRGenDebugInfo crash.
+// Crash occurred only with `-g` and `-O`.
+//
+// ```
+// Assertion failed: (OffsetInBits + SizeInBits <= getSizeInBits(Var) && "pars > totum"),
+// function emitVariableDeclaration, file swift/lib/IRGen/IRGenDebugInfo.cpp, line 2216.
+// Stack dump:
+// 1. Swift version 5.1-dev (LLVM 200186e28b, Swift c09c14dec5)
+// 2. While emitting IR SIL function "@$ss8pullback2at2inyx_q_xXFts14DifferentiableRzsADR_r0_lF".
+// for 'pullback(at:in:)' (at swift/test/AutoDiff/differentiable_func_debuginfo.swift:21:8)
+// ```
+//
+// The crash was because `IRGenDebugInfoImpl::getOrCreateType` computes
+// `llvm::DIType` type debug info by demangling type names.
+//
+// Since `@differentiable` and `@differentiable(linear)` function types did
+// not have mangling support, `getOrCreateType` computed a regular `(A) -> B`
+// function type instead of a `@differentiable (A) -> B` function type, leading
+// to a type size inconsistency.
+//
+// Conclusion: mangling coverage is important.
+
+// Minimal dummy compiler-known `Differentiable` protocol.
+public protocol Differentiable {
+ associatedtype TangentVector
+}
+
+// This declaration is necessary to reproduce the crash.
+// `Builtin.autodiffApply_vjp` constructs a use of the `tf597ProblematicVarDecl`
+// type, which was mangled without `@differentiable` attribute. The parameter
+// for `blackHole` is of type `$@noescape @callee_guaranteed (@in_guaranteed T) -> @out U`,
+// which matched the mangled name for the type of the parameter of `blackHole`.
+// As a result, the types are uniqued when generating debug info. The type of
+// the parameter of `blackHole` is smaller than the `@differentiable` function
+// type, causing IRGenDebugInfo to crash.
+public func blackHole<T, U>(_: (T) -> U) {}
+
+public func pullback<T, R>(
+ at x: T, in tf597ProblematicVarDecl: @differentiable (T) -> R
+) {
+ let _ = Builtin.applyDerivative_vjp(tf597ProblematicVarDecl, x)
+}
diff --git a/test/AutoDiff/downstream/differentiable_func_type.sil b/test/AutoDiff/downstream/differentiable_func_type.sil
new file mode 100644
index 0000000..ad297d8
--- /dev/null
+++ b/test/AutoDiff/downstream/differentiable_func_type.sil
@@ -0,0 +1,76 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name differentiable_func_type
+// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name differentiable_func_type
+// RUN: %target-sil-opt %t/tmp.2.sib -module-name differentiable_func_type | %FileCheck %s -check-prefix=CHECK-SIL
+
+// RUN: %target-swift-frontend %s -emit-ir -module-name differentiable_func_type | %FileCheck %s -check-prefix=CHECK-LLVM
+
+sil_stage raw
+
+import Swift
+
+sil @takeAndReturnLinear : $@convention(thin) (@differentiable(linear) (Float) -> Float) -> @differentiable(linear) (Float) -> Float {
+bb0(%0 : $@differentiable(linear) (Float) -> Float):
+ return %0 : $@differentiable(linear) (Float) -> Float
+}
+
+// CHECK-SIL-LABEL: sil @takeAndReturnLinear : $@convention(thin) (@differentiable(linear) (Float) -> Float) -> @differentiable(linear) (Float) -> Float {
+// CHECK-SIL: bb0([[ARG:%.*]] : $@differentiable(linear) (Float) -> Float):
+// CHECK-SIL: return [[ARG]] : $@differentiable(linear) (Float) -> Float
+// CHECK-SIL: }
+
+// CHECK-LLVM-LABEL: define{{.*}} swiftcc { i8*, %swift.refcounted*, i8*, %swift.refcounted* } @takeAndReturnLinear(i8* %0, %swift.refcounted* %1, i8* %2, %swift.refcounted* %3) #0 {
+// CHECK-LLVM: entry:
+// CHECK-LLVM: %4 = insertvalue { i8*, %swift.refcounted*, i8*, %swift.refcounted* } undef, i8* %0, 0
+// CHECK-LLVM: %5 = insertvalue { i8*, %swift.refcounted*, i8*, %swift.refcounted* } %4, %swift.refcounted* %1, 1
+// CHECK-LLVM: %6 = insertvalue { i8*, %swift.refcounted*, i8*, %swift.refcounted* } %5, i8* %2, 2
+// CHECK-LLVM: %7 = insertvalue { i8*, %swift.refcounted*, i8*, %swift.refcounted* } %6, %swift.refcounted* %3, 3
+// CHECK-LLVM: ret { i8*, %swift.refcounted*, i8*, %swift.refcounted* } %7
+// CHECK-LLVM: }
+
+
+sil @takeAndReturnDifferentiable : $@convention(thin) (@differentiable (Float) -> Float) -> @differentiable (Float) -> Float {
+bb0(%0 : $@differentiable (Float) -> Float):
+ return %0 : $@differentiable (Float) -> Float
+}
+
+// CHECK-SIL-LABEL: sil @takeAndReturnDifferentiable : $@convention(thin) (@differentiable (Float) -> Float) -> @differentiable (Float) -> Float {
+// CHECK-SIL: bb0([[ARG:%.*]] : $@differentiable (Float) -> Float):
+// CHECK-SIL: return [[ARG]] : $@differentiable (Float) -> Float
+// CHECK-SIL: }
+
+// CHECK-LLVM-LABEL: define{{.*}} swiftcc void @takeAndReturnDifferentiable(<{ %swift.function, %swift.function, %swift.function }>* noalias nocapture sret %0, <{ %swift.function, %swift.function, %swift.function }>* noalias nocapture dereferenceable(48) %1) #0 {
+// CHECK-LLVM: entry:
+// CHECK-LLVM: %.original = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %1, i32 0, i32 0
+// CHECK-LLVM: %.original.fn = getelementptr inbounds %swift.function, %swift.function* %.original, i32 0, i32 0
+// CHECK-LLVM: %2 = load i8*, i8** %.original.fn, align 8
+// CHECK-LLVM: %.original.data = getelementptr inbounds %swift.function, %swift.function* %.original, i32 0, i32 1
+// CHECK-LLVM: %3 = load %swift.refcounted*, %swift.refcounted** %.original.data, align 8
+// CHECK-LLVM: %.jvp = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %1, i32 0, i32 1
+// CHECK-LLVM: %.jvp.fn = getelementptr inbounds %swift.function, %swift.function* %.jvp, i32 0, i32 0
+// CHECK-LLVM: %4 = load i8*, i8** %.jvp.fn, align 8
+// CHECK-LLVM: %.jvp.data = getelementptr inbounds %swift.function, %swift.function* %.jvp, i32 0, i32 1
+// CHECK-LLVM: %5 = load %swift.refcounted*, %swift.refcounted** %.jvp.data, align 8
+// CHECK-LLVM: %.vjp = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %1, i32 0, i32 2
+// CHECK-LLVM: %.vjp.fn = getelementptr inbounds %swift.function, %swift.function* %.vjp, i32 0, i32 0
+// CHECK-LLVM: %6 = load i8*, i8** %.vjp.fn, align 8
+// CHECK-LLVM: %.vjp.data = getelementptr inbounds %swift.function, %swift.function* %.vjp, i32 0, i32 1
+// CHECK-LLVM: %7 = load %swift.refcounted*, %swift.refcounted** %.vjp.data, align 8
+// CHECK-LLVM: %.original1 = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %0, i32 0, i32 0
+// CHECK-LLVM: %.original1.fn = getelementptr inbounds %swift.function, %swift.function* %.original1, i32 0, i32 0
+// CHECK-LLVM: store i8* %2, i8** %.original1.fn, align 8
+// CHECK-LLVM: %.original1.data = getelementptr inbounds %swift.function, %swift.function* %.original1, i32 0, i32 1
+// CHECK-LLVM: store %swift.refcounted* %3, %swift.refcounted** %.original1.data, align 8
+// CHECK-LLVM: %.jvp2 = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %0, i32 0, i32 1
+// CHECK-LLVM: %.jvp2.fn = getelementptr inbounds %swift.function, %swift.function* %.jvp2, i32 0, i32 0
+// CHECK-LLVM: store i8* %4, i8** %.jvp2.fn, align 8
+// CHECK-LLVM: %.jvp2.data = getelementptr inbounds %swift.function, %swift.function* %.jvp2, i32 0, i32 1
+// CHECK-LLVM: store %swift.refcounted* %5, %swift.refcounted** %.jvp2.data, align 8
+// CHECK-LLVM: %.vjp3 = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %0, i32 0, i32 2
+// CHECK-LLVM: %.vjp3.fn = getelementptr inbounds %swift.function, %swift.function* %.vjp3, i32 0, i32 0
+// CHECK-LLVM: store i8* %6, i8** %.vjp3.fn, align 8
+// CHECK-LLVM: %.vjp3.data = getelementptr inbounds %swift.function, %swift.function* %.vjp3, i32 0, i32 1
+// CHECK-LLVM: store %swift.refcounted* %7, %swift.refcounted** %.vjp3.data, align 8
+// CHECK-LLVM: ret void
+// CHECK-LLVM: }
+
diff --git a/test/AutoDiff/downstream/differentiable_function_inst_lowered.sil b/test/AutoDiff/downstream/differentiable_function_inst_lowered.sil
new file mode 100644
index 0000000..e90f370
--- /dev/null
+++ b/test/AutoDiff/downstream/differentiable_function_inst_lowered.sil
@@ -0,0 +1,71 @@
+// RUN: %target-sil-opt %s | %target-sil-opt | %FileCheck %s
+
+// Test `differentiable_function_extract` and
+// `differentiability_witness_function` with explicit lowered type.
+// SIL generated via `%target-sil-opt -loadable-address %s`.
+// Note: SIL serialization/deserialization does not support lowered SIL.
+
+sil_stage lowered
+
+import Swift
+import Builtin
+
+struct Large : Differentiable {
+ @_hasStorage @noDerivative let a: Float { get }
+ @_hasStorage @noDerivative let b: Float { get }
+ @_hasStorage @noDerivative let c: Float { get }
+ @_hasStorage @noDerivative let d: Float { get }
+ @_hasStorage @noDerivative let e: Float { get }
+ init(a: Float, b: Float, c: Float, d: Float, e: Float)
+ struct TangentVector : Differentiable, AdditiveArithmetic {
+ init()
+ typealias TangentVector = Large.TangentVector
+ static var zero: Large.TangentVector { get }
+ static func + (lhs: Large.TangentVector, rhs: Large.TangentVector) -> Large.TangentVector
+ static func - (lhs: Large.TangentVector, rhs: Large.TangentVector) -> Large.TangentVector
+ @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Large.TangentVector, _ b: Large.TangentVector) -> Bool
+ }
+ mutating func move(along direction: Large.TangentVector)
+}
+
+sil_differentiability_witness [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
+
+sil @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
+sil @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
+
+// CHECK-LABEL: sil @test
+sil @test : $@convention(thin) () -> () {
+bb0:
+ %func = function_ref @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
+ %func_jvpwitness_wrt_012 = differentiability_witness_function [jvp] [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector, Large.TangentVector, Large.TangentVector) -> Large.TangentVector)
+ %func_vjpwitness_wrt_012 = differentiability_witness_function [vjp] [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
+ %func_diff_wrt_012 = differentiable_function [parameters 0 1 2] [results 0] %func : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large with_derivative {%func_jvpwitness_wrt_012 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector, Large.TangentVector, Large.TangentVector) -> Large.TangentVector), %func_vjpwitness_wrt_012 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))}
+ %func_vjp_wrt_012 = differentiable_function_extract [vjp] %func_diff_wrt_012 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
+
+ // CHECK: [[FUNC_REF:%.*]] = function_ref @examplefunc
+ // CHECK: [[DIFF_WRT_012:%.*]] = differentiable_function [parameters 0 1 2] [results 0] [[FUNC_REF]] : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
+ // CHECK: [[VJP_WRT_012:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_012]] : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
+
+ %func_diff_wrt_0 = differentiable_function [parameters 0] [results 0] %func : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
+ %func_vjp_wrt_0 = differentiable_function_extract [vjp] %func_diff_wrt_0 : $@differentiable @convention(thin) (@in_constant Large, @noDerivative @in_constant Large, @noDerivative @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
+
+ // CHECK: [[DIFF_WRT_0:%.*]] = differentiable_function [parameters 0] [results 0] [[FUNC_REF]] : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
+ // CHECK: [[VJP_WRT_0:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0]] : $@differentiable @convention(thin) (@in_constant Large, @noDerivative @in_constant Large, @noDerivative @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
+
+ %method = function_ref @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
+ %method_diff_wrt_0123 = differentiable_function [parameters 0 1 2] [results 0] %method : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
+ %7 = differentiable_function_extract [vjp] %method_diff_wrt_0123 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
+
+ // CHECK: [[METHOD_REF:%.*]] = function_ref @examplemethod
+ // CHECK: [[DIFF_WRT_0123:%.*]] = differentiable_function [parameters 0 1 2] [results 0] [[METHOD_REF]] : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
+ // CHECK: [[VJP_WRT_0123:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0123]] : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
+
+ %method_diff_wrt_0 = differentiable_function [parameters 0] [results 0] %method : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
+ %method_vjp_wrt_0 = differentiable_function_extract [vjp] %method_diff_wrt_0 : $@differentiable @convention(method) (@in_constant Large, @noDerivative @in_constant Large, @noDerivative @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
+
+ // CHECK: [[DIFF_WRT_0:%.*]] = differentiable_function [parameters 0] [results 0] [[METHOD_REF]] : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
+ // CHECK: [[VJP_WRT_0:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0]] : $@differentiable @convention(method) (@in_constant Large, @noDerivative @in_constant Large, @noDerivative @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
+
+ %10 = tuple ()
+ return %10 : $()
+}
diff --git a/test/AutoDiff/downstream/differentiable_function_silgen.swift b/test/AutoDiff/downstream/differentiable_function_silgen.swift
new file mode 100644
index 0000000..03e6933
--- /dev/null
+++ b/test/AutoDiff/downstream/differentiable_function_silgen.swift
@@ -0,0 +1,118 @@
+// RUN: %target-swift-frontend -dump-ast %s | %FileCheck %s -check-prefix=CHECK-AST
+// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s -check-prefix=CHECK-SILGEN
+
+//===----------------------------------------------------------------------===//
+// Closure conversion
+//===----------------------------------------------------------------------===//
+
+func thin(x: Float) -> Float { return x }
+
+func myfunction(_ f: @escaping @differentiable (Float) -> (Float)) -> (Float) -> Float {
+ // @differentiable functions should be callable.
+ _ = f(.zero)
+ return f
+}
+
+func myfunction2(_ f: @escaping @differentiable(linear) (Float) -> (Float)) -> (Float) -> Float {
+ // @differentiable(linear) functions should be callable.
+ _ = f(.zero)
+ return f
+}
+
+var global_f: @differentiable (Float) -> Float = {$0}
+var global_f_linear: @differentiable(linear) (Float) -> Float = {$0}
+
+func calls_global_f() {
+ _ = global_f(10)
+ // TODO(TF-900, TF-902): Uncomment the following line to test loading a linear function from memory and direct calls to a linear function.
+ // _ = global_f_linear(10)
+}
+
+func apply() {
+ _ = myfunction(thin)
+ _ = myfunction2(thin)
+}
+
+// CHECK-AST-LABEL: (func_decl {{.*}} "myfunction(_:)"
+// CHECK-AST: (call_expr type='(Float)'
+// CHECK-AST: (declref_expr type='@differentiable (Float) -> (Float)'
+// CHECK-AST: (return_stmt
+// CHECK-AST: (function_conversion_expr implicit type='(Float) -> Float'
+// CHECK-AST: (differentiable_function_extract_original implicit type='(Float) -> (Float)'
+// CHECK-AST: (declref_expr type='@differentiable (Float) -> (Float)'
+// CHECK-AST-LABEL: (func_decl {{.*}} "apply()"
+// CHECK-AST: (function_conversion_expr implicit type='@differentiable (Float) -> (Float)'
+// CHECK-AST: (differentiable_function implicit type='@differentiable (Float) -> Float'
+// CHECK-AST: (declref_expr type='(Float) -> Float'
+
+// CHECK-SILGEN-LABEL: @{{.*}}myfunction{{.*}}
+// CHECK-SILGEN: bb0([[DIFF:%.*]] : @guaranteed $@differentiable @callee_guaranteed (Float) -> Float):
+// CHECK-SILGEN: [[COPIED_DIFF:%.*]] = copy_value [[DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[BORROWED_DIFF:%.*]] = begin_borrow [[COPIED_DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: apply [[BORROWED_DIFF]]({{%.*}}) : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: end_borrow [[BORROWED_DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: destroy_value [[COPIED_DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[COPIED_DIFF:%.*]] = copy_value [[DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[BORROWED_DIFF:%.*]] = begin_borrow [[COPIED_DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[BORROWED_ORIG:%.*]] = differentiable_function_extract [original] [[BORROWED_DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[COPIED_ORIG:%.*]] = copy_value [[BORROWED_ORIG]] : $@callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: return [[COPIED_ORIG]] : $@callee_guaranteed (Float) -> Float
+
+// CHECK-SILGEN-LABEL: @{{.*}}myfunction2{{.*}}
+// CHECK-SILGEN: bb0([[LIN:%.*]] : @guaranteed $@differentiable(linear) @callee_guaranteed (Float) -> Float):
+// CHECK-SILGEN: [[COPIED_LIN:%.*]] = copy_value [[LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[BORROWED_LIN:%.*]] = begin_borrow [[COPIED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: apply [[BORROWED_LIN]]({{%.*}}) : $@differentiable(linear) @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: end_borrow [[BORROWED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[COPIED_LIN:%.*]] = copy_value [[LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[BORROWED_LIN:%.*]] = begin_borrow [[COPIED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[BORROWED_ORIG:%.*]] = linear_function_extract [original] [[BORROWED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[COPIED_ORIG:%.*]] = copy_value [[BORROWED_ORIG]] : $@callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: end_borrow [[BORROWED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: destroy_value [[COPIED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: return [[COPIED_ORIG]] : $@callee_guaranteed (Float) -> Float
+
+// CHECK-SILGEN-LABEL: @{{.*}}apply{{.*}}
+// CHECK-SILGEN: [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float
+// CHECK-SILGEN-NEXT: [[ORIG_THICK:%.*]] = thin_to_thick_function [[ORIG]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
+// CHECK-SILGEN-NEXT: [[DIFFED:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float
+// CHECK-SILGEN-NEXT: [[ORIG_THICK:%.*]] = thin_to_thick_function [[ORIG]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
+// CHECK-SILGEN-NEXT: [[LIN:%.*]] = linear_function [parameters 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float
+
+//===----------------------------------------------------------------------===//
+// Reabstraction
+//===----------------------------------------------------------------------===//
+
+func pullback<T, R>(
+ at x: T, in f: @escaping @differentiable (T) -> R
+) -> (R.TangentVector) -> T.TangentVector {
+ fatalError()
+}
+
+func appliesReabstraction(_ f: @escaping @differentiable (Float) -> Float) {
+ _ = pullback(at: .zero, in: f)
+}
+
+// CHECK-SILGEN-LABEL: @{{.*}}appliesReabstraction{{.*}}
+// CHECK-SILGEN: bb0([[DIFF_FUNC_ARG:%.*]] : @guaranteed $@differentiable @callee_guaranteed (Float) -> Float):
+// CHECK-SILGEN: [[DIFF_FUNC:%.*]] = copy_value [[DIFF_FUNC_ARG]] : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[DIFF_FUNC_BORROWED:%.*]] = begin_borrow [[DIFF_FUNC]] : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[ORIG:%.*]] = differentiable_function_extract [original] [[DIFF_FUNC_BORROWED]] : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[ORIG_COPY:%.*]] = copy_value [[ORIG]] : $@callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[REABS_ORIG:%.*]] = function_ref @$sS2fIegyd_S2fIegnr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> @out Float
+// CHECK-SILGEN: [[NEW_ORIG:%.*]] = partial_apply [callee_guaranteed] [[REABS_ORIG]]([[ORIG_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> @out Float
+// CHECK-SILGEN: [[NEW_ORIG_CONVERTED:%.*]] = convert_function [[NEW_ORIG]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>
+// CHECK-SILGEN: [[JVP:%.*]] = differentiable_function_extract [jvp] [[DIFF_FUNC_BORROWED]] : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[JVP_COPY:%.*]] = copy_value [[JVP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK-SILGEN: [[REABS_JVP:%.*]] = function_ref @$sS4fIegyd_Iegydo_S2fxq_r0_lyS2fIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>)
+// CHECK-SILGEN: [[NEW_JVP:%.*]] = partial_apply [callee_guaranteed] %19(%18) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>)
+// CHECK-SILGEN: [[NEW_JVP_CONVERTED:%.*]] = convert_function [[NEW_JVP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>
+// CHECK-SILGEN: [[VJP:%.*]] = differentiable_function_extract [vjp] [[DIFF_FUNC_BORROWED]] : $@differentiable @callee_guaranteed (Float) -> Float
+// CHECK-SILGEN: [[VJP_COPY:%.*]] = copy_value [[VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK-SILGEN: [[REABS_VJP:%.*]] = function_ref @$sS4fIegyd_Iegydo_S2fxq_r0_lyS2fIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>)
+// CHECK-SILGEN: [[NEW_VJP:%.*]] = partial_apply [callee_guaranteed] [[REABS_VJP]]([[VJP_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>)
+// CHECK-SILGEN: [[NEW_VJP_CONVERTED:%.*]] = convert_function [[NEW_VJP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>
+// CHECK-SILGEN: [[NEW_DIFF_FUNC:%.*]] = differentiable_function [parameters 0] [results 0] [[NEW_ORIG_CONVERTED]] : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float> with_derivative {[[NEW_JVP_CONVERTED]] : $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>, [[NEW_VJP_CONVERTED]] : $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>}
+// CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>) -> @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_1.TangentVector, τ_0_0.TangentVector>
+// CHECK-SILGEN: apply [[DIFF_API]]<Float, Float>({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>) -> @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_1.TangentVector, τ_0_0.TangentVector>
diff --git a/test/AutoDiff/downstream/differentiable_requirement_cross_module.swift b/test/AutoDiff/downstream/differentiable_requirement_cross_module.swift
new file mode 100644
index 0000000..da37a45
--- /dev/null
+++ b/test/AutoDiff/downstream/differentiable_requirement_cross_module.swift
@@ -0,0 +1,33 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-swift-frontend -emit-module -primary-file %S/Inputs/differentiable_requirement_other_module.swift -emit-module-path %t/differentiable_requirement_other_module.swiftmodule
+// RUN: %target-swift-frontend -typecheck -I %t -primary-file %s -verify
+
+import differentiable_requirement_other_module
+
+// Conform `Empty` to `Differentiable`.
+// The `foo` protocol requirement is `@differentiable` and has an `Empty` parameter.
+extension Empty : Differentiable {
+ public typealias TangentVector = Empty
+ public typealias AllDifferentiableVariables = Empty
+ public var zeroTangentVectorInitializer: () -> TangentVector { { .zero } }
+}
+
+private struct PrivateConforming : DifferentiableRequirement {
+ fileprivate func foo(float: Float, empty: Empty) -> Float {
+ return float
+ }
+}
+
+struct InternalConforming : DifferentiableRequirement {
+ func foo(float: Float, empty: Empty) -> Float {
+ return float
+ }
+}
+
+// expected-error @+1 {{type 'PublicConforming' does not conform to protocol 'DifferentiableRequirement'}}
+public struct PublicConforming : DifferentiableRequirement {
+ // expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: float)'}}
+ public func foo(float: Float, empty: Empty) -> Float {
+ return float
+ }
+}
diff --git a/test/AutoDiff/downstream/differentiable_sil_attr_roundtrip.swift b/test/AutoDiff/downstream/differentiable_sil_attr_roundtrip.swift
new file mode 100644
index 0000000..e5698ad
--- /dev/null
+++ b/test/AutoDiff/downstream/differentiable_sil_attr_roundtrip.swift
@@ -0,0 +1,26 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-swift-frontend -emit-silgen %s -o %t/roundtrip.sil
+// RUN: %target-swift-frontend -emit-sil %t/roundtrip.sil
+
+// TF-656: Verify that `AutoDiffIndexSubset` for SIL `[differentiable]`
+// attribute is set correctly.
+
+// Otherwise, an assertion is triggered during the differentiation transform:
+// Assertion failed: (newCapacity >= capacity), function extendingCapacity
+// ... ADContext::promoteToDifferentiableFunction
+
+// NOTE: We cannot differentiate external functions in roundtrip SIL tests.
+// Reason: When we print then parse the SIL we lose the information that the
+// external function is associated with an AST decl. So the differentiation
+// pass can't see the AST differentiable attrs, and the differentiation pass
+// thinks that we're trying to differentiate an external function without
+// explicit AST differentiable attrs.
+// TODO(TF-988): This can probably be fixed.
+
+@differentiable(wrt: x)
+func TF_656(_ x: Float, _ y: Float) -> Float {
+ // FIXME(TF-988): Cannot differentiate external functions.
+ // return x + y
+ return 0
+}
+_ = gradient(at: 1, in: { x in TF_656(x, 2) })
diff --git a/test/AutoDiff/downstream/e2e_cross_module.swift b/test/AutoDiff/downstream/e2e_cross_module.swift
new file mode 100644
index 0000000..5973332
--- /dev/null
+++ b/test/AutoDiff/downstream/e2e_cross_module.swift
@@ -0,0 +1,26 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-build-swift -working-directory %t -parse-as-library -emit-module -module-name e2e_cross_module_external_module -emit-module-path %t/e2e_cross_module_external_module.swiftmodule -emit-library -static %S/Inputs/e2e_cross_module_external_module.swift
+// RUN: %target-build-swift -I%t -L%t %s -o %t/a.out -lm -le2e_cross_module_external_module
+// RUN: %target-run %t/a.out
+// REQUIRES: executable_test
+
+import e2e_cross_module_external_module
+import StdlibUnittest
+import DifferentiationUnittest
+
+var Tests = TestSuite("E2ECrossModule")
+
+// Reproduces TF-1025.
+Tests.testWithLeakChecking("differentiable function default argument") {
+ let actualGrad = gradient(at: 0) { doubleThenApply($0) }
+ let expectedGrad = Tracked<Float>(2)
+ expectEqual(actualGrad, expectedGrad)
+}
+
+Tests.testWithLeakChecking("differentiable function specified default argument") {
+ let actualGrad = gradient(at: 0) { doubleThenApply($0, { 10 * $0 }) }
+ let expectedGrad = Tracked<Float>(20)
+ expectEqual(actualGrad, expectedGrad)
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/forward_mode_runtime.swift b/test/AutoDiff/downstream/forward_mode_runtime.swift
new file mode 100644
index 0000000..7b9a145
--- /dev/null
+++ b/test/AutoDiff/downstream/forward_mode_runtime.swift
@@ -0,0 +1,1257 @@
+// RUN: %target-run-simple-swift-forward-mode-differentiation
+// REQUIRES: executable_test
+
+import StdlibUnittest
+import DifferentiationUnittest
+#if os(macOS)
+import Darwin.C
+#else
+import Glibc
+#endif
+
+var ForwardModeTests = TestSuite("ForwardMode")
+
+//===----------------------------------------------------------------------===//
+// Basic tests.
+//===----------------------------------------------------------------------===//
+
+ForwardModeTests.test("Identity") {
+ func func_to_diff(x: Float) -> Float {
+ return x
+ }
+ let (y, differential) = valueWithDifferential(at: 4, in: func_to_diff)
+ expectEqual(4, y)
+ expectEqual(1, differential(1))
+}
+
+ForwardModeTests.test("Unary") {
+ func func_to_diff(x: Float) -> Float {
+ return x * x
+ }
+ let (y, differential) = valueWithDifferential(at: 4, in: func_to_diff)
+ expectEqual(16, y)
+ expectEqual(8, differential(1))
+}
+
+ForwardModeTests.test("Binary") {
+ func func_to_diff(x: Float, y: Float) -> Float {
+ return x * y
+ }
+ let (y, differential) = valueWithDifferential(at: 4, 5, in: func_to_diff)
+ expectEqual(20, y)
+ expectEqual(9, differential(1, 1))
+}
+
+ForwardModeTests.test("BinaryWithLets") {
+ func func_to_diff(x: Float, y: Float) -> Float {
+ let a = x + y
+ let b = a
+ return b * -y
+ }
+ let (y, differential) = valueWithDifferential(at: 4, 5, in: func_to_diff)
+ expectEqual(-45, y)
+ expectEqual(-19, differential(1, 1))
+}
+
+ForwardModeTests.test("SubsetParametersDiff") {
+ func func_to_diff1(x: Int, y: Float, z: Int) -> Float {
+ return y
+ }
+ let (y1, differential1) = valueWithDifferential(at: 5) { y in
+ func_to_diff1(x: 0, y: y, z: 0)
+ }
+ expectEqual(5, y1)
+ expectEqual(1, differential1(1))
+
+ func func_to_diff2(x: Float, y: Int, z: Int) -> Float {
+ return 2 * x
+ }
+ let (y2, differential2) = valueWithDifferential(at: 6) { x in
+ func_to_diff2(x: x, y: 0, z: 0)
+ }
+ expectEqual(12, y2)
+ expectEqual(2, differential2(1))
+
+ func func_to_diff3(x: Int, y: Int, z: Float) -> Float {
+ return 3 * z
+ }
+ let (y3, differential3) = valueWithDifferential(at: 7) { z in
+ func_to_diff3(x: 0, y: 0, z: z)
+ }
+ expectEqual(21, y3)
+ expectEqual(3, differential3(1))
+}
+
+//===----------------------------------------------------------------------===//
+// Functions with variables
+//===----------------------------------------------------------------------===//
+
+ForwardModeTests.test("UnaryWithVars") {
+ func unary(x: Float) -> Float {
+ var a = x
+ a = x
+ var b = a + 2
+ b = b - 1
+ let c: Float = 3
+ var d = a + b + c - 1
+ d = d + d
+ return d
+ }
+
+ let (y, differential) = valueWithDifferential(at: 4, in: unary)
+ expectEqual(22, y)
+ expectEqual(4, differential(1))
+}
+
+//===----------------------------------------------------------------------===//
+// Functions with basic struct
+//===----------------------------------------------------------------------===//
+
+struct A: Differentiable & AdditiveArithmetic {
+ var x: Float
+}
+
+ForwardModeTests.test("StructInit") {
+ func structInit(x: Float) -> A {
+ return A(x: 2 * x)
+ }
+
+ let (y, differential) = valueWithDifferential(at: 4, in: structInit)
+ expectEqual(A(x: 8), y)
+ expectEqual(A(x: 2), differential(1))
+}
+
+ForwardModeTests.test("StructExtract") {
+ func structExtract(x: A) -> Float {
+ return 2 * x.x
+ }
+
+ let (y, differential) = valueWithDifferential(
+ at: A(x: 4),
+ in: structExtract)
+ expectEqual(8, y)
+ expectEqual(2, differential(A(x: 1)))
+}
+
+ForwardModeTests.test("LocalStructVariable") {
+ func structExtract(x: A) -> A {
+ let a = A(x: 2 * x.x) // 2x
+ var b = A(x: a.x + 2) // 2x + 2
+ b = A(x: b.x + a.x) // 2x + 2 + 2x = 4x + 2
+ return b
+ }
+
+ let (y, differential) = valueWithDifferential(
+ at: A(x: 4),
+ in: structExtract)
+ expectEqual(A(x: 18), y)
+ expectEqual(A(x: 4), differential(A(x: 1)))
+}
+
+//===----------------------------------------------------------------------===//
+// Functions with methods
+//===----------------------------------------------------------------------===//
+
+extension A {
+ func noParamMethodA() -> A {
+ return A(x: 2 * x)
+ }
+
+ func noParamMethodx() -> Float {
+ return 2 * x
+ }
+
+ static func *(lhs: A, rhs: A) -> A {
+ return A(x: lhs.x * rhs.x)
+ }
+
+ func complexBinaryMethod(u: A, v: Float) -> A {
+ var b: A = u * A(x: 2) // A(x: u * 2)
+ b.x = b.x * v // A(x: u * 2 * v)
+ let c = b.x + 1 // u * 2 * v + 1
+
+ // A(x: u * 2 * v + 1 + u * 2 * v) = A(x: x * (4uv + 1))
+ return A(x: x * (c + b.x))
+ }
+}
+
+ForwardModeTests.test("noParamMethodA") {
+ let (y, differential) = valueWithDifferential(at: A(x: 4)) { x in
+ x.noParamMethodA()
+ }
+ expectEqual(A(x: 8), y)
+ expectEqual(A(x: 2), differential(A(x: 1)))
+}
+
+ForwardModeTests.test("noParamMethodx") {
+ let (y, differential) = valueWithDifferential(at: A(x: 4)) { x in
+ x.noParamMethodx()
+ }
+ expectEqual(8, y)
+ expectEqual(2, differential(A(x: 1)))
+}
+
+ForwardModeTests.test("complexBinaryMethod") {
+ let (y, differential) = valueWithDifferential(at: A(x: 4), A(x: 5), 3) {
+ (x, y, z) in
+ // derivative = A(x: 4uv + 4xv + 4ux + 1) = 4*5*3 + 4*4*3 + 4*5*4 + 1 = 189
+ x.complexBinaryMethod(u: y, v: z)
+ }
+ expectEqual(A(x: 244), y)
+ expectEqual(A(x: 189), differential(A(x: 1), A(x: 1), 1))
+}
+
+//===----------------------------------------------------------------------===//
+// Tracked struct
+//===----------------------------------------------------------------------===//
+
+ForwardModeTests.testWithLeakChecking("TrackedIdentity") {
+ func identity(x: Tracked<Float>) -> Tracked<Float> {
+ return x
+ }
+ let (y, differential) = valueWithDifferential(at: 4, in: identity)
+ expectEqual(4, y)
+ expectEqual(1, differential(1))
+}
+
+ForwardModeTests.testWithLeakChecking("TrackedAddition") {
+ func add(x: Tracked<Float>, y: Tracked<Float>) -> Tracked<Float> {
+ return x + y
+ }
+ let (y, differential) = valueWithDifferential(at: 4, 5, in: add)
+ expectEqual(9, y)
+ expectEqual(2, differential(1, 1))
+}
+
+ForwardModeTests.testWithLeakChecking("TrackedDivision") {
+ func divide(x: Tracked<Float>, y: Tracked<Float>) -> Tracked<Float> {
+ return x / y
+ }
+ let (y, differential) = valueWithDifferential(at: 10, 5, in: divide)
+ expectEqual(2, y)
+ expectEqual(-0.2, differential(1, 1))
+}
+
+ForwardModeTests.testWithLeakChecking("TrackedMultipleMultiplication") {
+ func add(x: Tracked<Float>, y: Tracked<Float>) -> Tracked<Float> {
+ return x * y * x
+ }
+ let (y, differential) = valueWithDifferential(at: 4, 5, in: add)
+ expectEqual(80, y)
+ // 2yx+xx
+ expectEqual(56, differential(1, 1))
+}
+
+ForwardModeTests.testWithLeakChecking("TrackedWithLets") {
+ func add(x: Tracked<Float>, y: Tracked<Float>) -> Tracked<Float> {
+ let a = x + y
+ let b = a * a // (x+y)^2
+ let c = b / x + y // (x+y)^2/x+y
+ return c
+ }
+ // (3x^2+2xy-y^2)/x^2+1
+ let (y, differential) = valueWithDifferential(at: 4, 5, in: add)
+ expectEqual(25.25, y)
+ expectEqual(4.9375, differential(1, 1))
+}
+
+//===----------------------------------------------------------------------===//
+// Tuples
+//===----------------------------------------------------------------------===//
+
+ForwardModeTests.test("TupleLet") {
+ do {
+ func tupleLet(_ x: Float) -> Float {
+ let tuple = (2 * x, x)
+ return tuple.0
+ }
+ let (value, derivative) = valueWithDerivative(at: 4, in: tupleLet)
+ expectEqual(8, value)
+ expectEqual(2, derivative)
+ }
+}
+
+ForwardModeTests.test("TupleVar") {
+ do {
+ func tupleVar(_ x: Float) -> Float {
+ var tuple = (2 * x, x)
+ return tuple.0
+ }
+ let (value, derivative) = valueWithDerivative(at: 4, in: tupleVar)
+ expectEqual(8, value)
+ expectEqual(2, derivative)
+ }
+
+ do {
+ // TF-964: Test tuple with non-tuple-typed adjoint value.
+ func TF_964(_ x: Float) -> Float {
+ var tuple = (2 * x, 1)
+ return tuple.0
+ }
+ let (value, derivative) = valueWithDerivative(at: 4, in: TF_964)
+ expectEqual(8, value)
+ expectEqual(2, derivative)
+ }
+}
+
+ForwardModeTests.test("TupleMutation") {
+ func foo(_ x: Float) -> Float {
+ var tuple = (x, x)
+ tuple.0 = tuple.0 * x
+ return x * tuple.0
+ }
+ expectEqual(27, derivative(at: 3, in: foo))
+
+ func fifthPower(_ x: Float) -> Float {
+ var tuple = (x, x)
+ tuple.0 = tuple.0 * x
+ tuple.1 = tuple.0 * x
+ return tuple.0 * tuple.1
+ }
+ expectEqual(405, derivative(at: 3, in: fifthPower))
+
+ func nested(_ x: Float) -> Float {
+ var tuple = ((x, x), x)
+ tuple.0.0 = tuple.0.0 * x
+ tuple.0.1 = tuple.0.0 * x
+ return tuple.0.0 * tuple.0.1
+ }
+ expectEqual(405, derivative(at: 3, in: nested))
+
+ func generic<T: Differentiable & AdditiveArithmetic>(_ x: T) -> T {
+ var tuple = (x, x)
+ return tuple.0
+ }
+ expectEqual(1, derivative(at: 3.0, in: generic))
+
+ // FIXME(TF-1033): Fix forward-mode ownership error for tuple with non-active
+ // initial values.
+ /*
+ func genericInitialNonactive<T: Differentiable & AdditiveArithmetic>(
+ _ x: T
+ ) -> T {
+ var tuple = (T.zero, T.zero)
+ tuple.0 = x
+ tuple.1 = x
+ return tuple.0
+ }
+ expectEqual(1, derivative(at: 3.0, in: genericInitialNonactive))
+ */
+}
+
+// Tests TF-321.
+ForwardModeTests.test("TupleNonDifferentiableElements") {
+ // TF-964: Test tuple with non-tuple-typed adjoint value.
+ func tupleLet(_ x: Tracked<Float>) -> Tracked<Float> {
+ let tuple = (2 * x, 1)
+ return tuple.0
+ }
+ expectEqual((8, 2), valueWithDerivative(at: 4, in: tupleLet))
+
+ func tupleVar(_ x: Tracked<Float>) -> Tracked<Float> {
+ var tuple = (x, 1)
+ tuple.0 = x
+ tuple.1 = 1
+ return tuple.0
+ }
+ expectEqual((3, 1), valueWithDerivative(at: 3, in: tupleVar))
+
+ func nested(_ x: Tracked<Float>) -> Tracked<Float> {
+ // Convoluted function computing `x * x`.
+ var tuple: (Int, (Int, Tracked<Float>), Tracked<Float>) = (1, (1, 0), 0)
+ tuple.0 = 1
+ tuple.1.0 = 1
+ tuple.1.1 = x
+ tuple.2 = x
+ return tuple.1.1 * tuple.2
+ }
+ expectEqual((16, 8), valueWithDerivative(at: 4, in: nested))
+
+ struct Wrapper<T> {
+ @differentiable(where T : Differentiable)
+ func baz(_ x: T) -> T {
+ var tuple = (1, 1, x, 1)
+ tuple.0 = 1
+ tuple.2 = x
+ tuple.3 = 1
+ return tuple.2
+ }
+ }
+ func wrapper(_ x: Tracked<Float>) -> Tracked<Float> {
+ let w = Wrapper<Tracked<Float>>()
+ return w.baz(x)
+ }
+ expectEqual((3, 1), valueWithDerivative(at: 3, in: wrapper))
+}
+
+//===----------------------------------------------------------------------===//
+// Generics
+//===----------------------------------------------------------------------===//
+
+struct Tensor<Scalar : FloatingPoint & Differentiable>
+ : AdditiveArithmetic, Differentiable {
+ // NOTE: `value` must have type with known size (e.g. `Float`, not `Scalar`)
+ // until differentiation has indirect passing support.
+ var value: Float
+ init(_ value: Float) { self.value = value }
+}
+
+ForwardModeTests.test("GenericIdentity") {
+ func identity<T : Differentiable>(_ x: T) -> T {
+ return x
+ }
+ let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in
+ identity(x)
+ }
+ expectEqual(4, y)
+ expectEqual(1, differential(1))
+}
+
+ForwardModeTests.test("GenericTensorIdentity") {
+ func identity<T : FloatingPoint & Differentiable>(
+ _ x: Tensor<T>) -> Tensor<T> {
+ return x
+ }
+ let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in
+ identity(Tensor<Float>(x))
+ }
+ expectEqual(Tensor<Float>(4), y)
+ expectEqual(Tensor<Float>(1), differential(1))
+}
+
+ForwardModeTests.test("GenericTensorPlus") {
+ func plus<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> Float {
+ return x.value + x.value
+ }
+ let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in
+ plus(Tensor<Float>(x))
+ }
+ expectEqual(8, y)
+ expectEqual(2, differential(1))
+}
+
+ForwardModeTests.test("GenericTensorBinaryInput") {
+ func binary<T : FloatingPoint & Differentiable>(
+ _ x: Tensor<T>, _ y: Tensor<T>) -> Float {
+ return x.value * y.value
+ }
+ let (y, differential) = valueWithDifferential(at: 4, 5) {
+ (x: Float, y: Float) in
+ binary(Tensor<Float>(x), Tensor<Float>(y))
+ }
+ expectEqual(20, y)
+ expectEqual(9, differential(1, 1))
+}
+
+ForwardModeTests.test("GenericTensorWithLets") {
+ func binary<T : FloatingPoint & Differentiable>(
+ _ x: Tensor<T>, _ y: Tensor<T>) -> Float {
+ let a = Tensor<T>(x.value)
+ let b = Tensor<T>(y.value)
+ return a.value * b.value
+ }
+ let (y, differential) = valueWithDifferential(at: 4, 5) {
+ (x: Float, y: Float) in
+ binary(Tensor<Float>(x), Tensor<Float>(y))
+ }
+ expectEqual(20, y)
+ expectEqual(9, differential(1, 1))
+}
+
+ForwardModeTests.test("GenericTensorWithVars") {
+ func binary<T : FloatingPoint & Differentiable>(
+ _ x: Tensor<T>, _ y: Tensor<T>) -> Float {
+ var a = Tensor<T>(x.value)
+ var b = Tensor<T>(y.value)
+ b = a
+ a = Tensor<T>(y.value)
+ return a.value * b.value
+ }
+ let (y, differential) = valueWithDifferential(at: 4, 5) {
+ (x: Float, y: Float) in
+ binary(Tensor<Float>(x), Tensor<Float>(y))
+ }
+ expectEqual(20, y)
+ expectEqual(9, differential(1, 1))
+}
+
+// Test case where associated derivative function's requirements are met.
+extension Tensor where Scalar : Numeric {
+ @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
+ func mean() -> Tensor {
+ return self
+ }
+
+ @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
+ func variance() -> Tensor {
+ return mean() // ok
+ }
+}
+_ = differential(at: Tensor<Float>(1), in: { $0.variance() })
+
+// Tests TF-508: differentiation requirements with dependent member types.
+protocol TF_508_Proto {
+ associatedtype Scalar
+}
+extension TF_508_Proto where Scalar : FloatingPoint {
+ @differentiable(
+ where Self : Differentiable, Scalar : Differentiable,
+ // Conformance requirement with dependent member type.
+ Self.TangentVector : TF_508_Proto
+ )
+ static func +(lhs: Self, rhs: Self) -> Self {
+ return lhs
+ }
+
+ @differentiable(
+ where Self : Differentiable, Scalar : Differentiable,
+ // Same-type requirement with dependent member type.
+ Self.TangentVector == Float
+ )
+ static func -(lhs: Self, rhs: Self) -> Self {
+ return lhs
+ }
+}
+extension TF_508_Proto where Self : Differentiable,
+ Scalar : FloatingPoint & Differentiable,
+ Self.TangentVector : TF_508_Proto {
+ @derivative(of: +)
+ static func jvpAdd(lhs: Self, rhs: Self)
+ -> (value: Self, differential: (TangentVector, TangentVector) -> TangentVector) {
+ return (lhs, { (dlhs, drhs) in dlhs })
+ }
+}
+extension TF_508_Proto where Self : Differentiable,
+ Scalar : FloatingPoint & Differentiable,
+ Self.TangentVector == Float {
+ @derivative(of: -)
+ static func jvpSubtract(lhs: Self, rhs: Self)
+ -> (value: Self, differential: (TangentVector, TangentVector) -> TangentVector) {
+ return (lhs, { (dlhs, drhs) in dlhs })
+ }
+}
+
+struct TF_508_Struct<Scalar : AdditiveArithmetic>
+ : TF_508_Proto, AdditiveArithmetic {}
+extension TF_508_Struct : Differentiable where Scalar : Differentiable {
+ typealias TangentVector = TF_508_Struct
+}
+
+// func TF_508() {
+// let x = TF_508_Struct<Float>()
+// // Test conformance requirement with dependent member type.
+// _ = differential(at: x, in: {
+// (x: TF_508_Struct<Float>) -> TF_508_Struct<Float> in
+// return x + x
+// })
+// // Test same-type requirement with dependent member type.
+// _ = differential(at: x, in: {
+// (x: TF_508_Struct<Float>) -> TF_508_Struct<Float> in
+// return x - x
+// })
+// }
+
+// TF-523
+struct TF_523_Struct : Differentiable & AdditiveArithmetic {
+ var a: Float = 1
+ typealias TangentVector = TF_523_Struct
+ typealias AllDifferentiableVariables = TF_523_Struct
+}
+
+@differentiable
+func TF_523_f(_ x: TF_523_Struct) -> Float {
+ return x.a * 2
+}
+
+// TF-534: Thunk substitution map remapping.
+protocol TF_534_Layer : Differentiable {
+ associatedtype Input : Differentiable
+ associatedtype Output : Differentiable
+
+ @differentiable
+ func callAsFunction(_ input: Input) -> Output
+}
+struct TF_534_Tensor<Scalar> : Differentiable {}
+
+func TF_534<Model: TF_534_Layer>(
+ _ model: inout Model, inputs: Model.Input
+) -> TF_534_Tensor<Float> where Model.Output == TF_534_Tensor<Float> {
+ return valueWithDifferential(at: model) { model -> Model.Output in
+ return model(inputs)
+ }.0
+}
+
+// TODO: uncomment once control flow is supported in forward mode.
+// TF-652: Test VJPEmitter substitution map generic signature.
+// The substitution map should have the VJP's generic signature, not the
+// original function's.
+// struct TF_652<Scalar> {}
+// extension TF_652 : Differentiable where Scalar : FloatingPoint {}
+
+// @differentiable(wrt: x where Scalar: FloatingPoint)
+// func test<Scalar: Numeric>(x: TF_652<Scalar>) -> TF_652<Scalar> {
+// for _ in 0..<10 {
+// let _ = x
+// }
+// return x
+// }
+
+//===----------------------------------------------------------------------===//
+// Tracked Generic.
+//===----------------------------------------------------------------------===//
+
+ForwardModeTests.test("GenericTrackedIdentity") {
+ func identity<T : Differentiable>(_ x: Tracked<T>) -> Tracked<T> {
+ return x
+ }
+ let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in
+ identity(Tracked(x))
+ }
+ expectEqual(4, y)
+ expectEqual(1, differential(1))
+}
+
+ForwardModeTests.test("GenericTrackedBinaryAdd") {
+ func add<T>(_ x: Tracked<T>, _ y: Tracked<T>) -> Tracked<T>
+ where T: Differentiable, T == T.TangentVector {
+ return x + y
+ }
+ let (y, differential) = valueWithDifferential(at: 4, 5) {
+ (x: Float, y: Float) in
+ add(Tracked(x), Tracked(y))
+ }
+ expectEqual(9, y)
+ expectEqual(2, differential(1, 1))
+}
+
+ForwardModeTests.test("GenericTrackedBinaryLets") {
+ func add<T>(_ x: Tracked<T>, _ y: Tracked<T>) -> Tracked<T>
+ where T: Differentiable & SignedNumeric,
+ T == T.TangentVector,
+ T == T.Magnitude {
+ let a = x * y // xy
+ let b = a + a // 2xy
+ return b + b // 4xy
+ }
+ // 4y + 4x
+ let (y, differential) = valueWithDifferential(at: 4, 5) { (x: Float, y: Float) in
+ add(Tracked(x), Tracked(y))
+ }
+ expectEqual(80, y)
+ expectEqual(36, differential(1, 1))
+}
+
+ForwardModeTests.test("GenericTrackedBinaryVars") {
+ func add<T>(_ x: Tracked<T>, _ y: Tracked<T>) -> Tracked<T>
+ where T: Differentiable & SignedNumeric,
+ T == T.TangentVector,
+ T == T.Magnitude {
+ var a = x * y // xy
+ a = a + a // 2xy
+ var b = x
+ b = a
+ return b + b // 4xy
+ }
+ // 4y + 4x
+ let (y, differential) = valueWithDifferential(at: 4, 5) { (x: Float, y: Float) in
+ add(Tracked(x), Tracked(y))
+ }
+ expectEqual(80, y)
+ expectEqual(36, differential(1, 1))
+}
+
+ForwardModeTests.testWithLeakChecking("TrackedDifferentiableFuncType") {
+ func valAndDeriv(
+ f: @escaping @differentiable (Tracked<Float>) -> Tracked<Float>
+ ) -> (Tracked<Float>, Tracked<Float>) {
+ let (y, diff) = valueWithDifferential(at: 5, in: f)
+ return (y, diff(1))
+ }
+
+ func func1(_ x: Tracked<Float>) -> Tracked<Float> {
+ let a = x + x // 2x
+ let b = a + a // 4x
+ return b * b // 16x^2
+ }
+ let (val1, dv1) = valAndDeriv(f: func1)
+ expectEqual(400, val1)
+ expectEqual(160, dv1)
+}
+
+//===----------------------------------------------------------------------===//
+// Classes
+//===----------------------------------------------------------------------===//
+
+ForwardModeTests.test("Final") {
+ final class Final : Differentiable {
+ func method(_ x: Float) -> Float {
+ return x * x
+ }
+ }
+
+ for i in -5...5 {
+ expectEqual(
+ Float(i) * 2,
+ derivative(at: Float(i)) { x in Final().method(x) })
+ }
+}
+
+ForwardModeTests.test("Simple") {
+ class Super {
+ @differentiable(wrt: x)
+ func f(_ x: Float) -> Float {
+ return 2 * x
+ }
+ @derivative(of: f)
+ final func jvpf(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
+ return (f(x), { v in 2 * v })
+ }
+ @derivative(of: f)
+ final func vjpf(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ return (f(x), { v in 2 * v })
+ }
+ }
+
+ class SubOverride : Super {
+ @differentiable(wrt: x)
+ override func f(_ x: Float) -> Float {
+ return 3 * x
+ }
+ }
+
+ class SubOverrideCustomDerivatives : Super {
+ @differentiable(wrt: x)
+ override func f(_ x: Float) -> Float {
+ return 3 * x
+ }
+ @derivative(of: f)
+ final func jvpf2(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
+ return (f(x), { v in 3 * v })
+ }
+ @derivative(of: f)
+ final func vjpf2(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ return (f(x), { v in 3 * v })
+ }
+ }
+
+ func classValueWithDerivative(_ c: Super) -> (Float, Float) {
+ return valueWithDerivative(at: 1) { c.f($0) }
+ }
+
+ expectEqual((2, 2), classValueWithDerivative(Super()))
+ expectEqual((3, 3), classValueWithDerivative(SubOverride()))
+ expectEqual((3, 3), classValueWithDerivative(SubOverrideCustomDerivatives()))
+}
+
+ForwardModeTests.test("SimpleWrtSelf") {
+ class Super : Differentiable {
+ var base: Float
+ // FIXME(TF-648): Dummy to make `Super.AllDifferentiableVariables` be nontrivial.
+ var _nontrivial: [Float] = []
+
+ // FIXME(SR-12175): Fix forward-mode differentiation crash.
+ // @differentiable
+ required init(base: Float) {
+ self.base = base
+ }
+
+ @differentiable(wrt: (self, x))
+ func f(_ x: Float) -> Float {
+ return base * x
+ }
+ @derivative(of: f)
+ final func jvpf(_ x: Float) -> (value: Float, differential: (TangentVector, Float) -> Float) {
+ return (f(x), { (dself, dx) in dself.base * dx })
+ }
+ @derivative(of: f)
+ final func vjpf(_ x: Float) -> (value: Float, pullback: (Float) -> (TangentVector, Float)) {
+ let base = self.base
+ return (f(x), { v in
+ (TangentVector(base: v * x, _nontrivial: []), base * v)
+ })
+ }
+ }
+
+ class SubOverride : Super {
+ @differentiable(wrt: (self, x))
+ override func f(_ x: Float) -> Float {
+ return 3 * x
+ }
+ }
+
+ class SubOverrideCustomDerivatives : Super {
+ @differentiable(wrt: (self, x))
+ @differentiable(wrt: x)
+ override func f(_ x: Float) -> Float {
+ return 3 * x
+ }
+ @derivative(of: f, wrt: x)
+ final func jvpf2(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
+ return (f(x), { v in 3 * v })
+ }
+ @derivative(of: f, wrt: x)
+ final func vjpf2(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ return (f(x), { v in 3 * v })
+ }
+ }
+
+ // FIXME(SR-12175): Fix forward-mode differentiation crash.
+ // let v = Super.TangentVector(base: 100, _nontrivial: [])
+ // expectEqual(100, pullback(at: 1337) { x in Super(base: x) }(v))
+ // expectEqual(100, pullback(at: 1337) { x in SubOverride(base: x) }(v))
+ // expectEqual(100, pullback(at: 1337) { x in SubOverrideCustomDerivatives(base: x) }(v))
+
+ // `valueWithDerivative` is not used because the derivative requires `Super`
+ // to conform to `FloatingPoint`.
+ func classDifferential(
+ _ c: Super
+ ) -> (Float, (Super.TangentVector, Float) -> Float) {
+ return valueWithDifferential(at: c, 10) { (c: Super, x: Float) in c.f(x) }
+ }
+
+ let (y1, diff1) = classDifferential(Super(base: 5))
+ expectEqual(50, y1)
+ let c1 = Super.TangentVector(base: 1, _nontrivial: [])
+ expectEqual(1, diff1(c1, 1))
+ let (y2, diff2) = classDifferential(SubOverride(base: 5))
+ expectEqual(30, y2)
+ let c2 = SubOverride.TangentVector(base: 1, _nontrivial: [])
+ expectEqual(3, diff2(c2, 1))
+ let (y3, diff3) = classDifferential(SubOverrideCustomDerivatives(base: 5))
+ expectEqual(30, y3)
+ let c3 = SubOverrideCustomDerivatives.TangentVector(base: 1, _nontrivial: [])
+ expectEqual(3, diff3(c3, 1))
+}
+
+//===----------------------------------------------------------------------===//
+// Protocols
+//===----------------------------------------------------------------------===//
+
+protocol Prot : Differentiable {
+ @differentiable(wrt: x)
+ func foo(x: Float) -> Float
+}
+ForwardModeTests.test("Simple Protocol") {
+ struct Linear: Prot, AdditiveArithmetic {
+ typealias TangentVector = Linear
+
+ let m: Float
+ let b: Float
+
+ @differentiable(wrt: x)
+ func foo(x: Float) -> Float {
+ return m * x + b
+ }
+ }
+
+ func genericFoo<T: Prot>(_ t: T, _ x: Float) -> Float {
+ t.foo(x: x)
+ }
+ let inst = Linear(m: 5, b: -2)
+ let (y1, diff1) = valueWithDifferential(at: 5) { x in genericFoo(inst, x) }
+ expectEqual(23, y1)
+ expectEqual(5, diff1(1))
+}
+
+protocol DiffReq : Differentiable {
+ @differentiable(wrt: (self, x))
+ func f(_ x: Float) -> Float
+}
+
+extension DiffReq where TangentVector : AdditiveArithmetic {
+ @inline(never) // Prevent specialization, to test all witness code.
+ func derivF(at x: Float) -> Float {
+ return (valueWithDifferential(at: x) { x in self.f(x) }).1(1)
+ }
+}
+
+struct Quadratic : DiffReq, AdditiveArithmetic {
+ typealias TangentVector = Quadratic
+
+ @differentiable
+ let a: Float
+
+ @differentiable
+ let b: Float
+
+ @differentiable
+ let c: Float
+
+ init(_ a: Float, _ b: Float, _ c: Float) {
+ self.a = a
+ self.b = b
+ self.c = c
+ }
+
+ @differentiable(wrt: (self, x))
+ func f(_ x: Float) -> Float {
+ return a * x * x + b * x + c
+ }
+}
+
+ForwardModeTests.test("ProtocolFunc") {
+ expectEqual(12, Quadratic(11, 12, 13).derivF(at: 0))
+ expectEqual(2 * 11 + 12, Quadratic(11, 12, 13).derivF(at: 1))
+ expectEqual(2 * 11 * 2 + 12, Quadratic(11, 12, 13).derivF(at: 2))
+}
+
+// MARK: Constructor, accessor, and subscript requirements.
+
+protocol FunctionsOfX: Differentiable {
+ @differentiable
+ init(x: Float)
+
+ @differentiable
+ var x: Float { get }
+
+ @differentiable
+ var y: Float { get }
+
+ @differentiable
+ var z: Float { get }
+
+ @differentiable
+ subscript() -> Float { get }
+}
+
+struct TestFunctionsOfX: FunctionsOfX {
+ @differentiable
+ init(x: Float) {
+ self.x = x
+ self.y = x * x
+ }
+
+ /// x = x
+ var x: Float
+
+ /// y = x * x
+ var y: Float
+
+ /// z = x * x + x
+ var z: Float {
+ return y + x
+ }
+
+ @differentiable
+ subscript() -> Float {
+ return z
+ }
+}
+
+@inline(never) // Prevent specialization, to test all witness code.
+func derivatives<F: FunctionsOfX>(at x: Float, in: F.Type)
+ -> (Float, Float, Float, Float)
+{
+ let dxdx = derivative(at: x) { x in F(x: x).x }
+ let dydx = derivative(at: x) { x in F(x: x).y }
+ let dzdx = derivative(at: x) { x in F(x: x).z }
+ let dsubscriptdx = derivative(at: x) { x in F(x: x)[] }
+ return (dxdx, dydx, dzdx, dsubscriptdx)
+}
+
+ForwardModeTests.test("constructor, accessor, subscript") {
+ expectEqual(
+ (1.0, 4.0, 5.0, 5.0),
+ derivatives(at: 2.0, in: TestFunctionsOfX.self))
+}
+
+// MARK: - Test witness method SIL type computation.
+
+protocol P : Differentiable {
+ @differentiable(wrt: (x, y))
+ func foo(_ x: Float, _ y: Double) -> Float
+}
+struct S : P {
+ @differentiable(wrt: (x, y))
+ func foo(_ x: Float, _ y: Double) -> Float {
+ return x
+ }
+}
+
+// MARK: - Overridden protocol method adding differentiable attribute.
+
+public protocol Distribution {
+ associatedtype Value
+ func logProbability(of value: Value) -> Float
+}
+
+public protocol DifferentiableDistribution: Differentiable, Distribution {
+ @differentiable(wrt: self)
+ func logProbability(of value: Value) -> Float
+}
+
+struct Foo: DifferentiableDistribution {
+ @differentiable(wrt: self)
+ func logProbability(of value: Float) -> Float {
+ .zero
+ }
+}
+
+@differentiable
+func blah<T: DifferentiableDistribution>(_ x: T) -> Float where T.Value: AdditiveArithmetic {
+ x.logProbability(of: .zero)
+}
+
+// Adding a more general `@differentiable` attribute.
+public protocol DoubleDifferentiableDistribution: DifferentiableDistribution
+ where Value: Differentiable {
+ @differentiable(wrt: self)
+ @differentiable(wrt: (self, value))
+ func logProbability(of value: Value) -> Float
+}
+
+@differentiable
+func blah2<T: DoubleDifferentiableDistribution>(_ x: T, _ value: T.Value) -> Float
+ where T.Value: AdditiveArithmetic {
+ x.logProbability(of: value)
+}
+
+protocol DifferentiableFoo {
+ associatedtype T: Differentiable
+ @differentiable(wrt: x)
+ func foo(_ x: T) -> Float
+}
+
+protocol MoreDifferentiableFoo: Differentiable, DifferentiableFoo {
+ @differentiable(wrt: (self, x))
+ func foo(_ x: T) -> Float
+}
+
+struct MoreDifferentiableFooStruct: MoreDifferentiableFoo {
+ @differentiable(wrt: (self, x))
+ func foo(_ x: Float) -> Float {
+ x
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Simple Math
+//===----------------------------------------------------------------------===//
+
+ForwardModeTests.test("Arithmetics") {
+ func foo1(x: Float, y: Float) -> Float {
+ return x * y
+ }
+ expectEqual(7, derivative(at: 3, 4, in: foo1))
+ func foo2(x: Float, y: Float) -> Float {
+ return -x * y
+ }
+ expectEqual(-7, derivative(at: 3, 4, in: foo2))
+ func foo3(x: Float, y: Float) -> Float {
+ return -x + y
+ }
+ expectEqual(0, derivative(at: 3, 4, in: foo3))
+}
+
+ForwardModeTests.test("Fanout") {
+ func foo1(x: Float) -> Float {
+ x - x
+ }
+ expectEqual(0, derivative(at: 100, in: foo1))
+ func foo2(x: Float) -> Float {
+ x + x
+ }
+ expectEqual(2, derivative(at: 100, in: foo2))
+ func foo3(x: Float, y: Float) -> Float {
+ x + x + x * y
+ }
+ expectEqual(7, derivative(at: 3, 2, in: foo3))
+}
+
+ForwardModeTests.test("FunctionCall") {
+ func foo(_ x: Float, _ y: Float) -> Float {
+ return 3 * x + { $0 * 3 }(3) * y
+ }
+ expectEqual(12, derivative(at: 3, 4, in: foo))
+ expectEqual(3, derivative(at: 3) { x in foo(x, 4) })
+}
+
+ForwardModeTests.test("ResultSelection") {
+ func foo(_ x: Float, _ y: Float) -> (Float, Float) {
+ return (x + 1, y + 2)
+ }
+ expectEqual(1, derivative(at: 3, 3, in: { x, y in foo(x, y).0 }))
+ expectEqual(1, derivative(at: 3, 3, in: { x, y in foo(x, y).1 }))
+}
+
+ForwardModeTests.test("CaptureLocal") {
+ let z: Float = 10
+ func foo(_ x: Float) -> Float {
+ return z * x
+ }
+ expectEqual(10, derivative(at: 0, in: foo))
+}
+
+var globalVar: Float = 10
+ForwardModeTests.test("CaptureGlobal") {
+ func foo(x: Float) -> Float {
+ globalVar += 20
+ return globalVar * x
+ }
+ expectEqual(30, derivative(at: 0, in: foo))
+}
+
+ForwardModeTests.test("Mutation") {
+ func fourthPower(x: Float) -> Float {
+ var a = x
+ a = a * x
+ a = a * x
+ return a * x
+ }
+ expectEqual(4 * 27, derivative(at: 3, in: fourthPower))
+}
+
+// Tests TF-21.
+ForwardModeTests.test("StructMemberwiseInitializer") {
+ struct Foo : AdditiveArithmetic, Differentiable {
+ var stored: Float
+ var computed: Float {
+ return stored * stored
+ }
+ }
+
+ let derivFoo = differential(at: Float(4), in: { input -> Foo in
+ let foo = Foo(stored: input)
+ let foo2 = foo + foo
+ return Foo(stored: foo2.stored)
+ })(1)
+ expectEqual(Foo.TangentVector(stored: 2), derivFoo)
+
+ let computed = derivative(at: Float(4)) { input -> Float in
+ let foo = Foo(stored: input)
+ return foo.computed
+ }
+ expectEqual(8, computed)
+
+ let derivProduct = derivative(at: Float(4)) { input -> Float in
+ let foo = Foo(stored: input)
+ return foo.computed * foo.stored
+ }
+ expectEqual(48, derivProduct)
+
+ struct Custom : AdditiveArithmetic, Differentiable {
+ var x: Float
+
+ // Custom initializer with `@differentiable`.
+ @differentiable
+ init(x: Float) {
+ self.x = x
+ }
+ }
+
+ let derivCustom = differential(at: Float(4), in: { input -> Custom in
+ let foo = Custom(x: input)
+ return foo + foo
+ })(1)
+ expectEqual(Custom.TangentVector(x: 2), derivCustom)
+}
+
+// Tests TF-319: struct with non-differentiable constant stored property.
+ForwardModeTests.test("StructConstantStoredProperty") {
+ struct TF_319 : Differentiable {
+ var x: Float
+ @noDerivative let constant = Float(2)
+
+ @differentiable
+ init(x: Float) {
+ self.x = x
+ }
+
+ @differentiable(wrt: (self, input))
+ func applied(to input: Float) -> Float {
+ return x * constant * input
+ }
+ }
+ func testStructInit(to input: Float) -> Float {
+ let model = TF_319(x: 10)
+ return model.applied(to: input)
+ }
+ expectEqual(6, derivative(at: 10, in: { TF_319(x: $0).applied(to: 3) }))
+ expectEqual(20, derivative(at: 3, in: testStructInit))
+}
+
+ForwardModeTests.test("StructMutation") {
+ struct Point : AdditiveArithmetic, Differentiable {
+ var x: Float
+ var y: Float
+ var z: Float
+ }
+
+ func double(_ input: Float) -> Point {
+ let point = Point(x: input, y: input, z: input)
+ return point + point
+ }
+ expectEqual(Point(x: 2, y: 2, z: 2), differential(at: 4, in: double)(1))
+
+ func fifthPower(_ input: Float) -> Float {
+ var point = Point(x: input, y: input, z: input)
+ point.x = point.x * input
+ point.y = point.x * input
+ return point.x * point.y
+ }
+ expectEqual(405, derivative(at: 3, in: fifthPower))
+
+ func mix(_ input: Float) -> Float {
+ var tuple = (point: Point(x: input, y: input, z: input), float: input)
+ tuple.point.x = tuple.point.x * tuple.float
+ tuple.point.y = tuple.point.x * input
+ return tuple.point.x * tuple.point.y
+ }
+ expectEqual(405, derivative(at: 3, in: mix))
+
+ // Test TF-282.
+ struct Add : Differentiable {
+ var bias: Float
+ func applied(to input: Float) -> Float {
+ var tmp = input
+ tmp = tmp + bias
+ return tmp
+ }
+ }
+ expectEqual(1, derivative(at: 1) { m in Add(bias: m).applied(to: 1) })
+}
+
+ForwardModeTests.test("StructGeneric") {
+ struct Generic<T : AdditiveArithmetic & Differentiable> : AdditiveArithmetic, Differentiable {
+ var x: T
+ var y: T
+ var z: T
+ }
+
+ let deriv = differential(at: Float(3), in: { input -> Generic<Float> in
+ var generic = Generic(x: input, y: input, z: input)
+ return generic
+ })(1)
+ expectEqual(Generic<Float>.TangentVector(x: 1, y: 1, z: 1), deriv)
+
+ func fifthPower(_ input: Float) -> Float {
+ var generic = Generic(x: input, y: input, z: input)
+ generic.x = generic.x * input
+ generic.y = generic.x * input
+ return generic.x * generic.y
+ }
+ expectEqual(405, derivative(at: 3, in: fifthPower))
+}
+
+ForwardModeTests.test("SubsetIndices") {
+ func deriv(_ lossFunction: @differentiable (Float, Float) -> Float) -> Float {
+ return derivative(at: 1) { x in lossFunction(x * x, 10.0) }
+ }
+ expectEqual(2, deriv { x, y in x + y })
+
+ func derivWRTNonDiff(_ lossFunction: @differentiable (Float, @noDerivative Int) -> Float) -> Float {
+ return derivative(at: 2) { x in lossFunction(x * x, 10) }
+ }
+ expectEqual(4, derivWRTNonDiff { x, y in x + Float(y) })
+}
+
+ForwardModeTests.test("ForceUnwrapping") {
+ func forceUnwrap<T: Differentiable & FloatingPoint>(_ t: T) -> Float where T == T.TangentVector {
+ derivative(at: t, Float(3)) { (x, y) in
+ (x as! Float) * y
+ }
+ }
+ expectEqual(5, forceUnwrap(Float(2)))
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/forward_mode_sil.swift b/test/AutoDiff/downstream/forward_mode_sil.swift
new file mode 100644
index 0000000..59be70c
--- /dev/null
+++ b/test/AutoDiff/downstream/forward_mode_sil.swift
@@ -0,0 +1,86 @@
+// RUN: %target-swift-frontend -emit-sil -verify -enable-experimental-forward-mode-differentiation -Xllvm -debug-only=differentiation %s 2>&1 | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
+// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -enable-experimental-forward-mode-differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
+// REQUIRES: asserts
+
+
+//===----------------------------------------------------------------------===//
+// Unary
+//===----------------------------------------------------------------------===//
+
+@differentiable
+@_silgen_name("unary")
+func unary(_ x: Float) -> Float {
+ return x * x * x
+}
+// CHECK-DATA-STRUCTURES: struct _AD__unary_bb0__DF__src_0_wrt_0 {
+// CHECK-DATA-STRUCTURES: var differential_0: (Float, Float) -> Float
+// CHECK-DATA-STRUCTURES: var differential_1: (Float, Float) -> Float
+// CHECK-DATA-STRUCTURES: }
+// CHECK-DATA-STRUCTURES: enum _AD__unary_bb0__Succ__src_0_wrt_0 {
+// CHECK-DATA-STRUCTURES: }
+
+// CHECK-SIL-LABEL: sil hidden [ossa] @AD__unary__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
+// CHECK-SIL: bb0([[X_ARG:%.*]] : $Float):
+// CHECK-SIL: [[MULT_FUNC_1:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
+// CHECK-SIL: [[MULT_FUNC_JVP_1:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
+// CHECK-SIL: [[MULT_FUNC_VJP_1:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
+// CHECK-SIL: [[AUTODIFF_INST_1:%.*]] = differentiable_function [parameters 0 1] [results 0] [[MULT_FUNC_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[MULT_FUNC_JVP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))}
+// CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = differentiable_function_extract [jvp] [[AUTODIFF_INST_1]] : $@differentiable @convention(method) (Float, Float, @noDerivative @thin Float.Type) -> Float
+// CHECK-SIL: [[MULT_JVP_APPLY_TUPLE_1:%.*]] = apply [[AUTODIFF_EXTRACT_INST_1]]([[X_ARG]], [[X_ARG]], %3) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float)
+// CHECK-SIL: ([[ORIG_RESULT_1:%.*]], [[MULT_DIFF_1:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE_1]] : $(Float, @callee_guaranteed (Float, Float) -> Float)
+// CHECK-SIL: [[MULT_FUNC_2:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
+// CHECK-SIL: [[MULT_FUNC_JVP_2:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
+// CHECK-SIL: [[MULT_FUNC_VJP_2:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
+// CHECK-SIL: [[AUTODIFF_INST_2:%.*]] = differentiable_function [parameters 0 1] [results 0] [[MULT_FUNC_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[MULT_FUNC_JVP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))}
+// CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = differentiable_function_extract [jvp] [[AUTODIFF_INST_2]] : $@differentiable @convention(method) (Float, Float, @noDerivative @thin Float.Type) -> Float
+// CHECK-SIL: [[MULT_JVP_APPLY_TUPLE_2:%.*]] = apply [[AUTODIFF_EXTRACT_INST_1]]([[ORIG_RESULT_1]], [[X_ARG]], %2) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float)
+// CHECK-SIL: ([[ORIG_RESULT_2:%.*]], [[MULT_DIFF_2:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE_2]] : $(Float, @callee_guaranteed (Float, Float) -> Float)
+// CHECK-SIL: [[DIFF_STRUCT:%.*]] = struct $_AD__unary_bb0__DF__src_0_wrt_0 ([[MULT_DIFF_1]] : $@callee_guaranteed (Float, Float) -> Float, [[MULT_DIFF_2]] : $@callee_guaranteed (Float, Float) -> Float)
+// CHECK-SIL: [[UNARY_DIFFERENTIAL:%.*]] = function_ref @AD__unary__differential_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__unary_bb0__DF__src_0_wrt_0) -> Float
+// CHECK-SIL: [[PARTIAL_APP_DIFFERENTIAL:%.*]] = partial_apply [callee_guaranteed] [[UNARY_DIFFERENTIAL]]([[DIFF_STRUCT]]) : $@convention(thin) (Float, @owned _AD__unary_bb0__DF__src_0_wrt_0) -> Float
+// CHECK-SIL: [[RESULT:%.*]] = tuple ([[ORIG_RESULT_2]] : $Float, [[PARTIAL_APP_DIFFERENTIAL]] : $@callee_guaranteed (Float) -> Float)
+// CHECK-SIL: return [[RESULT]] : $(Float, @callee_guaranteed (Float) -> Float)
+
+// CHECK-SIL-LABEL: sil hidden [ossa] @AD__unary__differential_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__unary_bb0__DF__src_0_wrt_0) -> Float {
+// CHECK-SIL: bb0([[X_TAN:%.*]] : $Float, [[DIFF_STRUCT:%.*]] : @owned $_AD__unary_bb0__DF__src_0_wrt_0):
+// CHECK-SIL: ([[MULT_DIFF_1:%.*]], [[MULT_DIFF_2:%.*]]) = destructure_struct %1 : $_AD__unary_bb0__DF__src_0_wrt_0
+// CHECK-SIL: [[TEMP_TAN_1:%.*]] = apply [[MULT_DIFF_1]]([[X_TAN]], [[X_TAN]]) : $@callee_guaranteed (Float, Float) -> Float
+// CHECK-SIL: [[TAN_RESULT:%.*]] = apply [[MULT_DIFF_2]]([[TEMP_TAN_1]], [[X_TAN]]) : $@callee_guaranteed (Float, Float) -> Float
+// CHECK-SIL: return [[TAN_RESULT]] : $Float
+
+//===----------------------------------------------------------------------===//
+// Binary
+//===----------------------------------------------------------------------===//
+
+@differentiable
+@_silgen_name("binary")
+func binary(x: Float, y: Float) -> Float {
+ return x * y
+}
+
+// CHECK-DATA-STRUCTURES: struct _AD__binary_bb0__DF__src_0_wrt_0_1 {
+// CHECK-DATA-STRUCTURES: var differential_0: (Float, Float) -> Float
+// CHECK-DATA-STRUCTURES: }
+// CHECK-DATA-STRUCTURES: enum _AD__binary_bb0__Succ__src_0_wrt_0_1 {
+// CHECK-DATA-STRUCTURES: }
+
+// CHECK-SIL-LABEL: sil hidden [ossa] @AD__binary__jvp_src_0_wrt_0_1 : $@convention(thin) (Float, Float) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) {
+// CHECK-SIL: bb0([[X_ARG:%.*]] : $Float, [[Y_ARG:%.*]] : $Float):
+// CHECK-SIL: [[MULT_FUNC:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
+// CHECK-SIL: [[MULT_FUNC_JVP:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
+// CHECK-SIL: [[MULT_FUNC_VJP:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
+// CHECK-SIL: [[AUTODIFF_INST:%.*]] = differentiable_function [parameters 0 1] [results 0] [[MULT_FUNC]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[MULT_FUNC_JVP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))}
+// CHECK-SIL: [[AUTODIFF_EXTRACT_INST:%.*]] = differentiable_function_extract [jvp] [[AUTODIFF_INST]] : $@differentiable @convention(method) (Float, Float, @noDerivative @thin Float.Type) -> Float
+// CHECK-SIL: [[MULT_JVP_APPLY_TUPLE:%.*]] = apply [[AUTODIFF_EXTRACT_INST]]([[X_ARG]], [[Y_ARG]], %4) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float)
+// CHECK-SIL: ([[ORIG_RESULT:%.*]], [[MULT_DIFF:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE]] : $(Float, @callee_guaranteed (Float, Float) -> Float)
+// CHECK-SIL: [[DIFF_STRUCT:%.*]] = struct $_AD__binary_bb0__DF__src_0_wrt_0_1 ([[MULT_DIFF]] : $@callee_guaranteed (Float, Float) -> Float)
+// CHECK-SIL: [[BINARY_DIFFERENTIAL:%.*]] = function_ref @AD__binary__differential_src_0_wrt_0_1 : $@convention(thin) (Float, Float, @owned _AD__binary_bb0__DF__src_0_wrt_0_1) -> Float
+// CHECK-SIL: [[PARTIAL_APP_DIFFERENTIAL:%.*]] = partial_apply [callee_guaranteed] [[BINARY_DIFFERENTIAL]]([[DIFF_STRUCT]]) : $@convention(thin) (Float, Float, @owned _AD__binary_bb0__DF__src_0_wrt_0_1) -> Float
+// CHECK-SIL: [[RESULT:%.*]] = tuple ([[ORIG_RESULT]] : $Float, [[PARTIAL_APP_DIFFERENTIAL]] : $@callee_guaranteed (Float, Float) -> Float)
+// CHECK-SIL: return [[RESULT:%.*]] : $(Float, @callee_guaranteed (Float, Float) -> Float)
+
+// CHECK-SIL-LABEL: sil hidden [ossa] @AD__binary__differential_src_0_wrt_0_1 : $@convention(thin) (Float, Float, @owned _AD__binary_bb0__DF__src_0_wrt_0_1) -> Float {
+// CHECK-SIL: bb0([[X_TAN:%.*]] : $Float, [[Y_TAN:%.*]] : $Float, [[DIFF_STRUCT:%.*]] : @owned $_AD__binary_bb0__DF__src_0_wrt_0_1):
+// CHECK-SIL: [[MULT_DIFF:%.*]] = destructure_struct [[DIFF_STRUCT]] : $_AD__binary_bb0__DF__src_0_wrt_0_1
+// CHECK-SIL: [[TAN_RESULT:%.*]] = apply [[MULT_DIFF]]([[X_TAN]], [[Y_TAN]]) : $@callee_guaranteed (Float, Float) -> Float
+// CHECK-SIL: return [[TAN_RESULT]] : $Float
diff --git a/test/AutoDiff/downstream/generics.swift b/test/AutoDiff/downstream/generics.swift
new file mode 100644
index 0000000..fa664d5
--- /dev/null
+++ b/test/AutoDiff/downstream/generics.swift
@@ -0,0 +1,388 @@
+// RUN: %target-swift-emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-SIL
+
+@_silgen_name("identity")
+func identity<T : Differentiable>(_ x: T) -> T {
+ return x
+}
+_ = gradient(at: Float(1), in: { x in identity(x) })
+
+// Test AdjointEmitter local buffer allocation.
+// Verify that local buffers are immediately set to zero.
+
+// CHECK-SIL-LABEL: sil private @AD__identity__pullback_src_0_wrt_0_s14DifferentiableRzl
+// CHECK-SIL: [[ORIG_COTAN:%.*]] = alloc_stack $τ_0_0.TangentVector
+// CHECK-SIL-NEXT: [[ZERO_WITNESS:%.*]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter
+// CHECK-SIL-NEXT: [[ORIG_COTAN_METATYPE:%.*]] = metatype $@thick τ_0_0.TangentVector.Type
+// CHECK-SIL-NEXT: [[EMIT_ZERO_INDIRECT:%.*]] = apply [[ZERO_WITNESS]]<τ_0_0.TangentVector>([[ORIG_COTAN]], [[ORIG_COTAN_METATYPE]])
+// CHECK-SIL: }
+
+// Test TF-201: differentiate direct references to generic function.
+// This involves reabstraction thunk differentiation.
+
+_ = gradient(at: Float(1), in: identity)
+
+protocol DifferentiableAdditiveArithmetic: Differentiable & AdditiveArithmetic {
+ @differentiable
+ static func + (lhs: Self, rhs: Self) -> Self
+}
+extension Float: DifferentiableAdditiveArithmetic {}
+func generic<T: DifferentiableAdditiveArithmetic>(_ x: T) -> T {
+ x + x + x
+}
+_ = gradient(at: Float(10), in: generic)
+
+struct Wrapper<Scalar : Differentiable> : Differentiable {
+ var value: Scalar
+ init(_ value: Scalar) { self.value = value }
+}
+func generic<T>(_ x: Wrapper<T>) -> T {
+ return x.value
+}
+_ = gradient(at: Wrapper<Float>(1), in: generic)
+
+func generic2<T: Differentiable, U: Differentiable>(_ x: T, _ y: Float, _ z: U) -> T {
+ return x
+}
+func foo<T>(_ x: Wrapper<T>) {
+ _ = gradient(at: Float(1), 2, x, in: generic2)
+}
+
+// Test case where associated derivative function's requirements are met.
+extension Wrapper where Scalar : Numeric {
+ @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
+ func mean() -> Wrapper {
+ return self
+ }
+
+ @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
+ func variance() -> Wrapper {
+ return mean() // ok
+ }
+}
+_ = pullback(at: Wrapper<Float>(1), in: { $0.variance() })
+
+// Tests TF-277.
+protocol Layer : Differentiable {
+ associatedtype Output : Differentiable
+}
+struct SupervisedTrainer<Model : Layer> {
+ var model: Model
+ var lossFunction: @differentiable (Model.Output, Model.Output) -> Float
+ func fit(y: Model.Output) {
+ _ = gradient(at: y) { y in return self.lossFunction(y, y) }
+ }
+}
+
+// Tests TF-440.
+struct TF_440_Input<Input: Differentiable, State: Differentiable>
+ : Differentiable {
+ var input: Input
+ var state: State
+}
+struct TF_440<T : Differentiable> {
+ @differentiable
+ func applied(to input: TF_440_Input<Float, Float>) -> Float {
+ return input.state
+ }
+
+ @differentiable
+ func applied(to input: TF_440_Input<T, Float>) -> Float {
+ return input.state
+ }
+
+ @differentiable
+ func applied(to input: TF_440_Input<T, Float>) -> T {
+ return input.input
+ }
+}
+
+// Tests TF-508: differentiation requirements with dependent member types.
+protocol TF_508_Proto {
+ associatedtype Scalar
+}
+extension TF_508_Proto where Scalar : FloatingPoint {
+ @differentiable(
+ where Self : Differentiable, Scalar : Differentiable,
+ // Conformance requirement with dependent member type.
+ Self.TangentVector : TF_508_Proto
+ )
+ static func +(lhs: Self, rhs: Self) -> Self {
+ return lhs
+ }
+
+ @differentiable(
+ where Self : Differentiable, Scalar : Differentiable,
+ // Same-type requirement with dependent member type.
+ Self.TangentVector == Float
+ )
+ static func -(lhs: Self, rhs: Self) -> Self {
+ return lhs
+ }
+}
+extension TF_508_Proto where Self : Differentiable,
+ Scalar : FloatingPoint & Differentiable,
+ Self.TangentVector : TF_508_Proto {
+ @derivative(of: +)
+ static func vjpAdd(lhs: Self, rhs: Self)
+ -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
+ return (lhs, { v in (v, v) })
+ }
+}
+extension TF_508_Proto where Self : Differentiable,
+ Scalar : FloatingPoint & Differentiable,
+ Self.TangentVector == Float {
+ @derivative(of: -)
+ static func vjpSubtract(lhs: Self, rhs: Self)
+ -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
+ return (lhs, { v in (v, v) })
+ }
+}
+
+struct TF_508_Struct<Scalar : AdditiveArithmetic>
+ : TF_508_Proto, AdditiveArithmetic {}
+extension TF_508_Struct : Differentiable where Scalar : Differentiable {
+ typealias TangentVector = TF_508_Struct
+}
+
+func TF_508() {
+ let x = TF_508_Struct<Float>()
+ // Test conformance requirement with dependent member type.
+ _ = pullback(at: x, in: { (x: TF_508_Struct<Float>) -> TF_508_Struct<Float> in
+ return x + x
+ })
+ // Test same-type requirement with dependent member type.
+ _ = pullback(at: x, in: { (x: TF_508_Struct<Float>) -> TF_508_Struct<Float> in
+ return x - x
+ })
+}
+
+// TF-523
+struct TF_523_Struct : Differentiable & AdditiveArithmetic {
+ var a: Float = 1
+ typealias TangentVector = TF_523_Struct
+}
+
+@differentiable
+func TF_523_f(_ x: TF_523_Struct) -> Float {
+ return x.a * 2
+}
+
+// TF-534: Thunk substitution map remapping.
+protocol TF_534_Layer : Differentiable {
+ associatedtype Input : Differentiable
+ associatedtype Output : Differentiable
+
+ @differentiable
+ func callAsFunction(_ input: Input) -> Output
+}
+struct TF_534_Tensor<Scalar> : Differentiable {}
+
+func TF_534<Model: TF_534_Layer>(
+ _ model: inout Model, inputs: Model.Input
+) -> TF_534_Tensor<Float> where Model.Output == TF_534_Tensor<Float> {
+ return valueWithPullback(at: model) { model -> Model.Output in
+ return model(inputs)
+ }.0
+}
+
+// TF-546: Test that SILGen linear map thunk performs correct reabstraction.
+struct TF_546<T: FloatingPoint>: AdditiveArithmetic {
+ var real: T
+ var imaginary: T
+
+ @differentiable(where T: Differentiable, T == T.TangentVector)
+ init(real: T = 0, imaginary: T = 0) {
+ self.real = real
+ self.imaginary = imaginary
+ }
+}
+extension TF_546: Differentiable where T: Differentiable {
+ typealias TangentVector = TF_546
+}
+extension TF_546 where T: Differentiable, T == T.TangentVector {
+ @derivative(of: init)
+ static func _vjpInit(real: T, imaginary: T) -> (value: TF_546, pullback: (TF_546) -> (T, T)) {
+ return (TF_546(real: real, imaginary: imaginary), { ($0.real, $0.imaginary) })
+ }
+}
+let _: @differentiable(Float, Float) -> TF_546<Float> = { r, i in
+ TF_546(real: r, imaginary: i)
+}
+
+// TF-652: Test VJPEmitter substitution map generic signature.
+// The substitution map should have the VJP's generic signature, not the
+// original function's.
+struct TF_652<Scalar> {}
+extension TF_652 : Differentiable where Scalar : FloatingPoint {}
+
+@differentiable(wrt: x where Scalar: FloatingPoint)
+func test<Scalar: Numeric>(x: TF_652<Scalar>) -> TF_652<Scalar> {
+ for _ in 0..<10 {
+ let _ = x
+ }
+ return x
+}
+
+// TF-682: Test that SILGen linear map thunk performs correct reabstraction.
+protocol TF_682_Proto {
+ associatedtype Scalar
+}
+extension TF_682_Proto where Scalar : FloatingPoint {
+ @differentiable(
+ where Self : Differentiable, Scalar : Differentiable,
+ // Same-type requirement with dependent member type.
+ Self.TangentVector == Float
+ )
+ func foo(lhs: Self) -> Self {
+ return lhs
+ }
+}
+extension TF_682_Proto where Self : Differentiable,
+ Scalar : FloatingPoint & Differentiable,
+ Self.TangentVector == Float {
+ @derivative(of: foo)
+ func vjpFoo(lhs: Self) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
+ return (lhs, { v in (v, v) })
+ }
+}
+
+// NOTE(TF-1208): Differentiation regression due to changes in curry thunk generation.
+/*
+// TF-688: Test generic curry thunk cloning.
+public struct TF_688_Struct<Scalar> {
+ var x: Scalar
+}
+extension TF_688_Struct: Differentiable where Scalar: Differentiable {
+ @differentiable
+ public static func id(x: Self) -> Self {
+ return x
+ }
+}
+@differentiable(wrt: x)
+public func TF_688<Scalar: Differentiable>(
+ _ x: TF_688_Struct<Scalar>,
+ reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
+) -> TF_688_Struct<Scalar> {
+ reduction(x)
+}
+*/
+
+// TF-697: Test generic requirements of generated derivative function.
+protocol TF_697_Module: Differentiable {
+ associatedtype Input
+ associatedtype Output: Differentiable
+
+ @differentiable(wrt: self)
+ func callModule(_ input: Input) -> Output
+}
+protocol TF_697_Layer: TF_697_Module where Input: Differentiable {
+ @differentiable
+ func callLayer(_ input: Input) -> Output
+}
+struct TF_697_Sequential<Layer1: TF_697_Module, Layer2: TF_697_Layer>: TF_697_Module
+ where Layer1.Output == Layer2.Input {
+ var layer1: Layer1
+ var layer2: Layer2
+
+ @differentiable(wrt: self)
+ func callModule(_ input: Layer1.Input) -> Layer2.Output {
+ layer2.callLayer(layer1.callModule(input))
+ }
+}
+extension TF_697_Sequential: TF_697_Layer where Layer1: TF_697_Layer {
+ @differentiable
+ func callLayer(_ input: Layer1.Input) -> Layer2.Output {
+ layer2.callLayer(layer1.callLayer(input))
+ }
+}
+
+// TF-817: Test remapping `apply` callee types in derivative function context.
+struct TF_817<T> {
+ func foo(_ index: Int) -> T {
+ fatalError()
+ }
+}
+extension TF_817: Differentiable where T: Differentiable {
+ @derivative(of: foo)
+ func vjpFoo(index: Int) -> (value: T, pullback: (T.TangentVector) -> (TangentVector)) {
+ fatalError()
+ }
+}
+extension TF_817 {
+ @differentiable(wrt: self where T: Differentiable)
+ public func test(index: Int) -> T {
+ return self.foo(0) // crash happened here
+ }
+}
+
+// TF-886: Test `partial_apply` of linear map subset parameters thunk.
+@differentiable
+func TF_886_foo<T, U: Differentiable>(_: Float, _: T, _: U) -> Float {
+ return 0
+}
+@differentiable
+func TF_886_bar<T>(x: Float, y: T) -> Float {
+ return TF_886_foo(x, y, 0)
+}
+
+// Test layout requirements.
+
+// The layout requirement is "contextual": the requirement is not on `T`, the
+// differentiable function parameter/result type.
+struct ContextualLayoutRequirement<T: Differentiable, U: AnyObject> {
+ var stored: T
+}
+extension ContextualLayoutRequirement {
+ func test(_ x: T) {
+ let _: @differentiable (T) -> T = { _ in self.stored }
+ let _: @differentiable (T) -> T = { $0 }
+ }
+}
+// The layout requirement directly involves `T`, the differentiable function
+// parameter/result type.
+// TODO(TF-851): Uncomment the tests below after `@differentiable` function
+// SILGen thunking is fixed.
+/*
+struct LayoutRequirement<T: AnyObject & Differentiable> {
+ var stored: T
+}
+extension LayoutRequirement {
+ func test(_ x: T) {
+ let _: @differentiable (T) -> T = { _ in self.stored }
+ let _: @differentiable (T) -> T = { $0 }
+ }
+}
+*/
+
+// Test superclass requirements.
+
+class Super: Differentiable {}
+
+// The superclass requirement is "contextual": the requirement is not on `T`,
+// the differentiable function parameter/result type.
+struct ContextualSuperclassRequirement<T: Differentiable, U: Super> {
+ var stored: T
+}
+extension ContextualSuperclassRequirement {
+ func test(_ x: T) {
+ let _: @differentiable (T) -> T = { _ in self.stored }
+ let _: @differentiable (T) -> T = { $0 }
+ }
+}
+// The superclass requirement directly involves `T`, the differentiable
+// function parameter/result type.
+// TODO(TF-851): Uncomment the tests below after `@differentiable` function
+// SILGen thunking is fixed.
+/*
+struct SuperclassRequirement<T: Super & Differentiable> {
+ var stored: T
+}
+extension SuperclassRequirement {
+ func test(_ x: T) {
+ let _: @differentiable (T) -> T = { _ in self.stored }
+ let _: @differentiable (T) -> T = { $0 }
+ }
+}
+*/
+
+// TODO: add more tests.
diff --git a/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_cross_file/Inputs/other_file.swift b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_cross_file/Inputs/other_file.swift
new file mode 100644
index 0000000..edbb5a2
--- /dev/null
+++ b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_cross_file/Inputs/other_file.swift
@@ -0,0 +1,49 @@
+protocol Protocol: Differentiable {
+ // expected-note @+2 {{protocol requires function 'internalMethod1' with type '(Float) -> Float'}}
+ @differentiable(wrt: (self, x))
+ func internalMethod1(_ x: Float) -> Float
+
+ // expected-note @+3 {{protocol requires function 'internalMethod2' with type '(Float) -> Float'}}
+ @differentiable(wrt: x)
+ @differentiable(wrt: (self, x))
+ func internalMethod2(_ x: Float) -> Float
+
+ @differentiable(wrt: x)
+ @differentiable(wrt: (self, x))
+ func internalMethod3(_ x: Float) -> Float
+}
+
+protocol Protocol2: Differentiable {
+ @differentiable(wrt: (self, x))
+ func internalMethod4(_ x: Float) -> Float
+}
+
+// Note:
+// - No `ConformingStruct: Protocol` conformance exists in this file, so this
+// file should compile just file.
+// - A `ConformingStruct: Protocol` conformance in a different file should be
+// diagnosed to prevent linker errors. Without a diagnostic, compilation of
+// the other file creates external references to symbols for implicit
+// `@differentiable` attributes, even though no such symbols exist.
+// Context: https://github.com/apple/swift/pull/29771#issuecomment-585059721
+
+struct ConformingStruct: Differentiable {
+ // Expected: error for missing `@differentiable` attribute.
+ // expected-note @+1 {{non-public instance method 'internalMethod1' must have explicit '@differentiable' attribute to satisfy requirement instance method 'internalMethod1' (in protocol 'Protocol') because it is declared in a different file than the conformance of 'ConformingStruct' to 'Protocol'}} {{3-3=@differentiable }}
+ func internalMethod1(_ x: Float) -> Float {
+ x
+ }
+
+ // Expected: error for missing `@differentiable` superset attribute.
+ // expected-note @+2 {{non-public instance method 'internalMethod2' must have explicit '@differentiable' attribute to satisfy requirement instance method 'internalMethod2' (in protocol 'Protocol') because it is declared in a different file than the conformance of 'ConformingStruct' to 'Protocol'}} {{3-3=@differentiable }}
+ @differentiable(wrt: x)
+ func internalMethod2(_ x: Float) -> Float {
+ x
+ }
+
+ // Expected: no error for missing `@differentiable` subset attribute.
+ @differentiable(wrt: (self, x))
+ func internalMethod3(_ x: Float) -> Float {
+ x
+ }
+}
diff --git a/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_cross_file/main.swift b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_cross_file/main.swift
new file mode 100644
index 0000000..712d95e
--- /dev/null
+++ b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_cross_file/main.swift
@@ -0,0 +1,35 @@
+// Test missing protocol requirement `@differentiable` attribute errors for
+// non-public protocol witnesses, when the protocol conformance is declared in a
+// separate file from witnesses.
+//
+// Implicit `@differentiable` attributes cannot be generated for protocol
+// witnesses when the conformance is declared from a separate file from the
+// witness. Otherwise, compilation of the file containing the conformance
+// creates external references to symbols for implicit `@differentiable`
+// attributes, even though no such symbols exist.
+//
+// Context: https://github.com/apple/swift/pull/29771#issuecomment-585059721
+
+// Note: `swiftc main.swift other_file.swift` runs three commands:
+// - `swiftc -frontend -primary-file main.swift other_file.swift -o ...`
+// - `swiftc -frontend main.swift -primary-file other_file.swift -o ...`
+// - `/usr/bin/ld ...`
+//
+// `%target-build-swift` performs `swiftc main.swift other_file.swift`, so it is expected to fail (hence `not`).
+// `swiftc -frontend -primary-file main.swift other_file.swift` should fail, so `-verify` is needed.
+// `swiftc -frontend main.swift -primary-file other_file.swift` should succeed, so no need for `-verify`.
+
+// RUN: %target-swift-frontend -c -verify -primary-file %s %S/Inputs/other_file.swift
+// RUN: %target-swift-frontend -c %s -primary-file %S/Inputs/other_file.swift
+// RUN: not %target-build-swift %s %S/Inputs/other_file.swift
+
+// Error: conformance is in different file than witnesses.
+// expected-error @+1 {{type 'ConformingStruct' does not conform to protocol 'Protocol'}}
+extension ConformingStruct: Protocol {}
+
+// No error: conformance is in same file as witnesses.
+extension ConformingStruct: Protocol2 {
+ func internalMethod4(_ x: Float) -> Float {
+ x
+ }
+}
diff --git a/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_sil.swift b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_sil.swift
new file mode 100644
index 0000000..ea17866
--- /dev/null
+++ b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_sil.swift
@@ -0,0 +1,24 @@
+// RUN: %target-swift-frontend -emit-sil -verify %s
+
+// Test end-to-end differentiation involving implicit `@differentiable`
+// attributes for non-public protocol witnesses.
+//
+// Specifically, test the diagnostic source locations for implicit attributes.
+
+protocol Protocol: Differentiable {
+ // expected-note @+1 {{differentiability required by the corresponding protocol requirement here}}
+ @differentiable(wrt: (self, x))
+ func internalMethod(_ x: Float) -> Float
+}
+
+struct ConformingStruct: Protocol {
+ // Expected:
+ // - No error for missing `@differentiable` attribute on internal protocol witness.
+ // An implicit `@differentiable` attribute should be created.
+ // - A non-differentiability error, because the method body is non-differentiable.
+ // expected-error @+1 {{function is not differentiable}}
+ func internalMethod(_ x: Float) -> Float {
+ // expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
+ return Float(Int(x))
+ }
+}
diff --git a/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_type_checking.swift b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_type_checking.swift
new file mode 100644
index 0000000..3653295
--- /dev/null
+++ b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_type_checking.swift
@@ -0,0 +1,23 @@
+// RUN: %target-swift-frontend -print-ast -verify %s | %FileCheck %s
+
+// Test implicit `@differentiable` attributes for non-public protocol witnesses.
+
+protocol InternalProtocol: Differentiable {
+ @differentiable(wrt: self)
+ @differentiable(wrt: (self, x))
+ func internalMethod(_ x: Float) -> Float
+}
+
+public struct PublicConformingStruct: InternalProtocol {
+ // Expected: no error for missing `@differentiable` attribute on internal protocol witness.
+ // Implicit `@differentiable` attributes should be created.
+ func internalMethod(_ x: Float) -> Float {
+ x
+ }
+}
+
+// CHECK-LABEL: public struct PublicConformingStruct : InternalProtocol {
+// CHECK: @differentiable(wrt: (self, x))
+// CHECK: @differentiable(wrt: self)
+// CHECK: internal func internalMethod(_ x: Float) -> Float
+// CHECK: }
diff --git a/test/AutoDiff/downstream/inout_parameters.swift b/test/AutoDiff/downstream/inout_parameters.swift
new file mode 100644
index 0000000..76cbb2f
--- /dev/null
+++ b/test/AutoDiff/downstream/inout_parameters.swift
@@ -0,0 +1,171 @@
+// RUN: %target-run-simple-swift
+// REQUIRES: executable_test
+
+// `inout` parameter differentiation tests.
+
+import StdlibUnittest
+
+var InoutParametersTests = TestSuite("InoutParameters")
+
+// TODO(TF-1173): Move floating-point mutating operation tests to
+// `floating_point.swift.gyb` when forward-mode differentiation supports `inout`
+// parameter differentiation.
+
+InoutParametersTests.test("Float.+=") {
+ func mutatingAddWrapper(_ x: Float, _ y: Float) -> Float {
+ var result: Float = x
+ result += y
+ return result
+ }
+ expectEqual((1, 1), gradient(at: 4, 5, in: mutatingAddWrapper))
+ expectEqual((10, 10), pullback(at: 4, 5, in: mutatingAddWrapper)(10))
+}
+
+InoutParametersTests.test("Float.-=") {
+ func mutatingSubtractWrapper(_ x: Float, _ y: Float) -> Float {
+ var result: Float = x
+ result += y
+ return result
+ }
+ expectEqual((1, 1), gradient(at: 4, 5, in: mutatingSubtractWrapper))
+ expectEqual((10, 10), pullback(at: 4, 5, in: mutatingSubtractWrapper)(10))
+}
+
+InoutParametersTests.test("Float.*=") {
+ func mutatingMultiplyWrapper(_ x: Float, _ y: Float) -> Float {
+ var result: Float = x
+ result += y
+ return result
+ }
+ expectEqual((1, 1), gradient(at: 4, 5, in: mutatingMultiplyWrapper))
+ expectEqual((10, 10), pullback(at: 4, 5, in: mutatingMultiplyWrapper)(10))
+}
+
+InoutParametersTests.test("Float./=") {
+ func mutatingDivideWrapper(_ x: Float, _ y: Float) -> Float {
+ var result: Float = x
+ result += y
+ return result
+ }
+ expectEqual((1, 1), gradient(at: 4, 5, in: mutatingDivideWrapper))
+ expectEqual((10, 10), pullback(at: 4, 5, in: mutatingDivideWrapper)(10))
+}
+
+// Simplest possible `inout` parameter differentiation.
+InoutParametersTests.test("InoutIdentity") {
+ // Semantically, an empty function with an `inout` parameter is an identity
+ // function.
+ func inoutIdentity(_ x: inout Float) {}
+
+ func identity(_ x: Float) -> Float {
+ var result = x
+ inoutIdentity(&result)
+ return result
+ }
+ expectEqual(1, gradient(at: 10, in: identity))
+ expectEqual(10, pullback(at: 10, in: identity)(10))
+}
+
+extension Float {
+ // Custom version of `Float.*=`, implemented using `Float.*` and mutation.
+ // Verify that its generated derivative has the same behavior as the
+ // registered derivative for `Float.*=`.
+ @differentiable
+ static func multiplyAssign(_ lhs: inout Float, _ rhs: Float) {
+ lhs = lhs * rhs
+ }
+}
+
+InoutParametersTests.test("ControlFlow") {
+ func sum(_ array: [Float]) -> Float {
+ var result: Float = 0
+ for i in withoutDerivative(at: array.indices) {
+ result += array[i]
+ }
+ return result
+ }
+ expectEqual([1, 1, 1], gradient(at: [1, 2, 3], in: sum))
+
+ func product(_ array: [Float]) -> Float {
+ var result: Float = 1
+ for i in withoutDerivative(at: array.indices) {
+ result *= array[i]
+ }
+ return result
+ }
+ expectEqual([20, 15, 12], gradient(at: [3, 4, 5], in: product))
+
+ func productCustom(_ array: [Float]) -> Float {
+ var result: Float = 1
+ for i in withoutDerivative(at: array.indices) {
+ Float.multiplyAssign(&result, array[i])
+ }
+ return result
+ }
+ expectEqual([20, 15, 12], gradient(at: [3, 4, 5], in: productCustom))
+}
+
+InoutParametersTests.test("SetAccessor") {
+ struct S: Differentiable {
+ var x: Float
+
+ var computed: Float {
+ get { x }
+ set { x = newValue }
+ }
+ }
+
+ // `squared` implemented using a `set` accessor.
+ func squared(_ x: Float) -> Float {
+ var s = S(x: 1)
+ s.x *= x
+ s.computed *= x
+ return s.x
+ }
+ expectEqual(6, gradient(at: 3, in: squared))
+ expectEqual(8, gradient(at: 4, in: squared))
+}
+
+// Test differentiation wrt `inout` parameters that have a class type.
+InoutParametersTests.test("InoutClassParameter") {
+ class Class: Differentiable {
+ @differentiable
+ var x: Float
+
+ init(_ x: Float) {
+ self.x = x
+ }
+ }
+
+ do {
+ func squaredViaMutation(_ c: inout Class) {
+ c = Class(c.x * c.x)
+ }
+ func squared(_ x: Float) -> Float {
+ var c = Class(x)
+ squaredViaMutation(&c)
+ return c.x
+ }
+ expectEqual((100, 20), valueWithGradient(at: 10, in: squared))
+ expectEqual(200, pullback(at: 10, in: squared)(10))
+ }
+
+ do {
+ func squaredViaModifyAccessor(_ c: inout Class) {
+ // The line below calls `Class.x.modify`.
+ c.x *= c.x
+ }
+ func squared(_ x: Float) -> Float {
+ var c = Class(x)
+ squaredViaModifyAccessor(&c)
+ return c.x
+ }
+ // FIXME(TF-1080): Fix incorrect class property `modify` accessor derivative values.
+ // expectEqual((100, 20), valueWithGradient(at: 10, in: squared))
+ // expectEqual(200, pullback(at: 10, in: squared)(10))
+ expectEqual((100, 1), valueWithGradient(at: 10, in: squared))
+ expectEqual(10, pullback(at: 10, in: squared)(10))
+ }
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/irgen_crashers.swift b/test/AutoDiff/downstream/irgen_crashers.swift
new file mode 100644
index 0000000..7860d89
--- /dev/null
+++ b/test/AutoDiff/downstream/irgen_crashers.swift
@@ -0,0 +1,12 @@
+// RUN: %target-swift-frontend -emit-ir %s
+
+// TF-917: `partial_apply` IRGen crash.
+public protocol TF_917: Differentiable {
+ @differentiable
+ func r<A>(_ a: A) -> Float
+}
+@differentiable
+public func tf_917<B: TF_917>(_ b: B) -> Float {
+ return b.r(0.0)
+}
+
diff --git a/test/AutoDiff/downstream/leakchecking.swift b/test/AutoDiff/downstream/leakchecking.swift
new file mode 100644
index 0000000..0e63601
--- /dev/null
+++ b/test/AutoDiff/downstream/leakchecking.swift
@@ -0,0 +1,633 @@
+// RUN: %target-run-simple-swift
+// REQUIRES: executable_test
+
+// Test differentiation-related memory leaks.
+
+import StdlibUnittest
+import DifferentiationUnittest
+
+var LeakCheckingTests = TestSuite("LeakChecking")
+
+struct ExampleLeakModel : Differentiable {
+ var bias: Tracked<Float> = 2.0
+ func applied(to input: Tracked<Float>) -> Tracked<Float> {
+ var v = input + bias
+ return v
+ }
+}
+
+struct FloatPair : Differentiable & AdditiveArithmetic {
+ var first, second: Tracked<Float>
+ init(_ first: Tracked<Float>, _ second: Tracked<Float>) {
+ self.first = first
+ self.second = second
+ }
+}
+
+struct Pair<T : Differentiable, U : Differentiable> : Differentiable
+ where T == T.TangentVector, U == U.TangentVector
+{
+ var first: Tracked<T>
+ var second: Tracked<U>
+ init(_ first: Tracked<T>, _ second: Tracked<U>) {
+ self.first = first
+ self.second = second
+ }
+}
+
+LeakCheckingTests.testWithLeakChecking("BasicLetLeakChecking") {
+ do {
+ let model = ExampleLeakModel()
+ let x: Tracked<Float> = 1.0
+ _ = gradient(at: model, x) { m, x in m.applied(to: x) }
+ }
+
+ do {
+ let model = ExampleLeakModel()
+ let x: Tracked<Float> = 1.0
+ _ = gradient(at: model, x) { m, x -> Tracked<Float> in
+ let (y0, y1) = (m.applied(to: x), m.applied(to: x))
+ return y0 + y0 - y1
+ }
+ }
+}
+
+LeakCheckingTests.testWithLeakChecking("BasicVarLeakChecking") {
+ var model = ExampleLeakModel()
+ var x: Tracked<Float> = 1.0
+ _ = gradient(at: model, x) { m, x -> Float in
+ var y = x + Tracked<Float>(x.value)
+ return m.applied(to: y).value
+ }
+}
+
+protocol DummyLayer : Differentiable {
+ associatedtype Input : Differentiable
+ associatedtype Output : Differentiable
+
+ @differentiable
+ func requirement(_ input: Input) -> Output
+}
+extension DummyLayer {
+ @differentiable
+ func defaultImpl(_ input: Input) -> Output {
+ return requirement(input)
+ }
+ @derivative(of: defaultImpl)
+ func vjpDefaultImpl(_ input: Input)
+ -> (value: Output,
+ pullback: (Self.Output.TangentVector)
+ -> (Self.TangentVector, Self.Input.TangentVector)) {
+ return valueWithPullback(at: self, input) { $0.requirement($1) }
+ }
+}
+
+LeakCheckingTests.testWithLeakChecking("TestProtocolDefaultDerivative") {
+ struct Foo : DummyLayer {
+ // NOTE: Make sure not to override `defaultImpl`.
+ // To reproduce the bug, the VJP of `Foo.requirement` should dispatch to
+ // `DummyLayer.vjpDefaultImpl`.
+
+ @differentiable
+ func requirement(_ input: Tracked<Float>) -> Tracked<Float> {
+ return input
+ }
+ }
+
+ let x = Tracked<Float>(1)
+ let model = Foo()
+ _ = valueWithGradient(at: model) { model in
+ // Call the protocol default implementation method.
+ model.defaultImpl(x)
+ }
+}
+
+protocol Module : Differentiable {
+ associatedtype Input
+ associatedtype Output : Differentiable
+ @differentiable(wrt: self)
+ func callAsFunction(_ input: Input) -> Output
+}
+protocol Layer : Module where Input : Differentiable {
+ @differentiable(wrt: (self, input))
+ func callAsFunction(_ input: Input) -> Output
+}
+
+LeakCheckingTests.testWithLeakChecking("ProtocolRequirements") {
+ struct Dense: Layer {
+ var w = Tracked<Float>(1)
+ @differentiable
+ func callAsFunction(_ input: Tracked<Float>) -> Tracked<Float> {
+ input * w
+ }
+ }
+ struct Model: Module {
+ var dense1 = Dense()
+ var dense2 = Dense()
+ @differentiable
+ func callAsFunction(_ input: Tracked<Int>) -> Tracked<Float> {
+ dense2(dense1(Tracked(Float(input.value))))
+ }
+ }
+ let x = Tracked<Int>(1)
+ let model = Model()
+ _ = valueWithGradient(at: model) { model in
+ model(x)
+ }
+}
+
+LeakCheckingTests.testWithLeakChecking("LetStructs") {
+ func structConstructionWithOwnedParams(_ x: Tracked<Float>) -> Tracked<Float> {
+ let z = Tracked(x)
+ return z.value
+ }
+ _ = valueWithGradient(at: 4, in: structConstructionWithOwnedParams)
+}
+
+LeakCheckingTests.testWithLeakChecking("NestedVarStructs") {
+ func nestedstruct_var(_ x: Tracked<Float>) -> Tracked<Float> {
+ var y = FloatPair(x + x, x - x)
+ var z = Pair(Tracked(y), x)
+ var w = FloatPair(x, x)
+ y.first = w.second
+ y.second = w.first
+ z.first = Tracked(FloatPair(z.first.value.first - y.first,
+ z.first.value.second + y.first))
+ return y.first + y.second - z.first.value.first + z.first.value.second
+ }
+ expectEqual((8, 2), valueWithGradient(at: 4, in: nestedstruct_var))
+}
+
+LeakCheckingTests.testWithLeakChecking("NestedVarTuples") {
+ func nestedtuple_var(_ x: Tracked<Float>) -> Tracked<Float> {
+ var y = (x + x, x - x)
+ var z = (y, x)
+ var w = (x, x)
+ y.0 = w.1
+ y.1 = w.0
+ z.0.0 = z.0.0 - y.0
+ z.0.1 = z.0.1 + y.0
+ return y.0 + y.1 - z.0.0 + z.0.1
+ }
+ expectEqual((8, 2), valueWithGradient(at: 4, in: nestedtuple_var))
+}
+
+// Tests class method differentiation and JVP/VJP vtable entry thunks.
+LeakCheckingTests.testWithLeakChecking("ClassMethods") {
+ class Super {
+ @differentiable(wrt: x)
+ func f(_ x: Tracked<Float>) -> Tracked<Float> {
+ return 2 * x
+ }
+ @derivative(of: f)
+ final func jvpf(_ x: Tracked<Float>) -> (value: Tracked<Float>, differential: (Tracked<Float>) -> Tracked<Float>) {
+ return (f(x), { v in 2 * v })
+ }
+ @derivative(of: f)
+ final func vjpf(_ x: Tracked<Float>) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
+ return (f(x), { v in 2 * v })
+ }
+ }
+
+ class SubOverride : Super {
+ @differentiable(wrt: x)
+ override func f(_ x: Tracked<Float>) -> Tracked<Float> {
+ return 3 * x
+ }
+ }
+
+ class SubOverrideCustomDerivatives : Super {
+ @differentiable(wrt: x)
+ override func f(_ x: Tracked<Float>) -> Tracked<Float> {
+ return 3 * x
+ }
+ @derivative(of: f)
+ final func jvpf2(_ x: Tracked<Float>) -> (value: Tracked<Float>, differential: (Tracked<Float>) -> Tracked<Float>) {
+ return (f(x), { v in 3 * v })
+ }
+ @derivative(of: f)
+ final func vjpf2(_ x: Tracked<Float>) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
+ return (f(x), { v in 3 * v })
+ }
+ }
+
+ func classValueWithGradient(_ c: Super) -> (Tracked<Float>, Tracked<Float>) {
+ return valueWithGradient(at: 1) { c.f($0) }
+ }
+ expectEqual((2, 2), classValueWithGradient(Super()))
+ expectEqual((3, 3), classValueWithGradient(SubOverride()))
+ expectEqual((3, 3), classValueWithGradient(SubOverrideCustomDerivatives()))
+}
+
+protocol TF_508_Proto {
+ associatedtype Scalar
+}
+extension TF_508_Proto where Scalar : FloatingPoint {
+ @differentiable(
+ where Self : Differentiable, Scalar : Differentiable,
+ // Conformance requirement with dependent member type.
+ Self.TangentVector : TF_508_Proto
+ )
+ static func +(lhs: Self, rhs: Self) -> Self {
+ return lhs
+ }
+
+ @differentiable(
+ where Self : Differentiable, Scalar : Differentiable,
+ // Same-type requirement with dependent member type.
+ Self.TangentVector == TF_508_Struct<Float>
+ )
+ func subtract(_ other: Self) -> Self {
+ return self
+ }
+}
+extension TF_508_Proto where Self : Differentiable,
+ Scalar : FloatingPoint & Differentiable,
+ Self.TangentVector : TF_508_Proto {
+ @derivative(of: +)
+ static func jvpAdd(lhs: Self, rhs: Self)
+ -> (value: Self, differential: (TangentVector, TangentVector) -> TangentVector) {
+ return (lhs, { (dlhs, drhs) in dlhs + drhs })
+ }
+ @derivative(of: +)
+ static func vjpAdd(lhs: Self, rhs: Self)
+ -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
+ return (lhs, { v in (v, v) })
+ }
+}
+extension TF_508_Proto where Self : Differentiable,
+ Scalar : FloatingPoint & Differentiable,
+ Self.TangentVector == TF_508_Struct<Float> {
+ @derivative(of: subtract)
+ func jvpSubtract(lhs: Self)
+ -> (value: Self, differential: (TangentVector, TangentVector) -> TangentVector) {
+ return (lhs, { dself, dlhs in dself - dlhs })
+ }
+ @derivative(of: subtract)
+ func vjpSubtract(lhs: Self)
+ -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
+ return (lhs, { v in (v, v) })
+ }
+}
+
+struct TF_508_Struct<Scalar : AdditiveArithmetic>
+ : TF_508_Proto, AdditiveArithmetic {}
+extension TF_508_Struct : Differentiable where Scalar : Differentiable {
+ typealias TangentVector = TF_508_Struct
+}
+
+// Test leaks regarding `SILGenFunction::getOrCreateAutoDiffLinearMapThunk`.
+LeakCheckingTests.testWithLeakChecking("LinearMapSILGenThunks") {
+ func testLinearMapSILGenThunks() {
+ let x = TF_508_Struct<Float>()
+ // Test conformance requirement with dependent member type.
+ _ = pullback(at: x, in: { (x: TF_508_Struct<Float>) -> TF_508_Struct<Float> in
+ return x + x
+ })
+ // Test same-type requirement with dependent member type.
+ _ = pullback(at: x, in: { (x: TF_508_Struct<Float>) -> TF_508_Struct<Float> in
+ return x.subtract(x)
+ })
+ }
+ testLinearMapSILGenThunks()
+}
+
+LeakCheckingTests.testWithLeakChecking("ParameterConventionMismatchLeakChecking") {
+ struct MyTrackedFloat<Dummy> : Differentiable {
+ // The property with type `Dummy` makes `Self` be indirect.
+ @noDerivative var indirectDummy: Dummy
+ var base: Tracked<Float>
+
+ // Test initializer and static VJP function.
+ // Initializers have owned parameters but functions have shared parameters.
+ @differentiable
+ init(_ base: Tracked<Float>, dummy: Dummy) {
+ self.base = base
+ self.indirectDummy = dummy
+ }
+ @derivative(of: init)
+ static func vjpInit(_ base: Tracked<Float>, dummy: Dummy)
+ -> (value: MyTrackedFloat, pullback: (TangentVector) -> Tracked<Float>) {
+ return (MyTrackedFloat(base, dummy: dummy), { v in v.base })
+ }
+
+ @differentiable
+ func ownedParameter(_ x: __owned Tracked<Float>) -> Tracked<Float> {
+ return x
+ }
+ @derivative(of: ownedParameter)
+ func vjpOwnedParameterMismatch(_ x: __shared Tracked<Float>)
+ -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (TangentVector, Tracked<Float>)) {
+ return (ownedParameter(x), { v in (.zero, v) })
+ }
+
+ @differentiable
+ func sharedParameter(_ x: __shared Tracked<Float>) -> Tracked<Float> {
+ return x
+ }
+ @derivative(of: sharedParameter)
+ func vjpSharedParameterMismatch(_ x: __owned Tracked<Float>)
+ -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (TangentVector, Tracked<Float>)) {
+ return (sharedParameter(x), { v in (.zero, v) })
+ }
+
+ @differentiable
+ func ownedParameterGeneric<T : Differentiable>(_ x: __owned T) -> T {
+ return x
+ }
+ @derivative(of: ownedParameterGeneric)
+ func vjpOwnedParameterGenericMismatch<T : Differentiable>(_ x: __shared T)
+ -> (value: T, pullback: (T.TangentVector) -> (TangentVector, T.TangentVector)) {
+ return (ownedParameterGeneric(x), { v in (.zero, v) })
+ }
+
+ @differentiable
+ func sharedParameterGeneric<T : Differentiable>(_ x: __shared T) -> T {
+ return x
+ }
+ @derivative(of: sharedParameterGeneric)
+ func vjpSharedParameterGenericMismatch<T : Differentiable>(_ x: __owned T)
+ -> (value: T, pullback: (T.TangentVector) -> (TangentVector, T.TangentVector)) {
+ return (sharedParameterGeneric(x), { v in (.zero, v) })
+ }
+
+ @differentiable
+ __consuming func consuming(_ x: Tracked<Float>) -> Tracked<Float> {
+ return x
+ }
+ @derivative(of: consuming)
+ func vjpConsumingMismatch(_ x: Tracked<Float>)
+ -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (TangentVector, Tracked<Float>)) {
+ return (consuming(x), { v in (.zero, v) })
+ }
+
+ @differentiable
+ __consuming func consumingGeneric<T : Differentiable>(_ x: T) -> T {
+ return x
+ }
+ @derivative(of: consumingGeneric)
+ func vjpConsumingGenericMismatch<T : Differentiable>(_ x: T)
+ -> (value: T, pullback: (T.TangentVector) -> (TangentVector, T.TangentVector)) {
+ return (consumingGeneric(x), { v in (.zero, v) })
+ }
+
+ @differentiable
+ func nonconsuming(_ x: Tracked<Float>) -> Tracked<Float> {
+ return x
+ }
+ @derivative(of: nonconsuming)
+ __consuming func vjpNonconsumingMismatch(_ x: Tracked<Float>)
+ -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (TangentVector, Tracked<Float>)) {
+ return (nonconsuming(x), { v in (.zero, v) })
+ }
+
+ @differentiable
+ func nonconsumingGeneric<T : Differentiable>(_ x: T) -> T {
+ return x
+ }
+ @derivative(of: nonconsumingGeneric)
+ __consuming func vjpNonconsumingGenericMismatch<T : Differentiable>(_ x: T)
+ -> (value: T, pullback: (T.TangentVector) -> (TangentVector, T.TangentVector)) {
+ return (nonconsumingGeneric(x), { v in (.zero, v) })
+ }
+ }
+ let v = MyTrackedFloat<Any>.TangentVector(base: 10)
+ expectEqual(10, pullback(at: Tracked<Float>(1)) { x in MyTrackedFloat(x, dummy: 1.0) }(v))
+ let x: Tracked<Float> = 1
+ _ = gradient(at: x) { x in MyTrackedFloat<Any>(x, dummy: 1).ownedParameter(x) }
+ _ = gradient(at: x) { x in MyTrackedFloat<Any>(x, dummy: 1).sharedParameter(x) }
+ _ = gradient(at: x) { x in MyTrackedFloat<Any>(x, dummy: 1).ownedParameterGeneric(x) }
+ _ = gradient(at: x) { x in MyTrackedFloat<Any>(x, dummy: 1).sharedParameterGeneric(x) }
+ _ = gradient(at: x) { x in MyTrackedFloat<Any>(x, dummy: 1).consuming(x) }
+ _ = gradient(at: x) { x in MyTrackedFloat<Any>(x, dummy: 1).consumingGeneric(x) }
+ _ = gradient(at: x) { x in MyTrackedFloat<Any>(x, dummy: 1).nonconsuming(x) }
+ _ = gradient(at: x) { x in MyTrackedFloat<Any>(x, dummy: 1).nonconsumingGeneric(x) }
+}
+
+LeakCheckingTests.testWithLeakChecking("ClosureCaptureLeakChecking") {
+ do {
+ var model = ExampleLeakModel()
+ let x: Tracked<Float> = 1.0
+
+ _ = gradient(at: model) { m in m.applied(to: x) }
+ for _ in 0..<10 {
+ _ = gradient(at: model) { m in m.applied(to: x) }
+ }
+ }
+
+ do {
+ var model = ExampleLeakModel()
+ var x: Tracked<Float> = 1.0
+ _ = gradient(at: model) { m -> Tracked<Float> in
+ x = x + x
+ var y = x + Tracked<Float>(x.value)
+ return m.applied(to: y)
+ }
+ }
+
+ do {
+ var model = ExampleLeakModel()
+ let x: Tracked<Float> = 1.0
+ _ = gradient(at: model) { m -> Tracked<Float> in
+ var model = m
+ model.bias = x
+ return model.applied(to: x)
+ }
+ }
+
+ do {
+ struct Foo {
+ let x: Tracked<Float> = .zero
+ var y: Tracked<Float> = .zero
+ mutating func differentiateSomethingThatCapturesSelf() {
+ _ = gradient(at: x) { x -> Tracked<Float> in
+ self.y += .zero
+ return .zero
+ }
+ }
+ }
+ var foo = Foo()
+ foo.differentiateSomethingThatCapturesSelf()
+ }
+}
+
+LeakCheckingTests.testWithLeakChecking("ControlFlowWithTrivialUnconditionalMath") {
+ func ControlFlowWithTrivialUnconditionalMath(_ x: Tracked<Float>) -> Tracked<Float> {
+ if true {}
+ return x
+ }
+ _ = valueWithGradient(at: 1, in: ControlFlowWithTrivialUnconditionalMath)
+}
+
+LeakCheckingTests.testWithLeakChecking("ControlFlowWithTrivialNestedIfElse") {
+ func ControlFlowNestedWithTrivialIfElse(_ x: Tracked<Float>) -> Tracked<Float> {
+ if true {
+ if false {
+ return x
+ } else {
+ return x
+ }
+ }
+ }
+ _ = valueWithGradient(at: 1, in: ControlFlowNestedWithTrivialIfElse)
+}
+
+LeakCheckingTests.testWithLeakChecking("ControlFlowWithActiveCFCondition") {
+ var model = ExampleLeakModel()
+ let x: Tracked<Float> = 1.0
+ func ControlFlowWithActiveCFCondition(m: ExampleLeakModel, x: Tracked<Float>) -> Tracked<Float> {
+ if x > 0 {
+ return x
+ } else {
+ return x
+ }
+ }
+ _ = gradient(at: model, x, in: ControlFlowWithActiveCFCondition)
+}
+
+LeakCheckingTests.testWithLeakChecking("ControlFlowWithIf") {
+ var model = ExampleLeakModel()
+ let x: Tracked<Float> = 1.0
+ _ = gradient(at: model, x) { m, x -> Tracked<Float> in
+ var result: Tracked<Float> = x
+ if x > 0 {
+ result = result + m.applied(to: x)
+ }
+ return result
+ }
+}
+
+LeakCheckingTests.testWithLeakChecking("ControlFlowWithIfInMethod") {
+ struct Dense : Differentiable {
+ var w1: Tracked<Float>
+ @noDerivative var w2: Tracked<Float>?
+
+ func callAsFunction(_ input: Tracked<Float>) -> Tracked<Float> {
+ if let w2 = w2 {
+ return input * w1 * w2
+ }
+ return input * w1
+ }
+ }
+ let dense = Dense(w1: 4, w2: 5)
+ let denseNil = Dense(w1: 4, w2: nil)
+ expectEqual((Dense.TangentVector(w1: 10), 20),
+ gradient(at: dense, 2, in: { dense, x in dense(x) }))
+ expectEqual((Dense.TangentVector(w1: 2), 4),
+ gradient(at: denseNil, 2, in: { dense, x in dense(x) }))
+}
+
+
+LeakCheckingTests.testWithLeakChecking("ControlFlowWithLoop") {
+ func for_loop(_ x: Tracked<Float>) -> Tracked<Float> {
+ var result = x
+ for _ in 1..<3 {
+ result = result * x
+ }
+ return result
+ }
+ expectEqual((8, 12), valueWithGradient(at: 2, in: for_loop))
+ expectEqual((27, 27), valueWithGradient(at: 3, in: for_loop))
+}
+
+LeakCheckingTests.testWithLeakChecking("ControlFlowWithNestedLoop") {
+ func nested_loop(_ x: Tracked<Float>) -> Tracked<Float> {
+ var outer = x
+ for _ in 1..<3 {
+ outer = outer * x
+
+ var inner = outer
+ var i = 1
+ while i < 3 {
+ inner = inner / x
+ i += 1
+ }
+ outer = inner
+ }
+ return outer
+ }
+ expectEqual((Tracked<Float>(0.5), Tracked<Float>(-0.25)), valueWithGradient(at: 2, in: nested_loop))
+ expectEqual((Tracked<Float>(0.25), Tracked<Float>(-0.0625)), valueWithGradient(at: 4, in: nested_loop))
+}
+
+LeakCheckingTests.testWithLeakChecking("ControlFlowWithNestedTuples") {
+ func cond_nestedtuple_var(_ x: Tracked<Float>) -> Tracked<Float> {
+ // Convoluted function returning `x + x`.
+ var y = (x + x, x - x)
+ var z = (y, x)
+ if x > 0 {
+ var w = (x, x)
+ y.0 = w.1
+ y.1 = w.0
+ z.0.0 = z.0.0 - y.0
+ z.0.1 = z.0.1 + y.0
+ } else {
+ z = ((y.0 - x, y.1 + x), x)
+ }
+ return y.0 + y.1 - z.0.0 + z.0.1
+ }
+ expectEqual((8, 2), valueWithGradient(at: 4, in: cond_nestedtuple_var))
+ expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_nestedtuple_var))
+ expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_nestedtuple_var))
+}
+
+LeakCheckingTests.testWithLeakChecking("ControlFlowWithNestedStructs") {
+ func cond_nestedstruct_var(_ x: Tracked<Float>) -> Tracked<Float> {
+ // Convoluted function returning `x + x`.
+ var y = FloatPair(x + x, x - x)
+ var z = Pair(Tracked(y), x)
+ if x > 0 {
+ var w = FloatPair(x, x)
+ y.first = w.second
+ y.second = w.first
+ z.first = Tracked(FloatPair(z.first.value.first - y.first,
+ z.first.value.second + y.first))
+ } else {
+ z = Pair(Tracked(FloatPair(y.first - x, y.second + x)), x)
+ }
+ return y.first + y.second - z.first.value.first + z.first.value.second
+ }
+ expectEqual((8, 2), valueWithGradient(at: 4, in: cond_nestedstruct_var))
+ expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_nestedstruct_var))
+ expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_nestedstruct_var))
+}
+
+LeakCheckingTests.testWithLeakChecking("ControlFlowWithSwitchEnumWithPayload") {
+ enum Enum {
+ case a(Tracked<Float>)
+ case b(Tracked<Float>, Tracked<Float>)
+ }
+ func enum_notactive2(_ e: Enum, _ x: Tracked<Float>) -> Tracked<Float> {
+ var y = x
+ if x > 0 {
+ var z = y + y
+ switch e {
+ case .a: z = z - y
+ case .b: y = y + x
+ }
+ var w = y
+ if case .a = e {
+ w = w + z
+ }
+ return w
+ } else if case .b = e {
+ return y + y
+ }
+ return x + y
+ }
+ expectEqual((8, 2), valueWithGradient(at: 4, in: { x in enum_notactive2(.a(10), x) }))
+ expectEqual((20, 2), valueWithGradient(at: 10, in: { x in enum_notactive2(.b(4, 5), x) }))
+ expectEqual((-20, 2), valueWithGradient(at: -10, in: { x in enum_notactive2(.a(10), x) }))
+ expectEqual((-2674, 2), valueWithGradient(at: -1337, in: { x in enum_notactive2(.b(4, 5), x) }))
+}
+
+LeakCheckingTests.testWithLeakChecking("ArrayLiteralInitialization") {
+ func concat(_ x: [Tracked<Float>]) -> Tracked<Float> { return x[0] }
+ func foo(_ x: Tracked<Float>) -> Float {
+ let y = x + x
+ return concat([x, y]).value
+ }
+ expectEqual(Tracked<Float>(1), gradient(at: .zero, in: foo))
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/loadable-by-address.swift b/test/AutoDiff/downstream/loadable-by-address.swift
new file mode 100644
index 0000000..2249970
--- /dev/null
+++ b/test/AutoDiff/downstream/loadable-by-address.swift
@@ -0,0 +1,75 @@
+// RUN: %target-swift-frontend -c -enable-large-loadable-types -Xllvm -sil-verify-after-pass=loadable-address %s
+// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s -check-prefix=CHECK-SIL
+// RUN: %target-swift-frontend -c -Xllvm -sil-print-after=loadable-address %s 2>&1 | %FileCheck %s -check-prefix=CHECK-LBA-SIL
+// RUN: %target-run-simple-swift
+// REQUIRES: executable_test
+
+// TF-11: Verify that LoadableByAddress works with differentiation-related instructions:
+// - `differentiable_function`
+// - `differentiable_function_extract`
+
+// TODO: Add tests for `@differentiable(linear)` functions.
+
+import StdlibUnittest
+
+var LBATests = TestSuite("LoadableByAddress")
+
+// `Large` is a large loadable type.
+// `Large.TangentVector` is not a large loadable type.
+struct Large : Differentiable {
+ var a: Float
+ var b: Float
+ var c: Float
+ var d: Float
+ @noDerivative let e: Float
+}
+
+@_silgen_name("large2large")
+@differentiable
+func large2large(_ foo: Large) -> Large {
+ foo
+}
+
+// `large2large` old verification error:
+// SIL verification failed: JVP type does not match expected JVP type
+// $@callee_guaranteed (@in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
+// $@callee_guaranteed (@in_constant Large) -> (@out Large, @owned @callee_guaranteed (@in_constant Large.TangentVector) -> @out Large.TangentVector)
+
+@_silgen_name("large2small")
+@differentiable
+func large2small(_ foo: Large) -> Float {
+ foo.a
+}
+
+// `large2small` old verification error:
+// SIL verification failed: JVP type does not match expected JVP type
+// $@callee_guaranteed (@in_constant Large) -> (Float, @owned @callee_guaranteed (Large.TangentVector) -> Float)
+// $@callee_guaranteed (@in_constant Large) -> (Float, @owned @callee_guaranteed (@in_constant Large.TangentVector) -> Float)
+
+// CHECK-SIL: sil hidden @large2large : $@convention(thin) (Large) -> Large {
+// CHECK-LBA-SIL: sil hidden @large2large : $@convention(thin) (@in_constant Large) -> @out Large {
+
+// CHECK-SIL-LABEL: sil hidden @large2small : $@convention(thin) (Large) -> Float {
+// CHECK-LBA-SIL: sil hidden @large2small : $@convention(thin) (@in_constant Large) -> Float {
+
+// CHECK-SIL: sil hidden @AD__large2large__jvp_src_0_wrt_0 : $@convention(thin) (Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) {
+// CHECK-LBA-SIL: sil hidden @AD__large2large__jvp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) {
+
+// CHECK-SIL: sil hidden @AD__large2large__vjp_src_0_wrt_0 : $@convention(thin) (Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) {
+// CHECK-LBA-SIL: sil hidden @AD__large2large__vjp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) {
+
+// CHECK-SIL: sil hidden @AD__large2small__jvp_src_0_wrt_0 : $@convention(thin) (Large) -> (Float, @owned @callee_guaranteed (Large.TangentVector) -> Float) {
+// CHECK-LBA-SIL: sil hidden @AD__large2small__jvp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Float, @owned @callee_guaranteed (Large.TangentVector) -> Float) {
+
+// CHECK-SIL: sil hidden @AD__large2small__vjp_src_0_wrt_0 : $@convention(thin) (Large) -> (Float, @owned @callee_guaranteed (Float) -> Large.TangentVector) {
+// CHECK-LBA-SIL: sil hidden @AD__large2small__vjp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Float, @owned @callee_guaranteed (Float) -> Large.TangentVector) {
+
+LBATests.test("Correctness") {
+ let one = Large.TangentVector(a: 1, b: 1, c: 1, d: 1)
+ expectEqual(one,
+ pullback(at: Large(a: 0, b: 0, c: 0, d: 0, e: 0), in: large2large)(one))
+ expectEqual(Large.TangentVector(a: 1, b: 0, c: 0, d: 0),
+ gradient(at: Large(a: 0, b: 0, c: 0, d: 0, e: 0), in: large2small))
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/loadable_by_address_cross_module.swift b/test/AutoDiff/downstream/loadable_by_address_cross_module.swift
new file mode 100644
index 0000000..f51d67a
--- /dev/null
+++ b/test/AutoDiff/downstream/loadable_by_address_cross_module.swift
@@ -0,0 +1,44 @@
+// First, check that LBA actually modifies the function, so that this test is useful.
+
+// RUN: %target-swift-frontend -emit-sil %S/Inputs/loadable_by_address_cross_module.swift | %FileCheck %s -check-prefix=CHECK-MODULE-PRE-LBA
+// RUN: %target-swift-frontend -c -Xllvm -sil-print-after=loadable-address %S/Inputs/loadable_by_address_cross_module.swift 2>&1 | %FileCheck %s -check-prefix=CHECK-MODULE-POST-LBA
+
+// CHECK-MODULE-PRE-LBA: sil {{.*}}LBAModifiedFunction{{.*}} $@convention(method) <T> (Float, LargeLoadableType<T>) -> Float
+// CHECK-MODULE-POST-LBA: sil {{.*}}LBAModifiedFunction{{.*}} $@convention(method) <T> (Float, @in_constant LargeLoadableType<T>) -> Float
+
+// Compile the module.
+
+// RUN: %empty-directory(%t)
+// RUN: %target-build-swift -working-directory %t -parse-as-library -emit-module -module-name external -emit-module-path %t/external.swiftmodule -emit-library -static %S/Inputs/loadable_by_address_cross_module.swift
+
+// Next, check that differentiability_witness_functions in the client get
+// correctly modified by LBA.
+
+// RUN: %target-swift-frontend -emit-sil -I%t %s
+// RUN: %target-swift-frontend -emit-sil -I%t %s | %FileCheck %s -check-prefix=CHECK-CLIENT-PRE-LBA
+// RUN: %target-swift-frontend -c -I%t %s -Xllvm -sil-print-after=loadable-address 2>&1 | %FileCheck %s -check-prefix=CHECK-CLIENT-POST-LBA
+
+// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] <T> @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float
+// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] <T> @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float
+
+// CHECK-CLIENT-POST-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] <T> @$s8external17LargeLoadableTypeV0A19LBAModifiedFunctionyS2fF : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float, τ_0_0) -> Float for <LargeLoadableType<τ_0_0>>)
+// CHECK-CLIENT-POST-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] <T> @$s8external17LargeLoadableTypeV0A19LBAModifiedFunctionyS2fF : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> (Float, τ_0_0) for <LargeLoadableType<τ_0_0>>)
+
+// Finally, execute the test.
+
+// RUN: %target-build-swift -I%t -L%t %s -o %t/a.out -lm -lexternal
+// RUN: %target-run %t/a.out
+
+// REQUIRES: executable_test
+
+import external
+import StdlibUnittest
+
+var Tests = TestSuite("LoadableByAddressCrossModule")
+
+Tests.test("Correctness") {
+ let g = gradient(at: LargeLoadableType<Int>(a: 5), 10) { $0.externalLBAModifiedFunction($1) }
+ expectEqual((LargeLoadableType<Int>(a: 10), 5), g)
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/mangling.swift b/test/AutoDiff/downstream/mangling.swift
new file mode 100644
index 0000000..801f459
--- /dev/null
+++ b/test/AutoDiff/downstream/mangling.swift
@@ -0,0 +1,16 @@
+// RUN: %target-swift-frontend -emit-sil -verify %s
+
+// TF-758: test JVP/VJP mangling + generic specialization mangling.
+//
+// Note: this test depends on `Array.differentiableReduce` defined in
+// `stdlib/public/core/AutoDiff.swift`. The crash is not reproducible if
+// `Array.differentiableReduce` is defined in this test file.
+struct TF_758<Scalar> : Differentiable {}
+struct TF_758_Wrapper {
+ var blocks: [TF_758<Float>]
+
+ @differentiable
+ func foo(_ input: TF_758<Float>) -> TF_758<Float> {
+ return blocks.differentiableReduce(input) { res, _ in res }
+ }
+}
diff --git a/test/AutoDiff/downstream/noderivative_attr.swift b/test/AutoDiff/downstream/noderivative_attr.swift
new file mode 100644
index 0000000..dfc93c7
--- /dev/null
+++ b/test/AutoDiff/downstream/noderivative_attr.swift
@@ -0,0 +1,75 @@
+// RUN: %target-swift-frontend -emit-silgen -verify %s | %FileCheck %s
+// REQUIRES: asserts
+
+@noDerivative var flag: Bool
+
+struct NotDifferentiable {
+ @noDerivative var stored: Float
+
+ @noDerivative
+ var computedProperty: Float {
+ get { 1 }
+ set {}
+ _modify { yield &stored }
+ }
+
+ @noDerivative
+ func instanceMethod(_ x: Float) -> Float { x }
+
+ @noDerivative
+ static func staticMethod(_ x: Float) -> Float { x }
+
+ @noDerivative
+ subscript(_ x: Float) -> Float {
+ get { 1 }
+ set {}
+ _modify { yield &stored }
+ }
+}
+
+// CHECK-LABEL: struct NotDifferentiable {
+// CHECK: @noDerivative @_hasStorage var stored: Float { get set }
+// CHECK: @noDerivative var computedProperty: Float { get set _modify }
+// CHECK: @noDerivative func instanceMethod(_ x: Float) -> Float
+// CHECK: @noDerivative static func staticMethod(_ x: Float) -> Float
+// CHECK: @noDerivative subscript(x: Float) -> Float { get set _modify }
+// CHECK: }
+
+// CHECK-LABEL: // NotDifferentiable.computedProperty.getter
+// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV16computedPropertySfvg : $@convention(method) (NotDifferentiable) -> Float
+
+// CHECK-LABEL: // NotDifferentiable.computedProperty.setter
+// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV16computedPropertySfvs : $@convention(method) (Float, @inout NotDifferentiable) -> ()
+
+// CHECK-LABEL: // NotDifferentiable.computedProperty.modify
+// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV16computedPropertySfvM : $@yield_once @convention(method) (@inout NotDifferentiable) -> @yields @inout Float
+
+// CHECK-LABEL: // NotDifferentiable.instanceMethod(_:)
+// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV14instanceMethodyS2fF : $@convention(method) (Float, NotDifferentiable) -> Float
+
+// CHECK-LABEL: // static NotDifferentiable.staticMethod(_:)
+// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV12staticMethodyS2fFZ : $@convention(method) (Float, @thin NotDifferentiable.Type) -> Float
+
+// CHECK-LABEL: // NotDifferentiable.subscript.getter
+// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableVyS2fcig : $@convention(method) (Float, NotDifferentiable) -> Float
+
+// CHECK-LABEL: // NotDifferentiable.subscript.setter
+// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableVyS2fcis : $@convention(method) (Float, Float, @inout NotDifferentiable) -> ()
+
+// CHECK-LABEL: // NotDifferentiable.subscript.modify
+// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableVyS2fciM : $@yield_once @convention(method) (Float, @inout NotDifferentiable) -> @yields @inout Float
+
+struct Bar: Differentiable {
+ @noDerivative var stored: Float
+}
+
+// Test TF-152: derived conformances "no interface type set" crasher.
+struct TF_152: Differentiable {
+ @differentiable(wrt: bar)
+ func applied(to input: Float, bar: TF_152_Bar) -> Float {
+ return input
+ }
+}
+struct TF_152_Bar: Differentiable {
+ @noDerivative let dense: Float
+}
diff --git a/test/AutoDiff/downstream/noderivative_attr_cross_file.swift b/test/AutoDiff/downstream/noderivative_attr_cross_file.swift
new file mode 100644
index 0000000..33c34c5
--- /dev/null
+++ b/test/AutoDiff/downstream/noderivative_attr_cross_file.swift
@@ -0,0 +1,8 @@
+// RUN: %target-swift-frontend -emit-sil %S/Inputs/noderivative_attr_other_file.swift %s | %FileCheck %s
+
+@differentiable
+func bar(_ x: Float) -> Float {
+ return Float(floatToIntNoDerivative(x))
+}
+
+// CHECK: sil hidden [_semantics "autodiff.nonvarying"] @float_to_int_noderivative : $@convention(thin) (Float) -> Int
diff --git a/test/AutoDiff/downstream/nondifferentiable_function_cross_module.swift b/test/AutoDiff/downstream/nondifferentiable_function_cross_module.swift
new file mode 100644
index 0000000..daa2f42
--- /dev/null
+++ b/test/AutoDiff/downstream/nondifferentiable_function_cross_module.swift
@@ -0,0 +1,11 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-swift-frontend -emit-module -primary-file %S/Inputs/nondifferentiable_function_other_module.swift -emit-module-path %t/nondifferentiable_function_other_module.swiftmodule
+// RUN: %target-swift-frontend -emit-sil -I %t -primary-file %s -verify
+
+import nondifferentiable_function_other_module
+
+func test() {
+ // expected-error @+2 {{function is not differentiable}}
+ // expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
+ _ = gradient(at: Float(1), in: externalFunction)
+}
diff --git a/test/AutoDiff/downstream/nonvaried_result.swift b/test/AutoDiff/downstream/nonvaried_result.swift
new file mode 100644
index 0000000..e0608dd
--- /dev/null
+++ b/test/AutoDiff/downstream/nonvaried_result.swift
@@ -0,0 +1,162 @@
+// RUN: %target-swift-frontend -Xllvm -sil-print-after=differentiation %s -emit-sil -o /dev/null 2>&1 | %FileCheck %s
+// RUN: %target-run-simple-swift
+// TODO: Test forward-mode differentiation when it supports control flow.
+// UN: %target-run-simple-swift-forward-mode-differentiation
+// REQUIRES: executable_test
+
+// Test differentiation edge case: functions with non-varied results.
+// The differentials of these functions should return zero.
+// The pullbacks of these functions should return zero with respect to the
+// parameters for which the result is non-varying.
+
+import StdlibUnittest
+import DifferentiationUnittest
+
+var NonVariedResultTests = TestSuite("TestCaseTests")
+
+NonVariedResultTests.testWithLeakChecking("SingleBasicBlock") {
+ @differentiable(wrt: y)
+ func simple(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
+ return x
+ }
+ expectEqual(0, gradient(at: 3) { simple(10, $0) })
+ expectEqual((1, 0), gradient(at: 3, 4, in: simple))
+}
+
+// CHECK-LABEL: sil private [ossa] @AD__${{.*}}simple{{.*}}pullback_src_0_wrt_1 : $@convention(thin) (@guaranteed Tracked<Float>, @owned _AD__$s4nullyycfU_6simpleL_y23DifferentiationUnittest7TrackedVySfGAF_AFtF_bb0__PB__src_0_wrt_1) -> @owned Tracked<Float> {
+// CHECK: bb0([[SEED:%.*]] : @guaranteed $Tracked<Float>, [[PB_STRUCT:%.*]] : [[PB_STRUCT_TYPE:.*]]):
+// CHECK: [[BUF:%.*]] = alloc_stack $Tracked<Float>
+// CHECK: [[ZERO_FN:%.*]] = witness_method $Tracked<Float>, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
+// CHECK: [[METATYPE:%.*]] = metatype $@thick Tracked<Float>.Type
+// CHECK: {{%.*}} = apply [[ZERO_FN]]<Tracked<Float>>([[BUF]], [[METATYPE]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
+// CHECK: [[ZERO_VALUE:%.*]] = load [take] [[BUF]] : $*Tracked<Float>
+// CHECK: dealloc_stack [[BUF]] : $*Tracked<Float>
+// CHECK: return [[ZERO_VALUE]]
+
+NonVariedResultTests.testWithLeakChecking("SingleBasicBlockGeneric") {
+ // Test zero wrt multiple arguments.
+ @differentiable(wrt: (x, y, z))
+ func simpleGeneric<T: Differentiable>(
+ _ x: T, _ y: T, _ z: Tracked<Float>
+ ) -> T where T == T.TangentVector {
+ return .zero
+ }
+ expectEqual((0, 0, 0), gradient(at: 3, 4, 5) { simpleGeneric($0, $1, $2) })
+}
+
+// CHECK-LABEL: sil private [ossa] @AD__${{.*}}simpleGeneric{{.*}}pullback_src_0_wrt_0_1_2{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 == τ_0_0.TangentVector> (@in_guaranteed τ_0_0, @owned _AD__$s4nullyycfU0_13simpleGenericL_yxx_x23DifferentiationUnittest7TrackedVySfGts14DifferentiableRz13TangentVectorsAGPQzRszlF_bb0__PB__src_0_wrt_0_1_2_s14DifferentiableRz13TangentVectorsAAPQzRszl<τ_0_0>) -> (@out τ_0_0, @out τ_0_0, @owned Tracked<Float>) {
+// CHECK: bb0([[DX:%.*]] : $*τ_0_0, [[DY:%.*]] : $*τ_0_0, [[SEED:%.*]] : $*τ_0_0, [[PB_STRUCT:%.*]] : [[PB_STRUCT_TYPE:.*]]):
+// CHECK: [[ZERO_FN_X:%.*]] = witness_method $τ_0_0, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
+// CHECK: [[METATYPE_X:%.*]] = metatype $@thick τ_0_0.Type
+// CHECK: {{.*}} = apply [[ZERO_FN_X]]<τ_0_0>([[DX]], [[METATYPE_X]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
+// CHECK: [[ZERO_FN_Y:%.*]] = witness_method $τ_0_0, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
+// CHECK: [[METATYPE_Y:%.*]] = metatype $@thick τ_0_0.Type
+// CHECK: {{.*}} = apply [[ZERO_FN_Y:%.*]]<τ_0_0>([[DY]], [[METATYPE_Y]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
+// CHECK: [[BUF_Z:%.*]] = alloc_stack $Tracked<Float>
+// CHECK: [[ZERO_FN_Z:%.*]] = witness_method $Tracked<Float>, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
+// CHECK: [[METATYPE_Z:%.*]] = metatype $@thick Tracked<Float>.Type
+// CHECK: {{%.*}} = apply [[ZERO_FN_Z]]<Tracked<Float>>([[BUF_Z]], [[METATYPE_Z]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
+// CHECK: [[ZERO_VALUE_Z:%.*]] = load [take] [[BUF_Z]] : $*Tracked<Float>
+// CHECK: dealloc_stack [[BUF_Z]] : $*Tracked<Float>
+// CHECK: return [[ZERO_VALUE_Z]]
+
+NonVariedResultTests.testWithLeakChecking("Conditionals") {
+ @differentiable(wrt: y)
+ func `if`(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
+ if x > 0 {}
+ return x
+ }
+ expectEqual(0, gradient(at: 3) { `if`(10, $0) })
+ expectEqual((1, 0), gradient(at: 3, 4, in: `if`))
+
+ @differentiable(wrt: y)
+ func `guard`(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
+ guard x > 0 else { return x }
+ return x
+ }
+ expectEqual(0, gradient(at: 3) { x in `guard`(10, x) })
+ expectEqual((1, 0), gradient(at: 3, 4, in: `guard`))
+
+ @differentiable(wrt: y)
+ func `switch`(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
+ switch x.value {
+ case 0: break
+ default: break
+ }
+ return x
+ }
+ expectEqual(0, gradient(at: 3) { `switch`(10, $0) })
+ expectEqual((1, 0), gradient(at: 3, 4, in: `switch`))
+}
+
+NonVariedResultTests.testWithLeakChecking("Loops") {
+ @differentiable(wrt: y)
+ func `for`(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
+ for i in 0..<10 {}
+ return x
+ }
+ expectEqual(0, gradient(at: 3) { `for`(10, $0) })
+ expectEqual((1, 0), gradient(at: 3, 4, in: `for`))
+
+ @differentiable(wrt: y)
+ func `while`(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
+ while 0 < 0 {}
+ return x
+ }
+ expectEqual(0, gradient(at: 3) { `while`(10, $0) })
+ expectEqual((1, 0), gradient(at: 3, 4, in: `while`))
+}
+
+NonVariedResultTests.testWithLeakChecking("Complex") {
+ @differentiable(wrt: y)
+ func complex(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
+ for i in 0..<10 {
+ for j in 0..<10 {
+ if x > 0 {}
+ while 0 < 0 {}
+ switch x.value {
+ case 0: break
+ default: break
+ }
+ }
+ }
+ return x + x + x
+ }
+ expectEqual(0, gradient(at: 3) { complex(10, $0) })
+ expectEqual((3, 0), gradient(at: 3, 4, in: complex))
+}
+
+// CHECK-LABEL: sil private [ossa] @AD__${{.*}}complex{{.*}}pullback_src_0_wrt_1 : $@convention(thin) (@guaranteed Tracked<Float>, @owned _AD__$s4nullyycfU3_7complexL_y23DifferentiationUnittest7TrackedVySfGAF_AFtF_bb15__PB__src_0_wrt_1) -> @owned Tracked<Float> {
+// CHECK: bb0([[SEED:%.*]] : @guaranteed $Tracked<Float>, [[PB_STRUCT:%.*]] : @owned [[PB_STRUCT_TYPE:.*]]):
+// CHECK: destroy_value [[PB_STRUCT]] : [[PB_STRUCT_TYPE]]
+// CHECK: [[BUF:%.*]] = alloc_stack $Tracked<Float>
+// CHECK: [[ZERO_FN:%.*]] = witness_method $Tracked<Float>, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
+// CHECK: [[METATYPE:%.*]] = metatype $@thick Tracked<Float>.Type
+// CHECK: {{%.*}} = apply [[ZERO_FN]]<Tracked<Float>>([[BUF]], [[METATYPE]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
+// CHECK: [[ZERO_VALUE:%.*]] = load [take] [[BUF]] : $*Tracked<Float>
+// CHECK: dealloc_stack [[BUF]] : $*Tracked<Float>
+// CHECK: return [[ZERO_VALUE]]
+
+NonVariedResultTests.testWithLeakChecking("ComplexGeneric") {
+ @differentiable(wrt: y)
+ func complexGeneric<T: Differentiable>(_ x: T, _ y: T) -> T {
+ for i in 0..<10 {
+ for j in 0..<10 {
+ while 0 < 0 {}
+ }
+ }
+ return x
+ }
+ expectEqual(0, pullback(at: Tracked<Float>(3)) { complexGeneric(10, $0) }(1))
+}
+
+// CHECK-LABEL: sil private [ossa] @AD__${{.*}}complexGeneric{{.*}}pullback_src_0_wrt_1{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, @owned _AD__$s4nullyycfU4_14complexGenericL_yxx_xts14DifferentiableRzlF_bb9__PB__src_0_wrt_1_s14DifferentiableRzl<τ_0_0>) -> @out τ_0_0.TangentVector {
+// CHECK: bb0([[DY:%.*]] : $*τ_0_0.TangentVector, [[SEED:%.*]] : $*τ_0_0.TangentVector, [[PB_STRUCT:%.*]] : @owned [[PB_STRUCT_TYPE:.*]]):
+// CHECK: destroy_value [[PB_STRUCT]] : [[PB_STRUCT_TYPE]]
+// CHECK: [[ZERO_FN:%.*]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
+// CHECK: [[METATYPE:%.*]] = metatype $@thick τ_0_0.TangentVector.Type
+// CHECK: {{%.*}} = apply [[ZERO_FN]]<τ_0_0.TangentVector>([[DY]], [[METATYPE]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
+// CHECK: [[VOID:%.*]] = tuple ()
+// CHECK: return [[VOID]]
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/nonvarying_semantics.swift b/test/AutoDiff/downstream/nonvarying_semantics.swift
new file mode 100644
index 0000000..25b93b3
--- /dev/null
+++ b/test/AutoDiff/downstream/nonvarying_semantics.swift
@@ -0,0 +1,88 @@
+// RUN: %target-swift-emit-sil -verify -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s
+// REQUIRES: asserts
+
+// Test approaches for affecting activity analysis (non-varying semantics):
+// - `@noDerivative` on declaration
+// - `@_semantics("autodiff.nonvarying")` on declaration
+// - `withoutDerivative(at:)` at use site
+
+extension Float {
+ // No non-varying semantics.
+ var int: Int { Int(self) }
+
+ // Non-varying semantics.
+ @noDerivative
+ var intNoDerivative: Int { int }
+
+ // Non-varying semantics.
+ @_semantics("autodiff.nonvarying")
+ var intNonvarying: Int { int }
+}
+
+// expected-error @+1 {{function is not differentiable}}
+@differentiable
+@_silgen_name("id")
+// expected-note @+1 {{when differentiating this function definition}}
+func id(_ x: Float) -> Float {
+ // expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
+ return Float(x.int)
+}
+
+// CHECK-LABEL: [AD] Activity info for id at (parameters=(0) results=(0))
+// CHECK: bb0:
+// CHECK: [ACTIVE] %0 = argument of bb0 : $Float
+// CHECK: [USEFUL] %2 = metatype $@thin Float.Type
+// CHECK: [NONE] // function_ref Float.int.getter
+// CHECK: [ACTIVE] %4 = apply %3(%0) : $@convention(method) (Float) -> Int
+// CHECK: [NONE] // function_ref Float.init(_:)
+// CHECK: [ACTIVE] %6 = apply %5(%4, %2) : $@convention(method) (Int, @thin Float.Type) -> Float
+
+@differentiable
+@_silgen_name("idWithoutDerivativeAt")
+func idWithoutDerivativeAt(_ x: Float) -> Float {
+ return Float(withoutDerivative(at: x.int))
+}
+
+// CHECK-LABEL: [AD] Activity info for idWithoutDerivativeAt at (parameters=(0) results=(0))
+// CHECK: bb0:
+// CHECK: [VARIED] %0 = argument of bb0 : $Float
+// CHECK: [USEFUL] %2 = metatype $@thin Float.Type
+// CHECK: [USEFUL] %3 = alloc_stack $Int
+// CHECK: [NONE] // function_ref Float.int.getter
+// CHECK: [VARIED] %5 = apply %4(%0) : $@convention(method) (Float) -> Int
+// CHECK: [VARIED] %6 = alloc_stack $Int
+// CHECK: [NONE] // function_ref withoutDerivative<A>(at:)
+// CHECK: [NONE] %9 = apply %8<Int>(%3, %6) : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0) -> @out τ_0_0
+// CHECK: [USEFUL] %11 = load [trivial] %3 : $*Int
+// CHECK: [NONE] // function_ref Float.init(_:)
+// CHECK: [USEFUL] %13 = apply %12(%11, %2) : $@convention(method) (Int, @thin Float.Type) -> Float
+
+@differentiable
+@_silgen_name("idNoDerivative")
+func idNoDerivative(_ x: Float) -> Float {
+ return Float(x.intNoDerivative)
+}
+
+// CHECK-LABEL: [AD] Activity info for idNoDerivative at (parameters=(0) results=(0))
+// CHECK: bb0:
+// CHECK: [VARIED] %0 = argument of bb0 : $Float
+// CHECK: [USEFUL] %2 = metatype $@thin Float.Type
+// CHECK: [NONE] // function_ref Float.intNoDerivative.getter
+// CHECK: [USEFUL] %4 = apply %3(%0) : $@convention(method) (Float) -> Int
+// CHECK: [NONE] // function_ref Float.init(_:)
+// CHECK: [USEFUL] %6 = apply %5(%4, %2) : $@convention(method) (Int, @thin Float.Type) -> Float
+
+@differentiable
+@_silgen_name("idNonvaryingSemantics")
+func idNonvaryingSemantics(_ x: Float) -> Float {
+ return Float(x.intNonvarying)
+}
+
+// CHECK-LABEL: [AD] Activity info for idNonvaryingSemantics at (parameters=(0) results=(0))
+// CHECK: bb0:
+// CHECK: [VARIED] %0 = argument of bb0 : $Float
+// CHECK: [USEFUL] %2 = metatype $@thin Float.Type
+// CHECK: [NONE] // function_ref Float.intNonvarying.getter
+// CHECK: [USEFUL] %4 = apply %3(%0) : $@convention(method) (Float) -> Int
+// CHECK: [NONE] // function_ref Float.init(_:)
+// CHECK: [USEFUL] %6 = apply %5(%4, %2) : $@convention(method) (Int, @thin Float.Type) -> Float
diff --git a/test/AutoDiff/downstream/pass_creates_differentiability_witnesses.swift b/test/AutoDiff/downstream/pass_creates_differentiability_witnesses.swift
new file mode 100644
index 0000000..323e6f3
--- /dev/null
+++ b/test/AutoDiff/downstream/pass_creates_differentiability_witnesses.swift
@@ -0,0 +1,78 @@
+// RUN: %target-swift-frontend -emit-sil -emit-sorted-sil %s | %FileCheck %s
+
+// MARK: - Public functions
+
+@differentiable
+@_silgen_name("f000_invokedDirectlyByDifferentiableAttrPublic")
+public func f000_invokedDirectlyByDifferentiableAttrPublic(_ x: Float) -> Float {
+ return f001_invokedIndirectlyByDifferentiableAttrPublic(x)
+}
+// CHECK-LABEL: sil_differentiability_witness [serialized] [parameters 0] [results 0] @f000_invokedDirectlyByDifferentiableAttrPublic
+// CHECK-NEXT: jvp
+// CHECK-NEXT: vjp
+
+@_silgen_name("f001_invokedIndirectlyByDifferentiableAttrPublic")
+public func f001_invokedIndirectlyByDifferentiableAttrPublic(_ x: Float) -> Float {
+ return x
+}
+// CHECK-LABEL: sil_differentiability_witness private [parameters 0] [results 0] @f001_invokedIndirectlyByDifferentiableAttrPublic
+// CHECK-NEXT: jvp
+// CHECK-NEXT: vjp
+
+@_silgen_name("f002_invokedDirectlyByConversionPublic")
+public func f002_invokedDirectlyByConversionPublic(_ x: Float) -> Float {
+ return f003_invokedIndirectlyByConversionPublic(x)
+}
+// CHECK-LABEL: sil_differentiability_witness private [parameters 0] [results 0] @f002_invokedDirectlyByConversionPublic
+// CHECK-NEXT: jvp
+// CHECK-NEXT: vjp
+
+@_silgen_name("f003_invokedIndirectlyByConversionPublic")
+public func f003_invokedIndirectlyByConversionPublic(_ x: Float) -> Float {
+ return x
+}
+// CHECK-LABEL: sil_differentiability_witness private [parameters 0] [results 0] @f003_invokedIndirectlyByConversionPublic
+// CHECK-NEXT: jvp
+// CHECK-NEXT: vjp
+
+// MARK: - Internal functions
+
+@differentiable
+@_silgen_name("f004_invokedDirectlyByDifferentiableAttrInternal")
+internal func f004_invokedDirectlyByDifferentiableAttrInternal(_ x: Float) -> Float {
+ return f005_invokedIndirectlyByDifferentiableAttrInternal(x)
+}
+// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f004_invokedDirectlyByDifferentiableAttrInternal
+// CHECK-NEXT: jvp
+// CHECK-NEXT: vjp
+
+@_silgen_name("f005_invokedIndirectlyByDifferentiableAttrInternal")
+internal func f005_invokedIndirectlyByDifferentiableAttrInternal(_ x: Float) -> Float {
+ return x
+}
+// CHECK-LABEL: sil_differentiability_witness private [parameters 0] [results 0] @f005_invokedIndirectlyByDifferentiableAttrInternal
+// CHECK-NEXT: jvp
+// CHECK-NEXT: vjp
+
+@_silgen_name("f006_invokedDirectlyByConversionInternal")
+internal func f006_invokedDirectlyByConversionInternal(_ x: Float) -> Float {
+ return f007_invokedIndirectlyByConversionInternal(x)
+}
+// CHECK-LABEL: sil_differentiability_witness private [parameters 0] [results 0] @f006_invokedDirectlyByConversionInternal
+// CHECK-NEXT: jvp
+// CHECK-NEXT: vjp
+
+@_silgen_name("f007_invokedIndirectlyByConversionInternal")
+internal func f007_invokedIndirectlyByConversionInternal(_ x: Float) -> Float {
+ return x
+}
+// CHECK-LABEL: sil_differentiability_witness private [parameters 0] [results 0] @f007_invokedIndirectlyByConversionInternal
+// CHECK-NEXT: jvp
+// CHECK-NEXT: vjp
+
+func invokesByConversion() -> Float {
+ var result: Float = 0
+ result += gradient(at: 0, in: f002_invokedDirectlyByConversionPublic)
+ result += gradient(at: 0, in: f006_invokedDirectlyByConversionInternal)
+ return result
+}
diff --git a/test/AutoDiff/downstream/protocol_requirement_autodiff.swift b/test/AutoDiff/downstream/protocol_requirement_autodiff.swift
new file mode 100644
index 0000000..4142ae2
--- /dev/null
+++ b/test/AutoDiff/downstream/protocol_requirement_autodiff.swift
@@ -0,0 +1,200 @@
+// RUN: %target-run-simple-swift
+
+import StdlibUnittest
+import DifferentiationUnittest
+
+var ProtocolRequirementAutodiffTests = TestSuite("ProtocolRequirementAutodiff")
+
+// MARK: - Func requirements.
+
+protocol DiffReq : Differentiable {
+ @differentiable(wrt: (self, x))
+ func f(_ x: Tracked<Float>) -> Tracked<Float>
+}
+
+extension DiffReq where TangentVector : AdditiveArithmetic {
+ @inline(never) // Prevent specialization, to test all witness code.
+ func gradF(at x: Tracked<Float>) -> (Self.TangentVector, Tracked<Float>) {
+ return (valueWithPullback(at: self, x) { s, x in s.f(x) }).1(1)
+ }
+}
+
+struct Quadratic : DiffReq, AdditiveArithmetic {
+ typealias TangentVector = Quadratic
+
+ @differentiable
+ let a: Tracked<Float>
+
+ @differentiable
+ let b: Tracked<Float>
+
+ @differentiable
+ let c: Tracked<Float>
+
+ init(_ a: Tracked<Float>, _ b: Tracked<Float>, _ c: Tracked<Float>) {
+ self.a = a
+ self.b = b
+ self.c = c
+ }
+
+ @differentiable(wrt: (self, x))
+ func f(_ x: Tracked<Float>) -> Tracked<Float> {
+ return a * x * x + b * x + c
+ }
+}
+
+ProtocolRequirementAutodiffTests.testWithLeakChecking("func") {
+ expectEqual((Quadratic(0, 0, 1), 12), Quadratic(11, 12, 13).gradF(at: 0))
+ expectEqual((Quadratic(1, 1, 1), 2 * 11 + 12),
+ Quadratic(11, 12, 13).gradF(at: 1))
+ expectEqual((Quadratic(4, 2, 1), 2 * 11 * 2 + 12),
+ Quadratic(11, 12, 13).gradF(at: 2))
+}
+
+// MARK: Constructor, accessor, and subscript requirements.
+
+protocol FunctionsOfX: Differentiable {
+ @differentiable
+ init(x: Tracked<Float>)
+
+ @differentiable
+ var x: Tracked<Float> { get }
+
+ @differentiable
+ var y: Tracked<Float> { get }
+
+ @differentiable
+ var z: Tracked<Float> { get }
+
+ @differentiable
+ subscript() -> Tracked<Float> { get }
+}
+
+struct TestFunctionsOfX: FunctionsOfX {
+ @differentiable
+ init(x: Tracked<Float>) {
+ self.x = x
+ self.y = x * x
+ }
+
+ /// x = x
+ var x: Tracked<Float>
+
+ /// y = x * x
+ var y: Tracked<Float>
+
+ /// z = x * x + x
+ var z: Tracked<Float> {
+ return y + x
+ }
+
+ @differentiable
+ subscript() -> Tracked<Float> {
+ return z
+ }
+}
+
+@inline(never) // Prevent specialization, to test all witness code.
+func derivatives<F: FunctionsOfX>(at x: Tracked<Float>, in: F.Type)
+ -> (Tracked<Float>, Tracked<Float>, Tracked<Float>, Tracked<Float>)
+{
+ let dxdx = gradient(at: x) { x in F(x: x).x }
+ let dydx = gradient(at: x) { x in F(x: x).y }
+ let dzdx = gradient(at: x) { x in F(x: x).z }
+ let dsubscriptdx = gradient(at: x) { x in F(x: x)[] }
+ return (dxdx, dydx, dzdx, dsubscriptdx)
+}
+
+ProtocolRequirementAutodiffTests.testWithLeakChecking("constructor, accessor, subscript") {
+ expectEqual(
+ (1.0, 4.0, 5.0, 5.0),
+ derivatives(at: 2.0, in: TestFunctionsOfX.self))
+}
+
+// MARK: - Test witness method SIL type computation.
+
+protocol P : Differentiable {
+ @differentiable(wrt: (x, y))
+ func foo(_ x: Tracked<Float>, _ y: Double) -> Tracked<Float>
+}
+struct S : P {
+ @differentiable(wrt: (x, y))
+ func foo(_ x: Tracked<Float>, _ y: Double) -> Tracked<Float> {
+ return x
+ }
+}
+
+// MARK: - Overridden protocol method adding differentiable attribute.
+
+public protocol Distribution {
+ associatedtype Value
+ func logProbability(of value: Value) -> Tracked<Float>
+}
+
+public protocol DifferentiableDistribution: Differentiable, Distribution {
+ @differentiable(wrt: self)
+ func logProbability(of value: Value) -> Tracked<Float>
+}
+
+struct Foo: DifferentiableDistribution {
+ @differentiable(wrt: self)
+ func logProbability(of value: Tracked<Float>) -> Tracked<Float> {
+ .zero
+ }
+}
+
+@differentiable
+func blah<T: DifferentiableDistribution>(_ x: T) -> Tracked<Float>
+where T.Value: AdditiveArithmetic {
+ x.logProbability(of: .zero)
+}
+
+// Adding a more general `@differentiable` attribute.
+public protocol DoubleDifferentiableDistribution: DifferentiableDistribution
+ where Value: Differentiable {
+ @differentiable(wrt: self)
+ @differentiable(wrt: (self, value))
+ func logProbability(of value: Value) -> Tracked<Float>
+}
+
+@differentiable
+func blah2<T: DoubleDifferentiableDistribution>(_ x: T, _ value: T.Value) -> Tracked<Float>
+ where T.Value: AdditiveArithmetic {
+ x.logProbability(of: value)
+}
+
+// Satisfying the requirement with more wrt indices than are necessary.
+
+protocol DifferentiableFoo {
+ associatedtype T: Differentiable
+ @differentiable(wrt: x)
+ func foo(_ x: T) -> Tracked<Float>
+}
+
+protocol MoreDifferentiableFoo: Differentiable, DifferentiableFoo {
+ @differentiable(wrt: (self, x))
+ func foo(_ x: T) -> Tracked<Float>
+}
+
+struct MoreDifferentiableFooStruct: MoreDifferentiableFoo {
+ @differentiable(wrt: (self, x))
+ func foo(_ x: Tracked<Float>) -> Tracked<Float> {
+ x
+ }
+}
+
+// Satisfiying the requirement with a less-constrained derivative than is necessary.
+
+protocol ExtraDerivativeConstraint {}
+
+protocol HasExtraConstrainedDerivative {
+ @differentiable
+ func requirement<T: Differentiable & ExtraDerivativeConstraint>(_ x: T) -> T
+}
+
+struct SatisfiesDerivativeWithLessConstraint: HasExtraConstrainedDerivative {
+ @differentiable
+ func requirement<T: Differentiable>(_ x: T) -> T { x }
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/protocol_requirement_autodiff_diags.swift b/test/AutoDiff/downstream/protocol_requirement_autodiff_diags.swift
new file mode 100644
index 0000000..ea55f93
--- /dev/null
+++ b/test/AutoDiff/downstream/protocol_requirement_autodiff_diags.swift
@@ -0,0 +1,18 @@
+// RUN: %target-swift-frontend -typecheck -verify %s
+
+protocol P {}
+
+public protocol HasRequirement {
+ @differentiable
+ // expected-note @+1 {{protocol requires function 'requirement' with type '<T> (T, T) -> T'; do you want to add a stub?}}
+ func requirement<T: Differentiable>(_ x: T, _ y: T) -> T
+}
+
+// expected-error @+1 {{type 'AttemptsToSatisfyRequirement' does not conform to protocol 'HasRequirement'}}
+public struct AttemptsToSatisfyRequirement: HasRequirement {
+ // This does not satisfy the requirement because the differentiable attribute is more
+ // constrained than the requirement's differentiable attribute.
+ @differentiable(where T: P)
+ // expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))'}}
+ public func requirement<T: Differentiable>(_ x: T, _ y: T) -> T { x }
+}
diff --git a/test/AutoDiff/downstream/refcounting.swift b/test/AutoDiff/downstream/refcounting.swift
new file mode 100644
index 0000000..77d12c2
--- /dev/null
+++ b/test/AutoDiff/downstream/refcounting.swift
@@ -0,0 +1,133 @@
+// RUN: %target-swift-frontend -emit-sil -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
+// RUN: %target-swift-frontend -emit-sil -Xllvm -differentiation-skip-folding-differentiable-function-extraction %s | %FileCheck %s
+// REQUIRES: asserts
+
+public class NonTrivialStuff : Equatable {
+ public init() {}
+ public static func == (lhs: NonTrivialStuff, rhs: NonTrivialStuff) -> Bool { return true }
+}
+
+@frozen
+public struct Vector : AdditiveArithmetic, Differentiable, Equatable {
+ public var x: Float
+ public var y: Float
+ public var nonTrivialStuff = NonTrivialStuff()
+ public typealias TangentVector = Vector
+ public typealias VectorSpaceScalar = Float
+ public static var zero: Vector { return Vector(0) }
+ public init(_ scalar: Float) { self.x = scalar; self.y = scalar }
+
+ @_silgen_name("Vector_plus")
+ @differentiable
+ public static func + (lhs: Vector, rhs: Vector) -> Vector { abort() }
+
+ @_silgen_name("Vector_subtract")
+ @differentiable
+ public static func - (lhs: Vector, rhs: Vector) -> Vector { abort() }
+
+ public func adding(_ scalar: Float) -> Vector { abort() }
+ public func subtracting(_ scalar: Float) -> Vector { abort() }
+ public func scaled(by scalar: Float) -> Vector { abort() }
+
+ @derivative(of: +)
+ @derivative(of: -)
+ public static func fakeVJP(lhs: Vector, rhs: Vector) -> (value: Vector, pullback: (Vector) -> (Vector, Vector)) { abort() }
+}
+
+// This exists to minimize generated SIL.
+@inline(never) func abort() -> Never { fatalError() }
+
+func testOwnedVector(_ x: Vector) -> Vector {
+ return x + x
+}
+_ = pullback(at: Vector.zero, in: testOwnedVector)
+
+// CHECK-DATA-STRUCTURES-LABEL: struct {{.*}}testOwnedVector{{.*}}__PB__src_0_wrt_0 {
+// CHECK-DATA-STRUCTURES-NEXT: var pullback_0: (Vector) -> (Vector, Vector)
+// CHECK-DATA-STRUCTURES-NEXT: }
+// CHECK-DATA-STRUCTURES-LABEL: enum {{.*}}testOwnedVector{{.*}}__Pred__src_0_wrt_0 {
+// CHECK-DATA-STRUCTURES-NEXT: }
+
+// CHECK-LABEL: sil private @{{.*}}UsesMethodOfNoDerivativeMember{{.*}}applied2to{{.*}}__pullback_src_0_wrt_0_1
+// CHECK: bb0([[SEED:%.*]] : $Vector, [[PB_STRUCT:%.*]] : ${{.*}}UsesMethodOfNoDerivativeMember{{.*}}applied2to{{.*}}__PB__src_0_wrt_0_1):
+// CHECK: [[PB:%.*]] = struct_extract [[PB_STRUCT]] : ${{.*}}UsesMethodOfNoDerivativeMember{{.*}}applied2to{{.*}}__PB__src_0_wrt_0_1
+// CHECK: [[NEEDED_COTAN:%.*]] = apply [[PB]]([[SEED]]) : $@callee_guaranteed (@guaranteed Vector) -> @owned Vector
+
+// CHECK-LABEL: sil private @{{.*}}subset_pullback_releases_unused_ones{{.*}}__pullback_src_0_wrt_0
+// CHECK: bb0([[SEED:%.*]] : $Vector, [[PB_STRUCT:%.*]] : ${{.*}}subset_pullback_releases_unused_ones{{.*}}__PB__src_0_wrt_0):
+// CHECK: [[PB1:%.*]] = struct_extract [[PB_STRUCT]] : ${{.*}}subset_pullback_releases_unused_ones{{.*}}__PB__src_0_wrt_0, #{{.*}}subset_pullback_releases_unused_ones{{.*}}__PB__src_0_wrt_0.pullback_0
+// CHECK: [[PB0:%.*]] = struct_extract [[PB_STRUCT]] : ${{.*}}subset_pullback_releases_unused_ones{{.*}}, #{{.*}}subset_pullback_releases_unused_ones{{.*}}__PB__src_0_wrt_0.pullback_1
+// CHECK: [[NEEDED_COTAN0:%.*]] = apply [[PB0]]([[SEED]]) : $@callee_guaranteed (@guaranteed Vector) -> @owned Vector
+// CHECK: strong_release [[PB0]]
+// CHECK-NOT: release_value [[NEEDED_COTAN0]] : $Vector
+// CHECK: [[NEEDED_COTAN1:%.*]] = apply [[PB1]]([[NEEDED_COTAN0]]) : $@callee_guaranteed (@guaranteed Vector) -> @owned Vector
+// CHECK: strong_release [[PB1]]
+// CHECK: release_value [[NEEDED_COTAN0]] : $Vector
+// CHECK: return [[NEEDED_COTAN1]] : $Vector
+
+// CHECK-LABEL: sil private @{{.*}}side_effect_release_zero{{.*}}__pullback_src_0_wrt_0
+// CHECK: bb0([[SEED:%.*]] : $Vector, %1 : ${{.*}}side_effect_release_zero{{.*}}_bb0__PB__src_0_wrt_0):
+// CHECK: [[BUF:%.*]] = alloc_stack $Vector
+// CHECK: [[ZERO_GETTER:%.*]] = function_ref @$s11refcounting6VectorV4zeroACvgZ
+// CHECK: [[ZERO:%.*]] = apply [[ZERO_GETTER]]({{%.*}}) : $@convention(method) (@thin Vector.Type) -> @owned Vector
+// CHECK: store [[ZERO]] to [[BUF]] : $*Vector
+// CHECK: load [[BUF]] : $*Vector
+// CHECK: [[ZERO_GETTER:%.*]] = function_ref @$s11refcounting6VectorV4zeroACvgZ
+// CHECK: [[ZERO:%.*]] = apply [[ZERO_GETTER]]({{%.*}}) : $@convention(method) (@thin Vector.Type) -> @owned Vector
+// CHECK: store [[ZERO]] to [[BUF]] : $*Vector
+// CHECK: retain_value [[SEED:%.*]] : $Vector
+// CHECK: release_value [[SEED:%.*]] : $Vector
+// CHECK: destroy_addr [[BUF]] : $*Vector
+// CHECK: dealloc_stack [[BUF]] : $*Vector
+// CHECK: }
+
+// The vjp should not release pullback values.
+//
+// CHECK-LABEL: sil private @{{.*}}testOwnedVector{{.*}}__vjp_src_0_wrt_0 : $@convention(thin) (@guaranteed Vector) -> (@owned Vector, @owned @callee_guaranteed (@guaranteed Vector) -> @owned Vector)
+// CHECK: [[ADD:%.*]] = function_ref @Vector_plus
+// CHECK: [[ADD_JVP:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] @Vector_plus
+// CHECK: [[ADD_VJP:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @Vector_plus
+// CHECK: [[ADD_AD_FUNC:%.*]] = differentiable_function [parameters 0 1] [results 0] [[ADD]] {{.*}} with_derivative {[[ADD_JVP]] {{.*}}, [[ADD_VJP]] {{.*}}}
+// CHECK: [[ADD_AD_FUNC_EXTRACT:%.*]] = differentiable_function_extract [vjp] [[ADD_AD_FUNC]]
+// CHECK: [[ADD_VJP_RESULT:%.*]] = apply [[ADD_AD_FUNC_EXTRACT]]({{.*}}, {{.*}}, {{.*}}) : $@convention(method) (@guaranteed Vector, @guaranteed Vector, @thin Vector.Type) -> (@owned Vector, @owned @callee_guaranteed (@guaranteed Vector) -> (@owned Vector, @owned Vector))
+// CHECK: [[ADD_PULLBACK:%.*]] = tuple_extract [[ADD_VJP_RESULT]] : $(Vector, @callee_guaranteed (@guaranteed Vector) -> (@owned Vector, @owned Vector)), 1
+// CHECK-NOT: release_value [[ADD_VJP_RESULT]]
+// CHECK-NOT: release_value [[ADD_PULLBACK]]
+
+// The pullback should not release pullback struct argument because it has @guaranteed convention.
+//
+// CHECK-LABEL: @{{.*}}testOwnedVector{{.*}}__pullback_src_0_wrt_0
+// CHECK: bb0({{%.*}} : $Vector, [[PB_STRUCT:%.*]] : ${{.*}}testOwnedVector{{.*}}__PB__src_0_wrt_0):
+// CHECK: [[PULLBACK0:%.*]] = struct_extract [[PB_STRUCT]] : ${{.*}}testOwnedVector{{.*}}__PB__src_0_wrt_0, #{{.*}}testOwnedVector{{.*}}__PB__src_0_wrt_0.pullback_0
+// CHECK-NOT: release_value [[PULLBACK0]] : @callee_guaranteed (@guaranteed Vector) -> (@owned Vector, @owned Vector)
+// CHECK-NOT: release_value [[PB_STRUCT]] : ${{.*}}testOwnedVector{{.*}}__PB__src_0_wrt_0
+// CHECK: }
+
+func side_effect_release_zero(_ x: Vector) -> Vector {
+ var a = x
+ a = a + x
+ a = a - a
+ return a
+}
+_ = pullback(at: Vector.zero, in: side_effect_release_zero)
+
+func subset_pullback_releases_unused_ones(_ x: Vector) -> Vector {
+ let y = x + .zero
+ return .zero + y
+}
+_ = pullback(at: .zero, in: subset_pullback_releases_unused_ones)
+
+struct FakeMaxPool : Differentiable {
+ @differentiable(wrt: (self, input))
+ func applied(to input: Vector) -> Vector { return input }
+}
+
+struct UsesMethodOfNoDerivativeMember : Differentiable {
+ @noDerivative var maxPool = FakeMaxPool()
+
+ func applied(to input: Vector) -> Vector {
+ return maxPool.applied(to: input)
+ }
+}
+
+_ = pullback(at: UsesMethodOfNoDerivativeMember(), .zero) { $0.applied(to: $1) }
diff --git a/test/AutoDiff/downstream/side_effects.swift b/test/AutoDiff/downstream/side_effects.swift
new file mode 100644
index 0000000..e04a9e3
--- /dev/null
+++ b/test/AutoDiff/downstream/side_effects.swift
@@ -0,0 +1,72 @@
+// RUN: %target-swift-frontend -emit-sil -verify %s
+
+func simpleStoreLoad(x: Float) -> Float {
+ var y = x
+ y = x + 1
+ y = x + y
+ return y
+}
+let _: @differentiable (Float) -> Float = simpleStoreLoad(x:)
+
+var global: Float = 10
+
+// Test differentiation of write to non-useful global variable.
+let _: @differentiable (Float) -> Float = { x in
+ global = x
+ return x * x
+}
+
+// Test differentiation of write to non-useful local variable.
+let _: @differentiable (Float) -> Float = { x in
+ var local = x // expected-warning {{initialization of variable 'local' was never used}}
+ return x + x
+}
+
+// Test differentiation of write to useful global variable.
+// expected-error @+1 {{function is not differentiable}}
+let _: @differentiable (Float) -> Float = { x in
+ // expected-note @+1 {{cannot differentiate writes to global variables}}
+ global = x
+ return global + x
+}
+
+// Test differentiation of mutation to captured variables.
+func testMutableCaptures() {
+ var y: Float = 10
+ // expected-error @+1 {{function is not differentiable}}
+ let _: @differentiable (Float) -> Float = { x in
+ // expected-note @+1 {{cannot differentiate writes to mutable captures}}
+ y = x
+ return y + x
+ }
+}
+
+// Test differentiation of write to useful local variable.
+let _: @differentiable (Float) -> Float = { x in
+ var local = x // expected-warning {{variable 'local' was never mutated}}
+ return local + x
+}
+
+// Test differentiation with partial application of @noescape closure.
+// Addresses SR-9653.
+func noEscapePartialApplyTest() {
+ var y: Float = 0 // expected-warning {{variable 'y' was written to, but never read}}
+ let _ = gradient(at: 0) { (x: Float) -> Float in
+ y = x
+ return x + x
+ }
+}
+
+// TF-529: Crash when apply's result is active but arguments aren't.
+struct TF_529_Vector<T: Numeric & Differentiable>: AdditiveArithmetic & Differentiable {
+ var x, y: T
+}
+
+@differentiable
+func TF_529<T>(x: TF_529_Vector<T>) -> TF_529_Vector<T> {
+ var zero = TF_529_Vector<T>.zero
+ zero = x
+ return zero
+}
+
+// TODO: Add file checks.
diff --git a/test/AutoDiff/downstream/sil_diagnostics_after_differentiation.swift b/test/AutoDiff/downstream/sil_diagnostics_after_differentiation.swift
new file mode 100644
index 0000000..c0f07a5
--- /dev/null
+++ b/test/AutoDiff/downstream/sil_diagnostics_after_differentiation.swift
@@ -0,0 +1,19 @@
+// RUN: %target-swift-frontend -emit-sil -verify %s
+
+// This test file contains SIL diagnostics tests for differentiable functions
+// such as escaping capture errors.
+// NOTE: Only add tests for errors that would occur after the differentiation
+// transform.
+
+func nonescapingArgument(f: @differentiable (Float, Float) -> Float) -> Float {
+ return gradient(at: 1) { x in f(x, x) }
+}
+
+// expected-note @+2 {{parameter 'f' is implicitly non-escaping}}
+func nonescapingArgumentError(
+ f: @differentiable (Float, Float) -> Float
+) -> @differentiable (Float) -> Float{
+ // expected-error @+2 {{escaping closure captures non-escaping parameter 'f'}}
+ // expected-note @+1 {{captured here}}
+ return { x in f(x, x) }
+}
diff --git a/test/AutoDiff/downstream/sil_differentiability_witness_reference_serialization.sil b/test/AutoDiff/downstream/sil_differentiability_witness_reference_serialization.sil
new file mode 100644
index 0000000..e51d485
--- /dev/null
+++ b/test/AutoDiff/downstream/sil_differentiability_witness_reference_serialization.sil
@@ -0,0 +1,18 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-swift-frontend -emit-module -emit-module-path %t/test.swiftmodule -module-name test %s
+// RUN: %target-sil-opt %t/test.swiftmodule
+
+sil_stage raw
+
+import Swift
+import Builtin
+
+sil_differentiability_witness [parameters 0] [results 0] @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float
+
+sil @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float
+
+sil [serialized] @test_serialized : $@convention(thin) () -> () {
+bb0:
+ %referenced_from_serialized_jvp_wrt_0 = differentiability_witness_function [jvp] [parameters 0] [results 0] @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float
+ return undef : $()
+}
diff --git a/test/AutoDiff/downstream/sil_differentiability_witness_silgen.swift b/test/AutoDiff/downstream/sil_differentiability_witness_silgen.swift
new file mode 100644
index 0000000..92ef113
--- /dev/null
+++ b/test/AutoDiff/downstream/sil_differentiability_witness_silgen.swift
@@ -0,0 +1,238 @@
+// RUN: %target-swift-frontend -emit-silgen %s | %target-sil-opt | %FileCheck %s
+
+// Test SIL differentiability witness SIL generation.
+
+// Test public non-generic function.
+// SIL differentiability witness:
+// - Has public linkage (implicit).
+// - Has no `where` clause.
+
+public func foo(_ x: Float) -> Float { x }
+
+@derivative(of: foo)
+public func foo_jvp(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
+ (x, { $0 })
+}
+
+@derivative(of: foo)
+public func foo_vjp(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ (x, { $0 })
+}
+
+// CHECK-LABEL: // differentiability witness for foo(_:)
+// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3fooyS2fF : $@convention(thin) (Float) -> Float {
+// CHECK-NEXT: jvp: @AD__$s36sil_differentiability_witness_silgen3fooyS2fF__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK-NEXT: vjp: @AD__$s36sil_differentiability_witness_silgen3fooyS2fF__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK-NEXT: }
+
+// Test internal non-generic function.
+// SIL differentiability witness:
+// - Has hidden linkage.
+// - Has no `where` clause.
+// - Has only VJP.
+
+func bar<T>(_ x: Float, _ y: T) -> Float { x }
+
+@usableFromInline
+@derivative(of: bar)
+func bar_jvp<T>(_ x: Float, _ y: T) -> (value: Float, differential: (Float) -> Float) {
+ (x, { $0 })
+}
+
+// CHECK-LABEL: // differentiability witness for bar<A>(_:_:)
+// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <τ_0_0> @$s36sil_differentiability_witness_silgen3baryS2f_xtlF : $@convention(thin) <T> (Float, @in_guaranteed T) -> Float {
+// CHECK-NEXT: jvp: @AD__$s36sil_differentiability_witness_silgen3baryS2f_xtlF__jvp_src_0_wrt_0_l : $@convention(thin) <τ_0_0> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK-NEXT: }
+
+// Test internal generic function.
+// SIL differentiability witness:
+// - Has hidden linkage.
+// - Has `where` clause.
+
+func generic<T>(_ x: T, _ y: Float) -> T { x }
+
+@derivative(of: generic)
+func generic_jvp<T: Differentiable>(_ x: T, _ y: Float) -> (
+ value: T, differential: (T.TangentVector, Float) -> T.TangentVector
+) {
+ (x, { dx, dy in dx })
+}
+
+@derivative(of: generic)
+func generic_vjp<T: Differentiable>(_ x: T, _ y: Float) -> (
+ value: T, pullback: (T.TangentVector) -> (T.TangentVector, Float)
+) {
+ (x, { ($0, .zero) })
+}
+
+// CHECK-LABEL: // differentiability witness for generic<A>(_:_:)
+// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @$s36sil_differentiability_witness_silgen7genericyxx_SftlF : $@convention(thin) <T> (@in_guaranteed T, Float) -> @out T {
+// CHECK-NEXT: jvp: @AD__$s36sil_differentiability_witness_silgen7genericyxx_SftlF__jvp_src_0_wrt_0_1_s14DifferentiableRzl : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0, Float) -> @out τ_0_1 for <τ_0_0.TangentVector, τ_0_0.TangentVector>)
+// CHECK-NEXT: vjp: @AD__$s36sil_differentiability_witness_silgen7genericyxx_SftlF__vjp_src_0_wrt_0_1_s14DifferentiableRzl : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (@out τ_0_1, Float) for <τ_0_0.TangentVector, τ_0_0.TangentVector>)
+// CHECK-NEXT: }
+
+public struct Foo: Differentiable {
+ public var x: Float
+
+// CHECK-LABEL: // differentiability witness for Foo.x.getter
+// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV1xSfvg : $@convention(method) (Foo) -> Float {
+// CHECK-NEXT: }
+
+ @differentiable
+ public init(_ x: Float) {
+ self.x = x
+ }
+
+// CHECK-LABEL: // differentiability witness for Foo.init(_:)
+// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooVyACSfcfC : $@convention(method) (Float, @thin Foo.Type) -> Foo {
+// CHECK-NEXT: }
+
+ @differentiable
+ public func method() -> Float {
+ x
+ }
+
+// CHECK-LABEL: // differentiability witness for Foo.method()
+// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV6methodSfyF : $@convention(method) (Foo) -> Float {
+// CHECK-NEXT: }
+
+ @differentiable
+ public var computedProperty: Float {
+ x
+ }
+
+// CHECK-LABEL: // differentiability witness for Foo.computedProperty.getter
+// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV16computedPropertySfvg : $@convention(method) (Foo) -> Float {
+// CHECK-NEXT: }
+
+ @differentiable
+ public subscript() -> Float {
+ x
+ }
+
+// CHECK-LABEL: // differentiability witness for Foo.subscript.getter
+// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooVSfycig : $@convention(method) (Foo) -> Float {
+// CHECK-NEXT: }
+}
+
+// Test function that is differentiable wrt subset of its parameters:
+// - wrt x: explicit @differentiable attribute, with no custom derivative specified
+// - wrt y: explicit @differentiable attribute, with custom derivative specified
+// - wrt x, y: custom deriviative specified, with no explicit @differentiable attribute
+// Has a tuple argument to verify that indices are correctly lowered to SIL.
+
+@differentiable(wrt: x)
+public func wrt_subset(_ tup: (Int, Int), _ x: Float, _ y: Float) -> Float {
+ return 0
+}
+
+@derivative(of: wrt_subset, wrt: y)
+public func wrt_subset_jvp_wrt_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> (value: Float, differential: (Float) -> Float) {
+ return (0, { $0 })
+}
+
+@derivative(of: wrt_subset, wrt: y)
+public func wrt_subset_vjp_wrt_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> (value: Float, pullback: (Float) -> Float) {
+ return (0, { $0 })
+}
+
+@derivative(of: wrt_subset)
+public func wrt_subset_jvp_wrt_x_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> (value: Float, differential: (Float, Float) -> Float) {
+ return (0, { $0 + $1 })
+}
+
+@derivative(of: wrt_subset)
+public func wrt_subset_vjp_wrt_x_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
+ return (0, { ($0, $0) })
+}
+
+// CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:)
+// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 2] [results 0] @$s36sil_differentiability_witness_silgen10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float {
+// CHECK-NEXT: }
+
+// CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:)
+// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 3] [results 0] @$s36sil_differentiability_witness_silgen10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float {
+// CHECK-NEXT: jvp:
+// CHECK-NEXT: vjp:
+// CHECK-NEXT: }
+
+// CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:)
+// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 2 3] [results 0] @$s36sil_differentiability_witness_silgen10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float {
+// CHECK-NEXT: jvp:
+// CHECK-NEXT: vjp:
+// CHECK-NEXT: }
+
+// Test original function with `@differentiable` and `@derivative` attributes.
+
+protocol P1: Differentiable {}
+extension P1 {
+ @differentiable // derivative generic signature: none
+ func foo() -> Float { 1 }
+}
+extension P1 {
+ @derivative(of: foo) // derivative generic signature: `<P1 where Self: P1>`
+ func vjpFoo() -> (value: Float, pullback: (Float) -> (TangentVector)) {
+ fatalError()
+ }
+}
+
+// CHECK-LABEL: // differentiability witness for P1.foo()
+// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <τ_0_0 where τ_0_0 : P1> @$s36sil_differentiability_witness_silgen2P1PAAE3fooSfyF : $@convention(method) <Self where Self : P1> (@in_guaranteed Self) -> Float {
+// CHECK-NEXT: vjp: @AD__$s36sil_differentiability_witness_silgen2P1PAAE3fooSfyF__vjp_src_0_wrt_0_36sil_differentiability_witness_silgen2P1Rzl : $@convention(method) <τ_0_0 where τ_0_0 : P1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_0.TangentVector>)
+// CHECK-NEXT: }
+
+// Test custom derivatives of functions with generic signatures and `@differentiable` attributes.
+
+@differentiable
+@_silgen_name("genericWithDiffAttr")
+public func genericWithDiffAttr<T: Differentiable>(_ x: T) -> T { fatalError() }
+
+@derivative(of: genericWithDiffAttr)
+public func vjpGenericWithDiffAttr<T: Differentiable>(_ x: T)
+ -> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
+{
+ fatalError()
+}
+
+// CHECK-LABEL: // differentiability witness for genericWithDiffAttr
+// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @genericWithDiffAttr : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T) -> @out T {
+// CHECK-NEXT: vjp
+// CHECK-NEXT: }
+
+// CHECK-NOT: // differentiability witness for genericWithDiffAttr
+
+@differentiable(where T: Differentiable)
+@_silgen_name("genericWithConstrainedDifferentiable")
+public func genericWithConstrainedDifferentiable<T>(_ x: T) -> T { fatalError() }
+
+@derivative(of: genericWithConstrainedDifferentiable)
+public func vjpGenericWithConstrainedDifferentiable<T: Differentiable>(_ x: T)
+ -> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
+{
+ fatalError()
+}
+
+// CHECK-LABEL: // differentiability witness for genericWithConstrainedDifferentiable
+// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @genericWithConstrainedDifferentiable : $@convention(thin) <T> (@in_guaranteed T) -> @out T {
+// CHECK-NEXT: vjp
+// CHECK-NEXT: }
+
+// CHECK-NOT: // differentiability witness for genericWithConstrainedDifferentiable
+
+public extension Differentiable {
+ @differentiable
+ @_silgen_name("protocolExtensionWithDiffAttr")
+ func protocolExtensionWithDiffAttr() -> Self { self }
+
+ @derivative(of: protocolExtensionWithDiffAttr)
+ func protocolExtensionWithDiffAttr() -> (value: Self, pullback: (TangentVector) -> TangentVector) {
+ fatalError("unimplemented")
+ }
+}
+
+// CHECK-LABEL: // differentiability witness for protocolExtensionWithDiffAttr
+// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @protocolExtensionWithDiffAttr : $@convention(method) <Self where Self : Differentiable> (@in_guaranteed Self) -> @out Self {
+// CHECK-NEXT: vjp
+// CHECK-NEXT: }
+
+// CHECK-NOT: // differentiability witness for protocolExtensionWithDiffAttr
diff --git a/test/AutoDiff/downstream/silgen_thunking/main.swift b/test/AutoDiff/downstream/silgen_thunking/main.swift
new file mode 100644
index 0000000..0a814eb
--- /dev/null
+++ b/test/AutoDiff/downstream/silgen_thunking/main.swift
@@ -0,0 +1,194 @@
+// RUN: %target-swift-frontend -emit-silgen -verify %s %S/../Inputs/silgen_thunking_other_module.swift | %FileCheck %s
+
+// RUN: %empty-directory(%t)
+// RUN: %target-build-swift %S/../Inputs/silgen_thunking_other_module.swift %s -o %t/a.out
+// RUN: %target-codesign %t/a.out
+// RUN: %target-run %t/a.out
+
+// REQUIRES: executable_test
+
+import StdlibUnittest
+import DifferentiationUnittest
+
+// Verify that SILGen derivative thunks are never `[transparent]`.
+func noReabstraction<T: Differentiable>(_ x: T) -> T {
+ return x
+}
+@derivative(of: noReabstraction)
+func vjpNoReabstraction<T: Differentiable>(_ x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector) {
+ return (x, { $0 })
+}
+// Find the non-`[transparent]` SILGen thunk.
+// CHECK-LABEL: sil hidden [thunk] [always_inline] [ossa] @AD__$s4main15noReabstractionyxxs14DifferentiableRzlF__vjp_src_0_wrt_0{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0.TangentVector, τ_0_0.TangentVector>)
+
+var DerivativeSILGenThunkTests = TestSuite("DerivativeSILGenThunks")
+
+// TF-619: Test cross-module import of `@differentiable` methods with
+// self-ordering thunks.
+DerivativeSILGenThunkTests.testWithLeakChecking("CrossModuleMethodSelfReorderingThunk") {
+ expectEqual(1, gradient(at: 0) { x in TF_619().foo(x) })
+}
+
+// TF-698, TF-742: Test thunks that perform self-ordering but not reabstraction.
+struct SelfReordering : Differentiable & AdditiveArithmetic {
+ var x: Tracked<Float>
+ init(_ x: Tracked<Float>) {
+ self.x = x
+ }
+
+ // TF-742: Test method with three parameters (including `self`).
+ // Note: pullback returns direct `Self.TangentVector`.
+ func threeParameterMethod(_: Self, _: Self) -> Self {
+ return self
+ }
+ @derivative(of: threeParameterMethod)
+ func jvpThreeParameterMethod(_ x: Self, _ y: Self) -> (value: Self, differential: (Self, Self, Self) -> Self) {
+ let value = threeParameterMethod(x, y)
+ return (value, { dself, dx, dy in Self(dself.x + dx.x * 2 + dy.x * 3) })
+ }
+ @derivative(of: threeParameterMethod)
+ func vjpThreeParameterMethod(_ x: Self, _ y: Self) -> (value: Self, pullback: (Self) -> (Self, Self, Self)) {
+ let value = threeParameterMethod(x, y)
+ return (value, { v in (Self(1), Self(2), Self(3)) })
+ }
+
+// CHECK-LABEL: sil hidden [thunk] [always_inline] [ossa] @AD__$s4main14SelfReorderingV20threeParameterMethodyA2C_ACtF__jvp_src_0_wrt_0_1_2 : $@convention(method) (@guaranteed SelfReordering, @guaranteed SelfReordering, @guaranteed SelfReordering) -> (@owned SelfReordering, @owned @callee_guaranteed (@guaranteed SelfReordering, @guaranteed SelfReordering, @guaranteed SelfReordering) -> @owned SelfReordering)
+// CHECK: bb0([[X:%.*]] : @guaranteed $SelfReordering, [[Y:%.*]] : @guaranteed $SelfReordering, [[SELF:%.*]] : @guaranteed $SelfReordering):
+// CHECK: [[JVP:%.*]] = function_ref @$s4main14SelfReorderingV23jvpThreeParameterMethodyAC5value_A2C_A2Ctc12differentialtAC_ACtF
+// CHECK: [[JVP_RESULT:%.*]] = apply [[JVP]]([[X]], [[Y]], [[SELF]])
+// CHECK: ([[JVP_ORIG_RESULT:%.*]], [[DF:%.*]]) = destructure_tuple [[JVP_RESULT]]
+// CHECK: [[DF_SELF_REORDER_THUNK:%.*]] = function_ref @AD__$s4main14SelfReorderingVA3CIeggggo_A4CIeggggo_TR_differential_self_reordering_thunk
+// CHECK: [[THUNKED_DF:%.*]] = partial_apply [callee_guaranteed] [[DF_SELF_REORDER_THUNK]]([[DF]])
+// CHECK: [[RESULT:%.*]] = tuple ([[JVP_ORIG_RESULT]] : $SelfReordering, [[THUNKED_DF]] : {{.*}})
+// CHECK: return [[RESULT]]
+
+// CHECK-LABEL: sil shared [transparent] [serialized] [reabstraction_thunk] [ossa] @AD__$s4main14SelfReorderingVA3CIeggggo_A4CIeggggo_TR_differential_self_reordering_thunk : $@convention(thin) (@guaranteed SelfReordering, @guaranteed SelfReordering, @guaranteed SelfReordering, @guaranteed @callee_guaranteed (@guaranteed SelfReordering, @guaranteed SelfReordering, @guaranteed SelfReordering) -> @owned SelfReordering) -> @owned SelfReordering
+// CHECK: bb0([[DX:%.*]] : @guaranteed $SelfReordering, [[DY:%.*]] : @guaranteed $SelfReordering, [[DSELF:%.*]] : @guaranteed $SelfReordering, [[DF:%.*]] : @guaranteed $@callee_guaranteed (@guaranteed SelfReordering, @guaranteed SelfReordering, @guaranteed SelfReordering) -> @owned SelfReordering)
+// CHECK: [[DF_RESULT:%.*]] = apply [[DF]]([[DSELF]], [[DX]], [[DY]])
+// CHECK: return [[DF_RESULT]]
+
+// CHECK-LABEL: sil hidden [thunk] [always_inline] [ossa] @AD__$s4main14SelfReorderingV20threeParameterMethodyA2C_ACtF__vjp_src_0_wrt_0_1_2 : $@convention(method) (@guaranteed SelfReordering, @guaranteed SelfReordering, @guaranteed SelfReordering) -> (@owned SelfReordering, @owned @callee_guaranteed (@guaranteed SelfReordering) -> (@owned SelfReordering, @owned SelfReordering, @owned SelfReordering))
+// CHECK: bb0([[X:%.*]] : @guaranteed $SelfReordering, [[Y:%.*]] : @guaranteed $SelfReordering, [[SELF:%.*]] : @guaranteed $SelfReordering):
+// CHECK: [[VJP:%.*]] = function_ref @$s4main14SelfReorderingV23vjpThreeParameterMethodyAC5value_AC_A2CtACc8pullbacktAC_ACtF
+// CHECK: [[VJP_RESULT:%.*]] = apply [[VJP]]([[X]], [[Y]], [[SELF]])
+// CHECK: ([[VJP_ORIG_RESULT:%.*]], [[PB:%.*]]) = destructure_tuple [[VJP_RESULT]]
+// CHECK: [[PB_SELF_REORDER_THUNK:%.*]] = function_ref @AD__$s4main14SelfReorderingVA3CIeggooo_A4CIeggooo_TR_pullback_self_reordering_thunk
+// CHECK: [[THUNKED_PB:%.*]] = partial_apply [callee_guaranteed] [[PB_SELF_REORDER_THUNK]]([[PB]])
+// CHECK: [[RESULT:%.*]] = tuple ([[VJP_ORIG_RESULT]] : $SelfReordering, [[THUNKED_PB]] : {{.*}})
+// CHECK: return [[RESULT]]
+
+// CHECK-LABEL: sil shared [transparent] [serialized] [reabstraction_thunk] [ossa] @AD__$s4main14SelfReorderingVA3CIeggooo_A4CIeggooo_TR_pullback_self_reordering_thunk : $@convention(thin) (@guaranteed SelfReordering, @guaranteed @callee_guaranteed (@guaranteed SelfReordering) -> (@owned SelfReordering, @owned SelfReordering, @owned SelfReordering)) -> (@owned SelfReordering, @owned SelfReordering, @owned SelfReordering)
+// CHECK: bb0([[SEED:%.*]] : @guaranteed $SelfReordering, [[PB:%.*]] : @guaranteed $@callee_guaranteed (@guaranteed SelfReordering) -> (@owned SelfReordering, @owned SelfReordering, @owned SelfReordering)):
+// CHECK: [[PB_RESULT:%.*]] = apply [[PB]]([[SEED]])
+// CHECK: ([[SELF_ADJ:%.*]], [[X_ADJ:%.*]], [[Y_ADJ:%.*]]) = destructure_tuple %2 : $(SelfReordering, SelfReordering, SelfReordering)
+// CHECK: [[RESULT:%.*]] = tuple ([[X_ADJ]] : $SelfReordering, [[Y_ADJ]] : $SelfReordering, [[SELF_ADJ]] : $SelfReordering)
+// CHECK: return [[RESULT]]
+}
+
+// TF-742: Test thunks that perform self-ordering but not reabstraction.
+struct SelfReorderingGeneric<Dummy>: Differentiable
+where Dummy: Differentiable & ExpressibleByIntegerLiteral {
+ // The property with type `Dummy` makes `Self` be indirect.
+ var indirectDummy: Dummy = 0
+ var x: Tracked<Float>
+ init(_ x: Tracked<Float>) {
+ self.x = x
+ }
+
+ // TF-742: Test method with three parameters (including `self`).
+ // Note: pullback returns indirect `Self.TangentVector`.
+ func threeParameterMethod<T: Differentiable, U: Differentiable>(_: T, _: U) -> Self
+ where T.TangentVector: ExpressibleByFloatLiteral, U.TangentVector: ExpressibleByFloatLiteral {
+ return self
+ }
+ @derivative(of: threeParameterMethod)
+ func jvpThreeParameterMethod<T: Differentiable, U: Differentiable>(_ x: T, _ y: U)
+ -> (value: Self, differential: (Self.TangentVector, T.TangentVector, U.TangentVector) -> Self.TangentVector)
+ where T.TangentVector: ExpressibleByFloatLiteral, U.TangentVector: ExpressibleByFloatLiteral {
+ let value = threeParameterMethod(x, y)
+ // TODO: Make this test meaningful/robust.
+ return (value, { dself, dx, dy in dself })
+ }
+ @derivative(of: threeParameterMethod)
+ func vjpThreeParameterMethod<T: Differentiable, U: Differentiable>(_ x: T, _ y: U)
+ -> (value: Self, pullback: (Self.TangentVector) -> (Self.TangentVector, T.TangentVector, U.TangentVector))
+ where T.TangentVector: ExpressibleByFloatLiteral, U.TangentVector: ExpressibleByFloatLiteral {
+ let value = threeParameterMethod(x, y)
+ return (value, { v in (v, 2.0, 3.0) })
+ }
+
+// CHECK-LABEL: sil hidden [thunk] [always_inline] [ossa] @AD__$s4main21SelfReorderingGenericV20threeParameterMethodyACyxGqd___qd_0_ts14DifferentiableRd__sAFRd_0_s25ExpressibleByFloatLiteral13TangentVectorRpd__sAgHRpd_0_r0_lF__jvp_src_0_wrt_0_1_2_s14DifferentiableRzs27ExpressibleByIntegerLiteralRzsAARd__sAARd_0_s0bc5FloatE013TangentVectorRpd__sAcDRpd_0_r_0_l : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : ExpressibleByIntegerLiteral><τ_1_0, τ_1_1 where τ_1_0 : Differentiable, τ_1_1 : Differentiable, τ_1_0.TangentVector : ExpressibleByFloatLiteral, τ_1_1.TangentVector : ExpressibleByFloatLiteral> (@in_guaranteed τ_1_0, @in_guaranteed τ_1_1, @in_guaranteed SelfReorderingGeneric<τ_0_0>) -> (@out SelfReorderingGeneric<τ_0_0>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_1, @in_guaranteed τ_0_2) -> @out τ_0_3 for <τ_1_0.TangentVector, τ_1_1.TangentVector, SelfReorderingGeneric<τ_0_0>.TangentVector, SelfReorderingGeneric<τ_0_0>.TangentVector>) {
+// CHECK: bb0([[JVP_RESULT:%.*]] : $*SelfReorderingGeneric<τ_0_0>, [[X:%.*]] : $*τ_1_0, [[Y:%.*]] : $*τ_1_1, [[SELF:%.*]] : $*SelfReorderingGeneric<τ_0_0>):
+// CHECK: [[JVP:%.*]] = function_ref @$s4main21SelfReorderingGenericV23jvpThreeParameterMethodyACyxG5value_AC13TangentVectorVyx_GAI_AGQyd__AGQyd_0_tc12differentialtqd___qd_0_ts14DifferentiableRd__sAMRd_0_s25ExpressibleByFloatLiteralAJRQsAnKRQr0_lF
+// CHECK: [[DF:%.*]] = apply [[JVP]]<τ_0_0, τ_1_0, τ_1_1>([[JVP_RESULT]], [[X]], [[Y]], [[SELF]])
+// CHECK: [[DF_CONVERTED:%.*]] = convert_function [[DF]]
+// CHECK: [[DF_SELF_REORDER_THUNK:%.*]] = function_ref @AD__$s4main21SelfReorderingGenericV13TangentVectorVyx_GADs14DifferentiablePQyd__AdHQyd_0_AFIegnnnr_Aij2FIegnnnr_sAGRzs27ExpressibleByIntegerLiteralRzsAGRd__sAGRd_0_s0hi5FloatK0ADRpd__sAlDRpd_0_r_0_lTR_differential_self_reordering_thunk
+// CHECK: [[THUNKED_DF:%.*]] = partial_apply [callee_guaranteed] [[DF_SELF_REORDER_THUNK]]<τ_0_0, τ_1_0, τ_1_1>([[DF_CONVERTED]])
+// CHECK: [[THUNKED_DF_CONVERTED:%.*]] = convert_function [[THUNKED_DF]]
+// CHECK: return [[THUNKED_DF_CONVERTED]]
+
+// CHECK-LABEL: sil shared [transparent] [serialized] [reabstraction_thunk] [ossa] @AD__$s4main21SelfReorderingGenericV13TangentVectorVyx_GADs14DifferentiablePQyd__AdHQyd_0_AFIegnnnr_Aij2FIegnnnr_sAGRzs27ExpressibleByIntegerLiteralRzsAGRd__sAGRd_0_s0hi5FloatK0ADRpd__sAlDRpd_0_r_0_lTR_differential_self_reordering_thunk : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : ExpressibleByIntegerLiteral><τ_1_0, τ_1_1 where τ_1_0 : Differentiable, τ_1_1 : Differentiable, τ_1_0.TangentVector : ExpressibleByFloatLiteral, τ_1_1.TangentVector : ExpressibleByFloatLiteral> (@in_guaranteed τ_1_0.TangentVector, @in_guaranteed τ_1_1.TangentVector, @in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector, @guaranteed @callee_guaranteed (@in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector, @in_guaranteed τ_1_0.TangentVector, @in_guaranteed τ_1_1.TangentVector) -> @out SelfReorderingGeneric<τ_0_0>.TangentVector) -> @out SelfReorderingGeneric<τ_0_0>.TangentVector {
+// CHECK: bb0([[DF_RESULT:%.*]] : $*SelfReorderingGeneric<τ_0_0>.TangentVector, [[DX:%.*]] : $*τ_1_0.TangentVector, [[DY:%.*]] : $*τ_1_1.TangentVector, [[DSELF:%.*]] : $*SelfReorderingGeneric<τ_0_0>.TangentVector, [[DF:%.*]] : @guaranteed $@callee_guaranteed (@in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector, @in_guaranteed τ_1_0.TangentVector, @in_guaranteed τ_1_1.TangentVector) -> @out SelfReorderingGeneric<τ_0_0>.TangentVector):
+// CHECK: {{%.*}} = apply [[DF]]([[DF_RESULT]], [[DSELF]], [[DX]], [[DY]])
+// CHECK: [[VOID:%.*]] = tuple ()
+// CHECK: return [[VOID]]
+
+// CHECK-LABEL: sil hidden [thunk] [always_inline] [ossa] @AD__$s4main21SelfReorderingGenericV20threeParameterMethodyACyxGqd___qd_0_ts14DifferentiableRd__sAFRd_0_s25ExpressibleByFloatLiteral13TangentVectorRpd__sAgHRpd_0_r0_lF__vjp_src_0_wrt_0_1_2_s14DifferentiableRzs27ExpressibleByIntegerLiteralRzsAARd__sAARd_0_s0bc5FloatE013TangentVectorRpd__sAcDRpd_0_r_0_l : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : ExpressibleByIntegerLiteral><τ_1_0, τ_1_1 where τ_1_0 : Differentiable, τ_1_1 : Differentiable, τ_1_0.TangentVector : ExpressibleByFloatLiteral, τ_1_1.TangentVector : ExpressibleByFloatLiteral> (@in_guaranteed τ_1_0, @in_guaranteed τ_1_1, @in_guaranteed SelfReorderingGeneric<τ_0_0>) -> (@out SelfReorderingGeneric<τ_0_0>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2, @out τ_0_3) for <SelfReorderingGeneric<τ_0_0>.TangentVector, τ_1_0.TangentVector, τ_1_1.TangentVector, SelfReorderingGeneric<τ_0_0>.TangentVector>) {
+// CHECK: bb0([[VJP_RESULT:%.*]] : $*SelfReorderingGeneric<τ_0_0>, [[X:%.*]] : $*τ_1_0, [[Y:%.*]] : $*τ_1_1, [[SELF:%.*]] : $*SelfReorderingGeneric<τ_0_0>):
+// CHECK: [[VJP:%.*]] = function_ref @$s4main21SelfReorderingGenericV23vjpThreeParameterMethodyACyxG5value_AC13TangentVectorVyx_G_AGQyd__AGQyd_0_tAIc8pullbacktqd___qd_0_ts14DifferentiableRd__sAMRd_0_s25ExpressibleByFloatLiteralAJRQsAnKRQr0_lF
+// CHECK: [[PB:%.*]] = apply [[VJP]]<τ_0_0, τ_1_0, τ_1_1>([[VJP_RESULT]], [[X]], [[Y]], [[SELF]])
+// CHECK: [[PB_CONVERTED:%.*]] = convert_function [[PB]]
+// CHECK: [[PB_SELF_REORDER_THUNK:%.*]] = function_ref @AD__$s4main21SelfReorderingGenericV13TangentVectorVyx_GAfDs14DifferentiablePQyd__AdHQyd_0_Iegnrrr_AfijFIegnrrr_sAGRzs27ExpressibleByIntegerLiteralRzsAGRd__sAGRd_0_s0hi5FloatK0ADRpd__sAlDRpd_0_r_0_lTR_pullback_self_reordering_thunk
+// CHECK: [[THUNKED_PB:%.*]] = partial_apply [callee_guaranteed] [[PB_SELF_REORDER_THUNK]]<τ_0_0, τ_1_0, τ_1_1>([[PB_CONVERTED]])
+// CHECK: [[THUNKED_PB_CONVERTED:%.*]] = convert_function [[THUNKED_PB]]
+// CHECK: return [[THUNKED_PB_CONVERTED]]
+
+// CHECK-LABEL: sil shared [transparent] [serialized] [reabstraction_thunk] [ossa] @AD__$s4main21SelfReorderingGenericV13TangentVectorVyx_GAfDs14DifferentiablePQyd__AdHQyd_0_Iegnrrr_AfijFIegnrrr_sAGRzs27ExpressibleByIntegerLiteralRzsAGRd__sAGRd_0_s0hi5FloatK0ADRpd__sAlDRpd_0_r_0_lTR_pullback_self_reordering_thunk : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : ExpressibleByIntegerLiteral><τ_1_0, τ_1_1 where τ_1_0 : Differentiable, τ_1_1 : Differentiable, τ_1_0.TangentVector : ExpressibleByFloatLiteral, τ_1_1.TangentVector : ExpressibleByFloatLiteral> (@in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector, @guaranteed @callee_guaranteed (@in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector) -> (@out SelfReorderingGeneric<τ_0_0>.TangentVector, @out τ_1_0.TangentVector, @out τ_1_1.TangentVector)) -> (@out τ_1_0.TangentVector, @out τ_1_1.TangentVector, @out SelfReorderingGeneric<τ_0_0>.TangentVector) {
+// CHECK: bb0([[X_ADJ:%.*]] : $*τ_1_0.TangentVector, [[Y_ADJ:%.*]] : $*τ_1_1.TangentVector, [[SELF_ADJ:%.*]] : $*SelfReorderingGeneric<τ_0_0>.TangentVector, [[SEED:%.*]] : $*SelfReorderingGeneric<τ_0_0>.TangentVector, [[PB:%.*]] : @guaranteed $@callee_guaranteed (@in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector) -> (@out SelfReorderingGeneric<τ_0_0>.TangentVector, @out τ_1_0.TangentVector, @out τ_1_1.TangentVector)):
+// CHECK: {{%.*}} = apply [[PB]]([[SELF_ADJ]], [[X_ADJ]], [[Y_ADJ]], [[SEED]])
+// CHECK: [[VOID:%.*]] = tuple ()
+// CHECK: return [[VOID]]
+}
+
+// Test thunk linkage.
+
+public func hasInternalDerivative(_ x: Float) -> Float { x }
+
+@usableFromInline
+@derivative(of: hasInternalDerivative)
+internal func internalDerivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ (x, { $0 })
+}
+
+// CHECK-LABEL: sil [thunk] [always_inline] [ossa] @AD__$s4main21hasInternalDerivativeyS2fF__vjp_src_0_wrt_0
+
+public func hasPublicDerivative(_ x: Float) -> Float { x }
+
+@derivative(of: hasPublicDerivative)
+public func publicDerivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
+ (x, { $0 })
+}
+
+// CHECK-LABEL: sil [thunk] [always_inline] [ossa] @AD__$s4main19hasPublicDerivativeyS2fF__vjp_src_0_wrt_0
+
+extension SelfReorderingGeneric.TangentVector : ExpressibleByFloatLiteral {}
+
+DerivativeSILGenThunkTests.testWithLeakChecking("SelfReorderingNonReabstractingThunk") {
+ do {
+ let v = SelfReordering(1)
+ // TODO: Add JVP/differential tests.
+ expectEqual((SelfReordering(1), SelfReordering(2), SelfReordering(3)),
+ pullback(at: v, v, v) { x, y, z in x.threeParameterMethod(y, z) }(v))
+ }
+ do {
+ let dummy: Float = 0
+ let x = SelfReorderingGeneric<Float>(1)
+ let v = SelfReorderingGeneric<Float>.TangentVector(indirectDummy: dummy, x: 1)
+ let tracked = Tracked<Float>(1.0)
+ // TODO: Add JVP/differential tests.
+ expectEqual((v, 2, 3),
+ pullback(at: x, tracked, tracked) { x, y, z in x.threeParameterMethod(y, z) }(v))
+ }
+}
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/simple_real_vector.swift b/test/AutoDiff/downstream/simple_real_vector.swift
new file mode 100644
index 0000000..c649ffb
--- /dev/null
+++ b/test/AutoDiff/downstream/simple_real_vector.swift
@@ -0,0 +1,72 @@
+// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
+
+@frozen
+public struct Vector : AdditiveArithmetic, Differentiable {
+ public var x: Float
+ public var y: Float
+
+ public static var zero: Vector {
+ return Vector(0)
+ }
+
+ public init(_ scalar: Float) {
+ self.x = scalar
+ self.y = scalar
+ }
+
+ @differentiable
+ public static func + (lhs: Vector, rhs: Vector) -> Vector {
+ abort()
+ }
+
+ @differentiable
+ public static func - (lhs: Vector, rhs: Vector) -> Vector {
+ abort()
+ }
+
+ public static func * (lhs: Float, rhs: Vector) -> Vector {
+ abort()
+ }
+
+ @derivative(of: +)
+ @derivative(of: -)
+ public static func fakeVJP(lhs: Vector, rhs: Vector) -> (value: Vector, pullback: (Vector) -> (Vector, Vector)) {
+ abort()
+ }
+}
+
+// This exists to minimize generated SIL.
+@inline(never) func abort() -> Never { fatalError() }
+
+public func test1() -> Vector {
+ func foo(_ x: Vector) -> Float {
+ return (x + x).x
+ }
+ return gradient(at: Vector(10), in: foo)
+}
+
+// CHECK-LABEL: @{{.*}}test1{{.*}}
+// CHECK: [[CLOSURE:%.*]] = function_ref @{{.*}}test1{{.*}}foo{{.*}} : $@convention(thin) (Vector) -> Float
+// CHECK: [[CLOSURE_THICK:%.*]] = thin_to_thick_function [[CLOSURE]] : $@convention(thin) (Vector) -> Float to $@callee_guaranteed (Vector) -> Float
+// CHECK: [[CLOSURE_DIFF:%.*]] = differentiable_function [parameters 0] [results 0] [[CLOSURE_THICK]] : $@callee_guaranteed (Vector) -> Float
+// CHECK: [[CLOSURE_DIFF_NOESC:%.*]] = convert_escape_to_noescape [not_guaranteed] [[CLOSURE_DIFF]] : $@differentiable @callee_guaranteed (Vector) -> Float to $@differentiable @noescape @callee_guaranteed (Vector) -> Float
+
+// TF-189: `TF189` is a non-trivial type but `TF189.AllDifferentiableVariables` is trivial.
+// Should pass verification.
+@_fixed_layout
+public class NonTrivial {}
+@frozen
+public struct TF189: Differentiable {
+ @noDerivative public let x: Double
+ @noDerivative public let nonTrivial: NonTrivial
+
+ func foo(input: Vector) -> Vector {
+ return pullback(at: self, input) { m, x in
+ m.applied(to: x)
+ }(.zero).1
+ }
+
+ func applied(to input: Vector) -> Vector {
+ return input
+ }
+}
diff --git a/test/AutoDiff/downstream/subset_parameters_thunk.swift b/test/AutoDiff/downstream/subset_parameters_thunk.swift
new file mode 100644
index 0000000..e89fb14
--- /dev/null
+++ b/test/AutoDiff/downstream/subset_parameters_thunk.swift
@@ -0,0 +1,147 @@
+// RUN: %target-run-simple-swift
+// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s
+// REQUIRES: executable_test
+
+import StdlibUnittest
+
+var SubsetParameterThunkTests = TestSuite("SubsetParameterThunks")
+
+// MARK: Subset parameter thunk application SIL FileChecks
+
+func foo<T: Numeric>(_ x: T, _ y: T) -> T { x * y }
+
+@derivative(of: foo)
+func foo_vjp<T: Numeric & Differentiable>(_ x: T, _ y: T) -> (
+ value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)
+) {
+ (foo(x, y), { _ in (.zero, .zero) })
+}
+
+@differentiable
+func differentiate_foo_wrt_0(_ x: Float) -> Float {
+ foo(x, 1)
+}
+
+// CHECK-LABEL: sil hidden @{{.*}}differentiate_foo_wrt_0{{.*}}__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
+// CHECK: bb0
+// CHECK: [[FOO_ORIG:%.*]] = function_ref @{{.*}}foo{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0
+// CHECK: [[FOO_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_ORIG]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0
+// CHECK: [[FOO_JVP:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] <T where T : Differentiable, T : Numeric> @{{.*}}foo{{.*}} : $@convention(thin) <T where T : Numeric> (@in_guaranteed T, @in_guaranteed T) -> @out T
+// CHECK: [[FOO_JVP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_JVP]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_1) -> @out τ_0_2 for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>)
+// CHECK: [[FOO_JVP_SUBSET_THUNK_THIN:%.*]] = function_ref @AD__orig_{{.*}}foo{{.*}}_src_0_wrt_0_jvp_subset_parameters_thunk : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
+// CHECK: [[FOO_JVP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_JVP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
+// CHECK: [[FOO_VJP:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <T where T : Differentiable, T : Numeric> @{{.*}}foo{{.*}} : $@convention(thin) <T where T : Numeric> (@in_guaranteed T, @in_guaranteed T) -> @out T
+// CHECK: [[FOO_VJP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_VJP]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>)
+// CHECK: [[FOO_VJP_SUBSET_THUNK_THIN:%.*]] = function_ref @AD__orig_{{.*}}foo{{.*}}_src_0_wrt_0_vjp_subset_parameters_thunk : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
+// CHECK: [[FOO_VJP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_VJP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
+// CHECK: [[FOO_DIFF:%.*]] = differentiable_function [parameters 0] [results 0] [[FOO_FLOAT]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> @out Float with_derivative {[[FOO_JVP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[FOO_VJP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)}
+// CHECK: }
+
+// MARK: `inout` parameters
+
+// TF-1204: Test pullback subset parameter thunks.
+
+func inoutDirect(_ x: Float, _ y: inout Double, _ z: Float) {}
+
+@derivative(of: inoutDirect)
+func vjpInoutDirect(_ x: Float, _ y: inout Double, _ z: Float) -> (
+ value: Void, pullback: (inout Double) -> (Float, Float)
+) {
+ return ((), { dy in
+ dy = 3
+ return (2, 4)
+ })
+}
+
+SubsetParameterThunkTests.test("InoutParametersDirect") {
+ @differentiable(wrt: x)
+ @differentiable(wrt: y)
+ @differentiable(wrt: z)
+ func inoutDirectCaller(_ x: Float, _ y: Double, _ z: Float) -> Double {
+ var result = y
+ inoutDirect(x, &result, z)
+ return result
+ }
+
+ let x: Float = 3
+ let y: Double = 4
+ let z: Float = 5
+ expectEqual((2, 3, 4), gradient(at: x, y, z, in: inoutDirectCaller))
+ expectEqual((3, 4), gradient(at: y, z, in: { y, z in inoutDirectCaller(x, y, z) }))
+ expectEqual((2, 4), gradient(at: x, z, in: { x, z in inoutDirectCaller(x, y, z) }))
+ expectEqual((2, 3), gradient(at: x, y, in: { x, y in inoutDirectCaller(x, y, z) }))
+}
+
+func inoutIndirect<T: Differentiable, U: Differentiable, V: Differentiable>(
+ _ x: T, _ y: inout U, _ z: V
+) {}
+
+@derivative(of: inoutIndirect)
+func vjpInoutIndirect<T: Differentiable, U: Differentiable, V: Differentiable>(
+ _ x: T, _ y: inout U, _ z: V
+) -> (
+ value: Void, pullback: (inout U.TangentVector) -> (T.TangentVector, V.TangentVector)
+) {
+ return ((), { dy in
+ return (.zero, .zero)
+ })
+}
+
+SubsetParameterThunkTests.test("InoutParametersIndirect") {
+ @differentiable(wrt: x)
+ @differentiable(wrt: y)
+ @differentiable(wrt: z)
+ @differentiable
+ func inoutIndirectCaller<T: Differentiable, U: Differentiable, V: Differentiable>(
+ _ x: T, _ y: U, _ z: V
+ ) -> U {
+ var result = y
+ inoutIndirect(x, &result, z)
+ return result
+ }
+
+ let x: Float = 3
+ let y: Double = 4
+ let z: Float = 5
+ expectEqual((0, 1, 0), gradient(at: x, y, z, in: inoutIndirectCaller))
+ expectEqual((1, 0), gradient(at: y, z, in: { y, z in inoutIndirectCaller(x, y, z) }))
+ expectEqual((0, 0), gradient(at: x, z, in: { x, z in inoutIndirectCaller(x, y, z) }))
+ expectEqual((0, 1), gradient(at: x, y, in: { x, y in inoutIndirectCaller(x, y, z) }))
+}
+
+// Check SIL for representative pullback subset parameters thunks.
+
+// CHECK-LABEL: sil shared [transparent] [serialized] [thunk] @AD__$s13TangentVectors14DifferentiablePQy_AaCQzAaCQy0_Ieglrr_AdEIeglr_sABRzsABR_sABR0_r1_lTR_src_0_wrt_0_1_pullback_index_subset_thunk : $@convention(thin) <τ_0_0, τ_0_1, τ_0_2 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_2 : Differentiable> (@inout τ_0_1.TangentVector, @guaranteed @callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)) -> @out τ_0_0.TangentVector {
+// CHECK: bb0(%0 : $*τ_0_0.TangentVector, %1 : $*τ_0_1.TangentVector, %2 : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)):
+// CHECK: %3 = alloc_stack $τ_0_2.TangentVector
+// CHECK: %4 = apply %2(%0, %3, %1) : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)
+// CHECK: destroy_addr %3 : $*τ_0_2.TangentVector
+// CHECK: dealloc_stack %3 : $*τ_0_2.TangentVector
+// CHECK: %7 = tuple ()
+// CHECK: return %7 : $()
+// CHECK: }
+
+// CHECK-LABEL: sil shared [transparent] [serialized] [thunk] @AD__$s13TangentVectors14DifferentiablePQy_AaCQzAaCQy0_Ieglrr_ADIegl_sABRzsABR_sABR0_r1_lTR_src_0_wrt_1_pullback_index_subset_thunk : $@convention(thin) <τ_0_0, τ_0_1, τ_0_2 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_2 : Differentiable> (@inout τ_0_1.TangentVector, @guaranteed @callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)) -> () {
+// CHECK: bb0(%0 : $*τ_0_1.TangentVector, %1 : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)):
+// CHECK: %2 = alloc_stack $τ_0_0.TangentVector
+// CHECK: %3 = alloc_stack $τ_0_2.TangentVector
+// CHECK: %4 = apply %1(%2, %3, %0) : $@callee_guaranteed (@inout τ_0_1.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_2.TangentVector)
+// CHECK: destroy_addr %2 : $*τ_0_0.TangentVector
+// CHECK: destroy_addr %3 : $*τ_0_2.TangentVector
+// CHECK: dealloc_stack %3 : $*τ_0_2.TangentVector
+// CHECK: dealloc_stack %2 : $*τ_0_0.TangentVector
+// CHECK: %9 = tuple ()
+// CHECK: return %9 : $()
+// CHECK: }
+
+// CHECK-LABEL: sil shared [transparent] [serialized] [thunk] @AD__$sSdSfSdSfIegnrrr_SdS2fIegnrr_TR_src_0_wrt_0_2_pullback_index_subset_thunk : $@convention(thin) (@in_guaranteed Double, @guaranteed @callee_guaranteed (@in_guaranteed Double) -> (@out Float, @out Double, @out Float)) -> (@out Float, @out Float) {
+// CHECK: bb0(%0 : $*Float, %1 : $*Float, %2 : $*Double, %3 : $@callee_guaranteed (@in_guaranteed Double) -> (@out Float, @out Double, @out Float)):
+// CHECK: %4 = alloc_stack $Double
+// CHECK: %5 = apply %3(%0, %4, %1, %2) : $@callee_guaranteed (@in_guaranteed Double) -> (@out Float, @out Double, @out Float)
+// CHECK: destroy_addr %4 : $*Double
+// CHECK: dealloc_stack %4 : $*Double
+// CHECK: %8 = tuple ()
+// CHECK: return %8 : $()
+// CHECK: }
+
+runAllTests()
diff --git a/test/AutoDiff/downstream/tbdgen.swift b/test/AutoDiff/downstream/tbdgen.swift
new file mode 100644
index 0000000..a1caa9a
--- /dev/null
+++ b/test/AutoDiff/downstream/tbdgen.swift
@@ -0,0 +1,135 @@
+// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=all %s
+// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=all %s -O
+// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=missing %s -enable-testing
+// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=missing %s -enable-testing -O
+
+// TODO: These tests are disabled because the pullback struct makes the TBDGen be different before/after SILGen.
+// UN: %empty-directory(%t)
+// UN: %target-swift-frontend -typecheck -parse-as-library -module-name test %s -emit-tbd -emit-tbd-path %t/typecheck.tbd
+// UN: %target-swift-frontend -emit-ir -parse-as-library -module-name test %s -emit-tbd -emit-tbd-path %t/emit-ir.tbd
+// UN: diff -u %t/typecheck.tbd %t/emit-ir.tbd
+
+@differentiable public func publicDiffable(_ x: Float, _ y: Float) -> Float { return x }
+@differentiable(wrt: (x)) public func publicDiffableWRT(_ x: Float, _ y: Float) -> Float { return x }
+
+// Tests SILGen derivative "forwarding thunk" (no derivative reabstraction/self-reordering).
+@differentiable
+public func publicNoDerivativeReabstraction<T: Differentiable>(_ x: T) -> T { return x }
+@derivative(of: publicNoDerivativeReabstraction)
+public func publicNoDerivativeReabstractionVJP<T: Differentiable>(_ x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector) {
+ return (x, { $0 })
+}
+
+@differentiable internal func internalDiffable(_ x: Float, _ y: Float) -> Float { return x }
+@differentiable(wrt: (x)) internal func internalDiffableWRT(_ x: Float, _ y: Float) -> Float { return x }
+
+@differentiable private func privateDiffable(_ x: Float, _ y: Float) -> Float { return x }
+@differentiable(wrt: (x)) private func privateDiffableWRT(_ x: Float, _ y: Float) -> Float { return x }
+
+public extension Float {
+ // This should generate public symbols for both JVP and VJP.
+ @differentiable
+ var x: Float {
+ return self
+ }
+
+ // This should generate public symbols for JVP but not VJP, because VJP is user-defined.
+ @differentiable
+ var y: Float {
+ return .zero
+ }
+
+ @derivative(of: y)
+ func vjpY() -> (value: Float, pullback: (Float) -> Float) {
+ return (.zero, { $0 })
+ }
+
+ // This should generate public symbols for both JVP and VJP.
+ @differentiable
+ init(x: Float) {
+ self = x
+ }
+
+ // This should generate public symbols for both JVP and VJP.
+ // Tests self-reordering-method thunking.
+ @differentiable
+ func method(x: Float, y: Float) -> Float {
+ return x
+ }
+ @derivative(of: method)
+ func jvpMethod(x: Float, y: Float) -> (value: Float, differential: (Float, Float, Float) -> Float) {
+ return (x, { dself, dx, dy in dx })
+ }
+
+ // This should generate public symbols for both JVP and VJP.
+ // Tests self-reordering-method thunking.
+ @differentiable
+ subscript(x: Float) -> Float {
+ return x
+ }
+ @derivative(of: subscript)
+ func vjpSubscript(x: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
+ return (x, { v in (0, v) })
+ }
+}
+
+struct Nontrivial : Differentiable {
+ var base: [Float]
+
+ // This should generate public symbols for both JVP and VJP.
+ // Tests differential/pullback thunking.
+ @differentiable
+ init(_ base: [Float]) {
+ self.base = base
+ }
+ @derivative(of: init)
+ static func jvpInit(_ base: [Float])
+ -> (value: Nontrivial, differential: (Array<Float>.TangentVector) -> Nontrivial.TangentVector) {
+ return (Nontrivial(base), { v in Nontrivial.TangentVector(base: v) })
+ }
+ @derivative(of: init)
+ static func vjpInit(_ base: [Float])
+ -> (value: Nontrivial, pullback: (Nontrivial.TangentVector) -> Array<Float>.TangentVector) {
+ return (Nontrivial(base), { v in v.base })
+ }
+
+ // This should generate public symbols for both JVP and VJP.
+ // Tests differential/pullback thunking.
+ @differentiable
+ func ownedParameter(_ x: __owned [Float]) -> [Float] {
+ return x
+ }
+ @derivative(of: ownedParameter)
+ func vjpOwnedParameterMismatch(_ x: __shared [Float])
+ -> (value: [Float], pullback: (Array<Float>.TangentVector) -> (Nontrivial.TangentVector, Array<Float>.TangentVector)) {
+ return (ownedParameter(x), { v in (.zero, v) })
+ }
+
+ // This should generate public symbols for both JVP and VJP.
+ // Tests differential/pullback thunking.
+ @differentiable
+ func sharedParameter(_ x: __shared [Float]) -> [Float] {
+ return x
+ }
+ @derivative(of: sharedParameter)
+ func vjpSharedParameterMismatch(_ x: __owned [Float])
+ -> (value: [Float], pullback: (Array<Float>.TangentVector) -> (Nontrivial.TangentVector, Array<Float>.TangentVector)) {
+ return (sharedParameter(x), { v in (.zero, v) })
+ }
+}
+
+public func publicDiffableIndirect(_ x: Float, _ y: Float) -> Float { return x }
+
+internal func internalDiffableIndirect(_ x: Float, _ y: Float) -> Float { return x }
+
+private func privateDiffableIndirect(_ x: Float, _ y: Float) -> Float { return x }
+
+func invokeIndirect() {
+ print(gradient(of: publicDiffableIndirect)(1, 2))
+ print(gradient(of: internalDiffableIndirect)(1, 2))
+ print(gradient(of: privateDiffableIndirect)(1, 2))
+}
+
+@inlinable
+@differentiable
+public func inlinableDifferentiable(_ x: Float) -> Float { x }
diff --git a/test/AutoDiff/downstream/witness_method_autodiff.sil b/test/AutoDiff/downstream/witness_method_autodiff.sil
new file mode 100644
index 0000000..3e2d268
--- /dev/null
+++ b/test/AutoDiff/downstream/witness_method_autodiff.sil
@@ -0,0 +1,55 @@
+// RUN: %target-sil-opt -differentiation -differentiation-skip-folding-differentiable-function-extraction %s | %FileCheck %s
+
+sil_stage raw
+
+import Builtin
+import Swift
+import SwiftShims
+
+protocol DiffReq {
+ @differentiable(wrt: (x))
+ func f(_ x: Float) -> Float
+}
+
+sil @differentiateWitnessMethod : $@convention(thin) <T where T : DiffReq> (@in_guaranteed T) -> () {
+bb0(%0 : $*T):
+ %1 = witness_method $T, #DiffReq.f : <Self where Self : DiffReq> (Self) -> (Float) -> Float : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float
+ %2 = differentiable_function [parameters 0] [results 0] %1 : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float
+
+ %ret = tuple ()
+ return %ret : $()
+}
+
+// CHECK-LABEL: sil @differentiateWitnessMethod
+// CHECK: [[ORIG_REF:%.*]] = witness_method $T, #DiffReq.f
+// CHECK: [[JVP_REF:%.*]] = witness_method $T, #DiffReq.f!jvp.SU
+// CHECK: [[VJP_REF:%.*]] = witness_method $T, #DiffReq.f!vjp.SU
+// CHECK: differentiable_function [parameters 0] [results 0] [[ORIG_REF]] : {{.*}} with_derivative {[[JVP_REF]] : {{.*}}, [[VJP_REF]] : {{.*}}}
+// CHECK: } // end sil function 'differentiateWitnessMethod'
+
+sil @differentiatePartiallyAppliedWitnessMethod : $@convention(thin) <T where T : DiffReq> (@in_guaranteed T) -> () {
+bb0(%0 : $*T):
+ %1 = witness_method $T, #DiffReq.f : <Self where Self : DiffReq> (Self) -> (Float) -> Float : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float
+ %2 = partial_apply [callee_guaranteed] %1<T>(%0) : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float
+ %3 = differentiable_function [parameters 0] [results 0] %2 : $@callee_guaranteed (Float) -> Float
+
+ %ret = tuple ()
+ return %ret : $()
+}
+
+// CHECK-LABEL: sil @differentiatePartiallyAppliedWitnessMethod
+// CHECK: bb0([[ARG:%.*]] : $*T):
+// CHECK: [[ORIG_REF:%.*]] = witness_method $T, #DiffReq.f
+// CHECK: [[ARGCOPY1:%.*]] = alloc_stack $T
+// CHECK: copy_addr [[ARG]] to [initialization] [[ARGCOPY1]] : $*T
+// CHECK: [[ARGCOPY2:%.*]] = alloc_stack $T
+// CHECK: copy_addr [[ARG]] to [initialization] [[ARGCOPY2]] : $*T
+// CHECK: [[ORIG_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[ORIG_REF]]<T>(%0)
+// CHECK: [[JVP_REF:%.*]] = witness_method $T, #DiffReq.f!jvp.SU
+// CHECK: [[JVP_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[JVP_REF]]<T>([[ARGCOPY1]])
+// CHECK: [[VJP_REF:%.*]] = witness_method $T, #DiffReq.f!vjp.SU
+// CHECK: [[VJP_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[VJP_REF]]<T>([[ARGCOPY2]])
+// CHECK: dealloc_stack [[ARGCOPY2]]
+// CHECK: dealloc_stack [[ARGCOPY1]]
+// CHECK: differentiable_function [parameters 0] [results 0] [[ORIG_REF_PARTIALLY_APPLIED]] : {{.*}} with_derivative {[[JVP_REF_PARTIALLY_APPLIED]] : {{.*}}, [[VJP_REF_PARTIALLY_APPLIED]] : {{.*}}}
+// CHECK: } // end sil function 'differentiatePartiallyAppliedWitnessMethod'
diff --git a/test/AutoDiff/validation-test/differentiable_protocol_requirements.swift b/test/AutoDiff/validation-test/differentiable_protocol_requirements.swift
index 1810750..bff9fab 100644
--- a/test/AutoDiff/validation-test/differentiable_protocol_requirements.swift
+++ b/test/AutoDiff/validation-test/differentiable_protocol_requirements.swift
@@ -1,12 +1,6 @@
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
-// Disabled due to test failure with `-O`: SR-13250.
-// SR-13250 is tracking the fix for compiling this test with optimizations.
-// XFAIL: swift_test_mode_optimize
-// XFAIL: swift_test_mode_optimize_size
-// XFAIL: swift_test_mode_optimize_unchecked
-
import StdlibUnittest
import DifferentiationUnittest
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index ede6aee..5fc1590 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -165,6 +165,9 @@
normalize_boolean_spelling(SWIFT_ENABLE_SOURCEKIT_TESTS)
normalize_boolean_spelling(SWIFT_ENABLE_EXPERIMENTAL_DIFFERENTIABLE_PROGRAMMING)
normalize_boolean_spelling(SWIFT_ENABLE_EXPERIMENTAL_CONCURRENCY)
+# SWIFT_ENABLE_TENSORFLOW
+normalize_boolean_spelling(SWIFT_ENABLE_TENSORFLOW)
+# SWIFT_ENABLE_TENSORFLOW END
is_build_type_optimized("${SWIFT_STDLIB_BUILD_TYPE}" SWIFT_OPTIMIZED)
set(profdata_merge_worker
diff --git a/test/Constraints/fast-operator-typechecking.swift b/test/Constraints/fast-operator-typechecking.swift
index c51bb03..d2ca27e 100644
--- a/test/Constraints/fast-operator-typechecking.swift
+++ b/test/Constraints/fast-operator-typechecking.swift
@@ -1,3 +1,5 @@
+// SWIFT_ENABLE_TENSORFLOW
+// UNSUPPORTED: tensorflow
// RUN: %target-typecheck-verify-swift -swift-version 5 -solver-enable-operator-designated-types -solver-disable-shrink -disable-constraint-solver-performance-hacks
// rdar://problem/32998180
diff --git a/test/Driver/linker-rpath.swift b/test/Driver/linker-rpath.swift
index 1167e20..c7ef15d 100644
--- a/test/Driver/linker-rpath.swift
+++ b/test/Driver/linker-rpath.swift
@@ -2,26 +2,32 @@
// Note: This is really about the /host/ environment, but since there are RUN
// lines for multiple targets anyway it doesn't make a huge difference.
-// RUN: %swiftc_driver_plain -driver-print-jobs -target x86_64-apple-macosx10.9 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target x86_64-apple-macosx10.14 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target x86_64-apple-macosx10.14.3 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target x86_64-apple-macosx10.14.4 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target x86_64-apple-macosx10.15 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+// SWIFT_ENABLE_TENSORFLOW
+// All "-no-toolchain-stdlib-rpath" additions are SWIFT_ENABLE_TENSORFLOW.
-// RUN: %swiftc_driver_plain -driver-print-jobs -target arm64-apple-ios12 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target arm64-apple-ios12.1 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target arm64-apple-ios12.2 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target arm64-apple-ios13 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target x86_64-apple-macosx10.9 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target x86_64-apple-macosx10.9 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target x86_64-apple-macosx10.14 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target x86_64-apple-macosx10.14.3 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target x86_64-apple-macosx10.14.4 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target x86_64-apple-macosx10.15 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target arm64-apple-tvos12 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target arm64-apple-tvos12.1 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target arm64-apple-tvos12.2 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target arm64-apple-tvos13 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target arm64-apple-ios12 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target arm64-apple-ios12.1 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target arm64-apple-ios12.2 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target arm64-apple-ios13 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target armv7k-apple-watchos5 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target armv7k-apple-watchos5.1 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target armv7k-apple-watchos5.2 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -target armv7k-apple-watchos6 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target arm64-apple-tvos12 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target arm64-apple-tvos12.1 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target arm64-apple-tvos12.2 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target arm64-apple-tvos13 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target armv7k-apple-watchos5 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target armv7k-apple-watchos5.1 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target armv7k-apple-watchos5.2 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target armv7k-apple-watchos6 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+
+// SWIFT_ENABLE_TENSORFLOW END
// RPATH: bin/ld{{"? }}
// RPATH-SAME: -rpath {{"?/usr/lib/swift(-.+)?"? }}
@@ -37,9 +43,11 @@
// RUN: %swiftc_driver_plain -driver-print-jobs -toolchain-stdlib-rpath -target x86_64-apple-macosx10.9 %S/../Inputs/empty.swift -resource-dir garbage/ | %FileCheck -check-prefix TOOLCHAIN-RPATH -DPLATFORM=%target-sdk-name %s
// RUN: %swiftc_driver_plain -driver-print-jobs -toolchain-stdlib-rpath -target x86_64-apple-macosx10.15 %S/../Inputs/empty.swift -resource-dir garbage/ | %FileCheck -check-prefix TOOLCHAIN-RPATH -DPLATFORM=%target-sdk-name %s
-// ### Test with -no-toolchain-stdlib-rpath
-// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target x86_64-apple-macosx10.9 %S/../Inputs/empty.swift | %FileCheck -check-prefix RPATH %s
-// RUN: %swiftc_driver_plain -driver-print-jobs -no-toolchain-stdlib-rpath -target x86_64-apple-macosx10.15 %S/../Inputs/empty.swift | %FileCheck -check-prefix NO-RPATH %s
+// SWIFT_ENABLE_TENSORFLOW
+// ### Test with implicit -toolchain-stdlib-rpath
+// RUN: %swiftc_driver_plain -driver-print-jobs -target x86_64-apple-macosx10.9 %S/../Inputs/empty.swift -resource-dir garbage/ | %FileCheck -check-prefix TOOLCHAIN-RPATH -DPLATFORM=%target-sdk-name %s
+// RUN: %swiftc_driver_plain -driver-print-jobs -target x86_64-apple-macosx10.15 %S/../Inputs/empty.swift -resource-dir garbage/ | %FileCheck -check-prefix TOOLCHAIN-RPATH -DPLATFORM=%target-sdk-name %s
+// SWIFT_ENABLE_TENSORFLOW END
// TOOLCHAIN-RPATH: bin/ld{{"? }}
// TOOLCHAIN-RPATH-SAME: -rpath garbage/[[PLATFORM]]{{ }}
diff --git a/test/Driver/options-interpreter.swift b/test/Driver/options-interpreter.swift
index 33741e3..ba2d801 100644
--- a/test/Driver/options-interpreter.swift
+++ b/test/Driver/options-interpreter.swift
@@ -10,7 +10,7 @@
// ARGS: -- a b c
// RUN: %swift_driver -### -parse-stdlib %s | %FileCheck -check-prefix PARSE_STDLIB %s
-// RUN: %swift_driver -### -parse-stdlib | %FileCheck -check-prefix PARSE_STDLIB %s
+// SWIFT_ENABLE_TENSORFLOW: REPL is disabled in tensorflow branch.
// PARSE_STDLIB: -parse-stdlib
diff --git a/test/Driver/options-repl-darwin.swift b/test/Driver/options-repl-darwin.swift
index cbb6339..3112385 100644
--- a/test/Driver/options-repl-darwin.swift
+++ b/test/Driver/options-repl-darwin.swift
@@ -1,3 +1,6 @@
+// SWIFT_ENABLE_TENSORFLOW: REPL is disabled in tensorflow branch.
+// UNSUPPORTED: tensorflow
+
// REQUIRES: OS=macosx
// Test LLDB detection, first in a clean environment, then in one that looks
diff --git a/test/Driver/options-repl.swift b/test/Driver/options-repl.swift
index 118ebd7..aca4289 100644
--- a/test/Driver/options-repl.swift
+++ b/test/Driver/options-repl.swift
@@ -1,3 +1,6 @@
+// SWIFT_ENABLE_TENSORFLOW: REPL is disabled in tensorflow branch.
+// UNSUPPORTED: tensorflow
+
// RUN: not %swift -repl %s 2>&1 | %FileCheck -check-prefix=REPL_NO_FILES %s
// RUN: not %swift_driver -sdk "" -repl %s 2>&1 | %FileCheck -check-prefix=REPL_NO_FILES %s
// RUN: not %swift_driver -sdk "" -lldb-repl %s 2>&1 | %FileCheck -check-prefix=REPL_NO_FILES %s
diff --git a/test/Driver/options.swift b/test/Driver/options.swift
index 130eb95..a9b7d97 100644
--- a/test/Driver/options.swift
+++ b/test/Driver/options.swift
@@ -5,11 +5,11 @@
// STDLIB_MODULE: error: module name "Swift" is reserved for the standard library{{$}}
// RUN: not %swiftc_driver -crazy-option-that-does-not-exist %s 2>&1 | %FileCheck -check-prefix=INVALID_OPTION %s
-// RUN: not %swift_driver -crazy-option-that-does-not-exist 2>&1 | %FileCheck -check-prefix=INVALID_OPTION %s
+// SWIFT_ENABLE_TENSORFLOW: REPL is disabled in tensorflow branch.
// INVALID_OPTION: error: unknown argument: '-crazy-option-that-does-not-exist'
// RUN: %swiftc_driver -assert-config Debug -### %s | %FileCheck -check-prefix=ASSERTCONFIG %s
-// RUN: %swift_driver -assert-config Debug -### | %FileCheck -check-prefix=ASSERTCONFIG %s
+// SWIFT_ENABLE_TENSORFLOW: REPL is disabled in tensorflow branch.
// ASSERTCONFIG: -assert-config Debug
// RUN: %swiftc_driver -assert-config Release -### %s | %FileCheck -check-prefix=ASSERTCONFIG_RELEASE %s
@@ -31,26 +31,7 @@
// RUN: not %swiftc_driver -import-objc-header fake.h -emit-module-interface-path fake.swiftinterface %s 2>&1 | %FileCheck -check-prefix=BRIDGING_HEADER_SWIFTINTERFACE %s
// BRIDGING_HEADER_SWIFTINTERFACE: error: using bridging headers with module interfaces is unsupported
-// RUN: %swift_driver -### | %FileCheck -check-prefix=DEFAULT_REPL %s
-// DEFAULT_REPL: -repl
-// RUN: not %swiftc_driver 2>&1 | %FileCheck -check-prefix=DEFAULT_EXEC_ERR %s
-// DEFAULT_EXEC_ERR: error: no input files
-// RUN: %swiftc_driver %s -### 2>&1 | %FileCheck -check-prefix=DEFAULT_EXEC %s
-// DEFAULT_EXEC: -c
-// DEFAULT_EXEC: {{ld|clang\+\+}}
-
-// RUN: %swift_driver -repl -### 2>&1 | %FileCheck -check-prefix=REPL %s
-// REPL: warning: unnecessary option '-repl'
-
-// RUN: %swift_driver -lldb-repl -### 2>&1 | %FileCheck -check-prefix=LLDB_REPL %s
-// LLDB_REPL-NOT: warning
-// LLDB_REPL: lldb
-// LLDB_REPL-NOT: warning
-
-// RUN: %swift_driver -deprecated-integrated-repl -### 2>&1 | %FileCheck -check-prefix=INTEGRATED_REPL %s
-// INTEGRATED_REPL-NOT: warning
-// INTEGRATED_REPL: -repl
-// INTEGRATED_REPL-NOT: warning
+// SWIFT_ENABLE_TENSORFLOW: REPL is disabled in tensorflow branch.
// RUN: %swift_driver -### %s | %FileCheck -check-prefix=DEFAULT_I %s
// DEFAULT_I: -interpret
diff --git a/test/Driver/sdk-apple.swift b/test/Driver/sdk-apple.swift
index 2286572..c374679 100644
--- a/test/Driver/sdk-apple.swift
+++ b/test/Driver/sdk-apple.swift
@@ -1,3 +1,6 @@
+// SWIFT_ENABLE_TENSORFLOW: REPL is disabled in tensorflow branch.
+// UNSUPPORTED: tensorflow
+
// XFAIL: freebsd, openbsd, linux, windows
// Test SDK detection for immediate mode.
diff --git a/test/Driver/sdk.swift b/test/Driver/sdk.swift
index 98cc05c..92c5c06 100644
--- a/test/Driver/sdk.swift
+++ b/test/Driver/sdk.swift
@@ -4,9 +4,7 @@
// RUN: %swiftc_driver -driver-print-jobs -target x86_64-unknown-windows-msvc -g -sdk %S/../Inputs/clang-importer-sdk %s 2>&1 | %FileCheck %s --check-prefix WINDOWS
// RUN: %swiftc_driver -driver-print-jobs -target wasm32-unknown-wasi -g -sdk %S/../Inputs/clang-importer-sdk %s 2>&1 | %FileCheck %s --check-prefix WASI
-// RUN: env SDKROOT=%S/../Inputs/clang-importer-sdk %swiftc_driver_plain -target x86_64-apple-macosx10.9 -g -driver-print-jobs %s 2>&1 | %FileCheck %s --check-prefix OSX
-// RUN: env SDKROOT=%S/../Inputs/clang-importer-sdk %swiftc_driver_plain -target x86_64-unknown-linux-gnu -g -driver-print-jobs %s 2>&1 | %FileCheck %s --check-prefix LINUX
-// RUN: env SDKROOT=%S/../Inputs/clang-importer-sdk %swiftc_driver_plain -target x86_64-unknown-freebsd -g -driver-print-jobs %s 2>&1 | %FileCheck %s --check-prefix FREEBSD
+// SWIFT_ENABLE_TENSORFLOW: REPL is disabled in tensorflow branch.
// OSX-NOT: warning: no such SDK:
// OSX: bin{{/|\\\\}}swift
@@ -53,12 +51,7 @@
// WASI: -sdk {{.*}}/Inputs/clang-importer-sdk
// WASI: {{-syslibroot|--sysroot}} {{.*}}/Inputs/clang-importer-sdk
-// RUN: %swift_driver -driver-print-jobs -repl -sdk %S/Inputs/nonexistent-sdk 2>&1 | %FileCheck %s --check-prefix=SDKWARNING
-// RUN: %swift_driver -driver-print-jobs -sdk %S/Inputs/nonexistent-sdk 2>&1 | %FileCheck %s --check-prefix=SDKWARNING
-// RUN: env SDKROOT=%S/Inputs/nonexistent-sdk %swift_driver_plain -driver-print-jobs -repl 2>&1 | %FileCheck %s --check-prefix=SDKWARNING
-
-// SDKWARNING: warning: no such SDK: '{{.*}}/Inputs/nonexistent-sdk'
-// SDKWARNING: -sdk {{.*}}/Inputs/nonexistent-sdk
+// SWIFT_ENABLE_TENSORFLOW: REPL is disabled in tensorflow branch.
// RUN: %swiftc_driver -driver-print-jobs -typecheck -sdk %S/../Inputs/clang-importer-sdk -module-cache-path /path/to/cache %s 2>&1 | %FileCheck %s --check-prefix=CACHE-PATH
diff --git a/test/Driver/subcommands.swift b/test/Driver/subcommands.swift
index 14024c7..617842b 100644
--- a/test/Driver/subcommands.swift
+++ b/test/Driver/subcommands.swift
@@ -1,3 +1,6 @@
+// SWIFT_ENABLE_TENSORFLOW: REPL is disabled in tensorflow branch.
+// UNSUPPORTED: tensorflow
+
// REQUIRES: shell
// Check that 'swift' and 'swift repl' invoke the REPL.
diff --git a/test/Driver/working-directory.swift b/test/Driver/working-directory.swift
index 127ee02..e078172 100644
--- a/test/Driver/working-directory.swift
+++ b/test/Driver/working-directory.swift
@@ -10,10 +10,7 @@
// -working-directory=
// RUN: cd %t && %swiftc_driver -driver-print-jobs -working-directory=%S/Inputs -c %/S/Inputs/main.swift | %FileCheck %s -check-prefix=INPUT
-// In another driver mode.
-// RUN: cd %t && %swift_driver -driver-print-jobs -working-directory %/S/Inputs -F. | %FileCheck %s -check-prefix=REPL
-// RUN: cd %t && %swift_driver -driver-print-jobs -deprecated-integrated-repl -working-directory %/S/Inputs -F. | %FileCheck %s -check-prefix=REPL
-// REPL: -F {{\\?"?}}SOURCE_DIR/test/Driver/Inputs{{(\\\\(\\\\)?)|/}}.
+// SWIFT_ENABLE_TENSORFLOW: REPL is disabled in tensorflow branch.
// RUN: cd %t && %swiftc_driver -driver-print-jobs -working-directory=%/S/Inputs -c -module-name m main.swift lib.swift | %FileCheck %s -check-prefix=MULTI_INPUT
// MULTI_INPUT: SOURCE_DIR/test/Driver/Inputs{{/|\\\\}}main.swift
diff --git a/test/IDE/complete_autodiff.swift b/test/IDE/complete_autodiff.swift
new file mode 100644
index 0000000..2ed3368
--- /dev/null
+++ b/test/IDE/complete_autodiff.swift
@@ -0,0 +1,25 @@
+// SWIFT_ENABLE_TENSORFLOW
+// RUN: %target-swift-ide-test -code-completion -source-filename %s -code-completion-token=COMPLETE1 | %FileCheck --check-prefix=COMPLETE1 %s
+
+protocol DifferentiableRequirements {
+ @differentiable
+ func f(_ x: Float) -> Float
+}
+
+struct Complete1 : DifferentiableRequirements {
+ @differentiable
+ func f#^COMPLETE1^#
+}
+
+// COMPLETE1-LABEL: Begin completions
+// COMPLETE1: func f(_ x: Float) -> Float
+// COMPLETE1: End completions
+
+struct Complete2 : DifferentiableRequirements {
+ @differentiable
+ func f(_ x: Float)#^COMPLETE2^#
+}
+
+// COMPLETE2-LABEL: Begin completions
+// COMPLETE2: func f(_ x: Float) -> Float
+// COMPLETE2: End completions
diff --git a/test/IDE/complete_decl_attribute.swift b/test/IDE/complete_decl_attribute.swift
index 23606da..72df159 100644
--- a/test/IDE/complete_decl_attribute.swift
+++ b/test/IDE/complete_decl_attribute.swift
@@ -70,6 +70,9 @@
// KEYWORD2-NEXT: Keyword/None: derivative[#Func Attribute#]; name=derivative
// KEYWORD2-NEXT: Keyword/None: transpose[#Func Attribute#]; name=transpose
// KEYWORD2-NEXT: Keyword/None: noDerivative[#Func Attribute#]; name=noDerivative
+// SWIFT_ENABLE_TENSORFLOW
+// KEYWORD2-NEXT: Keyword/None: compilerEvaluable[#Func Attribute#]; name=compilerEvaluable
+// SWIFT_ENABLE_TENSORFLOW END
// KEYWORD2-NOT: Keyword
// KEYWORD2: Decl[Struct]/CurrModule: MyStruct[#MyStruct#]; name=MyStruct
// KEYWORD2: End completions
@@ -182,11 +185,13 @@
// ON_METHOD-DAG: Keyword/None: warn_unqualified_access[#Func Attribute#]; name=warn_unqualified_access
// ON_METHOD-DAG: Keyword/None: usableFromInline[#Func Attribute#]; name=usableFromInline
// ON_METHOD-DAG: Keyword/None: discardableResult[#Func Attribute#]; name=discardableResult
-// ON_METHOD-DAG: Keyword/None: IBSegueAction[#Func Attribute#]; name=IBSegueAction
// ON_METHOD-DAG: Keyword/None: differentiable[#Func Attribute#]; name=differentiable
// ON_METHOD-DAG: Keyword/None: derivative[#Func Attribute#]; name=derivative
// ON_METHOD-DAG: Keyword/None: transpose[#Func Attribute#]; name=transpose
// ON_METHOD-DAG: Keyword/None: noDerivative[#Func Attribute#]; name=noDerivative
+// SWIFT_ENABLE_TENSORFLOW
+// ON_METHOD-DAG: Keyword/None: compilerEvaluable[#Func Attribute#]; name=compilerEvaluable
+// SWIFT_ENABLE_TENSORFLOW END
// ON_METHOD-NOT: Keyword
// ON_METHOD: Decl[Struct]/CurrModule: MyStruct[#MyStruct#]; name=MyStruct
// ON_METHOD: End completions
@@ -245,6 +250,9 @@
// ON_MEMBER_LAST-DAG: Keyword/None: derivative[#Declaration Attribute#]; name=derivative
// ON_MEMBER_LAST-DAG: Keyword/None: transpose[#Declaration Attribute#]; name=transpose
// ON_MEMBER_LAST-DAG: Keyword/None: noDerivative[#Declaration Attribute#]; name=noDerivative
+// SWIFT_ENABLE_TENSORFLOW
+// ON_MEMBER_LAST-DAG: Keyword/None: compilerEvaluable[#Declaration Attribute#]; name=compilerEvaluable
+// SWIFT_ENABLE_TENSORFLOW END
// ON_MEMBER_LAST-NOT: Keyword
// ON_MEMBER_LAST: Decl[Struct]/CurrModule: MyStruct[#MyStruct#]; name=MyStruct
// ON_MEMBER_LAST-NOT: Decl[PrecedenceGroup]
@@ -292,6 +300,9 @@
// KEYWORD_LAST-NEXT: Keyword/None: derivative[#Declaration Attribute#]; name=derivative
// KEYWORD_LAST-NEXT: Keyword/None: transpose[#Declaration Attribute#]; name=transpose
// KEYWORD_LAST-NEXT: Keyword/None: noDerivative[#Declaration Attribute#]; name=noDerivative
+// SWIFT_ENABLE_TENSORFLOW
+// KEYWORD_LAST-NEXT: Keyword/None: compilerEvaluable[#Declaration Attribute#]; name=compilerEvaluable
+// SWIFT_ENABLE_TENSORFLOW END
// KEYWORD_LAST-NOT: Keyword
// KEYWORD_LAST: Decl[Struct]/CurrModule: MyStruct[#MyStruct#]; name=MyStruct
// KEYWORD_LAST: End completions
diff --git a/test/IDE/complete_repl_pattern_binding_with_closure.swift b/test/IDE/complete_repl_pattern_binding_with_closure.swift
new file mode 100644
index 0000000..24926fb
--- /dev/null
+++ b/test/IDE/complete_repl_pattern_binding_with_closure.swift
@@ -0,0 +1,6 @@
+// RUN: %target-swift-ide-test -repl-code-completion -source-filename %s | %FileCheck %s
+// CHECK: Begin completions
+// CHECK: print
+// CHECK: End completions
+let (a, b) = ({ _ in (1, 2) })(0)
+p
diff --git a/test/IDE/complete_tf_198.swift b/test/IDE/complete_tf_198.swift
new file mode 100644
index 0000000..b43254b
--- /dev/null
+++ b/test/IDE/complete_tf_198.swift
@@ -0,0 +1,7 @@
+// https://bugs.swift.org/browse/TF-198: `@dynamicCallable` REPL completer crash.
+// RUN: %target-swift-ide-test -repl-code-completion -source-filename=%s
+
+// TODO(TF-214): Require `python` lit feature, after it is created.
+
+import Python
+Python.str("name").strip(
diff --git a/test/IDE/complete_tf_239.swift b/test/IDE/complete_tf_239.swift
new file mode 100644
index 0000000..67fa01a
--- /dev/null
+++ b/test/IDE/complete_tf_239.swift
@@ -0,0 +1,10 @@
+// SWIFT_ENABLE_TENSORFLOW
+// RUN: %target-swift-ide-test -code-completion -source-filename %s -code-completion-token=COMPLETE | %FileCheck %s
+
+if true {
+ print("\(1)")
+ let foo = #^COMPLETE^#
+}
+
+// CHECK-LABEL: Begin completions
+// CHECK: End completions
diff --git a/test/IDE/complete_tf_315.swift b/test/IDE/complete_tf_315.swift
new file mode 100644
index 0000000..eadff7c
--- /dev/null
+++ b/test/IDE/complete_tf_315.swift
@@ -0,0 +1,5 @@
+// SWIFT_ENABLE_TENSORFLOW
+// RUN: %target-swift-ide-test -code-completion -source-filename %s -code-completion-token=COMPLETE
+
+import TensorFlow
+let t = Tensor#^COMPLETE^#
diff --git a/test/IRGen/sil_linkage.sil b/test/IRGen/sil_linkage.sil
index e926fd1..1230c73 100644
--- a/test/IRGen/sil_linkage.sil
+++ b/test/IRGen/sil_linkage.sil
@@ -5,7 +5,8 @@
// CHECK: define{{( dllexport)?}}{{( protected)?}} swiftcc void @public_fragile_function_test() {{.*}} {
// CHECK: define{{( dllexport)?}}{{( protected)?}} swiftcc void @public_transparent_fragile_function_test() {{.*}} {
// CHECK: define{{( dllexport)?}}{{( protected)?}} swiftcc void @public_transparent_function_test() {{.*}} {
-// CHECK: define weak_odr hidden swiftcc void @public_non_abi_function_test() {{.*}} {
+// SWIFT_ENABLE_TENSORFLOW: Fix for TF-587.
+// CHECK: define hidden swiftcc void @public_non_abi_function_test() {{.*}} {
// CHECK: define hidden swiftcc void @hidden_fragile_function_test() {{.*}} {
// CHECK: define linkonce_odr hidden swiftcc void @shared_fragile_function_test() {{.*}} {
// CHECK: define{{( internal)?}} swiftcc void @private_fragile_function_test() {{.*}} {
diff --git a/test/Incremental/Verifier/single-file-private/AnyObject.swift b/test/Incremental/Verifier/single-file-private/AnyObject.swift
index ff93b85..cd19de1 100644
--- a/test/Incremental/Verifier/single-file-private/AnyObject.swift
+++ b/test/Incremental/Verifier/single-file-private/AnyObject.swift
@@ -58,6 +58,9 @@
// expected-private-member {{Swift.Encodable.callAsFunction}}
// expected-private-member {{Swift.Decodable.callAsFunction}}
// expected-private-member {{Foundation._OptionalForKVO.callAsFunction}}
+// SWIFT_ENABLE_TENSORFLOW
+// expected-private-member {{Swift.Differentiable.callAsFunction}}
+// SWIFT_ENABLE_TENSORFLOW END
// expected-provides {{AnyObject}}
func lookupOnAnyObject(object: AnyObject) { // expected-provides {{lookupOnAnyObject}}
diff --git a/test/Incremental/Verifier/single-file/AnyObject.swift b/test/Incremental/Verifier/single-file/AnyObject.swift
index 0130674..419b933 100644
--- a/test/Incremental/Verifier/single-file/AnyObject.swift
+++ b/test/Incremental/Verifier/single-file/AnyObject.swift
@@ -58,6 +58,9 @@
// expected-private-member {{Swift.Encodable.callAsFunction}}
// expected-private-member {{Swift.Decodable.callAsFunction}}
// expected-private-member {{Foundation._OptionalForKVO.callAsFunction}}
+// SWIFT_ENABLE_TENSORFLOW
+// expected-private-member {{Swift.Differentiable.callAsFunction}}
+// SWIFT_ENABLE_TENSORFLOW END
// expected-provides {{AnyObject}}
func lookupOnAnyObject(object: AnyObject) { // expected-provides {{lookupOnAnyObject}}
diff --git a/test/Interpreter/SDK/objc_bridge_cast.swift b/test/Interpreter/SDK/objc_bridge_cast.swift
index a4e0e0f..65afa6a 100644
--- a/test/Interpreter/SDK/objc_bridge_cast.swift
+++ b/test/Interpreter/SDK/objc_bridge_cast.swift
@@ -3,6 +3,14 @@
// REQUIRES: objc_interop
+// SWIFT_ENABLE_TENSORFLOW
+// As of 07-12, this test fails with the following error:
+// <stdin>:56:1: error: CHECK-NOT: string occurred!
+// oh noes
+//
+// The test is temporarily marked as unsupported. Upgrading to macOS 10.14 may fix the problem.
+// UNSUPPORTED: objc_interop
+
// Test dynamic casts that bridge value types through the runtime.
import Foundation
diff --git a/test/Misc/stats_dir_instructions.swift b/test/Misc/stats_dir_instructions.swift
index 9c41253..31550b4 100644
--- a/test/Misc/stats_dir_instructions.swift
+++ b/test/Misc/stats_dir_instructions.swift
@@ -1,3 +1,6 @@
+// SWIFT_ENABLE_TENSORFLOW
+// UNSUPPORTED: tensorflow
+
// REQUIRES: OS=macosx
// RUN: %empty-directory(%t)
// RUN: %target-swiftc_driver -o %t/main -module-name main -stats-output-dir %t %s
diff --git a/test/NameBinding/astscope-differentiable-attr.swift b/test/NameBinding/astscope-differentiable-attr.swift
new file mode 100644
index 0000000..9514d2a
--- /dev/null
+++ b/test/NameBinding/astscope-differentiable-attr.swift
@@ -0,0 +1,54 @@
+// SWIFT_ENABLE_TENSORFLOW
+// Check that ASTScope lookup works for `@differentiable` attribute.
+
+// NOTE(TF-815): Without custom scope support, ASTScopeLookup crashes for
+// `@differentiable` attribute with where clauses on subscript and `var`
+// declarations.
+
+// RUN: %target-swift-frontend -typecheck %s -enable-astscope-lookup
+
+struct Test<Element> {
+ var element: Element
+}
+extension Test: Differentiable where Element: Differentiable {}
+extension Test {
+ @differentiable(where Element: Differentiable)
+ init(_ element: Element) {
+ self.element = element
+ }
+
+ @differentiable(where Element: Differentiable)
+ func method() -> Element {
+ element
+ }
+
+ @differentiable(where T: Differentiable)
+ func method<T>(_ x: T) -> T {
+ x
+ }
+
+ // NOTE(TF-815): This crashed without `DifferentiableAttributeScope` support.
+ @differentiable(where Element: Differentiable)
+ subscript(implicitGetterOnly_ : Void) -> Element {
+ element
+ }
+
+ subscript(explicitGetterAndSetter _: Void) -> Element {
+ @differentiable(where Element: Differentiable)
+ get { element }
+ set {}
+ }
+
+ // NOTE(TF-815): This crashed without `DifferentiableAttributeScope` support.
+ @differentiable(where Element: Differentiable)
+ var computedProperty: Element {
+ element
+ }
+
+ var computedPropertyExplicitGetter: Element {
+ @differentiable(where Element: Differentiable)
+ get {
+ element
+ }
+ }
+}
diff --git a/test/NameBinding/astscope-differentiating-attr.swift b/test/NameBinding/astscope-differentiating-attr.swift
new file mode 100644
index 0000000..16a7ea6
--- /dev/null
+++ b/test/NameBinding/astscope-differentiating-attr.swift
@@ -0,0 +1,37 @@
+// SWIFT_ENABLE_TENSORFLOW
+// Check that ASTScope lookup works for `@derivative` attribute.
+
+// NOTE(TF-835): This test is only necessary because `@derivative` attribute
+// type-checking generates implicit `@differentiable` attributes on the
+// referenced declaration. Robust lowering for `@derivative` attributes should
+// make special logic regarding implicit `@differentiable` attributes
+// unnecessary.
+
+// RUN: %target-swift-frontend -typecheck %s -enable-astscope-lookup
+
+struct Test<Element> {
+ var element: Element
+}
+extension Test: Differentiable where Element: Differentiable {}
+extension Test {
+ static func +(lhs: Self, rhs: Self) -> Self {
+ lhs
+ }
+ static func -(lhs: Self, rhs: Self) -> Self {
+ lhs
+ }
+}
+
+extension Test where Element : Differentiable {
+ @derivative(of: +)
+ internal static func _vjpAdd(lhs: Self, rhs: Self)
+ -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
+ return (lhs + rhs, { v in (v, v) })
+ }
+
+ @derivative(of: -)
+ internal static func _vjpSubtract(lhs: Self, rhs: Self)
+ -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
+ return (lhs + rhs, { v in (v, v) })
+ }
+}
diff --git a/test/Prototypes/BigInt.swift b/test/Prototypes/BigInt.swift
index f6c96ea..88b857d 100644
--- a/test/Prototypes/BigInt.swift
+++ b/test/Prototypes/BigInt.swift
@@ -16,6 +16,13 @@
// REQUIRES: executable_test
// REQUIRES: CPU=x86_64
+// SWIFT_ENABLE_TENSORFLOW
+// This test is currently unsupported because the addition of `+` operators
+// to the stdlib (via `VectorProtocol`) causes type-checking to fail.
+// Re-enable when type-checking no longer fails.
+// REQUIRES: no_tensorflow
+// SWIFT_ENABLE_TENSORFLOW END
+
// See rdar://problem/65251059
// UNSUPPORTED: windows
// rdar://problem/65015626
diff --git a/test/Reflection/box_descriptors.sil b/test/Reflection/box_descriptors.sil.disabled
similarity index 100%
rename from test/Reflection/box_descriptors.sil
rename to test/Reflection/box_descriptors.sil.disabled
diff --git a/test/Runtime/linux-fatal-backtrace.swift b/test/Runtime/linux-fatal-backtrace.swift
index edac2a7..c69a2f0 100644
--- a/test/Runtime/linux-fatal-backtrace.swift
+++ b/test/Runtime/linux-fatal-backtrace.swift
@@ -15,6 +15,14 @@
// run when optimizations are enabled.
// REQUIRES: swift_test_mode_optimize_none
+// SWIFT_ENABLE_TENSORFLOW
+// `utils/symbolicate-linux-fatal` fails with TensorFlow support because
+// libtensorflow.so is not linked properly. `import lldb` causes an import
+// error:
+// "ImportError: libtensorflow.so: cannot open shared object file"
+// The lldb swig setup scripts should be edited to fix this.
+// UNSUPPORTED: tensorflow
+
func funcB() {
fatalError("linux-fatal-backtrace");
}
diff --git a/test/SILGen/mangling.swift b/test/SILGen/mangling.swift
index 2d401ff..491f2ac 100644
--- a/test/SILGen/mangling.swift
+++ b/test/SILGen/mangling.swift
@@ -175,3 +175,11 @@
// CHECK-LABEL: sil hidden [ossa] @$s8mangling14varargsVsArray3arr1nySaySiGd_SStF : $@convention(thin) (@guaranteed Array<Array<Int>>, @guaranteed String) -> ()
func varargsVsArray(arr: [Int]..., n: String) { }
+
+// SWIFT_ENABLE_TENSORFLOW
+// CHECK-LABEL: sil hidden [ossa] @$s8mangling15funcVsDiffFunc12fnyS2fXE_tF : $@convention(thin) (@noescape @callee_guaranteed (Float) -> Float) -> ()
+func funcVsDiffFunc1(fn: (Float) -> Float) {}
+
+// CHECK-LABEL: sil hidden [ossa] @$s8mangling15funcVsDiffFunc22fnyS2fXF_tF : $@convention(thin) (@differentiable @noescape @callee_guaranteed (Float) -> Float) -> ()
+func funcVsDiffFunc2(fn: @differentiable (Float) -> Float) {}
+// SWIFT_ENABLE_TENSORFLOW END
diff --git a/test/SILOptimizer/prune-vtables.sil b/test/SILOptimizer/prune-vtables.sil
index 633dc9e..fa0d759 100644
--- a/test/SILOptimizer/prune-vtables.sil
+++ b/test/SILOptimizer/prune-vtables.sil
@@ -1,5 +1,6 @@
// RUN: %target-sil-opt -prune-vtables %s | %FileCheck --check-prefix=NOWMO %s
// RUN: %target-sil-opt -wmo -prune-vtables %s | %FileCheck --check-prefix=WMO %s
+// UNSUPPORTED: tensorflow
sil_stage canonical
diff --git a/test/Sema/Inputs/struct_elementary_functions_other_module.swift b/test/Sema/Inputs/struct_elementary_functions_other_module.swift
new file mode 100644
index 0000000..6d9c65e
--- /dev/null
+++ b/test/Sema/Inputs/struct_elementary_functions_other_module.swift
@@ -0,0 +1,14 @@
+// SWIFT_ENABLE_TENSORFLOW
+
+// expected-note @+1 24 {{type declared here}}
+struct OtherFileNonconforming : Equatable {
+ let float: Float
+ var double: Double
+}
+
+// expected-note @+1 24 {{type declared here}}
+struct GenericOtherFileNonconforming<T : ElementaryFunctions> : Equatable {
+ let x: T
+ var float: Float
+ var double: Double
+}
diff --git a/test/Sema/Inputs/struct_key_path_iterable_other_module.swift b/test/Sema/Inputs/struct_key_path_iterable_other_module.swift
new file mode 100644
index 0000000..3feea92
--- /dev/null
+++ b/test/Sema/Inputs/struct_key_path_iterable_other_module.swift
@@ -0,0 +1,14 @@
+// SWIFT_ENABLE_TENSORFLOW
+
+// expected-note @+1 {{type declared here}}
+struct OtherFileNonconforming {
+ var int: Int
+ var float: Float
+}
+
+// expected-note @+1 {{type declared here}}
+struct GenericOtherFileNonconforming<T : KeyPathIterable> {
+ var x: T
+ var int: Int
+ var float: Float
+}
diff --git a/test/Sema/Inputs/struct_pointwise_multiplicative_other_module.swift b/test/Sema/Inputs/struct_pointwise_multiplicative_other_module.swift
new file mode 100644
index 0000000..08a60c1
--- /dev/null
+++ b/test/Sema/Inputs/struct_pointwise_multiplicative_other_module.swift
@@ -0,0 +1,16 @@
+// SWIFT_ENABLE_TENSORFLOW
+
+struct Base : PointwiseMultiplicative {}
+
+// expected-note @+1 3 {{type declared here}}
+struct OtherFileNonconforming : Equatable, AdditiveArithmetic {
+ var base: Base
+}
+
+// expected-note @+1 3 {{type declared here}}
+struct GenericOtherFileNonconforming<
+ T : PointwiseMultiplicative
+> : Equatable, AdditiveArithmetic {
+ var x: T
+ var base: Base
+}
diff --git a/test/Sema/Inputs/struct_vector_protocol_other_module.swift b/test/Sema/Inputs/struct_vector_protocol_other_module.swift
new file mode 100644
index 0000000..cb27919
--- /dev/null
+++ b/test/Sema/Inputs/struct_vector_protocol_other_module.swift
@@ -0,0 +1,12 @@
+// SWIFT_ENABLE_TENSORFLOW
+
+// expected-note @+1 {{type declared here}}
+struct OtherFileNonconforming : AdditiveArithmetic {
+ var float: Float
+}
+
+// expected-note @+1 {{type declared here}}
+struct GenericOtherFileNonconforming<T : VectorProtocol> : AdditiveArithmetic {
+ var x: T
+ var y: T
+}
diff --git a/test/Sema/differentiable_access_level.swift b/test/Sema/differentiable_access_level.swift
new file mode 100644
index 0000000..74c1e64
--- /dev/null
+++ b/test/Sema/differentiable_access_level.swift
@@ -0,0 +1,58 @@
+// SWIFT_ENABLE_TENSORFLOW
+// RUN: %target-swift-frontend -print-ast %s | %FileCheck %s
+
+// TF-1077: Verify access levels of `TangentVector` types and their memberwise
+// initializers, synthesized during `Differentiable` derived conformances.
+
+// `TangentVector` memberwise initializer access level should default to public,
+// for usability.
+
+public struct PublicStruct: Differentiable {}
+internal struct InternalStruct: Differentiable {}
+private struct PrivateStruct: Differentiable {}
+
+// CHECK-LABEL: public struct PublicStruct : Differentiable {
+// CHECK: internal init()
+// CHECK: public struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative, ElementaryFunctions {
+// CHECK: public init()
+// CHECK: }
+// CHECK: }
+
+// CHECK-LABEL: internal struct InternalStruct : Differentiable {
+// CHECK: internal init()
+// CHECK: internal struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative, ElementaryFunctions {
+// CHECK: public init()
+// CHECK: }
+// CHECK: }
+
+// CHECK-LABEL: private struct PrivateStruct : Differentiable {
+// CHECK: internal init()
+// CHECK: fileprivate struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative, ElementaryFunctions {
+// CHECK: public init()
+// CHECK: }
+// CHECK: }
+
+public class PublicClass: Differentiable {}
+internal class InternalClass: Differentiable {}
+private class PrivateClass: Differentiable {}
+
+// CHECK-LABEL: public class PublicClass : Differentiable {
+// CHECK: internal init()
+// CHECK: public struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative, ElementaryFunctions {
+// CHECK: public init()
+// CHECK: }
+// CHECK: }
+
+// CHECK-LABEL: internal class InternalClass : Differentiable {
+// CHECK: internal init()
+// CHECK: internal struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative, ElementaryFunctions {
+// CHECK: public init()
+// CHECK: }
+// CHECK: }
+
+// CHECK-LABEL: private class PrivateClass : Differentiable {
+// CHECK: internal init()
+// CHECK: fileprivate struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative, ElementaryFunctions {
+// CHECK: public init()
+// CHECK: }
+// CHECK: }
diff --git a/test/Sema/struct_differentiable_member_types.swift b/test/Sema/struct_differentiable_member_types.swift
new file mode 100644
index 0000000..e213871
--- /dev/null
+++ b/test/Sema/struct_differentiable_member_types.swift
@@ -0,0 +1,19 @@
+// SWIFT_ENABLE_TENSORFLOW
+// RUN: %target-typecheck-verify-swift
+
+// Test usages of synthesized `Differentiable` member struct types.
+
+// TF-466: Test conforming a synthesized member type to a protocol with
+// property requirements.
+protocol Proto {
+ var weight: Float { get }
+}
+struct Foo : Differentiable {
+ var weight: Float
+}
+// Note: the global variable here is necessary for type-checking to pass
+// when extending a synthesized member type. Enabling general extensions of
+// synthesized member types require extra non-trivial work, due to the
+// current type-checker design.
+let randomGlobal = 1
+extension Foo.TangentVector : Proto {}
diff --git a/test/Sema/struct_elementary_functions.swift b/test/Sema/struct_elementary_functions.swift
new file mode 100644
index 0000000..b4c17b7
--- /dev/null
+++ b/test/Sema/struct_elementary_functions.swift
@@ -0,0 +1,91 @@
+// SWIFT_ENABLE_TENSORFLOW
+// RUN: %target-swift-frontend -typecheck -verify -primary-file %s %S/Inputs/struct_elementary_functions_other_module.swift
+
+struct Empty : ElementaryFunctions {}
+func testEmpty() {
+ _ = Empty()
+}
+
+struct Float2: ElementaryFunctions {
+ let a: Float
+ var b: Float
+}
+func testFloat2() {
+ _ = Float2(a: 1, b: 1)
+}
+
+// Test generic type.
+struct Vector2<T : ElementaryFunctions>: ElementaryFunctions {
+ let x: T
+ var y: T
+}
+func testVector2() {
+ _ = Vector2<Double>(x: 1, y: 1)
+}
+
+// Test nested type.
+struct Nested: ElementaryFunctions {
+ let float2: Float2
+ var float: Float
+}
+func testNested(float2: Float2) {
+ _ = Nested(float2: float2, float: 1)
+}
+
+// Test mixed type.
+struct Mixed: ElementaryFunctions {
+ let nested: Nested
+ var float = Float(1)
+ var double: Double
+}
+func testMixed(nested: Nested) {
+ _ = Mixed(nested: nested, float: 1, double: 1)
+}
+
+// Test type in generic context.
+struct A<T> {
+ struct B<U, V> {
+ struct GenericContextNested : ElementaryFunctions {
+ var nested: Nested
+ let float: Float
+ var double = Double(2)
+ }
+ }
+}
+func testGenericContext<T, U, V>(nested: Nested) -> A<T>.B<U, V>.GenericContextNested {
+ A<T>.B<U, V>.GenericContextNested(nested: nested, float: 1, double: 1)
+}
+
+// Test extension.
+struct Extended {
+ var x: Float
+}
+extension Extended : ElementaryFunctions {}
+
+// Test extension of generic type.
+struct GenericExtended<T> {
+ var x: T
+}
+extension GenericExtended : ElementaryFunctions where T : ElementaryFunctions {}
+
+// Test memberwise initializer synthesis.
+struct NoMemberwiseInitializer<T : ElementaryFunctions> : ElementaryFunctions {
+ var value: T
+ init(randomLabel value: T) { self.value = value }
+}
+struct NoMemberwiseInitializerExtended<T> {
+ var value: T
+ init(_ value: T) {
+ self.value = value
+ }
+}
+extension NoMemberwiseInitializerExtended: ElementaryFunctions
+ where T : ElementaryFunctions {}
+
+// Test derived conformances in disallowed contexts.
+
+// expected-error @+1 24 {{extension outside of file declaring struct 'OtherFileNonconforming' prevents automatic synthesis of}}
+extension OtherFileNonconforming : ElementaryFunctions {}
+
+// expected-error @+1 24 {{extension outside of file declaring generic struct 'GenericOtherFileNonconforming' prevents automatic synthesis of}}
+extension GenericOtherFileNonconforming : ElementaryFunctions {}
diff --git a/test/Sema/struct_key_path_iterable.swift b/test/Sema/struct_key_path_iterable.swift
new file mode 100644
index 0000000..1922e0e
--- /dev/null
+++ b/test/Sema/struct_key_path_iterable.swift
@@ -0,0 +1,145 @@
+// SWIFT_ENABLE_TENSORFLOW
+// RUN: %target-swift-frontend -typecheck -verify -primary-file %s %S/Inputs/struct_key_path_iterable_other_module.swift
+
+struct Tensor<Scalar> {
+ var scalar: Scalar
+ init(_ scalar: Scalar) {
+ self.scalar = scalar
+ }
+}
+extension Tensor : Equatable where Scalar : Equatable {}
+extension Tensor : AdditiveArithmetic where Scalar : AdditiveArithmetic {}
+extension Tensor : VectorProtocol where Scalar : AdditiveArithmetic {
+ typealias VectorSpaceScalar = Scalar
+ func adding(_: Scalar) -> Self { self }
+ func subtracting(_: Scalar) -> Self { self }
+ func scaled(by scalar: Scalar) -> Self { self }
+}
+
+// Synthesis should work for empty structs.
+// `allKeyPaths` simply returns `[]`.
+struct Empty : KeyPathIterable {}
+
+struct Parameters : KeyPathIterable {
+ var w: Float
+ var b: Float
+}
+func testParameters() {
+ var params = Parameters(w: 1, b: 2)
+ assert(params.allKeyPaths.count == 2)
+ assert(params.allKeyPaths(to: Float.self).count == 2)
+ assert(params.allKeyPaths(to: Int.self).count == 0)
+ for kp in params.allWritableKeyPaths(to: Float.self) {
+ params[keyPath: kp] *= 2
+ }
+}
+
+struct TensorParameters : KeyPathIterable {
+ var w: Tensor<Float>
+ var b: Tensor<Float>
+
+ // Non-stored-property members should not affect synthesis.
+ var computed: Float {
+ return (w + b).scalar
+ }
+ func foo() {}
+ typealias Foo = Int
+}
+
+extension TensorParameters : VectorProtocol {
+ static var zero: TensorParameters {
+ return TensorParameters(w: Tensor(0), b: Tensor(0))
+ }
+ static func + (lhs: TensorParameters, rhs: TensorParameters) -> TensorParameters {
+ return TensorParameters(w: lhs.w + rhs.w, b: lhs.b + rhs.b)
+ }
+ static func - (lhs: TensorParameters, rhs: TensorParameters) -> TensorParameters {
+ return TensorParameters(w: lhs.w + rhs.w, b: lhs.b + rhs.b)
+ }
+ typealias VectorSpaceScalar = Float
+ func adding(_ x: VectorSpaceScalar) -> TensorParameters {
+ return TensorParameters(w: w.adding(x), b: b.adding(x))
+ }
+ func subtracting(_ x: VectorSpaceScalar) -> TensorParameters {
+ return TensorParameters(w: w.subtracting(x), b: b.subtracting(x))
+ }
+ func scaled(by scalar: VectorSpaceScalar) -> TensorParameters {
+ return TensorParameters(w: w.scaled(by: scalar), b: b.scaled(by: scalar))
+ }
+}
+
+struct HeterogeneousParameters : KeyPathIterable {
+ var float: Float
+ var double: Double
+ var tensor: Tensor<Float>
+ var params: Parameters
+}
+func testHeterogenousParameters(_ params: Parameters) {
+ let hetero = HeterogeneousParameters(float: 0, double: 0,
+ tensor: Tensor(0), params: params)
+ assert(hetero.allKeyPaths.count == 4)
+ assert(hetero.recursivelyAllKeyPaths.count == 6)
+ assert(hetero.allKeyPaths(to: Float.self).count == 1)
+ assert(hetero.recursivelyAllKeyPaths(to: Float.self).count == 3)
+ assert(hetero.allKeyPaths(to: Tensor<Float>.self).count == 1)
+ assert(hetero.allKeyPaths(to: Parameters.self).count == 1)
+ assert(hetero.allKeyPaths(to: Int.self).count == 0)
+}
+
+// Test type in generic context.
+struct A<T> {
+ struct B<U, V> {
+ struct GenericContextParams : KeyPathIterable {
+ var params: Parameters
+ var float: Float
+ }
+ }
+}
+
+// Test generic optimizer.
+
+struct DummyOptimizer<P : KeyPathIterable, Scalar : BinaryFloatingPoint>
+ where P : VectorProtocol, P.VectorSpaceScalar == Scalar
+{
+ let learningRate: Scalar
+ var firstMoments: P = P.zero
+
+ mutating func fitParameters(
+ parameters: inout P, withGradients gradients: P
+ ) {
+ for kp in parameters.recursivelyAllWritableKeyPaths(to: Tensor<Scalar>.self) {
+ firstMoments[keyPath: kp].scale(by: learningRate)
+ parameters[keyPath: kp] -= parameters[keyPath: kp].scaled(by: learningRate)
+ }
+ }
+}
+
+// TF-575: Test overloaded key path component name.
+protocol NameLookupConflictProtocol {}
+extension NameLookupConflictProtocol {
+ func member() {}
+}
+struct NameLookupConflict: NameLookupConflictProtocol & KeyPathIterable {
+ // Note: `NameLookupConflict.member` is overloaded with
+ // `MemberNameConflictProtocol.member`.
+ // This makes the following generated code fail:
+ //
+ // var allKeyPaths: [PartialKeyPath<Self>] {
+ // [\Self.member]
+ // }
+ //
+ // error: cannot convert value of type
+ // 'WritableKeyPath<NameLookupConflict, Float>' to expected element type
+ // 'PartialKeyPath<NameLookupConflict>'
+ var member: Float
+}
+
+// Test derived conformances in disallowed contexts.
+
+// expected-error @+2 {{type 'OtherFileNonconforming' does not conform to protocol 'KeyPathIterable'}}
+// expected-error @+1 {{extension outside of file declaring struct 'OtherFileNonconforming' prevents automatic synthesis of 'AllKeyPaths' for protocol 'KeyPathIterable'}}
+extension OtherFileNonconforming : KeyPathIterable {}
+
+// expected-error @+2 {{type 'GenericOtherFileNonconforming<T>' does not conform to protocol 'KeyPathIterable'}}
+// expected-error @+1 {{extension outside of file declaring generic struct 'GenericOtherFileNonconforming' prevents automatic synthesis of 'AllKeyPaths' for protocol 'KeyPathIterable'}}
+extension GenericOtherFileNonconforming : KeyPathIterable {}
diff --git a/test/Sema/struct_pointwise_multiplicative.swift b/test/Sema/struct_pointwise_multiplicative.swift
new file mode 100644
index 0000000..8d64c6f
--- /dev/null
+++ b/test/Sema/struct_pointwise_multiplicative.swift
@@ -0,0 +1,94 @@
+// SWIFT_ENABLE_TENSORFLOW
+// RUN: %target-swift-frontend -typecheck -verify -primary-file %s %S/Inputs/struct_pointwise_multiplicative_other_module.swift
+
+func testPointwiseMultiplicative<T : PointwiseMultiplicative>(
+ _ x: inout T
+) {
+ // Test `PointwiseMultiplicative` requirements: `one`, `.*`.
+ let one = T.one
+ x .*= x .* one
+}
+
+struct Empty : PointwiseMultiplicative {}
+func testEmpty() {
+ var empty = Empty()
+ testPointwiseMultiplicative(&empty)
+}
+
+// Test generic type.
+struct Vector2<T : PointwiseMultiplicative>: PointwiseMultiplicative {
+ var x: T
+ var y: T
+}
+func testVector2() {
+ var vec2 = Vector2<Empty>(x: Empty(), y: Empty.one)
+ testPointwiseMultiplicative(&vec2)
+}
+
+// Test nested type.
+struct Nested: PointwiseMultiplicative {
+ var empty: Empty
+ var vec2: Vector2<Empty>
+}
+func testNested(vec2: Vector2<Empty>) {
+ var nested = Nested(empty: Empty(), vec2: vec2)
+ testPointwiseMultiplicative(&nested)
+}
+
+// Test type in generic context.
+struct A<T> {
+ struct B<U, V> {
+ struct GenericContextNested : PointwiseMultiplicative {
+ var empty: Empty
+ var nested: Nested
+ }
+ }
+}
+func testGenericContext<T, U, V>(nested: Nested) -> A<T>.B<U, V>.GenericContextNested {
+ var genericNested =
+ A<T>.B<U, V>.GenericContextNested(empty: Empty(), nested: nested)
+ testPointwiseMultiplicative(&genericNested)
+ return genericNested
+}
+
+// Test extension.
+struct Extended {
+ var empty: Empty
+}
+extension Extended : Equatable, AdditiveArithmetic, PointwiseMultiplicative {}
+
+// Test extension of generic type.
+struct GenericExtended<T> {
+ var x: T
+}
+extension GenericExtended : Equatable, AdditiveArithmetic, PointwiseMultiplicative
+ where T : PointwiseMultiplicative {}
+
+// Test memberwise initializer synthesis.
+struct NoMemberwiseInitializer<T : PointwiseMultiplicative> : PointwiseMultiplicative {
+ var value: T
+ init(randomLabel value: T) { self.value = value }
+}
+struct NoMemberwiseInitializerCustomOne: PointwiseMultiplicative {
+ var x: Empty
+ static var one: Self { return NoMemberwiseInitializerCustomOne(Empty()) }
+ init(_ x: Empty) {
+ self.x = x
+ }
+}
+struct NoMemberwiseInitializerExtended<T> {
+ var value: T
+ init(_ value: T) {
+ self.value = value
+ }
+}
+extension NoMemberwiseInitializerExtended: Equatable, AdditiveArithmetic, PointwiseMultiplicative
+ where T : PointwiseMultiplicative {}
+
+// Test derived conformances in disallowed contexts.
+
+// expected-error @+1 3 {{extension outside of file declaring struct 'OtherFileNonconforming' prevents automatic synthesis of}}
+extension OtherFileNonconforming : PointwiseMultiplicative {}
+
+// expected-error @+1 3 {{extension outside of file declaring generic struct 'GenericOtherFileNonconforming' prevents automatic synthesis of}}
+extension GenericOtherFileNonconforming : PointwiseMultiplicative {}
diff --git a/test/Sema/struct_vector_protocol.swift b/test/Sema/struct_vector_protocol.swift
new file mode 100644
index 0000000..dd7567c
--- /dev/null
+++ b/test/Sema/struct_vector_protocol.swift
@@ -0,0 +1,142 @@
+// SWIFT_ENABLE_TENSORFLOW
+// RUN: %target-swift-frontend -typecheck -verify -primary-file %s %S/Inputs/struct_vector_protocol_other_module.swift
+
+func testVectorProtocol<T : VectorProtocol>(
+_ x: inout T, scalar: T.VectorSpaceScalar
+) {
+ // Test `AdditiveArithmetic` requirements: `zero`, `+`, `-`.
+ let zero = T.zero
+ x += x + zero
+ x += x - zero
+ // Test `VectorProtocol` requirements: `VectorSpaceScalar`, `adding(_:)`, `add(_:)`
+ // `subtracting(_:)`, `subtract(_:)`, `scaled(by:)`, and `scale(by:)`.
+ x.add(scalar)
+ x.add(scalar)
+ x.scale(by: scalar)
+ _ = x.adding(scalar)
+ _ = x.subtracting(scalar)
+ _ = x.scaled(by: scalar)
+
+ // NOTE: Operators have been disabled for type checker performance reasons.
+ // x += x + zero
+ // x -= x - zero
+ // Test `VectorProtocol` requirements: `VectorSpaceScalar`, `+`, `-`, `*`.
+ // x += scalar
+ // x -= scalar
+ // x *= scalar
+ // _ = x + scalar
+ // _ = scalar + x
+ // _ = x - scalar
+ // _ = scalar * x
+ // _ = x * scalar
+}
+
+struct Float2: VectorProtocol {
+ var x: Float
+ var y: Float
+}
+func testFloat2() {
+ var float2 = Float2(x: 1, y: 1)
+ testVectorProtocol(&float2, scalar: 1)
+}
+
+// Test generic type.
+struct Vector2<T : VectorProtocol>: VectorProtocol {
+ var x: T
+ var y: T
+}
+func testVector2(float2: Float2) {
+ _ = Vector2<Double>(x: 1, y: 1)
+ _ = Vector2<Float2>(x: float2, y: float2)
+}
+func testGeneric<T : VectorProtocol>(vec2: inout Vector2<T>, scalar: T.VectorSpaceScalar) {
+ testVectorProtocol(&vec2, scalar: scalar)
+}
+
+// Test nested types.
+struct Nested: VectorProtocol {
+ var float2: Float2
+ var float: Float
+}
+func testNested(float2: Float2) {
+ var nested = Nested(float2: float2, float: 1)
+ testVectorProtocol(&nested, scalar: 1)
+}
+
+struct NestedGeneric: VectorProtocol {
+ var vec2: Vector2<Float>
+ var float2: Float2
+ var float: Float
+}
+func testNestedGeneric(float2: Float2) {
+ var nestedGeneric = NestedGeneric(vec2: Vector2<Float>(x: 1, y: 1),
+ float2: float2, float: 1)
+ testVectorProtocol(&nestedGeneric, scalar: 1)
+}
+
+// Test type in generic context.
+struct A<T> {
+ struct B<U, V> {
+ struct GenericContextNested : VectorProtocol {
+ var float2: Float2
+ var float: Float
+ }
+ }
+}
+func testGenericContext<T, U, V>(float2: Float2) -> A<T>.B<U, V>.GenericContextNested {
+ var genericNested =
+ A<T>.B<U, V>.GenericContextNested(float2: float2, float: 1)
+ testVectorProtocol(&genericNested, scalar: 1)
+ return genericNested
+}
+
+// Test extension.
+struct Extended {
+ var x: Float
+}
+extension Extended : Equatable, AdditiveArithmetic, VectorProtocol {}
+
+// Test extension of generic type.
+struct GenericExtended<T> {
+ var x: T
+}
+extension GenericExtended : Equatable, AdditiveArithmetic, VectorProtocol where T : VectorProtocol {}
+
+// Test errors.
+
+// expected-error @+1 {{type 'Empty' does not conform to protocol 'VectorProtocol'}}
+struct Empty : VectorProtocol {}
+
+// Test type whose members conform to `VectorProtocol`
+// but have different `VectorSpaceScalar` associated type.
+// expected-error @+1 {{type 'InvalidMixedScalar' does not conform to protocol 'VectorProtocol'}}
+struct InvalidMixedScalar: VectorProtocol {
+ var float: Float
+ var double: Double
+}
+
+// Test memberwise initializer synthesis.
+struct NoMemberwiseInitializer<T : VectorProtocol> : VectorProtocol {
+ var value: T
+ init(randomLabel value: T) { self.value = value }
+}
+struct NoMemberwiseInitializerExtended<T> {
+ var value: T
+ init(_ value: T) {
+ self.value = value
+ }
+}
+extension NoMemberwiseInitializerExtended : Equatable, AdditiveArithmetic
+ where T : AdditiveArithmetic {}
+extension NoMemberwiseInitializerExtended : VectorProtocol
+ where T : VectorProtocol {}
+
+// Test derived conformances in disallowed contexts.
+
+// expected-error @+2 {{type 'OtherFileNonconforming' does not conform to protocol 'VectorProtocol'}}
+// expected-error @+1 {{extension outside of file declaring struct 'OtherFileNonconforming' prevents automatic synthesis of 'VectorSpaceScalar' for protocol 'VectorProtocol'}}
+extension OtherFileNonconforming : VectorProtocol {}
+
+// expected-error @+2 {{type 'GenericOtherFileNonconforming<T>' does not conform to protocol 'VectorProtocol'}}
+// expected-error @+1 {{extension outside of file declaring generic struct 'GenericOtherFileNonconforming' prevents automatic synthesis of 'VectorSpaceScalar' for protocol 'VectorProtocol'}}
+extension GenericOtherFileNonconforming : VectorProtocol {}
diff --git a/test/Serialization/derivative_attr.swift b/test/Serialization/derivative_attr.swift
new file mode 100644
index 0000000..ba93dca
--- /dev/null
+++ b/test/Serialization/derivative_attr.swift
@@ -0,0 +1,65 @@
+// SWIFT_ENABLE_TENSORFLOW
+
+// RUN: %empty-directory(%t)
+// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
+// RUN: llvm-bcanalyzer %t/derivative_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER
+// RUN: %target-sil-opt -enable-sil-verify-all %t/derivative_attr.swiftmodule -o - | %FileCheck %s
+
+// BCANALYZER-NOT: UnknownCode
+
+func add(x: Float, y: Float) -> Float {
+ return x + y
+}
+// CHECK: @derivative(of: add, wrt: x)
+@derivative(of: add, wrt: x)
+func jvpAddWrtX(x: Float, y: Float) -> (value: Float, differential: (Float) -> (Float)) {
+ return (x + y, { $0 })
+}
+// CHECK: @derivative(of: add, wrt: (x, y))
+@derivative(of: add)
+func vjpAdd(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
+ return (x + y, { ($0, $0) })
+}
+
+func generic<T : Numeric>(x: T) -> T {
+ return x
+}
+// CHECK: @derivative(of: generic, wrt: x)
+@derivative(of: generic)
+func vjpGeneric<T>(x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
+ where T : Numeric, T : Differentiable
+{
+ return (x, { v in v })
+}
+
+protocol InstanceMethod : Differentiable {
+ func foo(_ x: Self) -> Self
+ func bar<T : Differentiable>(_ x: T) -> Self
+}
+extension InstanceMethod {
+ func foo(_ x: Self) -> Self { self }
+ func bar<T : Differentiable>(_ x: T) -> Self { self }
+}
+extension InstanceMethod {
+ // CHECK: @derivative(of: foo, wrt: (self, x))
+ @derivative(of: foo)
+ func vjpFoo(x: Self) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
+ return (x, { ($0, $0) })
+ }
+
+ // CHECK: @derivative(of: bar, wrt: (self, x))
+ @derivative(of: bar, wrt: (self, x))
+ func jvpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, differential: (TangentVector, T) -> TangentVector)
+ where T == T.TangentVector
+ {
+ return (self, { dself, dx in dself })
+ }
+
+ // CHECK: @derivative(of: bar, wrt: (self, x))
+ @derivative(of: bar, wrt: (self, x))
+ func vjpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, pullback: (TangentVector) -> (TangentVector, T))
+ where T == T.TangentVector
+ {
+ return (self, { v in (v, .zero) })
+ }
+}
diff --git a/test/Serialization/differentiable_attr.swift b/test/Serialization/differentiable_attr.swift
new file mode 100644
index 0000000..97be011
--- /dev/null
+++ b/test/Serialization/differentiable_attr.swift
@@ -0,0 +1,96 @@
+// SWIFT_ENABLE_TENSORFLOW
+
+// RUN: %empty-directory(%t)
+// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
+// RUN: llvm-bcanalyzer %t/differentiable_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER
+// RUN: %target-sil-opt -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s
+
+// BCANALYZER-NOT: UnknownCode
+
+// CHECK: @differentiable(wrt: x)
+// CHECK-NEXT: func simple(x: Float) -> Float
+@differentiable
+func simple(x: Float) -> Float {
+ return x
+}
+
+// CHECK: @differentiable(linear, wrt: x)
+// CHECK-NEXT: func simple2(x: Float) -> Float
+@differentiable(linear)
+func simple2(x: Float) -> Float {
+ return x
+}
+
+// CHECK: @differentiable(linear, wrt: x)
+// CHECK-NEXT: func simple4(x: Float) -> Float
+@differentiable(linear, wrt: x)
+func simple4(x: Float) -> Float {
+ return x
+}
+
+func jvpSimple(x: Float) -> (Float, (Float) -> Float) {
+ return (x, { v in v })
+}
+
+func vjpSimple(x: Float) -> (Float, (Float) -> Float) {
+ return (x, { v in v })
+}
+
+// CHECK: @differentiable(wrt: x)
+// CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float
+@differentiable(wrt: x)
+func testWrtClause(x: Float, y: Float) -> Float {
+ return x + y
+}
+
+struct InstanceMethod : Differentiable {
+ // CHECK: @differentiable(wrt: (self, y))
+ // CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float
+ @differentiable(wrt: (self, y))
+ func testWrtClause(x: Float, y: Float) -> Float {
+ return x + y
+ }
+}
+
+// CHECK: @differentiable(wrt: x where T : Differentiable)
+// CHECK-NEXT: func testOnlyWhereClause<T>(x: T) -> T where T : Numeric
+@differentiable(where T : Differentiable)
+func testOnlyWhereClause<T : Numeric>(x: T) -> T {
+ return x
+}
+
+// CHECK: @differentiable(wrt: x where T : Differentiable)
+// CHECK-NEXT: func testWhereClause<T>(x: T) -> T where T : Numeric
+@differentiable(where T : Differentiable)
+func testWhereClause<T : Numeric>(x: T) -> T {
+ return x
+}
+
+protocol P {}
+extension P {
+ // CHECK: @differentiable(wrt: self where Self : Differentiable)
+ // CHECK-NEXT: func testWhereClauseMethod() -> Self
+ @differentiable(wrt: self where Self : Differentiable)
+ func testWhereClauseMethod() -> Self {
+ return self
+ }
+}
+
+// CHECK: @differentiable(wrt: x where T : Differentiable, T == T.TangentVector)
+// CHECK-NEXT: func testWhereClauseMethodTypeConstraint<T>(x: T) -> T where T : Numeric
+@differentiable(where T : Differentiable, T == T.TangentVector)
+func testWhereClauseMethodTypeConstraint<T : Numeric>(x: T) -> T {
+ return x
+}
+
+extension P {
+ // CHECK: @differentiable(wrt: self where Self : Differentiable, Self == Self.TangentVector)
+ // CHECK-NEXT: func testWhereClauseMethodTypeConstraint() -> Self
+ @differentiable(wrt: self where Self.TangentVector == Self, Self : Differentiable)
+ func testWhereClauseMethodTypeConstraint() -> Self {
+ return self
+ }
+}
+
+// CHECK: func testDifferentiableParam(f: @differentiable (Float) -> Float)
+func testDifferentiableParam(f: @differentiable (Float) -> Float) {}
diff --git a/test/Serialization/transpose_attr.swift b/test/Serialization/transpose_attr.swift
new file mode 100644
index 0000000..865cd3d
--- /dev/null
+++ b/test/Serialization/transpose_attr.swift
@@ -0,0 +1,97 @@
+// SWIFT_ENABLE_TENSORFLOW
+
+// RUN: %empty-directory(%t)
+// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
+// RUN: llvm-bcanalyzer %t/transpose_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER
+// RUN: %target-sil-opt -enable-sil-verify-all %t/transpose_attr.swiftmodule -o - | %FileCheck %s
+
+// BCANALYZER-NOT: UnknownCode
+
+// Dummy `Differentiable`-conforming type.
+struct S: Differentiable & AdditiveArithmetic {
+ static var zero: S { S() }
+ static func + (_: S, _: S) -> S { S() }
+ static func - (_: S, _: S) -> S { S() }
+ typealias TangentVector = S
+}
+
+// Test top-level functions.
+
+func top1(_ x: S) -> S {
+ x
+}
+// CHECK: @transpose(of: top1, wrt: 0)
+@transpose(of: top1, wrt: 0)
+func transposeTop1(v: S) -> S {
+ v
+}
+
+func top2<T, U>(_ x: T, _ i: Int, _ y: U) -> U {
+ y
+}
+// CHECK: @transpose(of: top2, wrt: (0, 2))
+@transpose(of: top2, wrt: (0, 2))
+func transposeTop2<T, U>(_ int: Int, v: U) -> (T, U)
+where T: Differentiable, U: Differentiable,
+ T == T.TangentVector, U == U.TangentVector {
+ (.zero, v)
+}
+
+// Test instance methods.
+
+extension S {
+ func instanceMethod(_ other: S) -> S {
+ self + other
+ }
+
+ // CHECK: @transpose(of: instanceMethod, wrt: 0)
+ @transpose(of: instanceMethod, wrt: 0)
+ func transposeInstanceMethod(v: S) -> (S, S) {
+ (v, v)
+ }
+
+ // CHECK: @transpose(of: instanceMethod, wrt: self)
+ @transpose(of: instanceMethod, wrt: self)
+ static func transposeInstanceMethodWrtSelf(_ other: S, v: S) -> S {
+ v
+ }
+}
+
+// Test static methods.
+
+extension S {
+ static func staticMethod(x: S) -> S {
+ x
+ }
+
+ // FIXME(TF-1063): `@transpose` type-checking crash for static methods.
+ // CHECK-FIXME: @transpose(of: staticMethod, wrt: 0)
+ /*
+ @transpose(of: staticMethod, wrt: 0)
+ static func transposeStaticMethod() -> S {
+ self
+ }
+ */
+}
+
+// Test computed properties.
+extension S {
+ var computedProperty: S { self }
+
+ // CHECK: @transpose(of: computedProperty, wrt: self)
+ @transpose(of: computedProperty, wrt: self)
+ static func transposeProperty(v: Self) -> Self {
+ v
+ }
+}
+
+// Test subscripts.
+extension S {
+ subscript<T: Differentiable>(x: T) -> Self { self }
+
+ // CHECK: @transpose(of: subscript, wrt: self)
+ @transpose(of: subscript(_:), wrt: self)
+ static func transposeSubscript<T: Differentiable>(x: T, v: Self) -> Self {
+ v
+ }
+}
diff --git a/test/SourceKit/CodeComplete/complete_sequence_edit.swift b/test/SourceKit/CodeComplete/complete_sequence_edit.swift.disabled
similarity index 100%
rename from test/SourceKit/CodeComplete/complete_sequence_edit.swift
rename to test/SourceKit/CodeComplete/complete_sequence_edit.swift.disabled
diff --git a/test/SourceKit/CompileNotifications/diagnostics.swift b/test/SourceKit/CompileNotifications/diagnostics.swift.disabled
similarity index 100%
rename from test/SourceKit/CompileNotifications/diagnostics.swift
rename to test/SourceKit/CompileNotifications/diagnostics.swift.disabled
diff --git a/test/SourceKit/CursorInfo/in_memory_clang_module_cache.swift b/test/SourceKit/CursorInfo/in_memory_clang_module_cache.swift
new file mode 100644
index 0000000..f1718a2
--- /dev/null
+++ b/test/SourceKit/CursorInfo/in_memory_clang_module_cache.swift
@@ -0,0 +1,15 @@
+let foo: Int = 10
+foo
+
+// Checks that the SourceKit request succeeded.
+// CHECK-SOURCEKIT: source.lang.swift.ref.var.global (1:5-1:8)
+
+// Checks that nothing has been written into the module cache on the real
+// filesystem.
+// CHECK-LS-NOT: ModuleCache
+
+// RUN: %empty-directory(%t)
+// RUN: %sourcekitd-test -in-memory-clang-module-cache -req=cursor -pos=2:1 %s -- %s -module-cache-path %t/ModuleCache | %FileCheck --check-prefix=CHECK-SOURCEKIT %s
+// RUN: ls -l %t | %FileCheck --check-prefix=CHECK-LS %s
+
+// REQUIRES: sourcekit_use_inproc_library
diff --git a/test/SourceKit/DocSupport/doc_clang_module.swift b/test/SourceKit/DocSupport/doc_clang_module.swift
index e65c713..47660ea 100644
--- a/test/SourceKit/DocSupport/doc_clang_module.swift
+++ b/test/SourceKit/DocSupport/doc_clang_module.swift
@@ -1,3 +1,12 @@
+// SWIFT_ENABLE_TENSORFLOW
+//
+// NOTE: This test is disabled on the 'tensorflow' branch because we are
+// actively developing and changing attributes such as `@differentiable` and
+// `@compilerEvaluable`. When various features get merged to 'master', a
+// canonical update of these tests will be included.
+//
+// UNSUPPORTED: tensorflow
+
// REQUIRES: objc_interop
// FIXME: the test output we're comparing to is specific to macOS.
diff --git a/test/SourceKit/InterfaceGen/gen_header.swift b/test/SourceKit/InterfaceGen/gen_header.swift.disabled
similarity index 100%
rename from test/SourceKit/InterfaceGen/gen_header.swift
rename to test/SourceKit/InterfaceGen/gen_header.swift.disabled
diff --git a/test/SourceKit/InterfaceGen/gen_header_swift_args.swift b/test/SourceKit/InterfaceGen/gen_header_swift_args.swift.disabled
similarity index 100%
rename from test/SourceKit/InterfaceGen/gen_header_swift_args.swift
rename to test/SourceKit/InterfaceGen/gen_header_swift_args.swift.disabled
diff --git a/test/SourceKit/Mixed/complete_twice_bridging_header.swift b/test/SourceKit/Mixed/complete_twice_bridging_header.swift.disabled
similarity index 100%
rename from test/SourceKit/Mixed/complete_twice_bridging_header.swift
rename to test/SourceKit/Mixed/complete_twice_bridging_header.swift.disabled
diff --git a/test/SourceKit/Sema/enum-toraw/enum-toraw.swift b/test/SourceKit/Sema/enum-toraw/enum-toraw.swift.disabled
similarity index 100%
rename from test/SourceKit/Sema/enum-toraw/enum-toraw.swift
rename to test/SourceKit/Sema/enum-toraw/enum-toraw.swift.disabled
diff --git a/test/SourceKit/Sema/sil_diags.swift b/test/SourceKit/Sema/sil_diags.swift.disabled
similarity index 100%
rename from test/SourceKit/Sema/sil_diags.swift
rename to test/SourceKit/Sema/sil_diags.swift.disabled
diff --git a/test/Syntax/Outputs/round_trip_parse_gen.swift.withkinds b/test/Syntax/Outputs/round_trip_parse_gen.swift.withkinds
index 61dfcac..f441799 100644
--- a/test/Syntax/Outputs/round_trip_parse_gen.swift.withkinds
+++ b/test/Syntax/Outputs/round_trip_parse_gen.swift.withkinds
@@ -595,4 +595,47 @@
#"<StringSegment>abc</StringSegment>"#</StringLiteralExpr><StringLiteralExpr>
#"<StringSegment>abc </StringSegment><ExpressionSegment>\#(<TupleExprElement><IdentifierExpr>foo</IdentifierExpr></TupleExprElement>)</ExpressionSegment><StringSegment></StringSegment>"#</StringLiteralExpr><StringLiteralExpr>
##"<StringSegment>abc</StringSegment>"##</StringLiteralExpr><StringLiteralExpr>
-##"<StringSegment>abc </StringSegment><ExpressionSegment>\##(<TupleExprElement><IdentifierExpr>foo</IdentifierExpr></TupleExprElement>)</ExpressionSegment><StringSegment></StringSegment>"##</StringLiteralExpr>
+##"<StringSegment>abc </StringSegment><ExpressionSegment>\##(<TupleExprElement><IdentifierExpr>foo</IdentifierExpr></TupleExprElement>)</ExpressionSegment><StringSegment></StringSegment>"##</StringLiteralExpr><FunctionDecl><Attribute>
+
+// SWIFT_ENABLE_TENSORFLOW
+@differentiable</Attribute>
+func bar<FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </FunctionParameter><FunctionParameter>_: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <SimpleTypeIdentifier>Float </SimpleTypeIdentifier></ReturnClause></FunctionSignature><CodeBlock>{ <ReturnStmt>return <IntegerLiteralExpr>1 </IntegerLiteralExpr></ReturnStmt>}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
+
+@differentiable(<DifferentiableAttributeArguments><GenericWhereClause>where <GenericRequirement><ConformanceRequirement><SimpleTypeIdentifier>T </SimpleTypeIdentifier>: <SimpleTypeIdentifier>FloatingPoint</SimpleTypeIdentifier></ConformanceRequirement></GenericRequirement></GenericWhereClause></DifferentiableAttributeArguments>)</Attribute>
+func bar<GenericParameterClause><<GenericParameter>T : <SimpleTypeIdentifier>Numeric</SimpleTypeIdentifier></GenericParameter>></GenericParameterClause><FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>T</SimpleTypeIdentifier>, </FunctionParameter><FunctionParameter>_: <SimpleTypeIdentifier>T</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <SimpleTypeIdentifier>T </SimpleTypeIdentifier></ReturnClause></FunctionSignature><CodeBlock>{ <ReturnStmt>return <IntegerLiteralExpr>1 </IntegerLiteralExpr></ReturnStmt>}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
+
+@differentiable(<DifferentiableAttributeArguments><DifferentiabilityParamsClause>wrt: <DifferentiabilityParam>x</DifferentiabilityParam></DifferentiabilityParamsClause></DifferentiableAttributeArguments>)</Attribute>
+func bar<FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </FunctionParameter><FunctionParameter>_: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <SimpleTypeIdentifier>Float </SimpleTypeIdentifier></ReturnClause></FunctionSignature><CodeBlock>{ <ReturnStmt>return <IntegerLiteralExpr>1 </IntegerLiteralExpr></ReturnStmt>}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
+
+@differentiable(<DifferentiableAttributeArguments><DifferentiabilityParamsClause>wrt: <DifferentiabilityParams>(<DifferentiabilityParam>self, </DifferentiabilityParam><DifferentiabilityParam>x, </DifferentiabilityParam><DifferentiabilityParam>y</DifferentiabilityParam>)</DifferentiabilityParams></DifferentiabilityParamsClause></DifferentiableAttributeArguments>)</Attribute>
+func bar<FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </FunctionParameter><FunctionParameter>y: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <SimpleTypeIdentifier>Float </SimpleTypeIdentifier></ReturnClause></FunctionSignature><CodeBlock>{ <ReturnStmt>return <IntegerLiteralExpr>1 </IntegerLiteralExpr></ReturnStmt>}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
+
+@differentiable(<DifferentiableAttributeArguments><DifferentiabilityParamsClause>wrt: <DifferentiabilityParams>(<DifferentiabilityParam>self, </DifferentiabilityParam><DifferentiabilityParam>x, </DifferentiabilityParam><DifferentiabilityParam>y</DifferentiabilityParam>) </DifferentiabilityParams></DifferentiabilityParamsClause><GenericWhereClause>where <GenericRequirement><ConformanceRequirement><SimpleTypeIdentifier>T </SimpleTypeIdentifier>: <SimpleTypeIdentifier>FloatingPoint</SimpleTypeIdentifier></ConformanceRequirement></GenericRequirement></GenericWhereClause></DifferentiableAttributeArguments>)</Attribute>
+func bar<GenericParameterClause><<GenericParameter>T : <SimpleTypeIdentifier>Numeric</SimpleTypeIdentifier></GenericParameter>></GenericParameterClause><FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>T</SimpleTypeIdentifier>, </FunctionParameter><FunctionParameter>y: <SimpleTypeIdentifier>T</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <SimpleTypeIdentifier>T </SimpleTypeIdentifier></ReturnClause></FunctionSignature><CodeBlock>{ <ReturnStmt>return <IntegerLiteralExpr>1 </IntegerLiteralExpr></ReturnStmt>}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
+
+@derivative(<DerivativeRegistrationAttributeArguments>of: <QualifiedDeclName>-</QualifiedDeclName></DerivativeRegistrationAttributeArguments>)</Attribute>
+func negateDerivative<FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>)</ParameterClause><ReturnClause>
+ -> <TupleType>(<TupleTypeElement>value: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </TupleTypeElement><TupleTypeElement>pullback: <FunctionType>(<TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier></TupleTypeElement>) -> <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionType></TupleTypeElement>) </TupleType></ReturnClause></FunctionSignature><CodeBlock>{<ReturnStmt>
+ return <TupleExpr>(<TupleExprElement><PrefixOperatorExpr>-<IdentifierExpr>x</IdentifierExpr></PrefixOperatorExpr>, </TupleExprElement><TupleExprElement><ClosureExpr>{ <ClosureSignature><ClosureParam>v </ClosureParam>in </ClosureSignature><PrefixOperatorExpr>-<IdentifierExpr>v </IdentifierExpr></PrefixOperatorExpr>}</ClosureExpr></TupleExprElement>)</TupleExpr></ReturnStmt>
+}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
+
+@derivative(<DerivativeRegistrationAttributeArguments>of: <QualifiedDeclName>baz<DeclNameArguments>(<DeclNameArgument>label:</DeclNameArgument><DeclNameArgument>_:</DeclNameArgument>)</DeclNameArguments></QualifiedDeclName>, <DifferentiabilityParamsClause>wrt: <DifferentiabilityParams>(<DifferentiabilityParam>x</DifferentiabilityParam>)</DifferentiabilityParams></DifferentiabilityParamsClause></DerivativeRegistrationAttributeArguments>)</Attribute>
+func bazDerivative<FunctionSignature><ParameterClause>(<FunctionParameter>_ x: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </FunctionParameter><FunctionParameter>y: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>)</ParameterClause><ReturnClause>
+ -> <TupleType>(<TupleTypeElement>value: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </TupleTypeElement><TupleTypeElement>pullback: <FunctionType>(<TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier></TupleTypeElement>) -> <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionType></TupleTypeElement>) </TupleType></ReturnClause></FunctionSignature><CodeBlock>{<ReturnStmt>
+ return <TupleExpr>(<TupleExprElement><IdentifierExpr>x</IdentifierExpr>, </TupleExprElement><TupleExprElement><ClosureExpr>{ <ClosureSignature><ClosureParam>v </ClosureParam>in </ClosureSignature><IdentifierExpr>v </IdentifierExpr>}</ClosureExpr></TupleExprElement>)</TupleExpr></ReturnStmt>
+}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
+
+@transpose(<DerivativeRegistrationAttributeArguments>of: <QualifiedDeclName>+</QualifiedDeclName></DerivativeRegistrationAttributeArguments>)</Attribute>
+func addTranspose<FunctionSignature><ParameterClause>(<FunctionParameter>_ v: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <TupleType>(<TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </TupleTypeElement><TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier></TupleTypeElement>) </TupleType></ReturnClause></FunctionSignature><CodeBlock>{<ReturnStmt>
+ return <TupleExpr>(<TupleExprElement><IdentifierExpr>v</IdentifierExpr>, </TupleExprElement><TupleExprElement><IdentifierExpr>v</IdentifierExpr></TupleExprElement>)</TupleExpr></ReturnStmt>
+}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
+
+@transpose(<DerivativeRegistrationAttributeArguments>of: <QualifiedDeclName>-</QualifiedDeclName>, <DifferentiabilityParamsClause>wrt: <DifferentiabilityParams>(<DifferentiabilityParam>0, </DifferentiabilityParam><DifferentiabilityParam>1</DifferentiabilityParam>)</DifferentiabilityParams></DifferentiabilityParamsClause></DerivativeRegistrationAttributeArguments>)</Attribute>
+func subtractTranspose<FunctionSignature><ParameterClause>(<FunctionParameter>_ v: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <TupleType>(<TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier>, </TupleTypeElement><TupleTypeElement><SimpleTypeIdentifier>Float</SimpleTypeIdentifier></TupleTypeElement>) </TupleType></ReturnClause></FunctionSignature><CodeBlock>{<ReturnStmt>
+ return <TupleExpr>(<TupleExprElement><IdentifierExpr>v</IdentifierExpr>, </TupleExprElement><TupleExprElement><PrefixOperatorExpr>-<IdentifierExpr>v</IdentifierExpr></PrefixOperatorExpr></TupleExprElement>)</TupleExpr></ReturnStmt>
+}</CodeBlock></FunctionDecl><FunctionDecl><Attribute>
+
+@transpose(<DerivativeRegistrationAttributeArguments>of: <QualifiedDeclName><SimpleTypeIdentifier>Float</SimpleTypeIdentifier>.-</QualifiedDeclName>, <DifferentiabilityParamsClause>wrt: <DifferentiabilityParam>0</DifferentiabilityParam></DifferentiabilityParamsClause></DerivativeRegistrationAttributeArguments>)</Attribute>
+func negateTranspose<FunctionSignature><ParameterClause>(<FunctionParameter>_ v: <SimpleTypeIdentifier>Float</SimpleTypeIdentifier></FunctionParameter>) </ParameterClause><ReturnClause>-> <SimpleTypeIdentifier>Float </SimpleTypeIdentifier></ReturnClause></FunctionSignature><CodeBlock>{<ReturnStmt>
+ return <PrefixOperatorExpr>-<IdentifierExpr>v</IdentifierExpr></PrefixOperatorExpr></ReturnStmt>
+}</CodeBlock></FunctionDecl>
diff --git a/test/Syntax/round_trip_parse_gen.swift b/test/Syntax/round_trip_parse_gen.swift
index 6c68ff5..deabcf3 100644
--- a/test/Syntax/round_trip_parse_gen.swift
+++ b/test/Syntax/round_trip_parse_gen.swift
@@ -596,3 +596,46 @@
#"abc \#(foo)"#
##"abc"##
##"abc \##(foo)"##
+
+// SWIFT_ENABLE_TENSORFLOW
+@differentiable
+func bar(_ x: Float, _: Float) -> Float { return 1 }
+
+@differentiable(where T : FloatingPoint)
+func bar<T : Numeric>(_ x: T, _: T) -> T { return 1 }
+
+@differentiable(wrt: x)
+func bar(_ x: Float, _: Float) -> Float { return 1 }
+
+@differentiable(wrt: (self, x, y))
+func bar(_ x: Float, y: Float) -> Float { return 1 }
+
+@differentiable(wrt: (self, x, y) where T : FloatingPoint)
+func bar<T : Numeric>(_ x: T, y: T) -> T { return 1 }
+
+@derivative(of: -)
+func negateDerivative(_ x: Float)
+ -> (value: Float, pullback: (Float) -> Float) {
+ return (-x, { v in -v })
+}
+
+@derivative(of: baz(label:_:), wrt: (x))
+func bazDerivative(_ x: Float, y: Float)
+ -> (value: Float, pullback: (Float) -> Float) {
+ return (x, { v in v })
+}
+
+@transpose(of: +)
+func addTranspose(_ v: Float) -> (Float, Float) {
+ return (v, v)
+}
+
+@transpose(of: -, wrt: (0, 1))
+func subtractTranspose(_ v: Float) -> (Float, Float) {
+ return (v, -v)
+}
+
+@transpose(of: Float.-, wrt: 0)
+func negateTranspose(_ v: Float) -> Float {
+ return -v
+}
diff --git a/test/TypeDecoder/structural_types.swift b/test/TypeDecoder/structural_types.swift
index 450cfa6..6081796 100644
--- a/test/TypeDecoder/structural_types.swift
+++ b/test/TypeDecoder/structural_types.swift
@@ -131,6 +131,43 @@
blackHole(b)
}
+// SWIFT_ENABLE_TENSORFLOW
+do {
+ let f: @differentiable (Float) -> Float = { $0 }
+ // FIXME(TF-123): `@differentiable` function type + opaque abstraction
+ // pattern bug.
+ // blackHole(f)
+ _ = f
+}
+
+do {
+ let f: (@escaping @differentiable (Float) -> Float) -> () = { _ in }
+ // FIXME(TF-123): `@differentiable` function type + opaque abstraction
+ // pattern bug.
+ // blackHole(f)
+ _ = f
+}
+
+// TODO: Uncomment when `@differentiable(linear)` function types are enabled.
+/*
+do {
+ let f: @differentiable(linear) (Float) -> Float = { $0 }
+ // FIXME(TF-123): `@differentiable` function type + opaque abstraction
+ // pattern bug.
+ // blackHole(f)
+ _ = f
+}
+
+do {
+ let f: (@escaping @differentiable(linear) (Float) -> Float) -> () = { _ in }
+ // FIXME(TF-123): `@differentiable` function type + opaque abstraction
+ // pattern bug.
+ // blackHole(f)
+ _ = f
+}
+*/
+// SWIFT_ENABLE_TENSORFLOW END
+
// DEMANGLE: $syycD
// DEMANGLE: $sySSzcD
// DEMANGLE: $sySSncD
@@ -150,6 +187,13 @@
// DEMANGLE: $sSayyyXCGD
// DEMANGLE: $sSayyyyXL_yyXBtcGD
+// SWIFT_ENABLE_TENSORFLOW
+// DEMANGLE: $sS2fXFD
+// DEMANGLE: $sS2fXGD
+// DEMANGLE: $sS2fXHD
+// DEMANGLE: $sS2fXID
+// SWIFT_ENABLE_TENSORFLOW END
+
// CHECK: () -> ()
// CHECK: (inout String) -> ()
// CHECK: (__owned String) -> ()
@@ -169,6 +213,13 @@
// CHECK: Array<@convention(c) () -> ()>
// CHECK: Array<(@escaping @convention(block) () -> (), @convention(block) () -> ()) -> ()>
+// SWIFT_ENABLE_TENSORFLOW
+// CHECK: @differentiable (Float) -> Float
+// CHECK: @differentiable (Float) -> Float
+// CHECK: @differentiable(linear) (Float) -> Float
+// CHECK: @differentiable(linear) (Float) -> Float
+// SWIFT_ENABLE_TENSORFLOW END
+
// DEMANGLE: $sSimD
// DEMANGLE: $syycmD
// DEMANGLE: $sySSzcmD
@@ -189,6 +240,13 @@
// DEMANGLE: $sSayyyXCGmD
// DEMANGLE: $sSayyyyXL_yyXBtcGmD
+// SWIFT_ENABLE_TENSORFLOW
+// DEMANGLE: $sS2fXFmD
+// DEMANGLE: $sS2fXGmD
+// DEMANGLE: $sS2fXHmD
+// DEMANGLE: $sS2fXImD
+// SWIFT_ENABLE_TENSORFLOW END
+
// CHECK: Int.Type
// CHECK: ((inout String) -> ()).Type
// CHECK: ((__owned String) -> ()).Type
@@ -207,3 +265,10 @@
// CHECK: ((@escaping () -> ()) -> ()).Type
// CHECK: Array<@convention(c) () -> ()>.Type
// CHECK: Array<(@escaping @convention(block) () -> (), @convention(block) () -> ()) -> ()>.Type
+
+// SWIFT_ENABLE_TENSORFLOW
+// CHECK: (@differentiable (Float) -> Float).Type
+// CHECK: (@differentiable (Float) -> Float).Type
+// CHECK: (@differentiable(linear) (Float) -> Float).Type
+// CHECK: (@differentiable(linear) (Float) -> Float).Type
+// SWIFT_ENABLE_TENSORFLOW END
diff --git a/test/api-digester/stability-stdlib-abi-with-asserts.test b/test/api-digester/stability-stdlib-abi-with-asserts.test
index c75b63b..2c46c99 100644
--- a/test/api-digester/stability-stdlib-abi-with-asserts.test
+++ b/test/api-digester/stability-stdlib-abi-with-asserts.test
@@ -1,3 +1,11 @@
+// SWIFT_ENABLE_TENSORFLOW
+// Note: this test is disabled on `tensorflow` branch because `tensorflow`
+// branch adds various public APIs to the stdlib without `@available`
+// attributes. These APIs should ideally be removed over time, or upstreamed to
+// `master` branch.
+// UNSUPPORTED: tensorflow
+// SWIFT_ENABLE_TENSORFLOW END
+
// Welcome, Build Wrangler!
//
// A failure in this test indicates that there is a potential ABI breaking
diff --git a/test/api-digester/stability-stdlib-abi-without-asserts.test b/test/api-digester/stability-stdlib-abi-without-asserts.test
index 6a3c887..96584e5 100644
--- a/test/api-digester/stability-stdlib-abi-without-asserts.test
+++ b/test/api-digester/stability-stdlib-abi-without-asserts.test
@@ -1,3 +1,11 @@
+// SWIFT_ENABLE_TENSORFLOW
+// Note: this test is disabled on `tensorflow` branch because `tensorflow`
+// branch adds various public APIs to the stdlib without `@available`
+// attributes. These APIs should ideally be removed over time, or upstreamed to
+// `master` branch.
+// UNSUPPORTED: tensorflow
+// SWIFT_ENABLE_TENSORFLOW END
+
// Welcome, Build Wrangler!
//
// A failure in this test indicates that there is a potential ABI breaking
diff --git a/test/lit.cfg b/test/lit.cfg
index 657adcf..d9ae1db 100644
--- a/test/lit.cfg
+++ b/test/lit.cfg
@@ -1797,6 +1797,24 @@
'%%line-directive %%t/main.swift -- '
'%s %%t/a.out'
% (config.target_build_swift, mcp_opt, config.target_codesign, config.target_run))
+ # SWIFT_ENABLE_TENSORFLOW
+ # TODO: Remove when forward mode AD support is robust.
+ config.target_run_simple_swift_forward_mode_differentiation = (
+ '%%empty-directory(%%t) && '
+ '%s %%s -enable-experimental-forward-mode-differentiation -o %%t/a.out %s -module-name main && '
+ '%s %%t/a.out &&'
+ '%s %%t/a.out'
+ % (config.target_build_swift, mcp_opt, config.target_codesign, config.target_run))
+ config.target_run_simple_swiftgyb_forward_mode_differentiation = (
+ '%%empty-directory(%%t) && '
+ '%%gyb %%s -o %%t/main.swift && '
+ '%%line-directive %%t/main.swift -- '
+ '%s %s %%t/main.swift -enable-experimental-forward-mode-differentiation -o %%t/a.out -module-name main && '
+ '%s %%t/a.out && '
+ '%%line-directive %%t/main.swift -- '
+ '%s %%t/a.out'
+ % (config.target_build_swift, mcp_opt, config.target_codesign, config.target_run))
+ # SWIFT_ENABLE_TENSORFLOW END
# FIXME: why can we not use %rth and have that be expanded out?
config.target_resilience_test = (
@@ -1836,6 +1854,13 @@
subst_target_swift_frontend_mock_sdk_after))))
config.substitutions.append(('%target-swift-frontend', config.target_swift_frontend))
+# SWIFT_ENABLE_TENSORFLOW
+# TODO: Remove when forward mode AD support is robust.
+config.substitutions.append(('%target-run-simple-swiftgyb-forward-mode-differentiation',
+ config.target_run_simple_swiftgyb_forward_mode_differentiation))
+config.substitutions.append(('%target-run-simple-swift-forward-mode-differentiation',
+ config.target_run_simple_swift_forward_mode_differentiation))
+# SWIFT_ENABLE_TENSORFLOW END
config.substitutions.append(('%target-run-simple-swiftgyb\(([^)]+)\)',
config.target_run_simple_swiftgyb_parameterized))
diff --git a/test/lit.site.cfg.in b/test/lit.site.cfg.in
index c4511d2..7d245c5 100644
--- a/test/lit.site.cfg.in
+++ b/test/lit.site.cfg.in
@@ -63,6 +63,11 @@
else:
config.available_features.add('no_asan')
+if "@SWIFT_ENABLE_TENSORFLOW@" == "TRUE":
+ config.available_features.add('tensorflow')
+else:
+ config.available_features.add('no_tensorflow')
+
if "@SWIFT_RUNTIME_ENABLE_LEAK_CHECKER@" == "TRUE":
config.available_features.add('leak-checker')
diff --git a/test/stdlib/ElementaryFunctions.swift.gyb b/test/stdlib/ElementaryFunctions.swift.gyb
new file mode 100644
index 0000000..5818853
--- /dev/null
+++ b/test/stdlib/ElementaryFunctions.swift.gyb
@@ -0,0 +1,109 @@
+//===--- ElementaryFunctions.swift.gyb ------------------------*- swift -*-===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+// SWIFT_ENABLE_TENSORFLOW
+// Runtime tests for ElementaryFunctions derived conformances.
+//===----------------------------------------------------------------------===//
+// -*- swift -*-
+// RUN: %empty-directory(%t)
+// RUN: %gyb %s -o %t/tgmath.swift
+// RUN: %line-directive %t/tgmath.swift -- %target-build-swift %t/tgmath.swift -o %t/a.out
+// RUN: %target-codesign %t/a.out
+// RUN: %line-directive %t/tgmath.swift -- %target-run %t/a.out
+// REQUIRES: executable_test
+
+#if (arch(i386) || arch(x86_64)) && !os(Windows)
+ typealias TestLiteralType = Float80
+#else
+ typealias TestLiteralType = Double
+#endif
+
+import StdlibUnittest
+
+let MathTests = TestSuite("Math")
+
+func expectEqualWithNaNEquality<T>(
+ _ expected: [T], _ actual: [T], file: String = #file, line: UInt = #line
+) where T: BinaryFloatingPoint {
+ for (x, y) in zip(expected, actual) {
+ expectTrue(x == y || x.isNaN && y.isNaN,
+ "\(x) != \(y) for \(T.self).",
+ file: file, line: line)
+ }
+}
+
+%from SwiftMathFunctions import *
+
+struct Wrapper<T: ElementaryFunctions & Equatable>: ElementaryFunctions & Equatable {
+ var x, y: T
+ var values: [T] { [x, y] }
+}
+
+// Prevent any optimizers from const-evaluating any of the math, which can
+// cause results that are different from the results calculated at runtime.
+@inline(never)
+func makeValues<T: BinaryFloatingPoint>() -> [T] {
+ return [-0.375, 0.375]
+}
+
+@available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *)
+internal extension ElementaryFunctions where Self: BinaryFloatingPoint {
+ static func elementaryFunctionTests() {
+ let values: [Self] = makeValues()
+ let wrapper = Wrapper<Self>(x: values[0], y: values[1])
+
+ expectEqualWithNaNEquality(values.map(Self.acos), Wrapper<Self>.acos(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.asin), Wrapper<Self>.asin(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.atan), Wrapper<Self>.atan(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.cos), Wrapper<Self>.cos(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.sin), Wrapper<Self>.sin(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.tan), Wrapper<Self>.tan(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.acosh), Wrapper<Self>.acosh(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.asinh), Wrapper<Self>.asinh(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.atanh), Wrapper<Self>.atanh(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.cosh), Wrapper<Self>.cosh(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.sinh), Wrapper<Self>.sinh(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.tanh), Wrapper<Self>.tanh(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.exp), Wrapper<Self>.exp(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.exp2), Wrapper<Self>.exp2(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.exp10), Wrapper<Self>.exp10(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.expm1), Wrapper<Self>.expm1(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.log), Wrapper<Self>.log(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.log2), Wrapper<Self>.log2(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.log10), Wrapper<Self>.log10(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.log1p), Wrapper<Self>.log1p(wrapper).values)
+ expectEqualWithNaNEquality(values.map(Self.sqrt), Wrapper<Self>.sqrt(wrapper).values)
+ expectEqualWithNaNEquality(values.map { x in Self.root(x, 3) }, Wrapper<Self>.root(wrapper, 3).values)
+ expectEqualWithNaNEquality(values.map { x in Self.pow(x, x) }, Wrapper<Self>.pow(wrapper, wrapper).values)
+ expectEqualWithNaNEquality(values.map { x in Self.pow(x, 3) }, Wrapper<Self>.pow(wrapper, 3).values)
+ }
+}
+
+%for T in ['Float', 'Double', 'CGFloat', 'Float80']:
+% if T == 'Float80':
+#if (arch(i386) || arch(x86_64)) && !os(Windows)
+% elif T == 'CGFloat':
+#if canImport(CoreGraphics)
+ import CoreGraphics
+% end
+
+MathTests.test("${T}") {
+ if #available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) {
+ ${T}.elementaryFunctionTests()
+ }
+}
+
+% if T in ['CGFloat', 'Float80']:
+#endif
+% end
+%end
+
+runAllTests()
diff --git a/test/stdlib/ErrorBridgedStatic.swift b/test/stdlib/ErrorBridgedStatic.swift
new file mode 100644
index 0000000..da98462
--- /dev/null
+++ b/test/stdlib/ErrorBridgedStatic.swift
@@ -0,0 +1,45 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-clang -fmodules -c -o %t/ErrorBridgedStaticImpl.o %S/Inputs/ErrorBridgedStaticImpl.m
+// RUN: %target-build-swift -static-stdlib -o %t/ErrorBridgedStatic %t/ErrorBridgedStaticImpl.o %s -import-objc-header %S/Inputs/ErrorBridgedStaticImpl.h
+// RUN: strip %t/ErrorBridgedStatic
+// RUN: %target-run %t/ErrorBridgedStatic
+
+// REQUIRES: rdar50279940
+// REQUIRES: executable_test
+// REQUIRES: objc_interop
+// REQUIRES: static_stdlib
+
+// SWIFT_ENABLE_TENSORFLOW: This test is unsupported because TensorFlow currently doesn't work with static-stdlib.
+// UNSUPPORTED: tensorflow
+
+import StdlibUnittest
+
+class Bar: Foo {
+ override func foo(_ x: Int32) throws {
+ try super.foo(5)
+ }
+
+ override func foothrows(_ x: Int32) throws {
+ try super.foothrows(5)
+ }
+}
+
+var ErrorBridgingStaticTests = TestSuite("ErrorBridging with static libs")
+
+ErrorBridgingStaticTests.test("round-trip Swift override of ObjC method") {
+ do {
+ try (Bar() as Foo).foo(5)
+ } catch { }
+}
+
+ErrorBridgingStaticTests.test("round-trip Swift override of throwing ObjC method") {
+ do {
+ try (Bar() as Foo).foothrows(5)
+ } catch {
+ print(error)
+ expectEqual(error._domain, "abcd")
+ expectEqual(error._code, 1234)
+ }
+}
+
+runAllTests()
diff --git a/test/stdlib/KeyPathIterable.swift b/test/stdlib/KeyPathIterable.swift
new file mode 100644
index 0000000..add855d
--- /dev/null
+++ b/test/stdlib/KeyPathIterable.swift
@@ -0,0 +1,184 @@
+// SWIFT_ENABLE_TENSORFLOW
+// RUN: %empty-directory(%t)
+// RUN: %target-build-swift -swift-version 5 -g %s -o %t/a.out
+// RUN: %target-codesign %t/a.out
+// RUN: %target-run %t/a.out
+// REQUIRES: executable_test
+//
+// `KeyPathIterable` tests.
+
+import StdlibUnittest
+
+var KeyPathIterableTests = TestSuite("KeyPathIterable")
+
+struct Simple : KeyPathIterable, Equatable {
+ var w, b: Float
+}
+
+struct Mixed : KeyPathIterable, Equatable {
+ // Mutable.
+ var string: String
+ var float: Float
+ // Immutable.
+ let int: Int
+}
+
+struct Nested : KeyPathIterable, Equatable {
+ // Immutable.
+ let simple: Simple
+ // Mutable.
+ var mixed: Mixed
+}
+
+struct ComplexNested : KeyPathIterable, Equatable {
+ var float: Float
+ let simple: Simple
+ let array: [Simple]
+ var dictionary: [String : Simple]
+}
+
+KeyPathIterableTests.test("Simple") {
+ var x = Simple(w: 1, b: 2)
+ expectEqual([\Simple.w, \Simple.b], x.allKeyPaths)
+ expectEqual([\Simple.w, \Simple.b], x.allKeyPaths(to: Float.self))
+ expectEqual([\Simple.w, \Simple.b], x.allWritableKeyPaths(to: Float.self))
+ expectEqual([\Simple.w, \Simple.b], x.recursivelyAllKeyPaths)
+ expectEqual([\Simple.w, \Simple.b], x.recursivelyAllKeyPaths(to: Float.self))
+ expectEqual([\Simple.w, \Simple.b], x.recursivelyAllWritableKeyPaths(to: Float.self))
+ expectEqual([], x.allKeyPaths(to: Int.self))
+ expectEqual([], x.recursivelyAllKeyPaths(to: Double.self))
+
+ // Mutate recursively all `Float` properties.
+ for kp in x.allWritableKeyPaths(to: Float.self) {
+ x[keyPath: kp] += x[keyPath: kp]
+ }
+ // Check that recursively all `Float` properties have been mutated.
+ expectEqual(Simple(w: 2, b: 4), x)
+}
+
+KeyPathIterableTests.test("Mixed") {
+ var x = Mixed(string: "hello", float: .pi, int: 0)
+ expectEqual([\Mixed.string, \Mixed.float, \Mixed.int], x.allKeyPaths)
+ expectEqual([\Mixed.string, \Mixed.float, \Mixed.int], x.recursivelyAllKeyPaths)
+
+ expectEqual([\Mixed.string], x.allKeyPaths(to: String.self))
+ expectEqual([\Mixed.string], x.allWritableKeyPaths(to: String.self))
+ expectEqual([\Mixed.string], x.recursivelyAllKeyPaths(to: String.self))
+ expectEqual([\Mixed.string], x.recursivelyAllWritableKeyPaths(to: String.self))
+
+ expectEqual([\Mixed.float], x.allKeyPaths(to: Float.self))
+ expectEqual([\Mixed.float], x.allWritableKeyPaths(to: Float.self))
+ expectEqual([\Mixed.float], x.recursivelyAllKeyPaths(to: Float.self))
+ expectEqual([\Mixed.float], x.recursivelyAllWritableKeyPaths(to: Float.self))
+
+ expectEqual([\Mixed.int], x.allKeyPaths(to: Int.self))
+ expectEqual([], x.allWritableKeyPaths(to: Int.self))
+ expectEqual([\Mixed.int], x.recursivelyAllKeyPaths(to: Int.self))
+ expectEqual([], x.recursivelyAllWritableKeyPaths(to: Int.self))
+
+ // Mutate recursively all `String` properties.
+ for kp in x.allWritableKeyPaths(to: String.self) {
+ x[keyPath: kp] = x[keyPath: kp] + " world"
+ }
+ // Check that recursively all `String` properties have been mutated.
+ expectEqual(Mixed(string: "hello world", float: .pi, int: 0), x)
+}
+
+KeyPathIterableTests.test("SimpleNested") {
+ var x = Nested(simple: Simple(w: 1, b: 2),
+ mixed: Mixed(string: "foo", float: 3, int: 0))
+
+ expectEqual([\Nested.simple, \Nested.mixed], x.allKeyPaths)
+ expectEqual([\Nested.simple, \Nested.simple.w, \Nested.simple.b,
+ \Nested.mixed, \Nested.mixed.string,
+ \Nested.mixed.float, \Nested.mixed.int],
+ x.recursivelyAllKeyPaths)
+
+ expectEqual([], x.allKeyPaths(to: Float.self))
+ expectEqual([], x.allKeyPaths(to: Int.self))
+ expectEqual([], x.allKeyPaths(to: String.self))
+
+ expectEqual([], x.allWritableKeyPaths(to: Float.self))
+ expectEqual([], x.allWritableKeyPaths(to: Int.self))
+ expectEqual([], x.allWritableKeyPaths(to: String.self))
+
+ expectEqual([\Nested.simple.w, \Nested.simple.b, \Nested.mixed.float],
+ x.recursivelyAllKeyPaths(to: Float.self))
+ expectEqual([\Nested.mixed.int], x.recursivelyAllKeyPaths(to: Int.self))
+ expectEqual([\Nested.mixed.string], x.recursivelyAllKeyPaths(to: String.self))
+
+ expectEqual([\Nested.mixed.float], x.recursivelyAllWritableKeyPaths(to: Float.self))
+ expectEqual([], x.recursivelyAllWritableKeyPaths(to: Int.self))
+ expectEqual([\Nested.mixed.string], x.recursivelyAllWritableKeyPaths(to: String.self))
+
+ expectEqual([], x.recursivelyAllKeyPaths(to: Double.self))
+
+ // Mutate recursively all `Float` properties.
+ for kp in x.recursivelyAllWritableKeyPaths(to: Float.self) {
+ x[keyPath: kp] *= 100
+ }
+ // Check that recursively all `Float` properties have been mutated.
+ let expected = Nested(simple: Simple(w: 1, b: 2),
+ mixed: Mixed(string: "foo", float: 300, int: 0))
+ expectEqual(expected, x)
+}
+
+KeyPathIterableTests.test("ComplexNested") {
+ var x = ComplexNested(float: 1, simple: Simple(w: 3, b: 4),
+ array: [Simple(w: 5, b: 6), Simple(w: 7, b: 8)],
+ dictionary: ["foo" : Simple(w: 1, b: 2),
+ "bar" : Simple(w: 3, b: 4)])
+ expectEqual([\ComplexNested.float, \ComplexNested.simple,
+ \ComplexNested.array, \ComplexNested.dictionary],
+ x.allKeyPaths)
+ expectEqual([\ComplexNested.float,
+ \ComplexNested.simple,
+ \ComplexNested.simple.w,
+ \ComplexNested.simple.b,
+ \ComplexNested.array,
+ \ComplexNested.array[0],
+ \ComplexNested.array[0].w,
+ \ComplexNested.array[0].b,
+ \ComplexNested.array[1],
+ \ComplexNested.array[1].w,
+ \ComplexNested.array[1].b,
+ \ComplexNested.dictionary,
+ \ComplexNested.dictionary["bar"]!,
+ \ComplexNested.dictionary["bar"]!.w,
+ \ComplexNested.dictionary["bar"]!.b,
+ \ComplexNested.dictionary["foo"]!,
+ \ComplexNested.dictionary["foo"]!.w,
+ \ComplexNested.dictionary["foo"]!.b],
+ x.recursivelyAllKeyPaths)
+ expectEqual([\ComplexNested.float,
+ \ComplexNested.simple.w,
+ \ComplexNested.simple.b,
+ \ComplexNested.array[0].w,
+ \ComplexNested.array[0].b,
+ \ComplexNested.array[1].w,
+ \ComplexNested.array[1].b,
+ \ComplexNested.dictionary["bar"]!.w,
+ \ComplexNested.dictionary["bar"]!.b,
+ \ComplexNested.dictionary["foo"]!.w,
+ \ComplexNested.dictionary["foo"]!.b],
+ x.recursivelyAllKeyPaths(to: Float.self))
+ expectEqual([\ComplexNested.float,
+ \ComplexNested.dictionary["bar"]!.w,
+ \ComplexNested.dictionary["bar"]!.b,
+ \ComplexNested.dictionary["foo"]!.w,
+ \ComplexNested.dictionary["foo"]!.b],
+ x.recursivelyAllWritableKeyPaths(to: Float.self))
+
+ // Mutate recursively all `Float` properties.
+ for kp in x.recursivelyAllWritableKeyPaths(to: Float.self) {
+ x[keyPath: kp] += 1
+ }
+ // Check that recursively all `Float` properties have been mutated.
+ let expected = ComplexNested(float: 2, simple: Simple(w: 3, b: 4),
+ array: [Simple(w: 5, b: 6), Simple(w: 7, b: 8)],
+ dictionary: ["foo" : Simple(w: 2, b: 3),
+ "bar" : Simple(w: 4, b: 5)])
+ expectEqual(expected, x)
+}
+
+runAllTests()
diff --git a/test/stdlib/MathFunctions.swift.gyb b/test/stdlib/MathFunctions.swift.gyb
new file mode 100644
index 0000000..b784db4
--- /dev/null
+++ b/test/stdlib/MathFunctions.swift.gyb
@@ -0,0 +1,101 @@
+//===--- Math.swift.gyb ---------------------------------------*- swift -*-===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+// -*- swift -*-
+// RUN: %empty-directory(%t)
+// RUN: %gyb %s -o %t/tgmath.swift
+// RUN: %line-directive %t/tgmath.swift -- %target-build-swift %t/tgmath.swift -o %t/a.out
+// RUN: %target-codesign %t/a.out
+// RUN: %line-directive %t/tgmath.swift -- %target-run %t/a.out
+// REQUIRES: executable_test
+
+#if (arch(i386) || arch(x86_64)) && !os(Windows)
+ typealias TestLiteralType = Float80
+#else
+ typealias TestLiteralType = Double
+#endif
+
+import StdlibUnittest
+
+let MathTests = TestSuite("Math")
+
+func expectEqualWithTolerance<T>(_ expected: TestLiteralType, _ actual: T,
+ ulps allowed: T = 3,
+ file: String = #file, line: UInt = #line)
+ where T: BinaryFloatingPoint {
+ if actual == T(expected) || actual.isNaN && expected.isNaN {
+ return
+ }
+ // Compute error in ulp, compare to tolerance.
+ let absoluteError = T(abs(TestLiteralType(actual) - expected))
+ let ulpError = absoluteError / T(expected).ulp
+ expectTrue(ulpError <= allowed,
+ "\(actual) != \(expected) as \(T.self)" +
+ "\n \(ulpError)-ulp error exceeds \(allowed)-ulp tolerance.",
+ file: file, line: line)
+}
+
+%from SwiftMathFunctions import *
+
+@available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *)
+internal extension ElementaryFunctions where Self: BinaryFloatingPoint {
+ static func elementaryFunctionTests() {
+ /* Default tolerance is 3 ulps unless specified otherwise. It's OK to relax
+ * this as needed for new platforms, as these tests are *not* intended to
+ * validate the math library--they are only intended to check that the
+ * Swift bindings are calling the right functions in the math library. */
+ expectEqualWithTolerance(1.1863995522992575361931268186727044683, Self.acos(0.375))
+ expectEqualWithTolerance(0.3843967744956390830381948729670469737, Self.asin(0.375))
+ expectEqualWithTolerance(0.3587706702705722203959200639264604997, Self.atan(0.375))
+ expectEqualWithTolerance(0.9305076219123142911494767922295555080, Self.cos(0.375))
+ expectEqualWithTolerance(0.3662725290860475613729093517162641571, Self.sin(0.375))
+ expectEqualWithTolerance(0.3936265759256327582294137871012180981, Self.tan(0.375))
+ expectEqualWithTolerance(0.4949329230945269058895630995767185785, Self.acosh(1.125))
+ expectEqualWithTolerance(0.9670596312833237113713762009167286709, Self.asinh(1.125))
+ expectEqualWithTolerance(0.7331685343967135223291211023213964500, Self.atanh(0.625))
+ expectEqualWithTolerance(1.0711403467045867672994980155670160493, Self.cosh(0.375))
+ expectEqualWithTolerance(0.3838510679136145687542956764205024589, Self.sinh(0.375))
+ expectEqualWithTolerance(0.3583573983507859463193602315531580424, Self.tanh(0.375))
+ expectEqualWithTolerance(1.4549914146182013360537936919875185083, Self.exp(0.375))
+ expectEqualWithTolerance(1.2968395546510096659337541177924511598, Self.exp2(0.375))
+ expectEqualWithTolerance(2.3713737056616552616517527574788898386, Self.exp10(0.375))
+ expectEqualWithTolerance(0.4549914146182013360537936919875185083, Self.expm1(0.375))
+ expectEqualWithTolerance(-0.980829253011726236856451127452003999, Self.log(0.375))
+ expectEqualWithTolerance(-1.415037499278843818546261056052183491, Self.log2(0.375))
+ expectEqualWithTolerance(0.3184537311185346158102472135905995955, Self.log1p(0.375))
+ expectEqualWithTolerance(-0.425968732272281148346188780918363771, Self.log10(0.375))
+ expectEqualWithTolerance(-0.7211247851537041911608191553900547941, Self.root(-0.375, 3))
+ expectEqualWithTolerance(0.6123724356957945245493210186764728479, Self.sqrt(0.375))
+ expectEqualWithTolerance(0.54171335479545025876069682133938570, Self.pow(0.375, 0.625))
+ expectEqualWithTolerance(-0.052734375, Self.pow(-0.375, 3))
+ }
+}
+
+%for T in ['Float', 'Double', 'CGFloat', 'Float80']:
+% if T == 'Float80':
+#if (arch(i386) || arch(x86_64)) && !os(Windows)
+% elif T == 'CGFloat':
+#if canImport(CoreGraphics)
+ import CoreGraphics
+% end
+
+MathTests.test("${T}") {
+ if #available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) {
+ ${T}.elementaryFunctionTests()
+ }
+}
+
+% if T in ['CGFloat', 'Float80']:
+#endif
+% end
+%end
+
+runAllTests()
diff --git a/test/stdlib/SIMD_as_AdditiveArithmetic.swift b/test/stdlib/SIMD_as_AdditiveArithmetic.swift
index f2641a3..d1889db 100644
--- a/test/stdlib/SIMD_as_AdditiveArithmetic.swift
+++ b/test/stdlib/SIMD_as_AdditiveArithmetic.swift
@@ -1,3 +1,11 @@
+// SWIFT_ENABLE_TENSORFLOW
+// NOTE(TF-801): Test is disabled because `tensorflow` branch already defines
+// `SIMD{X}: AdditiveArithmetic` conformances in the stdlib to support
+// `SIMD{X}: Differentiable` conformances.
+// Upstreaming `SIMD{X}: AdditiveArithmetic` is not yet possible because it is
+// an ABI breaking change.
+UNSUPPORTED: true
+
// RUN: %target-typecheck-verify-swift
extension SIMD2: AdditiveArithmetic where Scalar: FloatingPoint { }
extension SIMD3: AdditiveArithmetic where Scalar: FloatingPoint { }
diff --git a/test/stdlib/TensorFlowEnabled.swift b/test/stdlib/TensorFlowEnabled.swift
new file mode 100644
index 0000000..f1e796a
--- /dev/null
+++ b/test/stdlib/TensorFlowEnabled.swift
@@ -0,0 +1,11 @@
+// RUN: %target-run-simple-swift
+// REQUIRES: tensorflow
+
+import StdlibUnittest
+
+let TensorFlowEnabled = TestSuite("TensorFlowEnabled")
+TensorFlowEnabled.test("TensorFlowEnabled") {
+ expectPrinted("1", 1)
+}
+
+runAllTests()
diff --git a/test/stdlib/TestDecimal.swift b/test/stdlib/TestDecimal.swift
index ee0c7e4..15d780c 100644
--- a/test/stdlib/TestDecimal.swift
+++ b/test/stdlib/TestDecimal.swift
@@ -16,6 +16,12 @@
// REQUIRES: executable_test
// REQUIRES: objc_interop
+// SWIFT_ENABLE_TENSORFLOW
+// This test is currently unsupported because the addition of `+` operators
+// to the stdlib (via `VectorProtocol`) causes type-checking to fail.
+// Re-enable when type-checking no longer fails.
+// UNSUPPORTED: executable_test
+
import Foundation
import FoundationBridgeObjC
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index b50cbc7..bf4a964 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -54,3 +54,6 @@
endif()
add_swift_tool_subdirectory(swift-reflection-dump)
+
+# SWIFT_ENABLE_TENSORFLOW
+add_swift_tool_subdirectory(libInMemoryFrontend)
diff --git a/tools/SourceKit/cmake/modules/AddSwiftSourceKit.cmake b/tools/SourceKit/cmake/modules/AddSwiftSourceKit.cmake
index caba3e1..1a26393 100644
--- a/tools/SourceKit/cmake/modules/AddSwiftSourceKit.cmake
+++ b/tools/SourceKit/cmake/modules/AddSwiftSourceKit.cmake
@@ -7,6 +7,17 @@
_add_host_variant_c_compile_flags(${target})
_add_host_variant_link_flags(${target})
+ # SWIFT_ENABLE_TENSORFLOW
+ if(SWIFT_ENABLE_TENSORFLOW)
+ if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
+ # FIXME: This is a hack: adding rpaths with many `..` that jump across
+ # frameworks is bad practice. It would be cleaner/more robust to copy
+ # the TensorFlow libraries to sourcekitd.framework.
+ set_target_properties(${target} PROPERTIES
+ INSTALL_RPATH "@loader_path/../../../swift/${SOURCEKIT_DEPLOYMENT_OS};@loader_path/../../../../../../../swift/${SOURCEKIT_DEPLOYMENT_OS}")
+ endif()
+ endif()
+
# Set compilation and link flags.
if(${SWIFT_HOST_VARIANT_SDK} STREQUAL WINDOWS)
swift_windows_include_for_arch(${SWIFT_HOST_VARIANT_ARCH}
@@ -176,6 +187,8 @@
set_target_properties(${name} PROPERTIES FOLDER "SourceKit executables")
add_sourcekit_default_compiler_flags("${name}")
+ set_property(TARGET "${name}" APPEND_STRING PROPERTY
+ COMPILE_FLAGS " ${SOURCEKITEXE_C_COMPILE_FLAGS}")
endmacro()
# Add a new SourceKit framework.
diff --git a/tools/SourceKit/include/SourceKit/Core/LangSupport.h b/tools/SourceKit/include/SourceKit/Core/LangSupport.h
index 6ed2f91..fbd6552 100644
--- a/tools/SourceKit/include/SourceKit/Core/LangSupport.h
+++ b/tools/SourceKit/include/SourceKit/Core/LangSupport.h
@@ -21,6 +21,8 @@
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallString.h"
#include "swift/AST/Type.h"
+// SWIFT_ENABLE_TENSORFLOW
+#include "clang/Basic/InMemoryOutputFileSystem.h"
#include "llvm/Support/VirtualFileSystem.h"
#include <functional>
#include <memory>
@@ -657,6 +659,13 @@
virtual ~LangSupport() { }
+ // SWIFT_ENABLE_TENSORFLOW
+ /// Subsequent requests will write temporary output files to this filesystem
+ /// rather than to the real filesystem.
+ virtual void setInMemoryOutputFileSystem(
+ llvm::IntrusiveRefCntPtr<clang::InMemoryOutputFileSystem> FS) = 0;
+ // SWIFT_ENABLE_TENSORFLOW END
+
virtual void globalConfigurationUpdated(std::shared_ptr<GlobalConfig> Config) {};
virtual void indexSource(StringRef Filename,
@@ -827,6 +836,19 @@
Optional<VFSOptions> vfsOptions) = 0;
virtual void getStatistics(StatisticsReceiver) = 0;
+
+ // SWIFT_ENABLE_TENSORFLOW
+ /// Tempoary shim for clients that want to pass the filesystem directly.
+ virtual void
+ codeComplete(llvm::MemoryBuffer *InputBuf, unsigned Offset,
+ CodeCompletionConsumer &Consumer, ArrayRef<const char *> Args,
+ llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS) = 0;
+
+ /// Tempoary shim for clients that want to pass the filesystem directly.
+ virtual void
+ editorOpen(StringRef Name, llvm::MemoryBuffer *Buf, EditorConsumer &Consumer,
+ ArrayRef<const char *> Args,
+ llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS) = 0;
};
} // namespace SourceKit
diff --git a/tools/SourceKit/lib/SwiftLang/SwiftASTManager.cpp b/tools/SourceKit/lib/SwiftLang/SwiftASTManager.cpp
index dfc5ea5..81fa025 100644
--- a/tools/SourceKit/lib/SwiftLang/SwiftASTManager.cpp
+++ b/tools/SourceKit/lib/SwiftLang/SwiftASTManager.cpp
@@ -361,6 +361,12 @@
llvm::sys::Mutex CacheMtx;
std::time_t SessionTimestamp;
+ // SWIFT_ENABLE_TENSORFLOW
+ /// Requests will write temporary output files to this filesystem rather than
+ /// to the real filesystem.
+ llvm::IntrusiveRefCntPtr<clang::InMemoryOutputFileSystem>
+ InMemoryOutputFileSystem;
+
WorkQueue ASTBuildQueue{ WorkQueue::Dequeuing::Serial,
"sourcekit.swift.ASTBuilding" };
@@ -390,6 +396,12 @@
delete &Impl;
}
+// SWIFT_ENABLE_TENSORFLOW
+void SwiftASTManager::setInMemoryOutputFileSystem(
+ llvm::IntrusiveRefCntPtr<clang::InMemoryOutputFileSystem> FS) {
+ Impl.InMemoryOutputFileSystem = std::move(FS);
+}
+
std::unique_ptr<llvm::MemoryBuffer>
SwiftASTManager::getMemoryBuffer(StringRef Filename, std::string &Error) {
return Impl.getMemoryBuffer(Filename, llvm::vfs::getRealFileSystem(), Error);
@@ -420,6 +432,9 @@
std::string &Error) {
return ide::initCompilerInvocation(
Invocation, OrigArgs, Diags, UnresolvedPrimaryFile, FileSystem,
+ // SWIFT_ENABLE_TENSORFLOW
+ Impl.InMemoryOutputFileSystem,
+ // SWIFT_ENABLE_TENSORFLOW END
Impl.RuntimeResourcePath, Impl.DiagnosticDocumentationPath,
Impl.Config->shouldOptimizeForIDE(), Impl.SessionTimestamp, Error);
}
diff --git a/tools/SourceKit/lib/SwiftLang/SwiftASTManager.h b/tools/SourceKit/lib/SwiftLang/SwiftASTManager.h
index cbf10c5..61b2c78 100644
--- a/tools/SourceKit/lib/SwiftLang/SwiftASTManager.h
+++ b/tools/SourceKit/lib/SwiftLang/SwiftASTManager.h
@@ -15,6 +15,8 @@
#include "SwiftInvocation.h"
#include "SourceKit/Core/LLVM.h"
+// SWIFT_ENABLE_TENSORFLOW
+#include "clang/Basic/InMemoryOutputFileSystem.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
#include "llvm/ADT/StringRef.h"
@@ -96,6 +98,12 @@
StringRef DiagnosticDocumentationPath);
~SwiftASTManager();
+ // SWIFT_ENABLE_TENSORFLOW
+ /// Subsequent requests will write temporary output files to this filesystem
+ /// rather than to the real filesystem.
+ void setInMemoryOutputFileSystem(
+ llvm::IntrusiveRefCntPtr<clang::InMemoryOutputFileSystem> FS);
+
SwiftInvocationRef getInvocation(
ArrayRef<const char *> Args, StringRef PrimaryFile, std::string &Error);
diff --git a/tools/SourceKit/lib/SwiftLang/SwiftLangSupport.cpp b/tools/SourceKit/lib/SwiftLang/SwiftLangSupport.cpp
index 02b066e7..ac23079 100644
--- a/tools/SourceKit/lib/SwiftLang/SwiftLangSupport.cpp
+++ b/tools/SourceKit/lib/SwiftLang/SwiftLangSupport.cpp
@@ -259,6 +259,69 @@
return OverlayFS;
}
};
+
+class DirectlyPassedFileSystemProvider : public SourceKit::FileSystemProvider {
+public:
+ struct Options : public OptionsDictionary {
+ const unsigned FSID;
+
+ Options(unsigned FSID) : FSID(FSID) {}
+
+ bool valueForOption(UIdent Key, unsigned &Val) override {
+ static UIdent KeyFSID("key.fsid");
+ if (Key != KeyFSID)
+ return false;
+ Val = FSID;
+ return true;
+ }
+ bool valueForOption(UIdent Key, bool &Val) override { return true; }
+ bool valueForOption(UIdent Key, StringRef &Val) override { return false; }
+ bool
+ forEach(UIdent key,
+ llvm::function_ref<bool(OptionsDictionary &)> applier) override {
+ return false;
+ }
+ };
+
+ llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem>
+ getFileSystem(OptionsDictionary &options, std::string &error) override {
+ llvm::sys::ScopedLock L(mtx);
+ static UIdent KeyFSID("key.fsid");
+ unsigned FSID;
+ if (!options.valueForOption(KeyFSID, FSID)) {
+ error = "'key.fsid' not specified";
+ return nullptr;
+ }
+
+ auto result = FileSystems.find(FSID);
+ if (result == FileSystems.end()) {
+ error = "filesystem not found";
+ return nullptr;
+ }
+
+ return result->second;
+ }
+
+ Options addFileSystem(llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> fs) {
+ llvm::sys::ScopedLock L(mtx);
+ auto result = FileSystems.try_emplace(NextFSID, std::move(fs));
+ assert(result.second && "duplicate key");
+ Options ret(NextFSID);
+ ++NextFSID;
+ return ret;
+ }
+
+ void removeFileSystem(Options &options) {
+ llvm::sys::ScopedLock L(mtx);
+ FileSystems.erase(options.FSID);
+ }
+
+private:
+ llvm::sys::Mutex mtx;
+ unsigned NextFSID = 0;
+ llvm::DenseMap<unsigned, llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem>>
+ FileSystems;
+};
}
static void
@@ -293,11 +356,22 @@
// Provide a default file system provider.
setFileSystemProvider("in-memory-vfs", std::make_unique<InMemoryFileSystemProvider>());
+
+ // SWIFT_ENABLE_TENSORFLOW
+ setFileSystemProvider("directly-passed-vfs",
+ std::make_unique<DirectlyPassedFileSystemProvider>());
}
SwiftLangSupport::~SwiftLangSupport() {
}
+// SWIFT_ENABLE_TENSORFLOW
+void SwiftLangSupport::setInMemoryOutputFileSystem(
+ llvm::IntrusiveRefCntPtr<clang::InMemoryOutputFileSystem> FS) {
+ ASTMgr->setInMemoryOutputFileSystem(std::move(FS));
+}
+// SWIFT_ENABLE_TENSORFLOW END
+
void SwiftLangSupport::globalConfigurationUpdated(
std::shared_ptr<GlobalConfig> Config) {
configureCompletionInstance(CompletionInst, Config);
@@ -945,6 +1019,42 @@
receiver(stats);
}
+// SWIFT_ENABLE_TENSORFLOW
+void SwiftLangSupport::codeComplete(
+ llvm::MemoryBuffer *InputBuf, unsigned Offset,
+ CodeCompletionConsumer &Consumer, ArrayRef<const char *> Args,
+ llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS) {
+ VFSOptions vfsOptions;
+ vfsOptions.name = "directly-passed-vfs";
+ auto *provider = static_cast<DirectlyPassedFileSystemProvider *>(
+ getFileSystemProvider(vfsOptions.name));
+ assert(provider);
+ auto options = provider->addFileSystem(FS);
+ vfsOptions.options =
+ std::make_unique<DirectlyPassedFileSystemProvider::Options>(options);
+ codeComplete(InputBuf, Offset, vfsOptions.options.get(), Consumer, Args,
+ std::move(vfsOptions));
+ provider->removeFileSystem(options);
+}
+// SWIFT_ENABLE_TENSORFLOW END
+
+void SwiftLangSupport::editorOpen(
+ StringRef Name, llvm::MemoryBuffer *Buf, EditorConsumer &Consumer,
+ ArrayRef<const char *> Args,
+ llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS) {
+ VFSOptions vfsOptions;
+ vfsOptions.name = "directly-passed-vfs";
+ auto *provider = static_cast<DirectlyPassedFileSystemProvider *>(
+ getFileSystemProvider(vfsOptions.name));
+ assert(provider);
+ auto options = provider->addFileSystem(FS);
+ vfsOptions.options =
+ std::make_unique<DirectlyPassedFileSystemProvider::Options>(options);
+ editorOpen(Name, Buf, Consumer, Args, std::move(vfsOptions));
+ provider->removeFileSystem(options);
+}
+
+
FileSystemProvider *SwiftLangSupport::getFileSystemProvider(StringRef Name) {
auto It = FileSystemProviders.find(Name);
if (It == FileSystemProviders.end())
diff --git a/tools/SourceKit/lib/SwiftLang/SwiftLangSupport.h b/tools/SourceKit/lib/SwiftLang/SwiftLangSupport.h
index 81f61f4..b89f9b3 100644
--- a/tools/SourceKit/lib/SwiftLang/SwiftLangSupport.h
+++ b/tools/SourceKit/lib/SwiftLang/SwiftLangSupport.h
@@ -314,6 +314,10 @@
explicit SwiftLangSupport(SourceKit::Context &SKCtx);
~SwiftLangSupport();
+ // SWIFT_ENABLE_TENSORFLOW
+ void setInMemoryOutputFileSystem(
+ llvm::IntrusiveRefCntPtr<clang::InMemoryOutputFileSystem> FS) override;
+
std::shared_ptr<NotificationCenter> getNotificationCenter() const {
return NotificationCtr;
}
@@ -632,6 +636,16 @@
void getStatistics(StatisticsReceiver) override;
+ // SWIFT_ENABLE_TENSORFLOW
+ void
+ codeComplete(llvm::MemoryBuffer *InputBuf, unsigned Offset,
+ CodeCompletionConsumer &Consumer, ArrayRef<const char *> Args,
+ llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS) override;
+
+ void editorOpen(StringRef Name, llvm::MemoryBuffer *Buf,
+ EditorConsumer &Consumer, ArrayRef<const char *> Args,
+ llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS) override;
+
private:
swift::SourceFile *getSyntacticSourceFile(llvm::MemoryBuffer *InputBuf,
ArrayRef<const char *> Args,
diff --git a/tools/SourceKit/tools/sourcekitd-test/CMakeLists.txt b/tools/SourceKit/tools/sourcekitd-test/CMakeLists.txt
index 91c28f3..8ff4d20 100644
--- a/tools/SourceKit/tools/sourcekitd-test/CMakeLists.txt
+++ b/tools/SourceKit/tools/sourcekitd-test/CMakeLists.txt
@@ -7,6 +7,10 @@
TestOptions.cpp
LLVM_LINK_COMPONENTS option coverage lto
)
+# SWIFT_ENABLE_TENSORFLOW
+target_compile_definitions(sourcekitd-test PRIVATE
+ $<$<BOOL:${SWIFT_USE_SOURCEKIT_INPROC_LIBRARY}>:SWIFT_SOURCEKIT_USE_INPROC_LIBRARY>)
+# SWIFT_ENABLE_TENSORFLOW END
target_link_libraries(sourcekitd-test PRIVATE
SourceKitSupport
clangRewrite
diff --git a/tools/SourceKit/tools/sourcekitd-test/Options.td b/tools/SourceKit/tools/sourcekitd-test/Options.td
index 522cc38..c3ff8bb 100644
--- a/tools/SourceKit/tools/sourcekitd-test/Options.td
+++ b/tools/SourceKit/tools/sourcekitd-test/Options.td
@@ -134,6 +134,10 @@
HelpText<"Repeat the request n times">, MetaVarName<"<n>">;
def repeat_request_EQ : Joined<["-"], "repeat-request=">, Alias<repeat_request>;
+// SWIFT_ENABLE_TENSORFLOW
+def in_memory_clang_module_cache : Flag<["-"], "in-memory-clang-module-cache">,
+ HelpText<"Put the Clang module cache in memory">;
+
def vfs_files : CommaJoined<["-"], "vfs-files=">,
HelpText<"Injects a VFS into the request, overlaying files specified by the given <name>=<target> pairs over the real filesystem. Prefix destination with '@' to pass as sourcetext.">;
diff --git a/tools/SourceKit/tools/sourcekitd-test/TestOptions.cpp b/tools/SourceKit/tools/sourcekitd-test/TestOptions.cpp
index 7b49e63..488539a 100644
--- a/tools/SourceKit/tools/sourcekitd-test/TestOptions.cpp
+++ b/tools/SourceKit/tools/sourcekitd-test/TestOptions.cpp
@@ -355,6 +355,22 @@
}
break;
+ // SWIFT_ENABLE_TENSORFLOW
+ case OPT_in_memory_clang_module_cache:
+#ifdef SWIFT_SOURCEKIT_USE_INPROC_LIBRARY
+ InMemoryClangModuleCache = true;
+ break;
+#else
+ // The -in-memory-clang-module-cache option operates by making a function
+ // call to a function defined in the SourceKit InProc library
+ // (SourceKit::setGlobalInMemoryOutputFileSystem). So this option only
+ // works when sourcekitd-test is compiled using that library. It does not
+ // work when sourcekitd-test uses the XPC library.
+ llvm::errs() << "in-memory-clang-module-cache only supported when "
+ "SWIFT_SOURCEKIT_USE_INPROC_LIBRARY is set";
+ return true;
+#endif
+
case OPT_vfs_files:
VFSName = VFSName.getValueOr("in-memory-vfs");
for (const char *vfsFile : InputArg->getValues()) {
diff --git a/tools/SourceKit/tools/sourcekitd-test/TestOptions.h b/tools/SourceKit/tools/sourcekitd-test/TestOptions.h
index a0c6769..6f51ffc 100644
--- a/tools/SourceKit/tools/sourcekitd-test/TestOptions.h
+++ b/tools/SourceKit/tools/sourcekitd-test/TestOptions.h
@@ -116,6 +116,8 @@
bool SuppressDefaultConfigRequest = false;
llvm::Optional<unsigned> CompletionCheckDependencyInterval;
unsigned repeatRequest = 1;
+ // SWIFT_ENABLE_TENSORFLOW
+ bool InMemoryClangModuleCache;
struct VFSFile {
std::string path;
bool passAsSourceText;
diff --git a/tools/SourceKit/tools/sourcekitd-test/sourcekitd-test.cpp b/tools/SourceKit/tools/sourcekitd-test/sourcekitd-test.cpp
index 925fea4..d624b81 100644
--- a/tools/SourceKit/tools/sourcekitd-test/sourcekitd-test.cpp
+++ b/tools/SourceKit/tools/sourcekitd-test/sourcekitd-test.cpp
@@ -12,9 +12,15 @@
#include "sourcekitd/sourcekitd.h"
+// SWIFT_ENABLE_TENSORFLOW
+#include "sourcekitd/FileSystemProvider.h"
+// SWIFT_ENABLE_TENSORFLOW END
#include "SourceKit/Support/Concurrency.h"
#include "TestOptions.h"
#include "swift/Demangling/ManglingMacros.h"
+// SWIFT_ENABLE_TENSORFLOW
+#include "clang/Basic/InMemoryOutputFileSystem.h"
+// SWIFT_ENABLE_TENSORFLOW END
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
#include "llvm/ADT/Optional.h"
@@ -481,6 +487,16 @@
}
static int handleTestInvocation(TestOptions Opts, TestOptions &InitOpts) {
+ // SWIFT_ENABLE_TENSORFLOW
+#ifdef SWIFT_SOURCEKIT_USE_INPROC_LIBRARY
+ if (Opts.InMemoryClangModuleCache) {
+ SourceKit::setGlobalInMemoryOutputFileSystem(
+ new clang::InMemoryOutputFileSystem());
+ } else {
+ SourceKit::setGlobalInMemoryOutputFileSystem(nullptr);
+ }
+#endif
+
if (!Opts.JsonRequestPath.empty())
return handleJsonRequestPath(Opts.JsonRequestPath, Opts);
diff --git a/tools/SourceKit/tools/sourcekitd/include/sourcekitd/FileSystemProvider.h b/tools/SourceKit/tools/sourcekitd/include/sourcekitd/FileSystemProvider.h
new file mode 100644
index 0000000..e4aae1a
--- /dev/null
+++ b/tools/SourceKit/tools/sourcekitd/include/sourcekitd/FileSystemProvider.h
@@ -0,0 +1,34 @@
+//===--- FileSystemProvider.h - ---------------------------------*- C++ -*-===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SOURCEKITD_FILESYSTEMPROVIDER_H
+#define LLVM_SOURCEKITD_FILESYSTEMPROVIDER_H
+
+#include "clang/Basic/InMemoryOutputFileSystem.h"
+#include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "sourcekitd.h"
+
+namespace SourceKit {
+
+/// Subsequent requests will write temporary output files to this filesystem
+/// rather than to the real filesystem.
+///
+/// Is not threadsafe.
+///
+/// \param FS may be null, which makes subsequent requests start writing
+/// temporary output files to the real filesystem again.
+void SOURCEKITD_PUBLIC setGlobalInMemoryOutputFileSystem(
+ llvm::IntrusiveRefCntPtr<clang::InMemoryOutputFileSystem> FS);
+
+} // namespace SourceKit
+
+#endif
diff --git a/tools/SourceKit/tools/sourcekitd/lib/API/Requests.cpp b/tools/SourceKit/tools/sourcekitd/lib/API/Requests.cpp
index d95425d..970c4e9 100644
--- a/tools/SourceKit/tools/sourcekitd/lib/API/Requests.cpp
+++ b/tools/SourceKit/tools/sourcekitd/lib/API/Requests.cpp
@@ -130,6 +130,12 @@
return *GlobalCtx;
}
+namespace SourceKit {
+void SOURCEKITD_PUBLIC setGlobalInMemoryOutputFileSystem(IntrusiveRefCntPtr<clang::InMemoryOutputFileSystem> FS) {
+ getGlobalContext().getSwiftLangSupport().setInMemoryOutputFileSystem(std::move(FS));
+}
+} // namespace SourceKit
+
static sourcekitd_response_t demangleNames(ArrayRef<const char *> MangledNames,
bool Simplified);
diff --git a/tools/libInMemoryFrontend/CMakeLists.txt b/tools/libInMemoryFrontend/CMakeLists.txt
new file mode 100644
index 0000000..7879beb
--- /dev/null
+++ b/tools/libInMemoryFrontend/CMakeLists.txt
@@ -0,0 +1,5 @@
+include_directories(
+ ${CMAKE_CURRENT_SOURCE_DIR}/include
+)
+
+add_subdirectory(lib)
diff --git a/tools/libInMemoryFrontend/include/libInMemoryFrontend/InMemoryFrontend.h b/tools/libInMemoryFrontend/include/libInMemoryFrontend/InMemoryFrontend.h
new file mode 100644
index 0000000..4d84fe5
--- /dev/null
+++ b/tools/libInMemoryFrontend/include/libInMemoryFrontend/InMemoryFrontend.h
@@ -0,0 +1,41 @@
+//===--- InMemoryFrontend.h - Frontend operations, in memory ----*- C++ -*-===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2019 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SWIFT_LIBINMEMORYFRONTEND_INMEMORYFRONTEND_H
+#define SWIFT_LIBINMEMORYFRONTEND_INMEMORYFRONTEND_H
+
+#include "swift/Frontend/Frontend.h"
+#include "llvm/Support/MemoryBuffer.h"
+
+namespace swift {
+namespace inmemoryfrontend {
+
+/// Given a fully setup CompilerInstance, configured to emit one module, runs
+/// the compilation and emits the module to a memory buffer, without writing to
+/// the filesystem. Emits error information to the CompilerInstance's
+/// DiagnosticEngine.
+///
+/// \param moduleBuffer will be set to a pointer to the serialized module
+/// buffer. nullptr is allowed, in which case the module
+/// will not be serialized.
+/// \param moduleDocBuffer will be set to a pointer to the serialized module
+/// doc buffer. nullptr is allowed, in which case the
+/// module doc will not be serialized.
+/// \return true on error.
+bool compileSwiftModule(CompilerInstance &CI,
+ std::unique_ptr<llvm::MemoryBuffer> *moduleBuffer,
+ std::unique_ptr<llvm::MemoryBuffer> *moduleDocBuffer);
+
+} // end namespace inmemoryfrontend
+} // end namespace swift
+
+#endif
diff --git a/tools/libInMemoryFrontend/lib/CMakeLists.txt b/tools/libInMemoryFrontend/lib/CMakeLists.txt
new file mode 100644
index 0000000..39ed350
--- /dev/null
+++ b/tools/libInMemoryFrontend/lib/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_swift_host_library(libInMemoryFrontend STATIC
+ InMemoryFrontend.cpp
+)
+
+target_link_libraries(libInMemoryFrontend PRIVATE
+ swiftFrontend
+ swiftSerialization
+ swiftSIL
+)
diff --git a/tools/libInMemoryFrontend/lib/InMemoryFrontend.cpp b/tools/libInMemoryFrontend/lib/InMemoryFrontend.cpp
new file mode 100644
index 0000000..d189b74
--- /dev/null
+++ b/tools/libInMemoryFrontend/lib/InMemoryFrontend.cpp
@@ -0,0 +1,46 @@
+//===--- InMemoryFrontend.cpp - Frontend operations, in memory --*- C++ -*-===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2019 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+
+#include "libInMemoryFrontend/InMemoryFrontend.h"
+
+#include "swift/SIL/SILModule.h"
+#include "swift/Subsystems.h"
+
+namespace swift {
+namespace inmemoryfrontend {
+
+bool compileSwiftModule(CompilerInstance &CI,
+ std::unique_ptr<llvm::MemoryBuffer> *moduleBuffer,
+ std::unique_ptr<llvm::MemoryBuffer> *moduleDocBuffer) {
+ CI.performSema();
+ if (CI.getDiags().hadAnyError())
+ return true;
+
+ auto SILMod = performASTLowering(CI.getMainModule(), CI.getSILTypes(),
+ CI.getSILOptions());
+ if (!SILMod)
+ return true;
+
+ SerializationOptions SerOpts;
+ SILMod->setSerializeSILAction([&]() {
+ serializeToMemory(CI.getMainModule(), SerOpts, moduleBuffer,
+ moduleDocBuffer, SILMod.get());
+ });
+
+ if (CI.performSILProcessing(SILMod.get()))
+ return true;
+
+ return false;
+}
+
+} // end namespace inmemoryfrontend
+} // end namespace swift
diff --git a/tools/sil-opt/SILOpt.cpp b/tools/sil-opt/SILOpt.cpp
index f4d2b44..1074604 100644
--- a/tools/sil-opt/SILOpt.cpp
+++ b/tools/sil-opt/SILOpt.cpp
@@ -133,6 +133,9 @@
ResourceDir("resource-dir",
llvm::cl::desc("The directory that holds the compiler resource files"));
+static llvm::cl::list<std::string>
+ExtraClangArgs("Xcc", llvm::cl::desc("Extra flags to pass to Clang."));
+
static llvm::cl::opt<std::string>
SDKPath("sdk", llvm::cl::desc("The path to the SDK for use with the clang "
"importer."),
@@ -219,7 +222,10 @@
static llvm::cl::opt<bool> EnableExperimentalDifferentiableProgramming(
"enable-experimental-differentiable-programming", llvm::cl::Hidden,
- llvm::cl::init(false),
+ // SWIFT_ENABLE_TENSORFLOW
+ // Use default value true on `tensorflow` branch.
+ llvm::cl::init(true),
+ // SWIFT_ENABLE_TENSORFLOW END
llvm::cl::desc("Enable experimental differentiable programming"));
/// Regular expression corresponding to the value given in one of the
@@ -337,6 +343,7 @@
// Set the module cache path. If not passed in we use the default swift module
// cache.
Invocation.getClangImporterOptions().ModuleCachePath = ModuleCachePath;
+ Invocation.getClangImporterOptions().ExtraArgs = ExtraClangArgs;
Invocation.setParseStdlib();
Invocation.getLangOptions().DisableAvailabilityChecking = true;
Invocation.getLangOptions().EnableAccessControl = false;
diff --git a/unittests/CMakeLists.txt b/unittests/CMakeLists.txt
index 5c6daa8..6ca60a7 100644
--- a/unittests/CMakeLists.txt
+++ b/unittests/CMakeLists.txt
@@ -29,5 +29,7 @@
if(SWIFT_BUILD_SOURCEKIT)
add_subdirectory(SourceKit)
endif()
+
+ add_subdirectory(libInMemoryFrontend)
endif()
diff --git a/unittests/Syntax/SyntaxCollectionTests.cpp b/unittests/Syntax/SyntaxCollectionTests.cpp
index 3124120..c553577 100644
--- a/unittests/Syntax/SyntaxCollectionTests.cpp
+++ b/unittests/Syntax/SyntaxCollectionTests.cpp
@@ -131,6 +131,7 @@
ASSERT_TRUE(GottenArg2_1.hasSameIdentityAs(GottenArg2_2));
}
+/* SWIFT_ENABLE_TENSORFLOW - Turning off test (http://b/165034295)
TEST(SyntaxCollectionTests, removingFirst) {
ASSERT_DEATH({
SyntaxFactory::makeBlankTupleExprElementList().removingFirst();
@@ -148,6 +149,7 @@
List.print(OS);
ASSERT_EQ(OS.str().str(), "x: foo, x: foo");
}
+SWIFT_ENABLE_TENSORFLOW END */
TEST(SyntaxCollectionTests, inserting) {
auto Arg = getCannedArgument();
diff --git a/unittests/libInMemoryFrontend/CMakeLists.txt b/unittests/libInMemoryFrontend/CMakeLists.txt
new file mode 100644
index 0000000..341b155
--- /dev/null
+++ b/unittests/libInMemoryFrontend/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_swift_unittest(libInMemoryFrontendTests
+ InMemoryFrontendTests.cpp
+)
+
+target_link_libraries(libInMemoryFrontendTests PRIVATE
+ libInMemoryFrontend
+ swiftDriver
+)
+
+target_compile_definitions(libInMemoryFrontendTests PRIVATE
+ SWIFTLIB_DIR=\"${SWIFTLIB_DIR}\"
+)
+
+include_directories(
+ ${SWIFT_SOURCE_DIR}/tools/libInMemoryFrontend/include
+)
diff --git a/unittests/libInMemoryFrontend/InMemoryFrontendTests.cpp b/unittests/libInMemoryFrontend/InMemoryFrontendTests.cpp
new file mode 100644
index 0000000..3b0219c
--- /dev/null
+++ b/unittests/libInMemoryFrontend/InMemoryFrontendTests.cpp
@@ -0,0 +1,132 @@
+#include "libInMemoryFrontend/InMemoryFrontend.h"
+#include "swift/AST/DiagnosticConsumer.h"
+#include "swift/Driver/FrontendUtil.h"
+#include "swift/Frontend/Frontend.h"
+#include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace swift;
+
+class StreamDiagConsumer : public DiagnosticConsumer {
+ llvm::raw_ostream &OS;
+
+public:
+ StreamDiagConsumer(llvm::raw_ostream &OS) : OS(OS) {}
+
+ void handleDiagnostic(SourceManager &SM,
+ const DiagnosticInfo &Info) override {
+ switch (Info.Kind) {
+ case DiagnosticKind::Error:
+ OS << "error: ";
+ break;
+ case DiagnosticKind::Warning:
+ OS << "warning: ";
+ break;
+ case DiagnosticKind::Note:
+ OS << "note: ";
+ break;
+ case DiagnosticKind::Remark:
+ OS << "remark: ";
+ break;
+ }
+ DiagnosticEngine::formatDiagnosticText(OS, Info.FormatString,
+ Info.FormatArgs);
+ }
+};
+
+static StringRef getRuntimeLibPath() {
+ return llvm::sys::path::parent_path(SWIFTLIB_DIR);
+}
+
+class InMemoryFrontendTest : public ::testing::Test {
+protected:
+ InMemoryFrontendTest()
+ : MemFS(new llvm::vfs::InMemoryFileSystem()),
+ FS(new llvm::vfs::OverlayFileSystem(llvm::vfs::getRealFileSystem())),
+ ErrOS(ErrStr), DiagConsumer(ErrOS) {
+ FS->pushOverlay(MemFS);
+
+ CI.addDiagnosticConsumer(&DiagConsumer);
+ CI.getSourceMgr().setFileSystem(FS);
+ }
+
+ bool ParseArgsAndSetupInstance(llvm::ArrayRef<const char *> OrigArgs) {
+ SmallVector<const char *, 16> Args;
+ Args.push_back("-resource-dir");
+ Args.push_back(getRuntimeLibPath().data());
+ Args.append(OrigArgs.begin(), OrigArgs.end());
+
+ // Without this configuration option, the clang tries to emit object files
+ // for the modules that it compiles. To do this, it looks up the current
+ // triple in the llvm TargetRegistry. We have not initialized the
+ // TargetRegistry, so it fails.
+ Invocation.getClangImporterOptions().DetailedPreprocessingRecord = true;
+
+ bool ParseResult = driver::getSingleFrontendInvocationFromDriverArguments(
+ Args, CI.getDiags(), [&](ArrayRef<const char *> FrontendArgs) {
+ return Invocation.parseArgs(FrontendArgs, CI.getDiags());
+ });
+ if (ParseResult)
+ return true;
+
+ return CI.setup(Invocation);
+ }
+
+ llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> MemFS;
+ llvm::IntrusiveRefCntPtr<llvm::vfs::OverlayFileSystem> FS;
+
+ SmallString<32> ErrStr;
+ llvm::raw_svector_ostream ErrOS;
+ StreamDiagConsumer DiagConsumer;
+
+ CompilerInstance CI;
+ CompilerInvocation Invocation;
+};
+
+TEST_F(InMemoryFrontendTest, SemaError) {
+ MemFS->addFile("/file1.swift", /*ModificationTime=*/0,
+ llvm::MemoryBuffer::getMemBuffer("let x: String = \"hello\"",
+ "/file1.swift"));
+ MemFS->addFile(
+ "/file2.swift", /*ModificationTime=*/0,
+ llvm::MemoryBuffer::getMemBuffer("let y: Int = x", "/file2.swift"));
+
+ const char *Args[] = {"/file1.swift", "/file2.swift"};
+ bool SetupResult = ParseArgsAndSetupInstance(Args);
+ ASSERT_FALSE(SetupResult) << ErrStr;
+
+ std::unique_ptr<llvm::MemoryBuffer> ModBuf;
+ std::unique_ptr<llvm::MemoryBuffer> ModDocBuf;
+ bool CompileResult =
+ inmemoryfrontend::compileSwiftModule(CI, &ModBuf, &ModDocBuf);
+ EXPECT_TRUE(CompileResult);
+ EXPECT_EQ("error: cannot convert value of type 'String' to specified type "
+ "'Int'",
+ ErrStr);
+}
+
+TEST_F(InMemoryFrontendTest, Success) {
+ MemFS->addFile("/file1.swift", /*ModificationTime=*/0,
+ llvm::MemoryBuffer::getMemBuffer("let x: String = \"hello\"",
+ "/file1.swift"));
+ MemFS->addFile(
+ "/file2.swift", /*ModificationTime=*/0,
+ llvm::MemoryBuffer::getMemBuffer("let y: String = x", "/file2.swift"));
+
+ const char *Args[] = {"/file1.swift", "/file2.swift"};
+ bool SetupResult = ParseArgsAndSetupInstance(Args);
+ ASSERT_FALSE(SetupResult) << ErrStr;
+
+ std::unique_ptr<llvm::MemoryBuffer> ModBuf;
+ std::unique_ptr<llvm::MemoryBuffer> ModDocBuf;
+ bool CompileResult =
+ inmemoryfrontend::compileSwiftModule(CI, &ModBuf, &ModDocBuf);
+ ASSERT_FALSE(CompileResult) << ErrStr;
+ ASSERT_TRUE(ModBuf);
+ ASSERT_TRUE(ModDocBuf);
+
+ EXPECT_EQ(serialization::Status::Valid,
+ serialization::validateSerializedAST(ModBuf->getBuffer()).status);
+}
diff --git a/utils/SwiftMathFunctions.py b/utils/SwiftMathFunctions.py
new file mode 100644
index 0000000..d60f8c6
--- /dev/null
+++ b/utils/SwiftMathFunctions.py
@@ -0,0 +1,91 @@
+# SWIFT_ENABLE_TENSORFLOW
+# These changes were part of `ElementaryFunctions`, which was removed from
+# apple/swift master branch and moved to apple/swift-numerics.
+# TF-1203 tracks eliminating these ad-hoc tensorflow branch changes.
+
+
+class SwiftMathFunction(object):
+ def __init__(self, name, kind=None, swiftName=None, args="x", comment=None,
+ platforms=None):
+ self.name = name
+ self.swiftName = swiftName if swiftName is not None else name
+ self.kind = kind if kind is not None else "library"
+ self.args = args
+ if comment is not None:
+ self.comment = comment
+ else:
+ self.comment = "/// The " + str(self.swiftName) + " function."
+ self.platforms = platforms
+
+ def params(self, prefix="", suffix=""):
+ return ", ".join(map(lambda a: prefix + a + suffix, self.args))
+
+ def decl(self, type):
+ return self.swiftName + "(" + self.params("_ ", ": " + type) + \
+ ") -> " + type
+
+ def freeDecl(self, constraint):
+ return self.swiftName + "<T>(" + self.params("_ ", ": T") + \
+ ") -> T where " + constraint
+
+ def impl(self, type):
+ if self.kind == "intrinsic":
+ builtin = "Builtin.int_" + self.name + "_FPIEEE" + str(type.bits)
+ return type.stdlib_name + "(" + builtin + "(" + \
+ self.params("", "._value") + "))"
+ return "_stdlib_" + self.name + type.cFuncSuffix + "(" + \
+ self.params() + ")"
+
+
+ElementaryFunctions = [
+ SwiftMathFunction(name="sqrt", kind="intrinsic", comment="""
+ /// The square root of `x`.
+ ///
+ /// For real types, if the argument is negative, either the result is NaN
+ /// or a precondition failure occurs. For complex types, this function has
+ /// a branch cut along the negative real axis.
+"""),
+ SwiftMathFunction(name="cos", kind="intrinsic", comment="""
+ /// The cosine of `x`.
+ ///
+ /// For real types, `x` is interpreted as an angle measured in radians.
+"""),
+ SwiftMathFunction(name="sin", kind="intrinsic", comment="""
+ /// The sine of `x`.
+ ///
+ /// For real types, `x` is interpreted as an angle measured in radians.
+"""),
+ SwiftMathFunction(name="tan",
+ comment="/// The tangent of `x`."),
+ SwiftMathFunction(name="acos"),
+ SwiftMathFunction(name="asin"),
+ SwiftMathFunction(name="atan"),
+ SwiftMathFunction(name="cosh"),
+ SwiftMathFunction(name="sinh"),
+ SwiftMathFunction(name="tanh"),
+ SwiftMathFunction(name="acosh"),
+ SwiftMathFunction(name="asinh"),
+ SwiftMathFunction(name="atanh"),
+ SwiftMathFunction(name="exp", kind="intrinsic"),
+ SwiftMathFunction(name="exp2", kind="intrinsic"),
+ SwiftMathFunction(name="exp10"),
+ SwiftMathFunction(name="expm1"),
+ SwiftMathFunction(name="log", kind="intrinsic"),
+ SwiftMathFunction(name="log2", kind="intrinsic"),
+ SwiftMathFunction(name="log10", kind="intrinsic"),
+ SwiftMathFunction(name="log1p"),
+ # SwiftMathFunction(name="pow", kind="intrinsic", args="xy"), Handled
+ # separately for edge cases.
+ # SwiftMathFunction(name="root", args="xn"), Handled separately for
+ # implementation.
+]
+
+RealFunctions = [
+ # SwiftMathFunction(name="atan2"), Handled separately for explicit
+ # argument labels.
+ SwiftMathFunction(name="erf"),
+ SwiftMathFunction(name="erfc"),
+ SwiftMathFunction(name="hypot", args="xy"),
+ SwiftMathFunction(name="tgamma", swiftName="gamma"),
+ # SwiftMathFunction(name="lgamma"), Handled separately for sign result.
+]
diff --git a/utils/build-presets.ini b/utils/build-presets.ini
index 18dba19..76879d7 100644
--- a/utils/build-presets.ini
+++ b/utils/build-presets.ini
@@ -2564,3 +2564,170 @@
mixin-preset=source_compat_suite_linux_base
debug
no-assertions
+
+# SWIFT_ENABLE_TENSORFLOW
+#===------------------------------------------------------------------------===#
+# TensorFlow Support
+#===------------------------------------------------------------------------===#
+
+# Note: presets that mixin "mixin_codesigning" require a Keychain Access
+# certificate called "lldb_codesign".
+# To create the cerificate, follow these instructions:
+# https://github.com/llvm-mirror/lldb/blob/master/docs/code-signing.txt
+[preset: mixin_codesigning]
+darwin-toolchain-application-cert=lldb_codesign
+
+[preset: tensorflow_osx_base]
+mixin-preset=buildbot_osx_package
+
+build-subdir=buildbot_osx
+compiler-vendor=none
+
+# Skip SwiftSyntax verification.
+# TODO: Remove this when `utils/gyb_syntax_support` has no diff with master.
+swiftsyntax-verify-generated-files=0
+
+# Skip benchmarks.
+skip-build-benchmarks
+
+# Target only macOS.
+# Note: these `skip-{i,tv,watch}os` flags override `{i,tv,watch}os` flags from
+# the `buildbot_osx_package` preset.
+skip-ios
+skip-tvos
+skip-watchos
+stdlib-deployment-targets=macosx-x86_64
+swift-primary-variant-sdk=OSX
+swift-primary-variant-arch=x86_64
+
+# Do not link libz3.
+llvm-cmake-options=-DLLVM_ENABLE_Z3_SOLVER=NO
+
+[preset: tensorflow_osx]
+mixin-preset=tensorflow_osx_base
+test
+test-optimized
+validation-test
+long-test
+stress-test
+
+[preset: tensorflow_osx,no_test]
+mixin-preset=tensorflow_osx_base
+skip-test-cmark
+skip-test-swift
+skip-test-swiftpm
+skip-test-swiftsyntax
+skip-test-llbuild
+skip-test-lldb
+skip-test-playgroundsupport
+
+[preset: tensorflow_osx,tensorflow_swift_apis]
+mixin-preset=tensorflow_osx
+tensorflow-swift-apis
+install-tensorflow-swift-apis
+
+[preset: tensorflow_osx,tensorflow_swift_apis,no_test]
+mixin-preset=tensorflow_osx,no_test
+tensorflow-swift-apis
+install-tensorflow-swift-apis
+
+[preset: tensorflow_osx,installer]
+mixin-preset=
+ tensorflow_osx
+ mixin_codesigning
+darwin-toolchain-installer-package=%(darwin_toolchain_installer_package)s
+
+[preset: tensorflow_osx,no_test,installer]
+mixin-preset=
+ tensorflow_osx,no_test
+ mixin_codesigning
+darwin-toolchain-installer-package=%(darwin_toolchain_installer_package)s
+
+[preset: tensorflow_osx,tensorflow_swift_apis,installer]
+mixin-preset=
+ tensorflow_osx,tensorflow_swift_apis
+ mixin_codesigning
+darwin-toolchain-installer-package=%(darwin_toolchain_installer_package)s
+
+[preset: tensorflow_osx,tensorflow_swift_apis,no_test,installer]
+mixin-preset=
+ tensorflow_osx,tensorflow_swift_apis,no_test
+ mixin_codesigning
+darwin-toolchain-installer-package=%(darwin_toolchain_installer_package)s
+
+[preset: tensorflow_linux]
+mixin-preset=
+ buildbot_linux
+
+# The swift-package-tests fail when run with python3.
+test-installable-package=
+
+# Do not link libz3.
+llvm-cmake-options=-DLLVM_ENABLE_Z3_SOLVER=NO
+
+[preset: tensorflow_linux,no_test]
+mixin-preset=
+ buildbot_linux,no_test
+
+# The swift-package-tests fail when run with python3.
+test-installable-package=
+
+[preset: tensorflow_linux,tensorflow_swift_apis]
+mixin-preset=tensorflow_linux
+tensorflow-swift-apis
+install-tensorflow-swift-apis
+
+[preset: tensorflow_linux,tensorflow_swift_apis,no_test]
+mixin-preset=tensorflow_linux,no_test
+tensorflow-swift-apis
+install-tensorflow-swift-apis
+
+#===------------------------------------------------------------------------===#
+# Swift for TensorFlow Preset
+# Tools: DebInfo and Assertions
+# stdlib: DebInfo and Assertions
+#===------------------------------------------------------------------------===#
+[preset: tensorflow_linux,tools=DA,stdlib=DA]
+mixin-preset=mixin_linux_installation
+
+### From: buildbot_linux
+build-subdir=buildbot_linux
+lldb
+foundation
+libdispatch
+lit-args=-v
+dash-dash
+skip-test-lldb
+install-foundation
+install-libdispatch
+reconfigure
+
+### From: buildbot_incremental,tools=DA,stdlib=DA,build
+release-debuginfo
+debug-llvm
+debug-swift
+swift-stdlib-build-type=RelWithDebInfo
+swift-stdlib-enable-assertions=true
+
+# Enables Tensorflow and runs all the Swift tests, except for long tests. Sets
+# certain flags that are necessary for tests to pass.
+[preset: tensorflow_test]
+mixin-preset=mixin_lightweight_assertions
+release
+test
+test-optimized
+validation-test
+
+# Enables Tensorflow and runs all the Swift tests (including long tests). Sets
+# certain flags that are necessary for tests to pass.
+[preset: tensorflow_test,long_test]
+mixin-preset=tensorflow_test
+long-test
+stress-test
+
+# Enables Tensorflow with ASan support, and runs all the Swift tests, except for
+# long tests. Sets certain flags that are necessary for tests to pass.
+[preset: tensorflow_test,asan]
+mixin-preset=tensorflow_test
+enable-asan
+# SWIFT_ENABLE_TENSORFLOW END
diff --git a/utils/build-script b/utils/build-script
index 18bbe73..e55966e 100755
--- a/utils/build-script
+++ b/utils/build-script
@@ -220,6 +220,7 @@
'build_indexstoredb',
'build_playgroundsupport',
'build_sourcekitlsp',
+ 'build_tensorflow_swift_apis',
'build_toolchainbenchmarks',
'build_swift_inspect',
'tsan_libdispatch_test',
@@ -272,6 +273,7 @@
args.build_foundation or
args.build_indexstoredb or
args.build_sourcekitlsp or
+ args.build_tensorflow_swift_apis or
args.cmake_generator == 'Ninja'
)
if ninja_required and toolchain.ninja is None:
@@ -832,6 +834,8 @@
product_classes.append(products.IndexStoreDB)
if self.args.build_playgroundsupport:
product_classes.append(products.PlaygroundSupport)
+ if self.args.build_tensorflow_swift_apis:
+ product_classes.append(products.TensorFlowSwiftAPIs)
if self.args.build_sourcekitlsp:
product_classes.append(products.SourceKitLSP)
if self.args.build_toolchainbenchmarks:
diff --git a/utils/build-script-impl b/utils/build-script-impl
index 19b92b5..005330d 100755
--- a/utils/build-script-impl
+++ b/utils/build-script-impl
@@ -1157,7 +1157,9 @@
# llbuild and XCTest depend on Foundation, so Foundation must
# be added to the list of build products first.
[[ "${SKIP_BUILD_FOUNDATION}" ]] || PRODUCTS+=(foundation)
-[[ "${SKIP_BUILD_STATIC_FOUNDATION}" ]] || PRODUCTS+=(foundation_static)
+# SWIFT_ENABLE_TENSORFLOW
+# TODO(TF-490): Reenable this.
+# [[ "${SKIP_BUILD_STATIC_FOUNDATION}" ]] || PRODUCTS+=(foundation_static)
[[ "${SKIP_BUILD_LLBUILD}" ]] || PRODUCTS+=(llbuild)
[[ "${SKIP_BUILD_XCTEST}" ]] || PRODUCTS+=(xctest)
@@ -1795,6 +1797,10 @@
-DSWIFT_TOOLS_ENABLE_LTO:STRING="${SWIFT_TOOLS_ENABLE_LTO}"
-DSWIFT_BUILD_RUNTIME_WITH_HOST_COMPILER:BOOL=$(true_false "${BUILD_RUNTIME_WITH_HOST_COMPILER}")
-DLIBDISPATCH_CMAKE_BUILD_TYPE:STRING="${LIBDISPATCH_BUILD_TYPE}"
+ # SWIFT_ENABLE_TENSORFLOW
+ # CMake options specific to `tensorflow` branch here.
+ -DSWIFT_ENABLE_TENSORFLOW:BOOL=TRUE
+ # SWIFT_ENABLE_TENSORFLOW END
"${swift_cmake_options[@]}"
)
@@ -2543,7 +2549,7 @@
with_pushd ${lldb_build_dir} \
call ${NINJA_BIN} -j ${BUILD_JOBS} lldb-test-deps
with_pushd ${results_dir} \
- call "${llvm_build_dir}/bin/llvm-lit" \
+ call "/usr/bin/env" "python3" "${llvm_build_dir}/bin/llvm-lit" \
"${lldb_build_dir}/test" \
${LLVM_LIT_ARGS} ${LLVM_LIT_FILTER_ARG}
echo "--- Rerun LLDB Swift tests (using only DWARFImporter) ---"
@@ -2558,7 +2564,8 @@
"${LLDB_TEST_SWIFT_COMPATIBILITY}"
DOTEST_ARGS="-G swift-history --swift-compiler \"${LLDB_TEST_SWIFT_COMPATIBILITY}\""
with_pushd ${results_dir} \
- call "${llvm_build_dir}/bin/llvm-lit" \
+ # SWIFT_ENABLE_TENSORFLOW: use python3 to launch llvm-lit below.
+ call "/usr/bin/env" "python3" "${llvm_build_dir}/bin/llvm-lit" \
"${lldb_build_dir}/test" \
${LLVM_LIT_ARGS} \
--param dotest-args="${DOTEST_ARGS}" \
@@ -2739,6 +2746,10 @@
echo "--- Running tests for ${product} ---"
for target in "${results_targets[@]}"; do
+ # TODO (TF-1045): Ninja/Build/strip-colors.ninja in llbuild tests is spuriously failing
+ if [[ "${product}" == "llbuild" ]]; then
+ continue
+ fi
if [[ "${target}" != "" ]]; then
echo "--- ${target} ---"
trap "tests_busted ${product} '(${target})'" ERR
diff --git a/utils/build-toolchain-tensorflow b/utils/build-toolchain-tensorflow
new file mode 100755
index 0000000..5809599
--- /dev/null
+++ b/utils/build-toolchain-tensorflow
@@ -0,0 +1,194 @@
+#!/usr/bin/env bash
+#
+# SWIFT_ENABLE_TENSORFLOW
+#
+# utils/build-toolchain-tensorflow - documents process for building a toolchain
+#
+# This source file is part of the Swift.org open source project
+#
+# Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
+# Licensed under Apache License v2.0 with Runtime Library Exception
+#
+# See https://swift.org/LICENSE.txt for license information
+# See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+
+function usage() {
+ echo "$0 [OPTIONS]"
+ echo ""
+ echo "OPTIONS"
+ echo ""
+ echo "-h --help"
+ echo "Show help information."
+ echo ""
+ echo "-n --dry-run"
+ echo "Do a dry run."
+ echo ""
+ echo "-t --test"
+ echo "Run tests."
+ echo ""
+ echo "-r <version>, --release <version>"
+ echo "Build a release toolchain with the specified version name."
+ echo ""
+ echo "--tensorflow-swift-apis"
+ echo "Build and install the tensorflow-swift-apis library in the toolchain. (Default)."
+ echo ""
+ echo "--no-tensorflow-swift-apis"
+ echo "Do not build and install the tensorflow-swift-apis library in the toolchain."
+ echo ""
+ if [[ "$(uname -s)" == "Darwin" ]] ; then
+ echo "-p --pkg"
+ echo "Build a installer package."
+ echo ""
+ fi
+}
+
+cd "$(dirname $0)/.." || exit
+SRC_DIR=$PWD
+
+# Set defaults.
+DRY_RUN=
+BUNDLE_PREFIX=
+INSTALLER_PACKAGE=
+SWIFT_PACKAGE_BASE=
+SWIFT_PACKAGE_TENSORFLOW_SWIFT_APIS=,tensorflow_swift_apis
+SWIFT_PACKAGE_NOTEST=
+SWIFT_PACKAGE_INSTALLER=
+S4TF_RELEASE_VERSION=
+
+if [[ -z ${SWIFT_PACKAGE} ]]; then
+ case $(uname -s) in
+ Darwin)
+ SWIFT_PACKAGE_BASE=tensorflow_osx
+ SWIFT_PACKAGE_NOTEST=,no_test
+ ;;
+ Linux)
+ SWIFT_PACKAGE_BASE=tensorflow_linux
+ SWIFT_PACKAGE_NOTEST=,no_test
+ ;;
+ *)
+ echo "Unrecognised platform $(uname -s)"
+ exit 1
+ ;;
+ esac
+fi
+
+# Process command line arguments.
+FIRST_ARG_PROCESSED=0
+while [ $# -ne 0 ]; do
+ case "$1" in
+ -n|--dry-run)
+ DRY_RUN="-n"
+ ;;
+ -t|--test)
+ SWIFT_PACKAGE_NOTEST=
+ ;;
+ -h|--help)
+ usage
+ exit 0
+ ;;
+ -r|--release)
+ shift
+ S4TF_RELEASE_VERSION=$1
+ if [ -z "$S4TF_RELEASE_VERSION" ]; then
+ echo "Missing release version name after --release. See --help."
+ exit 1
+ fi
+ ;;
+ -p|--pkg)
+ INSTALLER_PACKAGE=1
+ if [ "$(uname -s)" == "Darwin" ]; then
+ SWIFT_PACKAGE_INSTALLER=,installer
+ else
+ echo "--pkg is not supported on \"$(uname -s)\". See --help."
+ exit 1
+ fi
+ ;;
+ -g|--gpu)
+ # Warn, for backwards compatibility with existing callers.
+ # TODO: Remove when we have fixed up all callers.
+ echo "Warning: --gpu deprecated."
+ ;;
+ -x|--x10)
+ # Warn, for backwards compatibility with existing callers.
+ # TODO: Remove when we have fixed up all callers.
+ echo "Warning: --x10 deprecated."
+ ;;
+ --tensorflow-swift-apis)
+ SWIFT_PACKAGE_TENSORFLOW_SWIFT_APIS=,tensorflow_swift_apis
+ ;;
+ --no-tensorflow-swift-apis)
+ SWIFT_PACKAGE_TENSORFLOW_SWIFT_APIS=
+ ;;
+ *)
+ if [ ${FIRST_ARG_PROCESSED} -ne 0 ]; then
+ echo "Unrecognised argument \"$1\""
+ exit 1
+ fi
+ ;;
+ esac
+ FIRST_ARG_PROCESSED=1
+ shift
+done
+
+SWIFT_PACKAGE="${SWIFT_PACKAGE_BASE}${SWIFT_PACKAGE_TENSORFLOW_SWIFT_APIS}${SWIFT_PACKAGE_NOTEST}${SWIFT_PACKAGE_INSTALLER}"
+
+# Get host name.
+HOST=
+if [ "$(uname -s)" == "Darwin" ]; then
+ HOST=osx
+elif [ "$(uname -s)" == "Linux" ]; then
+ linux_platform="$(lsb_release -i | cut -f2 | tr '[:upper:]' '[:lower:]')"
+ linux_version="$(lsb_release -r | cut -f2)"
+ HOST="${linux_platform}${linux_version}"
+fi
+
+
+# Report the commands being run.
+set -x
+YEAR=$(date +"%Y")
+MONTH=$(date +"%m")
+DAY=$(date +"%d")
+TOOLCHAIN_VERSION_DATE="5.0.${YEAR}${MONTH}${DAY}"
+
+# Set toolchain name based on whether this is a normal or release toolchain.
+if [[ -z ${S4TF_RELEASE_VERSION} ]]; then
+ TOOLCHAIN_NAME="swift-tensorflow-DEVELOPMENT-${YEAR}-${MONTH}-${DAY}-a"
+ DISPLAY_NAME_SHORT="Swift for TensorFlow Development Snapshot"
+ DISPLAY_NAME="${DISPLAY_NAME_SHORT} ${YEAR}-${MONTH}-${DAY}"
+else
+ TOOLCHAIN_NAME="swift-tensorflow-RELEASE-${S4TF_RELEASE_VERSION}"
+ DISPLAY_NAME_SHORT="Swift for TensorFlow ${S4TF_RELEASE_VERSION} Release"
+ DISPLAY_NAME="${DISPLAY_NAME_SHORT}"
+fi
+
+ARCHIVE="${TOOLCHAIN_NAME}-${HOST}.tar.gz"
+SYM_ARCHIVE="${TOOLCHAIN_NAME}-osx-symbols.tar.gz"
+BUNDLE_PREFIX=com.google.swift
+BUNDLE_IDENTIFIER="${BUNDLE_PREFIX}.${YEAR}${MONTH}${DAY}"
+
+SWIFT_INSTALLABLE_PACKAGE="${SRC_DIR}/${ARCHIVE}"
+SWIFT_INSTALL_DIR="${SRC_DIR}/swift-nightly-install"
+SWIFT_INSTALL_SYMROOT="${SRC_DIR}/swift-nightly-symroot"
+SWIFT_TOOLCHAIN_DIR="/Library/Developer/Toolchains/${TOOLCHAIN_NAME}.xctoolchain"
+SYMBOLS_PACKAGE="${SRC_DIR}/${SYM_ARCHIVE}"
+DRY_RUN="${DRY_RUN}"
+
+if [ ${INSTALLER_PACKAGE} ]; then
+ INSTALLER_PACKAGE="darwin_toolchain_installer_package=${TOOLCHAIN_NAME}-osx.pkg"
+fi
+
+./utils/build-script ${DRY_RUN} --preset="${SWIFT_PACKAGE}" \
+ --cmake-c-launcher=`which sccache` \
+ --cmake-cxx-launcher=`which sccache` \
+ install_destdir="${SWIFT_INSTALL_DIR}" \
+ installable_package="${SWIFT_INSTALLABLE_PACKAGE}" \
+ install_toolchain_dir="${SWIFT_TOOLCHAIN_DIR}" \
+ install_symroot="${SWIFT_INSTALL_SYMROOT}" \
+ symbols_package="${SYMBOLS_PACKAGE}" \
+ darwin_toolchain_bundle_identifier="${BUNDLE_IDENTIFIER}" \
+ darwin_toolchain_display_name="${DISPLAY_NAME}" \
+ darwin_toolchain_display_name_short="${DISPLAY_NAME_SHORT}" \
+ darwin_toolchain_xctoolchain_name="${TOOLCHAIN_NAME}" \
+ darwin_toolchain_version="${TOOLCHAIN_VERSION_DATE}" \
+ darwin_toolchain_alias="Swift for TensorFlow" \
+ ${INSTALLER_PACKAGE}
diff --git a/utils/build_swift/build_swift/driver_arguments.py b/utils/build_swift/build_swift/driver_arguments.py
index acf339a..0045982 100644
--- a/utils/build_swift/build_swift/driver_arguments.py
+++ b/utils/build_swift/build_swift/driver_arguments.py
@@ -633,6 +633,12 @@
toggle_true('install_playgroundsupport'),
help='install playground support')
+ option('--tensorflow-swift-apis', store_true('build_tensorflow_swift_apis'),
+ help='build TensorFlow Swift APIs')
+ option('--install-tensorflow-swift-apis',
+ store_true('install_tensorflow_swift_apis'),
+ help='install TensorFlow Swift APIs')
+
option('--build-ninja', toggle_true,
help='build the Ninja tool')
diff --git a/utils/build_swift/tests/build_swift/test_presets.py b/utils/build_swift/tests/build_swift/test_presets.py
index 2668d7a..29237b0 100644
--- a/utils/build_swift/tests/build_swift/test_presets.py
+++ b/utils/build_swift/tests/build_swift/test_presets.py
@@ -36,6 +36,10 @@
'darwin_toolchain_display_name_short': 'DispalyNameShort',
'darwin_toolchain_version': '1.0',
'darwin_toolchain_xctoolchain_name': 'default',
+ # SWIFT_ENABLE_TENSORFLOW
+ 'darwin_toolchain_installer_package': '/tmp/install/swift-installer.pkg',
+ # SWIFT_ENABLE_TENSORFLOW
+ 'darwin_toolchain_application_cert': 'cert_name',
'extra_swift_args': '',
'install_destdir': '/tmp/install',
'install_symroot': '/tmp/install/symroot',
diff --git a/utils/build_swift/tests/expected_options.py b/utils/build_swift/tests/expected_options.py
index e0381cb..e99c152 100644
--- a/utils/build_swift/tests/expected_options.py
+++ b/utils/build_swift/tests/expected_options.py
@@ -90,6 +90,7 @@
'build_swiftpm': False,
'build_swift_driver': False,
'build_swiftsyntax': False,
+ 'build_tensorflow_swift_apis': False,
'build_libparser_only': False,
'build_skstresstester': False,
'build_swiftformat': False,
@@ -106,6 +107,7 @@
'install_sourcekitlsp': False,
'install_skstresstester': False,
'install_swiftevolve': False,
+ 'install_tensorflow_swift_apis': False,
'build_toolchainbenchmarks': False,
'build_tvos': True,
'build_tvos_device': False,
@@ -470,6 +472,9 @@
SetTrueOption('--playgroundsupport', dest='build_playgroundsupport'),
SetTrueOption('--install-playgroundsupport',
dest='install_playgroundsupport'),
+ SetTrueOption('--tensorflow-swift-apis', dest='build_tensorflow_swift_apis'),
+ SetTrueOption('--install-tensorflow-swift-apis',
+ dest='install_tensorflow_swift_apis'),
SetTrueOption('--skip-build'),
SetTrueOption('--swiftpm', dest='build_swiftpm'),
SetTrueOption('--swift-driver', dest='build_swift_driver'),
diff --git a/utils/gyb_syntax_support/AttributeNodes.py b/utils/gyb_syntax_support/AttributeNodes.py
index 7d5bdab..4eab962 100644
--- a/utils/gyb_syntax_support/AttributeNodes.py
+++ b/utils/gyb_syntax_support/AttributeNodes.py
@@ -34,6 +34,8 @@
# | availability-spec-list
# | specialize-attr-spec-list
# | implements-attr-arguments
+ # | differentiable-attr-arguments
+ # | derivative-registration-attr-arguments
# | named-attribute-string-argument
# )? ')'?
Node('Attribute', kind='Syntax',
diff --git a/utils/gyb_syntax_support/Token.py b/utils/gyb_syntax_support/Token.py
index 6a2eb85..4ef350c 100644
--- a/utils/gyb_syntax_support/Token.py
+++ b/utils/gyb_syntax_support/Token.py
@@ -337,7 +337,6 @@
text=')', classification='StringInterpolationAnchor',
serialization_code=101),
Misc('Yield', 'kw_yield', serialization_code=116, text='yield'),
-
]
SYNTAX_TOKEN_MAP = {token.name + 'Token': token for token in SYNTAX_TOKENS}
diff --git a/utils/swift_build_support/swift_build_support/products/__init__.py b/utils/swift_build_support/swift_build_support/products/__init__.py
index 169fe50..c5ec249 100644
--- a/utils/swift_build_support/swift_build_support/products/__init__.py
+++ b/utils/swift_build_support/swift_build_support/products/__init__.py
@@ -31,6 +31,7 @@
from .swiftinspect import SwiftInspect
from .swiftpm import SwiftPM
from .swiftsyntax import SwiftSyntax
+from .tensorflow import TensorFlowSwiftAPIs
from .tsan_libdispatch import TSanLibDispatch
from .xctest import XCTest
@@ -51,6 +52,7 @@
'SwiftInspect',
'SwiftPM',
'SwiftDriver',
+ 'TensorFlowSwiftAPIs',
'XCTest',
'SwiftSyntax',
'SKStressTester',
diff --git a/utils/swift_build_support/swift_build_support/products/tensorflow.py b/utils/swift_build_support/swift_build_support/products/tensorflow.py
new file mode 100644
index 0000000..0af5e93
--- /dev/null
+++ b/utils/swift_build_support/swift_build_support/products/tensorflow.py
@@ -0,0 +1,113 @@
+# swift_build_support/products/tensorflow.py --------------------*- python -*-
+#
+# This source file is part of the Swift.org open source project
+#
+# Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
+# Licensed under Apache License v2.0 with Runtime Library Exception
+#
+# See https://swift.org/LICENSE.txt for license information
+# See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+#
+# ----------------------------------------------------------------------------
+
+import os
+
+from . import cmark
+from . import foundation
+from . import libcxx
+from . import libdispatch
+from . import libicu
+from . import llbuild
+from . import llvm
+from . import product
+from . import swift
+from . import swiftpm
+from . import xctest
+from .. import shell
+from .. import targets
+
+
+class TensorFlowSwiftAPIs(product.Product):
+ @classmethod
+ def product_source_name(cls):
+ return "tensorflow-swift-apis"
+
+ @classmethod
+ def is_build_script_impl_product(cls):
+ return False
+
+ def should_build(self, host_target):
+ return self.args.build_tensorflow_swift_apis
+
+ def build(self, host_target):
+ toolchain_path = targets.toolchain_path(self.args.install_destdir,
+ self.args.install_prefix)
+ swiftc = os.path.join(toolchain_path, 'bin', 'swiftc')
+
+ # FIXME: this is a workaround for CMake <3.16 which does not correctly
+ # generate the build rules if you are not in the build directory. As a
+ # result, we need to create the build tree before we can use it and
+ # change into it.
+ try:
+ shell.makedirs(self.build_dir)
+ except OSError:
+ pass
+
+ # SWIFT_ENABLE_TENSORFLOW
+ target = ''
+ if host_target.startswith('macosx'):
+ target = '-DCMAKE_Swift_COMPILER_TARGET=x86_64-apple-macosx10.13'
+ # SWIFT_ENABLE_TENSORFLOW END
+
+ with shell.pushd(self.build_dir):
+ shell.call([
+ self.toolchain.cmake,
+ '-G', 'Ninja',
+ '-D', 'BUILD_SHARED_LIBS=YES',
+ '-D', 'CMAKE_INSTALL_PREFIX={}'.format(
+ self.install_toolchain_path()),
+ '-D', 'CMAKE_MAKE_PROGRAM={}'.format(self.toolchain.ninja),
+ '-D', 'CMAKE_Swift_COMPILER={}'.format(swiftc),
+ # SWIFT_ENABLE_TENSORFLOW
+ target,
+ '-D', 'BUILD_TESTING={}'.format(
+ 'NO' if host_target.startswith('macosx') else 'YES'
+ ),
+ '-D', 'BUILD_X10=YES',
+ # SWIFT_ENABLE_TENSORFLOW END
+ '-B', self.build_dir,
+ '-S', self.source_dir,
+ ])
+ shell.call([
+ self.toolchain.cmake,
+ '--build', self.build_dir,
+ ])
+
+ def should_test(self, host_target):
+ return False
+
+ def test(self, host_target):
+ pass
+
+ def should_install(self, host_target):
+ return self.args.build_tensorflow_swift_apis
+
+ def install(self, host_target):
+ shell.call([
+ self.toolchain.cmake,
+ '--build', self.build_dir,
+ '--target', 'install',
+ ])
+
+ @classmethod
+ def get_dependencies(cls):
+ return [cmark.CMark,
+ llvm.LLVM,
+ libcxx.LibCXX,
+ libicu.LibICU,
+ swift.Swift,
+ libdispatch.LibDispatch,
+ foundation.Foundation,
+ xctest.XCTest,
+ llbuild.LLBuild,
+ swiftpm.SwiftPM]
diff --git a/utils/toolchain-installer b/utils/toolchain-installer
index 28fc10f..d4674a3 100755
--- a/utils/toolchain-installer
+++ b/utils/toolchain-installer
@@ -19,6 +19,12 @@
DARWIN_TOOLCHAIN_VERSION=$6
DARWIN_SCRIPTS=$7
+# SWIFT_ENABLE_TENSORFLOW
+SIGN_ARGUMENT=
+if [ -n "${DARWIN_INSTALLER_CERT}" ]; then
+ SIGN_ARGUMENT="--sign ${DARWIN_INSTALLER_CERT}"
+fi
+
pkgbuild --root "${TOOLCHAIN_PREFIX}" --install-location "${DARWIN_TOOLCHAIN_INSTALL_LOCATION}" "${DARWIN_INSTALLER_PACKAGE}" \
- --version "${DARWIN_TOOLCHAIN_VERSION}" --identifier "${DARWIN_BUNDLE_IDENTIFIER}" --sign "${DARWIN_INSTALLER_CERT}" \
+ --version "${DARWIN_TOOLCHAIN_VERSION}" --identifier "${DARWIN_BUNDLE_IDENTIFIER}" ${SIGN_ARGUMENT} \
--scripts "${DARWIN_SCRIPTS}"
diff --git a/utils/update_checkout/update-checkout-config.json b/utils/update_checkout/update-checkout-config.json
index d499ace..1fa1f8a 100644
--- a/utils/update_checkout/update-checkout-config.json
+++ b/utils/update_checkout/update-checkout-config.json
@@ -32,17 +32,24 @@
"remote": { "id": "apple/swift-xcode-playground-support" } },
"ninja": {
"remote": { "id": "ninja-build/ninja" } },
+ "tensorflow": {
+ "remote": { "id": "tensorflow/tensorflow" } },
+ "tensorflow-swift-apis": {
+ "remote": { "id": "tensorflow/swift-apis" } },
"icu": {
"remote": { "id": "unicode-org/icu" },
"platforms": [ "Linux" ]
},
"yams": {
- "remote": { "id": "jpsim/Yams" }
+ "remote": { "id": "asuhan/Yams" }
},
"cmake": {
"remote": { "id": "KitWare/CMake" },
"platforms": [ "Linux" ]
},
+ "tensorflow-swift-apis": {
+ "remote": { "id": "tensorflow/swift-apis" }
+ },
"indexstore-db": {
"remote": { "id": "apple/indexstore-db" } },
"sourcekit-lsp": {
@@ -78,7 +85,8 @@
"cmake": "v3.16.5",
"indexstore-db": "master",
"sourcekit-lsp": "master",
- "swift-format": "master"
+ "swift-format": "master",
+ "tensorflow-swift-apis": "master"
}
},
"next" : {
@@ -108,7 +116,8 @@
"cmake": "v3.16.5",
"indexstore-db": "master",
"sourcekit-lsp": "master",
- "swift-format": "master"
+ "swift-format": "master",
+ "tensorflow-swift-apis": "master"
}
},
"swift-3.0-branch" : {
@@ -304,7 +313,8 @@
"cmake": "v3.16.5",
"indexstore-db": "release/5.3",
"sourcekit-lsp": "release/5.3",
- "swift-format": "master"
+ "swift-format": "master",
+ "tensorflow-swift-apis": "master"
}
},
"master-rebranch": {
@@ -333,6 +343,36 @@
"sourcekit-lsp": "master",
"swift-format": "master"
}
+ },
+
+ "tensorflow": {
+ "aliases": ["tensorflow"],
+ "repos": {
+ "llvm-project": "14a10d635b229fb3d8fc0d80340f98430b46fcff",
+ "swift": "tensorflow",
+ "cmark": "swift-DEVELOPMENT-SNAPSHOT-2020-08-31-a",
+ "llbuild": "swift-DEVELOPMENT-SNAPSHOT-2020-08-31-a",
+ "swift-tools-support-core": "0.1.8",
+ "swiftpm": "swift-DEVELOPMENT-SNAPSHOT-2020-08-31-a",
+ "swift-argument-parser": "0.3.0",
+ "swift-driver": "7a441561a176f6eac2953aea2c395cc4e683f820",
+ "swift-syntax": "swift-DEVELOPMENT-SNAPSHOT-2020-08-31-a",
+ "swift-stress-tester": "swift-DEVELOPMENT-SNAPSHOT-2020-08-31-a",
+ "swift-corelibs-xctest": "swift-DEVELOPMENT-SNAPSHOT-2020-08-31-a",
+ "swift-corelibs-foundation": "swift-DEVELOPMENT-SNAPSHOT-2020-08-31-a",
+ "swift-corelibs-libdispatch": "swift-DEVELOPMENT-SNAPSHOT-2020-08-31-a",
+ "swift-integration-tests": "swift-DEVELOPMENT-SNAPSHOT-2020-08-31-a",
+ "swift-xcode-playground-support": "swift-DEVELOPMENT-SNAPSHOT-2020-08-31-a",
+ "ninja": "release",
+ "icu": "release-65-1",
+ "yams": "3.0.1-patched",
+ "indexstore-db": "swift-DEVELOPMENT-SNAPSHOT-2020-08-31-a",
+ "sourcekit-lsp": "swift-DEVELOPMENT-SNAPSHOT-2020-08-31-a",
+ "PythonKit": "master",
+ "swift-format": "master",
+ "tensorflow": "v2.3.0",
+ "tensorflow-swift-apis": "master"
+ }
}
}
}
diff --git a/validation-test/SIL/parse_stdlib.sil b/validation-test/SIL/parse_stdlib.sil
index 7b37365..d7c898a 100644
--- a/validation-test/SIL/parse_stdlib.sil
+++ b/validation-test/SIL/parse_stdlib.sil
@@ -1,5 +1,4 @@
-// RUN: rm -f %t.*
-// RUN: %target-sil-opt -enable-sil-verify-all=true -sil-disable-ast-dump %platform-module-dir/Swift.swiftmodule/%target-swiftmodule-name -module-name=Swift -o %t.sil || %target-sil-opt -enable-sil-verify-all=true -sil-disable-ast-dump %platform-module-dir/Swift.swiftmodule -module-name=Swift -o %t.sil
-// RUN: %target-sil-opt -enable-sil-verify-all=true %t.sil > /dev/null
+// TODO(SR-9714): Fix this and re-enable.
+// RUN: echo "disabled"
// REQUIRES: long_test
// REQUIRES: nonexecutable_test
diff --git a/validation-test/SIL/verify_all_overlays.py b/validation-test/SIL/verify_all_overlays.py
index 70c0ce6..8c23ab3 100755
--- a/validation-test/SIL/verify_all_overlays.py
+++ b/validation-test/SIL/verify_all_overlays.py
@@ -1,12 +1,17 @@
#!/usr/bin/python
-# RUN: ${python} %s %target-swiftmodule-name %platform-sdk-overlay-dir \
-# RUN: %target-sil-opt -sdk %sdk -enable-sil-verify-all \
-# RUN: -F %sdk/System/Library/PrivateFrameworks \
-# RUN: -F "%xcode-extra-frameworks-dir"
+# TODO(TF-491): Re-enable.
+# RUN: echo "disabled"
+# UN: ${python} %s %target-swiftmodule-name %platform-sdk-overlay-dir \
+# UN: %target-sil-opt -sdk %sdk -enable-sil-verify-all \
+# UN: -F %sdk/System/Library/PrivateFrameworks \
+# UN: -F "%xcode-extra-frameworks-dir"
# REQUIRES: long_test
# REQUIRES: nonexecutable_test
+# TODO(TF-491): Re-enable XFAIL.
+# XFAI: OS=macosx
+# https://bugs.swift.org/browse/SR-9847
from __future__ import print_function
diff --git a/validation-test/SIL/verify_all_overlays.sil b/validation-test/SIL/verify_all_overlays.sil
new file mode 100644
index 0000000..1a805fe
--- /dev/null
+++ b/validation-test/SIL/verify_all_overlays.sil
@@ -0,0 +1,7 @@
+// TODO(SR-9715): Fix this and re-enable.
+// RUN: echo "disabled"
+
+// CHECK-NOT: Unknown
+
+// REQUIRES: long_test
+// REQUIRES: nonexecutable_test
diff --git a/validation-test/Sema/type_checker_perf/fast/rdar19915443.swift.gyb b/validation-test/Sema/type_checker_perf/fast/rdar19915443.swift.gyb
new file mode 100644
index 0000000..2892e3e
--- /dev/null
+++ b/validation-test/Sema/type_checker_perf/fast/rdar19915443.swift.gyb
@@ -0,0 +1,10 @@
+// SWIFT_ENABLE_TENSORFLOW
+// UNSUPPORTED: macosx
+// RUN: %scale-test --begin 7 --end 15 --step 1 --select NumLeafScopes %s -Xfrontend=-swift-version -Xfrontend=5 -Xfrontend=-solver-disable-shrink -Xfrontend=-disable-constraint-solver-performance-hacks -Xfrontend=-solver-enable-operator-designated-types
+// REQUIRES: OS=macosx
+// REQUIRES: asserts
+let a = [0]
+let d = a[0] * 1
+%for i in range(0, N):
+ + a[0] * 1
+%end
diff --git a/validation-test/Sema/type_checker_perf/slow/nil_coalescing.swift.gyb b/validation-test/Sema/type_checker_perf/slow/nil_coalescing.swift.gyb
index f3f401f..ade43bc 100644
--- a/validation-test/Sema/type_checker_perf/slow/nil_coalescing.swift.gyb
+++ b/validation-test/Sema/type_checker_perf/slow/nil_coalescing.swift.gyb
@@ -2,6 +2,12 @@
// REQUIRES: asserts,no_asan
// REQUIRES: rdar38963783,no_asan
+// SWIFT_ENABLE_TENSORFLOW
+// This test is currently unsupported because the addition of `+` operators
+// to the stdlib (via `VectorNumeric`) causes type-checking to fail.
+// Re-enable when type-checking no longer fails.
+// UNSUPPORTED: OS=macosx
+
func t(_ x: Int?) -> Int {
return (x ?? 0)
%for i in range(1, N):
diff --git a/validation-test/Sema/type_checker_perf/slow/rdar23327871.swift.gyb b/validation-test/Sema/type_checker_perf/slow/rdar23327871.swift.gyb
index 1b4085f..923969e 100644
--- a/validation-test/Sema/type_checker_perf/slow/rdar23327871.swift.gyb
+++ b/validation-test/Sema/type_checker_perf/slow/rdar23327871.swift.gyb
@@ -1,3 +1,5 @@
+// SWIFT_ENABLE_TENSORFLOW
+// UNSUPPORTED: macosx
// RUN: %scale-test --begin 8 --end 16 --step 1 --select NumLeafScopes %s -Xfrontend=-solver-expression-time-threshold=1
// REQUIRES: asserts,no_asan
diff --git a/validation-test/Sema/type_checker_perf/slow/rdar35213699.swift b/validation-test/Sema/type_checker_perf/slow/rdar35213699.swift
new file mode 100644
index 0000000..27bb1ad
--- /dev/null
+++ b/validation-test/Sema/type_checker_perf/slow/rdar35213699.swift
@@ -0,0 +1,8 @@
+// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1
+// REQUIRES: tools-release,no_asserts
+
+func test() {
+ let x: UInt = 1 * 2 + 3 * 4 + 5 * 6 + 7 * 8 + 9 * 10 + 11 * 12 + 13 * 14
+ // expected-error@-1 {{reasonable time}}
+}
+
diff --git a/validation-test/stdlib/Inputs b/validation-test/stdlib/Inputs
index 8556d70e..a5356ee 120000
--- a/validation-test/stdlib/Inputs
+++ b/validation-test/stdlib/Inputs
@@ -1 +1 @@
-../../test/stdlib/Inputs
\ No newline at end of file
+../../test/stdlib/Inputs/
\ No newline at end of file