| //===--- SIMDDifferentiation.swift.gyb ------------------------*- swift -*-===// |
| // |
| // This source file is part of the Swift.org open source project |
| // |
| // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
| // Licensed under Apache License v2.0 with Runtime Library Exception |
| // |
| // See https://swift.org/LICENSE.txt for license information |
| // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| // |
| //===----------------------------------------------------------------------===// |
| |
| import Swift |
| |
| %{ |
| storagescalarCounts = [2,4,8,16,32,64] |
| vectorscalarCounts = storagescalarCounts + [3] |
| }% |
| |
| %for n in vectorscalarCounts: |
| |
| //===----------------------------------------------------------------------===// |
| // Protocol conformances |
| //===----------------------------------------------------------------------===// |
| |
// `SIMD${n}` is `AdditiveArithmetic` whenever its scalar is a floating-point
// type. The conformance declares no members of its own.
extension SIMD${n}: AdditiveArithmetic
where
  Scalar: FloatingPoint
{
}
| |
// `SIMD${n}` is `Differentiable` when its scalar is a differentiable binary
// floating-point type whose tangent vector is also binary floating-point.
extension SIMD${n}: Differentiable
where
  Scalar: Differentiable & BinaryFloatingPoint,
  Scalar.TangentVector: BinaryFloatingPoint
{
  /// The tangent vector of a `SIMD${n}` is a `SIMD${n}`.
  public typealias TangentVector = SIMD${n}

  /// A closure that builds a tangent vector by repeating `0` in every lane.
  @inlinable
  public var zeroTangentVectorInitializer: () -> TangentVector {
    let makeZero: () -> TangentVector = { .init(repeating: 0) }
    return makeZero
  }
}
| |
| //===----------------------------------------------------------------------===// |
| // Derivatives |
| //===----------------------------------------------------------------------===// |
| |
// Derivative registrations for lane access on `SIMD${n}`, available when the
// scalar type is its own tangent vector.
extension SIMD${n}
where
  Scalar: Differentiable & BinaryFloatingPoint,
  Scalar.TangentVector == Scalar
{
  // NOTE(TF-1094): serialized `@derivative` for `.swiftinterface` compilation.
  /// Reverse-mode derivative of the lane getter.
  ///
  /// The pullback starts from `Self.zero` and stores the incoming cotangent
  /// in lane `index`.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _vjpSubscript(_ index: Int)
    -> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector)
  {
    func pullback(_ cotangent: Scalar.TangentVector) -> TangentVector {
      var result = Self.zero
      result[index] = cotangent
      return result
    }
    return (value: self[index], pullback: pullback)
  }

  /// Forward-mode derivative of the lane getter.
  ///
  /// The differential reads lane `index` of the incoming tangent vector.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _jvpSubscript(index: Int)
    -> (value: Scalar, differential: (TangentVector) -> Scalar.TangentVector)
  {
    func differential(_ tangent: TangentVector) -> Scalar.TangentVector {
      return .init(tangent[index])
    }
    return (value: self[index], differential: differential)
  }

  /// Reverse-mode derivative of the lane setter.
  ///
  /// The setter itself is applied first. The pullback then takes lane `index`
  /// out of the cotangent of `self`, resets that lane to zero, and returns the
  /// extracted value as the cotangent of `newValue`.
  @inlinable
  @derivative(of: subscript(_:).set)
  internal mutating func _vjpSubscriptSetter(_ newValue: Scalar, _ index: Int)
    -> (value: Void, pullback: (inout TangentVector) -> Scalar.TangentVector)
  {
    self[index] = newValue
    func pullback(_ selfCotangent: inout TangentVector) -> Scalar.TangentVector {
      let newValueCotangent = selfCotangent[index]
      selfCotangent[index] = 0
      return newValueCotangent
    }
    return (value: (), pullback: pullback)
  }
}
| |
| %end |
| |
// Derivatives of vector `+`, binary `-`, and unary `-`.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint,
  TangentVector.Scalar: BinaryFloatingPoint
{
  /// Reverse-mode derivative of `lhs + rhs`.
  ///
  /// The pullback passes the incoming cotangent to both operands unchanged.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    func pullback(
      _ cotangent: TangentVector
    ) -> (TangentVector, TangentVector) {
      (cotangent, cotangent)
    }
    return (value: lhs + rhs, pullback: pullback)
  }

  /// Forward-mode derivative of `lhs + rhs`.
  ///
  /// The differential adds the two operand tangents.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    func differential(
      _ lhsTangent: TangentVector, _ rhsTangent: TangentVector
    ) -> TangentVector {
      lhsTangent + rhsTangent
    }
    return (value: lhs + rhs, differential: differential)
  }

  /// Reverse-mode derivative of `lhs - rhs`.
  ///
  /// The pullback passes the cotangent to `lhs` and its negation to `rhs`.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    func pullback(
      _ cotangent: TangentVector
    ) -> (TangentVector, TangentVector) {
      (cotangent, -cotangent)
    }
    return (value: lhs - rhs, pullback: pullback)
  }

  /// Forward-mode derivative of `lhs - rhs`.
  ///
  /// The differential subtracts the `rhs` tangent from the `lhs` tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    func differential(
      _ lhsTangent: TangentVector, _ rhsTangent: TangentVector
    ) -> TangentVector {
      lhsTangent - rhsTangent
    }
    return (value: lhs - rhs, differential: differential)
  }

  /// Reverse-mode derivative of unary `-rhs`.
  ///
  /// The pullback negates the incoming cotangent.
  @inlinable
  @derivative(of: -)
  static func _vjpNegate(rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector))
  {
    func pullback(_ cotangent: TangentVector) -> TangentVector {
      -cotangent
    }
    return (value: -rhs, pullback: pullback)
  }

  /// Forward-mode derivative of unary `-rhs`.
  ///
  /// The differential negates the incoming tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpNegate(rhs: Self)
    -> (value: Self, differential: (TangentVector) -> (TangentVector))
  {
    func differential(_ tangent: TangentVector) -> TangentVector {
      -tangent
    }
    return (value: -rhs, differential: differential)
  }
}
| |
// Derivatives of vector `*` and `/`, available when the vector type is its own
// tangent vector.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint,
  Self.TangentVector == Self
{
  /// Reverse-mode derivative of `lhs * rhs`.
  ///
  /// The pullback scales the cotangent by `rhs` for the `lhs` operand and by
  /// `lhs` for the `rhs` operand.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    func pullback(
      _ cotangent: TangentVector
    ) -> (TangentVector, TangentVector) {
      (cotangent * rhs, cotangent * lhs)
    }
    return (value: lhs * rhs, pullback: pullback)
  }

  /// Forward-mode derivative of `lhs * rhs` (product rule).
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    func differential(
      _ lhsTangent: TangentVector, _ rhsTangent: TangentVector
    ) -> TangentVector {
      lhs * rhsTangent + lhsTangent * rhs
    }
    return (value: lhs * rhs, differential: differential)
  }

  /// Reverse-mode derivative of `lhs / rhs`.
  ///
  /// The pullback yields `v / rhs` for `lhs` and `-lhs / (rhs * rhs) * v` for
  /// `rhs`, where `v` is the incoming cotangent.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    func pullback(
      _ cotangent: TangentVector
    ) -> (TangentVector, TangentVector) {
      (cotangent / rhs, -lhs / (rhs * rhs) * cotangent)
    }
    return (value: lhs / rhs, pullback: pullback)
  }

  /// Forward-mode derivative of `lhs / rhs` (quotient rule).
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    func differential(
      _ lhsTangent: TangentVector, _ rhsTangent: TangentVector
    ) -> TangentVector {
      (lhsTangent * rhs - lhs * rhsTangent) / (rhs * rhs)
    }
    return (value: lhs / rhs, differential: differential)
  }
}
| |
// Derivatives of `+` and `-` where one operand is a scalar and the other is a
// vector. In every pullback the scalar operand's cotangent is obtained by
// calling `sum()` on the incoming vector cotangent.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector.Scalar == Scalar.TangentVector
{
  /// Reverse-mode derivative of `lhs + rhs` with a scalar `lhs`.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    func pullback(
      _ cotangent: TangentVector
    ) -> (Scalar.TangentVector, TangentVector) {
      (cotangent.sum(), cotangent)
    }
    return (value: lhs + rhs, pullback: pullback)
  }

  /// Forward-mode derivative of `lhs + rhs` with a scalar `lhs`.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    func differential(
      _ lhsTangent: Scalar.TangentVector, _ rhsTangent: TangentVector
    ) -> TangentVector {
      lhsTangent + rhsTangent
    }
    return (value: lhs + rhs, differential: differential)
  }

  /// Reverse-mode derivative of `lhs - rhs` with a scalar `lhs`.
  ///
  /// The vector operand receives the negated cotangent.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    func pullback(
      _ cotangent: TangentVector
    ) -> (Scalar.TangentVector, TangentVector) {
      (cotangent.sum(), -cotangent)
    }
    return (value: lhs - rhs, pullback: pullback)
  }

  /// Forward-mode derivative of `lhs - rhs` with a scalar `lhs`.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    func differential(
      _ lhsTangent: Scalar.TangentVector, _ rhsTangent: TangentVector
    ) -> TangentVector {
      lhsTangent - rhsTangent
    }
    return (value: lhs - rhs, differential: differential)
  }

  /// Reverse-mode derivative of `lhs + rhs` with a scalar `rhs`.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    func pullback(
      _ cotangent: TangentVector
    ) -> (TangentVector, Scalar.TangentVector) {
      (cotangent, cotangent.sum())
    }
    return (value: lhs + rhs, pullback: pullback)
  }

  /// Forward-mode derivative of `lhs + rhs` with a scalar `rhs`.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    func differential(
      _ lhsTangent: TangentVector, _ rhsTangent: Scalar.TangentVector
    ) -> TangentVector {
      lhsTangent + rhsTangent
    }
    return (value: lhs + rhs, differential: differential)
  }

  /// Reverse-mode derivative of `lhs - rhs` with a scalar `rhs`.
  ///
  /// The scalar operand receives the negated sum of the cotangent.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    func pullback(
      _ cotangent: TangentVector
    ) -> (TangentVector, Scalar.TangentVector) {
      (cotangent, -cotangent.sum())
    }
    return (value: lhs - rhs, pullback: pullback)
  }

  /// Forward-mode derivative of `lhs - rhs` with a scalar `rhs`.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    func differential(
      _ lhsTangent: TangentVector, _ rhsTangent: Scalar.TangentVector
    ) -> TangentVector {
      lhsTangent - rhsTangent
    }
    return (value: lhs - rhs, differential: differential)
  }
}
| |
// Derivatives of `*` and `/` where one operand is a scalar and the other is a
// vector, available when both the vector and the scalar are their own tangent
// vectors. In every pullback the scalar operand's cotangent is obtained by
// calling `sum()` on a vector-valued term.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  /// Reverse-mode derivative of `lhs * rhs` with a scalar `rhs`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    func pullback(
      _ cotangent: TangentVector
    ) -> (TangentVector, Scalar.TangentVector) {
      (cotangent * rhs, (cotangent * lhs).sum())
    }
    return (value: lhs * rhs, pullback: pullback)
  }

  /// Forward-mode derivative of `lhs * rhs` with a scalar `rhs` (product
  /// rule).
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    func differential(
      _ lhsTangent: TangentVector, _ rhsTangent: Scalar.TangentVector
    ) -> TangentVector {
      lhs * rhsTangent + lhsTangent * rhs
    }
    return (value: lhs * rhs, differential: differential)
  }

  /// Reverse-mode derivative of `lhs / rhs` with a scalar `rhs`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    func pullback(
      _ cotangent: TangentVector
    ) -> (TangentVector, Scalar.TangentVector) {
      (cotangent / rhs, (-lhs / (rhs * rhs) * cotangent).sum())
    }
    return (value: lhs / rhs, pullback: pullback)
  }

  /// Forward-mode derivative of `lhs / rhs` with a scalar `rhs` (quotient
  /// rule).
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    func differential(
      _ lhsTangent: TangentVector, _ rhsTangent: Scalar.TangentVector
    ) -> TangentVector {
      (lhsTangent * rhs - lhs * rhsTangent) / (rhs * rhs)
    }
    return (value: lhs / rhs, differential: differential)
  }

  /// Reverse-mode derivative of `lhs * rhs` with a scalar `lhs`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    func pullback(
      _ cotangent: TangentVector
    ) -> (Scalar.TangentVector, TangentVector) {
      ((cotangent * rhs).sum(), cotangent * lhs)
    }
    return (value: lhs * rhs, pullback: pullback)
  }

  /// Forward-mode derivative of `lhs * rhs` with a scalar `lhs` (product
  /// rule).
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    func differential(
      _ lhsTangent: Scalar.TangentVector, _ rhsTangent: TangentVector
    ) -> TangentVector {
      lhs * rhsTangent + lhsTangent * rhs
    }
    return (value: lhs * rhs, differential: differential)
  }

  /// Reverse-mode derivative of `lhs / rhs` with a scalar `lhs`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    func pullback(
      _ cotangent: TangentVector
    ) -> (Scalar.TangentVector, TangentVector) {
      ((cotangent / rhs).sum(), -lhs / (rhs * rhs) * cotangent)
    }
    return (value: lhs / rhs, pullback: pullback)
  }

  /// Forward-mode derivative of `lhs / rhs` with a scalar `lhs` (quotient
  /// rule).
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    func differential(
      _ lhsTangent: Scalar.TangentVector, _ rhsTangent: TangentVector
    ) -> TangentVector {
      (lhsTangent * rhs - lhs * rhsTangent) / (rhs * rhs)
    }
    return (value: lhs / rhs, differential: differential)
  }
}
| |
| // FIXME(TF-1103): Derivative registration does not yet support |
| // `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`. |
| /* |
| extension SIMD |
| where |
| Self: Differentiable, |
| TangentVector: SIMD, |
| Scalar: BinaryFloatingPoint & Differentiable, |
| Scalar.TangentVector: BinaryFloatingPoint, |
| TangentVector == Self |
| { |
| @inlinable |
| @derivative(of: sum) |
| func _vjpSum() -> ( |
| value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector |
| ) { |
| return (sum(), { v in Self(repeating: Scalar(v)) }) |
| } |
| |
| @inlinable |
| @derivative(of: sum) |
| func _jvpSum() -> ( |
| value: Scalar, differential: (TangentVector) -> Scalar.TangentVector |
| ) { |
| return (sum(), { v in Scalar.TangentVector(v.sum()) }) |
| } |
| } |
| */ |
| |
// Derivatives of `init(repeating:)`, available when both the vector and the
// scalar are their own tangent vectors.
extension SIMD
where
  Self: Differentiable,
  Self.TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  /// Reverse-mode derivative of `init(repeating:)`.
  ///
  /// The pullback returns `sum()` of the incoming vector cotangent.
  @inlinable
  @derivative(of: init(repeating:))
  static func _vjpInit(repeating value: Scalar)
    -> (value: Self, pullback: (TangentVector) -> Scalar.TangentVector)
  {
    func pullback(_ cotangent: TangentVector) -> Scalar.TangentVector {
      cotangent.sum()
    }
    return (value: Self(repeating: value), pullback: pullback)
  }

  /// Forward-mode derivative of `init(repeating:)`.
  ///
  /// The differential repeats the scalar tangent in every lane.
  @inlinable
  @derivative(of: init(repeating:))
  static func _jvpInit(repeating value: Scalar)
    -> (value: Self, differential: (Scalar.TangentVector) -> TangentVector)
  {
    func differential(_ tangent: Scalar.TangentVector) -> TangentVector {
      Self(repeating: tangent)
    }
    return (value: Self(repeating: value), differential: differential)
  }
}