| //===--- ArrayDifferentiation.swift ---------------------------*- swift -*-===// |
| // |
| // This source file is part of the Swift.org open source project |
| // |
| // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
| // Licensed under Apache License v2.0 with Runtime Library Exception |
| // |
| // See https://swift.org/LICENSE.txt for license information |
| // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| // |
| //===----------------------------------------------------------------------===// |
| |
| import Swift |
| |
| //===----------------------------------------------------------------------===// |
| // Protocol conformances |
| //===----------------------------------------------------------------------===// |
| |
| // TODO(TF-938): Add `Element: Differentiable` requirement. |
| extension Array { |
| /// The view of an array as the differentiable product manifold of `Element` |
| /// multiplied with itself `count` times. |
| @frozen |
| public struct DifferentiableView { |
| var _base: [Element] |
| } |
| } |
| |
| extension Array.DifferentiableView: Differentiable |
| where Element: Differentiable { |
| /// The viewed array. |
| public var base: [Element] { |
| get { _base } |
| _modify { yield &_base } |
| } |
| |
| @usableFromInline |
| @derivative(of: base) |
| func _vjpBase() -> ( |
| value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector |
| ) { |
| return (base, { $0 }) |
| } |
| |
| @usableFromInline |
| @derivative(of: base) |
| func _jvpBase() -> ( |
| value: [Element], differential: (Array<Element>.TangentVector) -> TangentVector |
| ) { |
| return (base, { $0 }) |
| } |
| |
| /// Creates a differentiable view of the given array. |
| public init(_ base: [Element]) { self._base = base } |
| |
| @usableFromInline |
| @derivative(of: init(_:)) |
| static func _vjpInit(_ base: [Element]) -> ( |
| value: Array.DifferentiableView, pullback: (TangentVector) -> TangentVector |
| ) { |
| return (Array.DifferentiableView(base), { $0 }) |
| } |
| |
| @usableFromInline |
| @derivative(of: init(_:)) |
| static func _jvpInit(_ base: [Element]) -> ( |
| value: Array.DifferentiableView, differential: (TangentVector) -> TangentVector |
| ) { |
| return (Array.DifferentiableView(base), { $0 }) |
| } |
| |
| public typealias TangentVector = |
| Array<Element.TangentVector>.DifferentiableView |
| |
| public mutating func move(along direction: TangentVector) { |
| precondition( |
| base.count == direction.base.count, """ |
| Count mismatch: \(base.count) ('self') and \(direction.base.count) \ |
| ('direction') |
| """) |
| for i in base.indices { |
| base[i].move(along: direction.base[i]) |
| } |
| } |
| |
| /// A closure that produces a `TangentVector` of zeros with the same |
| /// `count` as `self`. |
| public var zeroTangentVectorInitializer: () -> TangentVector { |
| return base.zeroTangentVectorInitializer |
| } |
| } |
| |
| extension Array.DifferentiableView: Equatable |
| where Element: Differentiable & Equatable { |
| public static func == ( |
| lhs: Array.DifferentiableView, |
| rhs: Array.DifferentiableView |
| ) -> Bool { |
| return lhs.base == rhs.base |
| } |
| } |
| |
| extension Array.DifferentiableView: ExpressibleByArrayLiteral |
| where Element: Differentiable { |
| public init(arrayLiteral elements: Element...) { |
| self.init(elements) |
| } |
| } |
| |
| extension Array.DifferentiableView: CustomStringConvertible |
| where Element: Differentiable { |
| public var description: String { |
| return base.description |
| } |
| } |
| |
| /// Makes `Array.DifferentiableView` additive as the product space. |
| /// |
| /// Note that `Array.DifferentiableView([])` is the zero in the product spaces |
| /// of all counts. |
| extension Array.DifferentiableView: AdditiveArithmetic |
| where Element: AdditiveArithmetic & Differentiable { |
| |
| public static var zero: Array.DifferentiableView { |
| return Array.DifferentiableView([]) |
| } |
| |
| public static func + ( |
| lhs: Array.DifferentiableView, |
| rhs: Array.DifferentiableView |
| ) -> Array.DifferentiableView { |
| if lhs.base.count == 0 { |
| return rhs |
| } |
| if rhs.base.count == 0 { |
| return lhs |
| } |
| precondition( |
| lhs.base.count == rhs.base.count, |
| "Count mismatch: \(lhs.base.count) and \(rhs.base.count)") |
| return Array.DifferentiableView(zip(lhs.base, rhs.base).map(+)) |
| } |
| |
| public static func - ( |
| lhs: Array.DifferentiableView, |
| rhs: Array.DifferentiableView |
| ) -> Array.DifferentiableView { |
| if lhs.base.count == 0 { |
| return rhs |
| } |
| if rhs.base.count == 0 { |
| return lhs |
| } |
| precondition( |
| lhs.base.count == rhs.base.count, |
| "Count mismatch: \(lhs.base.count) and \(rhs.base.count)") |
| return Array.DifferentiableView(zip(lhs.base, rhs.base).map(-)) |
| } |
| |
| @inlinable |
| public subscript(_ index: Int) -> Element { |
| if index < base.count { |
| return base[index] |
| } else { |
| return Element.zero |
| } |
| } |
| } |
| |
| /// Makes `Array` differentiable as the product manifold of `Element` |
| /// multiplied with itself `count` times. |
| extension Array: Differentiable where Element: Differentiable { |
| // In an ideal world, `TangentVector` would be `[Element.TangentVector]`. |
| // Unfortunately, we cannot conform `Array` to `AdditiveArithmetic` for |
| // `TangentVector` because `Array` already has a static `+` method with |
| // different semantics from `AdditiveArithmetic.+`. So we use |
| // `Array.DifferentiableView` for all these associated types. |
| public typealias TangentVector = |
| Array<Element.TangentVector>.DifferentiableView |
| |
| public mutating func move(along direction: TangentVector) { |
| var view = DifferentiableView(self) |
| view.move(along: direction) |
| self = view.base |
| } |
| |
| /// A closure that produces a `TangentVector` of zeros with the same |
| /// `count` as `self`. |
| public var zeroTangentVectorInitializer: () -> TangentVector { |
| { [zeroInits = map(\.zeroTangentVectorInitializer)] in |
| TangentVector(zeroInits.map { $0() }) |
| } |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Derivatives |
| //===----------------------------------------------------------------------===// |
| |
| extension Array where Element: Differentiable { |
| @usableFromInline |
| @derivative(of: subscript) |
| func _vjpSubscript(index: Int) -> ( |
| value: Element, pullback: (Element.TangentVector) -> TangentVector |
| ) { |
| func pullback(_ v: Element.TangentVector) -> TangentVector { |
| var dSelf = [Element.TangentVector]( |
| repeating: .zero, |
| count: count) |
| dSelf[index] = v |
| return TangentVector(dSelf) |
| } |
| return (self[index], pullback) |
| } |
| |
| @usableFromInline |
| @derivative(of: subscript) |
| func _jvpSubscript(index: Int) -> ( |
| value: Element, differential: (TangentVector) -> Element.TangentVector |
| ) { |
| func differential(_ v: TangentVector) -> Element.TangentVector { |
| return v[index] |
| } |
| return (self[index], differential) |
| } |
| |
| @usableFromInline |
| @derivative(of: +) |
| static func _vjpConcatenate(_ lhs: Self, _ rhs: Self) -> ( |
| value: Self, |
| pullback: (TangentVector) -> (TangentVector, TangentVector) |
| ) { |
| func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) { |
| precondition( |
| v.base.count == lhs.count + rhs.count, """ |
| Tangent vector with invalid count; expected to equal the sum of \ |
| operand counts \(lhs.count) and \(rhs.count) |
| """) |
| return ( |
| TangentVector([Element.TangentVector](v.base[0..<lhs.count])), |
| TangentVector([Element.TangentVector](v.base[lhs.count...])) |
| ) |
| } |
| return (lhs + rhs, pullback) |
| } |
| |
| @usableFromInline |
| @derivative(of: +) |
| static func _jvpConcatenate(_ lhs: Self, _ rhs: Self) -> ( |
| value: Self, |
| differential: (TangentVector, TangentVector) -> TangentVector |
| ) { |
| func differential(_ l: TangentVector, _ r: TangentVector) -> TangentVector { |
| precondition( |
| l.base.count == lhs.count && r.base.count == rhs.count, """ |
| Tangent vectors with invalid count; expected to equal the \ |
| operand counts \(lhs.count) and \(rhs.count) |
| """) |
| return .init(l.base + r.base) |
| } |
| return (lhs + rhs, differential) |
| } |
| } |
| |
| |
| extension Array where Element: Differentiable { |
| @usableFromInline |
| @derivative(of: append) |
| mutating func _vjpAppend(_ element: Element) -> ( |
| value: Void, pullback: (inout TangentVector) -> Element.TangentVector |
| ) { |
| let appendedElementIndex = count |
| append(element) |
| return ((), { v in |
| defer { v.base.removeLast() } |
| return v.base[appendedElementIndex] |
| }) |
| } |
| |
| @usableFromInline |
| @derivative(of: append) |
| mutating func _jvpAppend(_ element: Element) -> ( |
| value: Void, |
| differential: (inout TangentVector, Element.TangentVector) -> Void |
| ) { |
| append(element) |
| return ((), { $0.base.append($1) }) |
| } |
| } |
| |
| extension Array where Element: Differentiable { |
| @usableFromInline |
| @derivative(of: +=) |
| static func _vjpAppend(_ lhs: inout Self, _ rhs: Self) -> ( |
| value: Void, pullback: (inout TangentVector) -> TangentVector |
| ) { |
| let lhsCount = lhs.count |
| lhs += rhs |
| return ((), { v in |
| let drhs = |
| TangentVector(.init(v.base.dropFirst(lhsCount))) |
| let rhsCount = drhs.base.count |
| v.base.removeLast(rhsCount) |
| return drhs |
| }) |
| } |
| |
| @usableFromInline |
| @derivative(of: +=) |
| static func _jvpAppend(_ lhs: inout Self, _ rhs: Self) -> ( |
| value: Void, differential: (inout TangentVector, TangentVector) -> Void |
| ) { |
| lhs += rhs |
| return ((), { $0.base += $1.base }) |
| } |
| } |
| |
| extension Array where Element: Differentiable { |
| @usableFromInline |
| @derivative(of: init(repeating:count:)) |
| static func _vjpInit(repeating repeatedValue: Element, count: Int) -> ( |
| value: Self, pullback: (TangentVector) -> Element.TangentVector |
| ) { |
| ( |
| value: Self(repeating: repeatedValue, count: count), |
| pullback: { v in |
| v.base.reduce(.zero, +) |
| } |
| ) |
| } |
| |
| @usableFromInline |
| @derivative(of: init(repeating:count:)) |
| static func _jvpInit(repeating repeatedValue: Element, count: Int) -> ( |
| value: Self, differential: (Element.TangentVector) -> TangentVector |
| ) { |
| ( |
| value: Self(repeating: repeatedValue, count: count), |
| differential: { v in TangentVector(.init(repeating: v, count: count)) } |
| ) |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Differentiable higher order functions for collections |
| //===----------------------------------------------------------------------===// |
| |
| extension Array where Element: Differentiable { |
| @inlinable |
| @differentiable(wrt: self) |
| public func differentiableMap<Result: Differentiable>( |
| _ body: @differentiable (Element) -> Result |
| ) -> [Result] { |
| map(body) |
| } |
| |
| @inlinable |
| @derivative(of: differentiableMap) |
| internal func _vjpDifferentiableMap<Result: Differentiable>( |
| _ body: @differentiable (Element) -> Result |
| ) -> ( |
| value: [Result], |
| pullback: (Array<Result>.TangentVector) -> Array.TangentVector |
| ) { |
| var values: [Result] = [] |
| var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = [] |
| for x in self { |
| let (y, pb) = valueWithPullback(at: x, in: body) |
| values.append(y) |
| pullbacks.append(pb) |
| } |
| func pullback(_ tans: Array<Result>.TangentVector) -> Array.TangentVector { |
| .init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) }) |
| } |
| return (value: values, pullback: pullback) |
| } |
| |
| @inlinable |
| @derivative(of: differentiableMap) |
| internal func _jvpDifferentiableMap<Result: Differentiable>( |
| _ body: @differentiable (Element) -> Result |
| ) -> ( |
| value: [Result], |
| differential: (Array.TangentVector) -> Array<Result>.TangentVector |
| ) { |
| var values: [Result] = [] |
| var differentials: [(Element.TangentVector) -> Result.TangentVector] = [] |
| for x in self { |
| let (y, df) = valueWithDifferential(at: x, in: body) |
| values.append(y) |
| differentials.append(df) |
| } |
| func differential(_ tans: Array.TangentVector) -> Array<Result>.TangentVector { |
| .init(zip(tans.base, differentials).map { tan, df in df(tan) }) |
| } |
| return (value: values, differential: differential) |
| } |
| } |
| |
| extension Array where Element: Differentiable { |
| @inlinable |
| @differentiable(wrt: (self, initialResult)) |
| public func differentiableReduce<Result: Differentiable>( |
| _ initialResult: Result, |
| _ nextPartialResult: @differentiable (Result, Element) -> Result |
| ) -> Result { |
| reduce(initialResult, nextPartialResult) |
| } |
| |
| @inlinable |
| @derivative(of: differentiableReduce) |
| internal func _vjpDifferentiableReduce<Result: Differentiable>( |
| _ initialResult: Result, |
| _ nextPartialResult: @differentiable (Result, Element) -> Result |
| ) -> ( |
| value: Result, |
| pullback: (Result.TangentVector) |
| -> (Array.TangentVector, Result.TangentVector) |
| ) { |
| var pullbacks: |
| [(Result.TangentVector) -> (Result.TangentVector, Element.TangentVector)] = |
| [] |
| let count = self.count |
| pullbacks.reserveCapacity(count) |
| var result = initialResult |
| for element in self { |
| let (y, pb) = |
| valueWithPullback(at: result, element, in: nextPartialResult) |
| result = y |
| pullbacks.append(pb) |
| } |
| return ( |
| value: result, |
| pullback: { tangent in |
| var resultTangent = tangent |
| var elementTangents = TangentVector([]) |
| elementTangents.base.reserveCapacity(count) |
| for pullback in pullbacks.reversed() { |
| let (newResultTangent, elementTangent) = pullback(resultTangent) |
| resultTangent = newResultTangent |
| elementTangents.base.append(elementTangent) |
| } |
| return (TangentVector(elementTangents.base.reversed()), resultTangent) |
| } |
| ) |
| } |
| |
| @inlinable |
| @derivative(of: differentiableReduce, wrt: (self, initialResult)) |
| func _jvpDifferentiableReduce<Result: Differentiable>( |
| _ initialResult: Result, |
| _ nextPartialResult: @differentiable (Result, Element) -> Result |
| ) -> (value: Result, |
| differential: (Array.TangentVector, Result.TangentVector) |
| -> Result.TangentVector) { |
| var differentials: |
| [(Result.TangentVector, Element.TangentVector) -> Result.TangentVector] |
| = [] |
| let count = self.count |
| differentials.reserveCapacity(count) |
| var result = initialResult |
| for element in self { |
| let (y, df) = |
| valueWithDifferential(at: result, element, in: nextPartialResult) |
| result = y |
| differentials.append(df) |
| } |
| return (value: result, differential: { dSelf, dInitial in |
| var dResult = dInitial |
| for (dElement, df) in zip(dSelf.base, differentials) { |
| dResult = df(dResult, dElement) |
| } |
| return dResult |
| }) |
| } |
| } |