blob: 26eda34b501faf1a0d73fdbd8c595eccfe542f11 [file] [log] [blame]
//===--- ArrayDifferentiation.swift ---------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Protocol conformances
//===----------------------------------------------------------------------===//
// TODO(TF-938): Add `Element: Differentiable` requirement.
extension Array {
/// The view of an array as the differentiable product manifold of `Element`
/// multiplied with itself `count` times.
///
/// The struct is `@frozen` and wraps a single stored array, `_base`.
@frozen
public struct DifferentiableView {
// Backing storage for the view; the only stored property of this struct.
var _base: [Element]
}
}
extension Array.DifferentiableView: Differentiable
where Element: Differentiable {
  /// The viewed array.
  public var base: [Element] {
    get { _base }
    _modify { yield &_base }
  }

  /// Reverse-mode derivative of `base`.
  ///
  /// The pullback is the identity: the tangent of the viewed array and the
  /// tangent of the view are the same type, so the value passes through.
  @usableFromInline
  @derivative(of: base)
  func _vjpBase() -> (
    value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector
  ) {
    return (base, { $0 })
  }

  /// Forward-mode derivative of `base`; the differential is the identity.
  @usableFromInline
  @derivative(of: base)
  func _jvpBase() -> (
    value: [Element], differential: (Array<Element>.TangentVector) -> TangentVector
  ) {
    return (base, { $0 })
  }

  /// Creates a differentiable view of the given array.
  public init(_ base: [Element]) { self._base = base }

  /// Reverse-mode derivative of `init(_:)`; the pullback is the identity.
  @usableFromInline
  @derivative(of: init(_:))
  static func _vjpInit(_ base: [Element]) -> (
    value: Array.DifferentiableView, pullback: (TangentVector) -> TangentVector
  ) {
    return (Array.DifferentiableView(base), { $0 })
  }

  /// Forward-mode derivative of `init(_:)`; the differential is the identity.
  @usableFromInline
  @derivative(of: init(_:))
  static func _jvpInit(_ base: [Element]) -> (
    value: Array.DifferentiableView, differential: (TangentVector) -> TangentVector
  ) {
    return (Array.DifferentiableView(base), { $0 })
  }

  public typealias TangentVector =
    Array<Element.TangentVector>.DifferentiableView

  /// Moves every element of `base` along the matching component of
  /// `direction`.
  ///
  /// An empty `direction` is the zero tangent vector for arrays of every
  /// count, so moving along it leaves `self` unchanged.
  ///
  /// - Precondition: A non-empty `direction` must have the same count as
  ///   `base`.
  public mutating func move(along direction: TangentVector) {
    // `DifferentiableView([])` is zero for all counts; moving along zero is a
    // no-op rather than a count mismatch.
    if direction.base.isEmpty {
      return
    }
    precondition(
      base.count == direction.base.count, """
        Count mismatch: \(base.count) ('self') and \(direction.base.count) \
        ('direction')
        """)
    for i in base.indices {
      base[i].move(along: direction.base[i])
    }
  }

  /// A closure that produces a `TangentVector` of zeros with the same
  /// `count` as `self`.
  public var zeroTangentVectorInitializer: () -> TangentVector {
    return base.zeroTangentVectorInitializer
  }
}
// SWIFT_ENABLE_TENSORFLOW
extension Array.DifferentiableView: EuclideanDifferentiable
where Element: EuclideanDifferentiable {
  /// The differentiable vector views of the elements of `base`, collected in
  /// order and wrapped as this view's tangent vector.
  public var differentiableVectorView: Array.DifferentiableView.TangentVector {
    let elementViews = base.map { element in element.differentiableVectorView }
    return Array.DifferentiableView.TangentVector(elementViews)
  }
}
// SWIFT_ENABLE_TENSORFLOW END
extension Array.DifferentiableView: Equatable
where Element: Differentiable & Equatable {
  /// Returns `true` exactly when the two views wrap equal arrays.
  public static func == (
    lhs: Array.DifferentiableView,
    rhs: Array.DifferentiableView
  ) -> Bool {
    let leftElements = lhs.base
    let rightElements = rhs.base
    return leftElements == rightElements
  }
}
extension Array.DifferentiableView: ExpressibleByArrayLiteral
where Element: Differentiable {
  /// Creates a view over an array holding the literal's elements, in order.
  public init(arrayLiteral elements: Element...) {
    self = Self(elements)
  }
}
extension Array.DifferentiableView: CustomStringConvertible
where Element: Differentiable {
  /// The textual description of the viewed array.
  public var description: String {
    let viewed = base
    return viewed.description
  }
}
/// Makes `Array.DifferentiableView` additive as the product space.
///
/// Note that `Array.DifferentiableView([])` is the zero in the product spaces
/// of all counts.
extension Array.DifferentiableView: AdditiveArithmetic
where Element: AdditiveArithmetic & Differentiable {
  /// The additive identity: an empty view, which acts as zero for every
  /// count.
  public static var zero: Array.DifferentiableView {
    return Array.DifferentiableView([])
  }

  /// Adds two views element-wise.
  ///
  /// An empty operand is treated as zero, so the other operand is returned
  /// unchanged.
  ///
  /// - Precondition: Two non-empty operands must have equal counts.
  public static func + (
    lhs: Array.DifferentiableView,
    rhs: Array.DifferentiableView
  ) -> Array.DifferentiableView {
    if lhs.base.isEmpty {
      return rhs
    }
    if rhs.base.isEmpty {
      return lhs
    }
    precondition(
      lhs.base.count == rhs.base.count,
      "Count mismatch: \(lhs.base.count) and \(rhs.base.count)")
    return Array.DifferentiableView(zip(lhs.base, rhs.base).map(+))
  }

  /// Subtracts `rhs` from `lhs` element-wise.
  ///
  /// An empty operand is treated as zero: `lhs - zero` is `lhs`, and
  /// `zero - rhs` is the element-wise negation of `rhs`.
  ///
  /// - Precondition: Two non-empty operands must have equal counts.
  public static func - (
    lhs: Array.DifferentiableView,
    rhs: Array.DifferentiableView
  ) -> Array.DifferentiableView {
    if lhs.base.isEmpty {
      // Zero minus `rhs` must negate `rhs`; returning `rhs` itself would
      // silently flip the sign of the result.
      return Array.DifferentiableView(rhs.base.map { Element.zero - $0 })
    }
    if rhs.base.isEmpty {
      return lhs
    }
    precondition(
      lhs.base.count == rhs.base.count,
      "Count mismatch: \(lhs.base.count) and \(rhs.base.count)")
    return Array.DifferentiableView(zip(lhs.base, rhs.base).map(-))
  }

  /// Accesses the element at `index`, or `Element.zero` when `index` is at or
  /// past the end of `base`.
  ///
  /// This lets an empty (zero) view be read at any non-negative index.
  @inlinable
  public subscript(_ index: Int) -> Element {
    if index < base.count {
      return base[index]
    } else {
      return Element.zero
    }
  }
}
/// Makes `Array` differentiable as the product manifold of `Element`
/// multiplied with itself `count` times.
extension Array: Differentiable where Element: Differentiable {
  // `[Element.TangentVector]` would be the natural `TangentVector`, but
  // `Array` cannot be conformed to `AdditiveArithmetic` for that purpose:
  // it already has a static `+` method whose semantics differ from
  // `AdditiveArithmetic.+`. `Array.DifferentiableView` is therefore used for
  // these associated types.
  public typealias TangentVector =
    Array<Element.TangentVector>.DifferentiableView

  /// Moves `self` along `direction` by wrapping it in a `DifferentiableView`,
  /// moving the view, and storing the view's array back into `self`.
  public mutating func move(along direction: TangentVector) {
    var wrapped = DifferentiableView(self)
    wrapped.move(along: direction)
    self = wrapped.base
  }

  /// A closure that produces a `TangentVector` of zeros with the same
  /// `count` as `self`.
  public var zeroTangentVectorInitializer: () -> TangentVector {
    // Collect the per-element initializers up front so the returned closure
    // captures only them.
    let elementInitializers = map { $0.zeroTangentVectorInitializer }
    return {
      TangentVector(elementInitializers.map { makeZero in makeZero() })
    }
  }
}
// SWIFT_ENABLE_TENSORFLOW
extension Array: EuclideanDifferentiable
where Element: EuclideanDifferentiable {
  /// The differentiable vector views of the elements, collected in order and
  /// wrapped as the array's tangent vector.
  public var differentiableVectorView: TangentVector {
    let elementViews = map { element in element.differentiableVectorView }
    return TangentVector(elementViews)
  }
}
// SWIFT_ENABLE_TENSORFLOW END
//===----------------------------------------------------------------------===//
// Derivatives
//===----------------------------------------------------------------------===//
extension Array where Element: Differentiable {
  /// Reverse-mode derivative of `subscript`.
  ///
  /// The pullback returns a tangent of `count` zeros with the incoming
  /// element tangent stored at `index`.
  @usableFromInline
  @derivative(of: subscript)
  func _vjpSubscript(index: Int) -> (
    value: Element, pullback: (Element.TangentVector) -> TangentVector
  ) {
    func pullback(_ v: Element.TangentVector) -> TangentVector {
      var dSelf = [Element.TangentVector](
        repeating: .zero,
        count: count)
      dSelf[index] = v
      return TangentVector(dSelf)
    }
    return (self[index], pullback)
  }

  /// Forward-mode derivative of `subscript`.
  ///
  /// The differential reads the tangent component at `index`; the view's
  /// subscript yields zero when the tangent has no component there.
  @usableFromInline
  @derivative(of: subscript)
  func _jvpSubscript(index: Int) -> (
    value: Element, differential: (TangentVector) -> Element.TangentVector
  ) {
    func differential(_ v: TangentVector) -> Element.TangentVector {
      return v[index]
    }
    return (self[index], differential)
  }

  /// Reverse-mode derivative of array concatenation (`+`).
  ///
  /// The pullback splits the incoming tangent into the leading `lhs.count`
  /// components and the remaining components.
  ///
  /// - Precondition: A non-empty incoming tangent must have
  ///   `lhs.count + rhs.count` components.
  @usableFromInline
  @derivative(of: +)
  static func _vjpConcatenate(_ lhs: Self, _ rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, TangentVector)
  ) {
    func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) {
      // An empty view is the zero tangent for every count; its pullback is
      // zero for both operands rather than a count mismatch.
      if v.base.isEmpty {
        return (.zero, .zero)
      }
      precondition(
        v.base.count == lhs.count + rhs.count, """
          Tangent vector with invalid count; expected to equal the sum of \
          operand counts \(lhs.count) and \(rhs.count)
          """)
      return (
        TangentVector([Element.TangentVector](v.base[0..<lhs.count])),
        TangentVector([Element.TangentVector](v.base[lhs.count...]))
      )
    }
    return (lhs + rhs, pullback)
  }

  /// Forward-mode derivative of array concatenation (`+`).
  ///
  /// The differential concatenates the operand tangents.
  ///
  /// - Precondition: Each non-empty operand tangent must have the same count
  ///   as its operand.
  @usableFromInline
  @derivative(of: +)
  static func _jvpConcatenate(_ lhs: Self, _ rhs: Self) -> (
    value: Self,
    differential: (TangentVector, TangentVector) -> TangentVector
  ) {
    func differential(_ l: TangentVector, _ r: TangentVector) -> TangentVector {
      // An empty view is the zero tangent for every count. Expand it to the
      // operand's count so the other operand's components keep their
      // positions in the concatenated tangent.
      let lBase = l.base.isEmpty
        ? [Element.TangentVector](repeating: .zero, count: lhs.count)
        : l.base
      let rBase = r.base.isEmpty
        ? [Element.TangentVector](repeating: .zero, count: rhs.count)
        : r.base
      precondition(
        lBase.count == lhs.count && rBase.count == rhs.count, """
          Tangent vectors with invalid count; expected to equal the \
          operand counts \(lhs.count) and \(rhs.count)
          """)
      return .init(lBase + rBase)
    }
    return (lhs + rhs, differential)
  }
}
extension Array where Element: Differentiable {
  /// Reverse-mode derivative of `append(_:)`.
  ///
  /// The pullback reads the tangent component at the appended element's
  /// index, removes the last component from the array tangent, and returns
  /// the component it read.
  @usableFromInline
  @derivative(of: append)
  mutating func _vjpAppend(_ element: Element) -> (
    value: Void, pullback: (inout TangentVector) -> Element.TangentVector
  ) {
    // The count before appending is the index the new element occupies.
    let newElementIndex = count
    append(element)
    let pullback: (inout TangentVector) -> Element.TangentVector = { tangent in
      let elementTangent = tangent.base[newElementIndex]
      tangent.base.removeLast()
      return elementTangent
    }
    return ((), pullback)
  }

  /// Forward-mode derivative of `append(_:)`.
  ///
  /// The differential appends the element tangent to the array tangent.
  @usableFromInline
  @derivative(of: append)
  mutating func _jvpAppend(_ element: Element) -> (
    value: Void,
    differential: (inout TangentVector, Element.TangentVector) -> Void
  ) {
    append(element)
    return ((), { arrayTangent, elementTangent in
      arrayTangent.base.append(elementTangent)
    })
  }
}
extension Array where Element: Differentiable {
  /// Reverse-mode derivative of `+=`.
  ///
  /// The pullback takes every tangent component after the first `lhs.count`
  /// (measured before the append) as the tangent of `rhs`, and removes that
  /// many trailing components from the in-out tangent.
  @usableFromInline
  @derivative(of: +=)
  static func _vjpAppend(_ lhs: inout Self, _ rhs: Self) -> (
    value: Void, pullback: (inout TangentVector) -> TangentVector
  ) {
    let originalCount = lhs.count
    lhs += rhs
    let pullback: (inout TangentVector) -> TangentVector = { tangent in
      let rhsComponents =
        [Element.TangentVector](tangent.base.dropFirst(originalCount))
      tangent.base.removeLast(rhsComponents.count)
      return TangentVector(rhsComponents)
    }
    return ((), pullback)
  }

  /// Forward-mode derivative of `+=`.
  ///
  /// The differential appends the components of the `rhs` tangent to the
  /// in-out `lhs` tangent.
  @usableFromInline
  @derivative(of: +=)
  static func _jvpAppend(_ lhs: inout Self, _ rhs: Self) -> (
    value: Void, differential: (inout TangentVector, TangentVector) -> Void
  ) {
    lhs += rhs
    return ((), { lhsTangent, rhsTangent in
      lhsTangent.base += rhsTangent.base
    })
  }
}
extension Array where Element: Differentiable {
  /// Reverse-mode derivative of `init(repeating:count:)`.
  ///
  /// The pullback sums all components of the array tangent, starting from
  /// zero, to give the tangent of the repeated value.
  @usableFromInline
  @derivative(of: init(repeating:count:))
  static func _vjpInit(repeating repeatedValue: Element, count: Int) -> (
    value: Self, pullback: (TangentVector) -> Element.TangentVector
  ) {
    let repeated = Self(repeating: repeatedValue, count: count)
    let pullback: (TangentVector) -> Element.TangentVector = { tangent in
      var total = Element.TangentVector.zero
      for component in tangent.base {
        total = total + component
      }
      return total
    }
    return (value: repeated, pullback: pullback)
  }

  /// Forward-mode derivative of `init(repeating:count:)`.
  ///
  /// The differential repeats the value's tangent `count` times.
  @usableFromInline
  @derivative(of: init(repeating:count:))
  static func _jvpInit(repeating repeatedValue: Element, count: Int) -> (
    value: Self, differential: (Element.TangentVector) -> TangentVector
  ) {
    let repeated = Self(repeating: repeatedValue, count: count)
    let differential: (Element.TangentVector) -> TangentVector = { tangent in
      let components = [Element.TangentVector](repeating: tangent, count: count)
      return TangentVector(components)
    }
    return (value: repeated, differential: differential)
  }
}
//===----------------------------------------------------------------------===//
// Differentiable higher order functions for collections
//===----------------------------------------------------------------------===//
extension Array where Element: Differentiable {
  /// Returns an array holding the result of applying `body` to each element,
  /// differentiable with respect to `self`.
  @inlinable
  @differentiable(wrt: self)
  public func differentiableMap<Result: Differentiable>(
    _ body: @differentiable (Element) -> Result
  ) -> [Result] {
    return map(body)
  }

  /// Reverse-mode derivative of `differentiableMap(_:)`.
  ///
  /// Records one pullback per element while computing the mapped values; the
  /// returned pullback applies each recorded pullback to the matching
  /// component of the result tangent.
  @inlinable
  @derivative(of: differentiableMap)
  internal func _vjpDifferentiableMap<Result: Differentiable>(
    _ body: @differentiable (Element) -> Result
  ) -> (
    value: [Result],
    pullback: (Array<Result>.TangentVector) -> Array.TangentVector
  ) {
    var outputs: [Result] = []
    var elementPullbacks: [(Result.TangentVector) -> Element.TangentVector] = []
    for element in self {
      let (output, elementPullback) = valueWithPullback(at: element, in: body)
      outputs.append(output)
      elementPullbacks.append(elementPullback)
    }
    return (
      value: outputs,
      pullback: { resultTangents in
        let paired = zip(resultTangents.base, elementPullbacks)
        return Array.TangentVector(
          paired.map { tangent, elementPullback in elementPullback(tangent) })
      }
    )
  }

  /// Forward-mode derivative of `differentiableMap(_:)`.
  ///
  /// Records one differential per element while computing the mapped values;
  /// the returned differential applies each recorded differential to the
  /// matching component of the input tangent.
  @inlinable
  @derivative(of: differentiableMap)
  internal func _jvpDifferentiableMap<Result: Differentiable>(
    _ body: @differentiable (Element) -> Result
  ) -> (
    value: [Result],
    differential: (Array.TangentVector) -> Array<Result>.TangentVector
  ) {
    var outputs: [Result] = []
    var elementDifferentials:
      [(Element.TangentVector) -> Result.TangentVector] = []
    for element in self {
      let (output, elementDifferential) =
        valueWithDifferential(at: element, in: body)
      outputs.append(output)
      elementDifferentials.append(elementDifferential)
    }
    return (
      value: outputs,
      differential: { inputTangents in
        let paired = zip(inputTangents.base, elementDifferentials)
        return Array<Result>.TangentVector(
          paired.map { tangent, elementDifferential in
            elementDifferential(tangent)
          })
      }
    )
  }
}
extension Array where Element: Differentiable {
  /// Returns the result of combining the elements with `nextPartialResult`,
  /// starting from `initialResult`; differentiable with respect to `self`
  /// and `initialResult`.
  @inlinable
  @differentiable(wrt: (self, initialResult))
  public func differentiableReduce<Result: Differentiable>(
    _ initialResult: Result,
    _ nextPartialResult: @differentiable (Result, Element) -> Result
  ) -> Result {
    reduce(initialResult, nextPartialResult)
  }

  /// Reverse-mode derivative of `differentiableReduce(_:_:)`.
  ///
  /// Records one pullback per reduction step. The returned pullback replays
  /// them last-to-first, threading the partial-result tangent backwards and
  /// collecting one element tangent per step.
  @inlinable
  @derivative(of: differentiableReduce)
  internal func _vjpDifferentiableReduce<Result: Differentiable>(
    _ initialResult: Result,
    _ nextPartialResult: @differentiable (Result, Element) -> Result
  ) -> (
    value: Result,
    pullback: (Result.TangentVector)
      -> (Array.TangentVector, Result.TangentVector)
  ) {
    var pullbacks:
      [(Result.TangentVector) -> (Result.TangentVector, Element.TangentVector)] =
        []
    let count = self.count
    pullbacks.reserveCapacity(count)
    var result = initialResult
    for element in self {
      let (y, pb) =
        valueWithPullback(at: result, element, in: nextPartialResult)
      result = y
      pullbacks.append(pb)
    }
    return (
      value: result,
      pullback: { tangent in
        var resultTangent = tangent
        var elementTangents = TangentVector([])
        elementTangents.base.reserveCapacity(count)
        for pullback in pullbacks.reversed() {
          let (newResultTangent, elementTangent) = pullback(resultTangent)
          resultTangent = newResultTangent
          elementTangents.base.append(elementTangent)
        }
        // Element tangents were collected last-to-first; restore element
        // order before returning.
        return (TangentVector(elementTangents.base.reversed()), resultTangent)
      }
    )
  }

  /// Forward-mode derivative of `differentiableReduce(_:_:)`.
  ///
  /// Records one differential per reduction step. The returned differential
  /// applies every step in order, threading the partial-result tangent
  /// forwards. A missing component of the array tangent (in particular, every
  /// component of the empty zero tangent) is treated as zero, so the
  /// contribution of `initialResult` is still propagated through each step.
  @inlinable
  @derivative(of: differentiableReduce, wrt: (self, initialResult))
  func _jvpDifferentiableReduce<Result: Differentiable>(
    _ initialResult: Result,
    _ nextPartialResult: @differentiable (Result, Element) -> Result
  ) -> (value: Result,
        differential: (Array.TangentVector, Result.TangentVector)
          -> Result.TangentVector) {
    var differentials:
      [(Result.TangentVector, Element.TangentVector) -> Result.TangentVector]
        = []
    let count = self.count
    differentials.reserveCapacity(count)
    var result = initialResult
    for element in self {
      let (y, df) =
        valueWithDifferential(at: result, element, in: nextPartialResult)
      result = y
      differentials.append(df)
    }
    return (value: result, differential: { dSelf, dInitial in
      let dElements = dSelf.base
      var dResult = dInitial
      // Iterate over the recorded steps rather than zipping with `dSelf`:
      // zipping with an empty (zero) tangent would skip every step and return
      // `dInitial` unchanged.
      for (index, df) in differentials.enumerated() {
        let dElement: Element.TangentVector =
          index < dElements.count ? dElements[index] : .zero
        dResult = df(dResult, dElement)
      }
      return dResult
    })
  }
}