| //===--- DifferentiationSupport.swift -------------------------*- swift -*-===// |
| // |
| // This source file is part of the Swift.org open source project |
| // |
| // Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors |
| // Licensed under Apache License v2.0 with Runtime Library Exception |
| // |
| // See https://swift.org/LICENSE.txt for license information |
| // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // SWIFT_ENABLE_TENSORFLOW |
| // |
| // This file defines not-yet-upstreamed support for differentiable programming |
| // and deep learning APIs. |
| // |
| //===----------------------------------------------------------------------===// |
| |
// Pointwise multiplication; binds as tightly as `*`.
infix operator .*: MultiplicationPrecedence
// In-place pointwise multiplication; binds like `=` and the other compound
// assignment operators.
infix operator .*=: AssignmentPrecedence
| |
| //===----------------------------------------------------------------------===// |
| // Compiler Protocols |
| //===----------------------------------------------------------------------===// |
| |
/// A type with values that support pointwise multiplication.
///
/// Conforming types provide a multiplicative identity (`one`), a
/// multiplicative inverse (`reciprocal`), the pointwise product operator `.*`,
/// and its in-place counterpart `.*=`.
public protocol PointwiseMultiplicative : AdditiveArithmetic {
  /// The one value.
  ///
  /// One is the identity element for multiplication. For any value,
  /// `x .* .one == x` and `.one .* x == x`.
  static var one: Self { get }

  /// The multiplicative inverse of `self`.
  ///
  /// For any value, `x .* x.reciprocal == .one` and
  /// `x.reciprocal .* x == .one`.
  var reciprocal: Self { get }

  /// Multiplies two values and produces their product.
  ///
  /// - Parameters:
  ///   - lhs: The first value to multiply.
  ///   - rhs: The second value to multiply.
  /// - Returns: The pointwise product of `lhs` and `rhs`.
  static func .*(lhs: Self, rhs: Self) -> Self

  /// Multiplies two values and stores the result in the left-hand-side
  /// variable.
  ///
  /// - Parameters:
  ///   - lhs: The first value to multiply. On return, it holds the product
  ///     of its original value and `rhs`.
  ///   - rhs: The second value to multiply.
  static func .*=(lhs: inout Self, rhs: Self)
}
| |
public extension PointwiseMultiplicative {
  /// Default in-place pointwise multiplication: computes `lhs .* rhs` and
  /// assigns the product back to `lhs`.
  static func .*=(lhs: inout Self, rhs: Self) {
    let product = lhs .* rhs
    lhs = product
  }
}
| |
public extension PointwiseMultiplicative where Self : ExpressibleByIntegerLiteral {
  /// Default multiplicative identity for integer-literal-expressible types:
  /// the value created from the literal `1`.
  static var one: Self { 1 }
}
| |
/// A type that represents an unranked vector space. Values of this type are
/// elements in this vector space and have either no shape or a static shape.
public protocol VectorProtocol : AdditiveArithmetic {
  /// The type of scalars in the vector space.
  associatedtype VectorSpaceScalar : AdditiveArithmetic

  /// Returns `self` with the given scalar added to it.
  ///
  /// - Parameter x: The scalar to add.
  func adding(_ x: VectorSpaceScalar) -> Self

  /// Adds the given scalar to `self`.
  ///
  /// - Parameter x: The scalar to add.
  mutating func add(_ x: VectorSpaceScalar)

  /// Returns `self` with the given scalar subtracted from it.
  ///
  /// - Parameter x: The scalar to subtract.
  func subtracting(_ x: VectorSpaceScalar) -> Self

  /// Subtracts the given scalar from `self`.
  ///
  /// - Parameter x: The scalar to subtract.
  mutating func subtract(_ x: VectorSpaceScalar)

  /// Returns `self` multiplied by the given scalar.
  ///
  /// - Parameter scalar: The scalar to multiply by.
  func scaled(by scalar: VectorSpaceScalar) -> Self

  /// Multiplies `self` by the given scalar.
  ///
  /// - Parameter scalar: The scalar to multiply by.
  mutating func scale(by scalar: VectorSpaceScalar)
}
| |
public extension VectorProtocol {
  /// Default in-place scalar addition: replaces `self` with `adding(x)`.
  mutating func add(_ x: VectorSpaceScalar) {
    let sum = adding(x)
    self = sum
  }

  /// Default in-place scalar subtraction: replaces `self` with
  /// `subtracting(x)`.
  mutating func subtract(_ x: VectorSpaceScalar) {
    let difference = subtracting(x)
    self = difference
  }

  /// Default in-place scaling: replaces `self` with `scaled(by: scalar)`.
  mutating func scale(by scalar: VectorSpaceScalar) {
    let scaledValue = scaled(by: scalar)
    self = scaledValue
  }
}
| |
| /* |
| // Note: These default-implemented operators will slow down type-checking |
| // performance and break existing code. |
| |
| public extension VectorProtocol { |
| static func + (lhs: Self, rhs: VectorSpaceScalar) -> Self { |
| lhs.adding(rhs) |
| } |
| |
| static func + (lhs: VectorSpaceScalar, rhs: Self) -> Self { |
| rhs.adding(lhs) |
| } |
| |
| static func += (lhs: inout Self, rhs: VectorSpaceScalar) { |
| lhs.add(rhs) |
| } |
| |
| static func - (lhs: Self, rhs: VectorSpaceScalar) -> Self { |
| lhs.subtracting(rhs) |
| } |
| |
| static func -= (lhs: inout Self, rhs: VectorSpaceScalar) { |
| lhs.subtract(rhs) |
| } |
| |
| static func * (lhs: Self, rhs: VectorSpaceScalar) -> Self { |
| lhs.scaled(by: rhs) |
| } |
| |
| static func * (lhs: VectorSpaceScalar, rhs: Self) -> Self { |
| rhs.scaled(by: lhs) |
| } |
| |
| static func *= (lhs: inout Self, rhs: VectorSpaceScalar) { |
| lhs.scale(by: rhs) |
| } |
| } |
| |
| public extension VectorProtocol where VectorSpaceScalar : SignedNumeric { |
| static func - (lhs: VectorSpaceScalar, rhs: Self) -> Self { |
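    // Parenthesize so the negation applies to `rhs` alone: `lhs - rhs`
    // equals `(-rhs) + lhs`, not `-(rhs + lhs)`.
    (-rhs).adding(lhs)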
| } |
| |
| static prefix func - (x: Self) -> Self { |
| .zero - x |
| } |
| } |
| */ |
| |
/// A type that is differentiable in Euclidean space.
///
/// A conforming type is either a vector space itself, or a vector space
/// combined with other components that are not differentiable.
///
/// In mathematical terms, the type models a product manifold built from a
/// differentiable vector space and an arbitrary manifold, where the tangent
/// bundle of the whole product manifold equals the vector space component.
///
/// This is useful for common differentiable data structures that mix
/// differentiable vector properties with stored properties that have no
/// derivative, for example:
///
/// ```swift
/// struct Perceptron: @memberwise EuclideanDifferentiable {
///   var weight: SIMD16<Float>
///   var bias: Float
///   @noDerivative var useBias: Bool
/// }
/// ```
///
/// - Note: Adopt `EuclideanDifferentiable` only when the type is
///   differentiable solely with respect to its vector space component, and
///   its `TangentVector` is that vector space component.
public protocol EuclideanDifferentiable : Differentiable {
  /// The differentiable vector part of `self`, expressed as a `TangentVector`.
  var differentiableVectorView: TangentVector { get }
}
| |
/// When a type is its own tangent vector, its differentiable vector view is
/// the value itself.
public extension EuclideanDifferentiable where TangentVector == Self {
  var differentiableVectorView: TangentVector {
    _read {
      yield self
    }
  }
}
| |
| //===----------------------------------------------------------------------===// |
| // Functional utilities |
| //===----------------------------------------------------------------------===// |
| |
/// Wraps `body` so that it is re-evaluated inside its pullback, a technique
/// known as "checkpointing" in traditional automatic differentiation.
///
/// The returned function produces `body(x)` as its value. Its pullback
/// captures only `x` and `body`, and recomputes `body` at `x` to obtain the
/// pullback it applies.
///
/// - Parameter body: The differentiable function to wrap.
/// - Returns: A differentiable function computing the same values as `body`.
@inlinable
public func withRecomputationInPullbacks<T, U>(
  _ body: @escaping @differentiable (T) -> U
) -> @differentiable (T) -> U where T : Differentiable, U : Differentiable {
  return differentiableFunction { input in
    (value: body(input), pullback: { pullback(at: input, in: body)($0) })
  }
}
| |
public extension Differentiable {
  /// Applies `body` to `self`. When this call is differentiated, `body` is
  /// recomputed in the pullback rather than reusing a pullback from the
  /// forward evaluation.
  ///
  /// - Parameter body: The differentiable function to apply to `self`.
  /// - Returns: `body(self)`.
  @inlinable
  @differentiable(wrt: self)
  func withRecomputationInPullbacks<Result : Differentiable>(
    _ body: @escaping @differentiable (Self) -> Result
  ) -> Result {
    body(self)
  }

  /// Registered derivative of the `withRecomputationInPullbacks(_:)` method.
  ///
  /// Delegates to the free function `Swift.withRecomputationInPullbacks(_:)`
  /// and returns the value and pullback of the wrapped function at `self`.
  @inlinable
  @derivative(of: withRecomputationInPullbacks)
  internal func _vjp_withRecomputationInPullbacks<Result : Differentiable>(
    _ body: @escaping @differentiable (Self) -> Result
  ) -> (value: Result, pullback: (Result.TangentVector) -> TangentVector) {
    let checkpointed = Swift.withRecomputationInPullbacks(body)
    return Swift.valueWithPullback(at: self, in: checkpointed)
  }
}