blob: f875fe268a48a48c6665225d7958aa5c90ce7074 [file] [log] [blame]
//===--- SIMDDifferentiation.swift.gyb ------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
%{
storagescalarCounts = [2,4,8,16,32,64]
vectorscalarCounts = storagescalarCounts + [3]
}%
%for n in vectorscalarCounts:
//===----------------------------------------------------------------------===//
// Protocol conformances
//===----------------------------------------------------------------------===//

// Lane-wise `+`, `-` and `zero` for SIMD${n} vectors of floating-point scalars.
extension SIMD${n}: AdditiveArithmetic where Scalar: FloatingPoint {}

// SIMD${n} is differentiable whenever its scalars are, and it serves as its own
// tangent vector type.
extension SIMD${n}: Differentiable
where
  Scalar: Differentiable & BinaryFloatingPoint,
  Scalar.TangentVector: BinaryFloatingPoint
{
  public typealias TangentVector = SIMD${n}

  /// A closure that produces the zero tangent vector (every lane set to `0`).
  @inlinable
  public var zeroTangentVectorInitializer: () -> TangentVector {
    return { TangentVector(repeating: 0) }
  }
}
// SWIFT_ENABLE_TENSORFLOW
extension SIMD${n}: EuclideanDifferentiable
where
  Scalar: EuclideanDifferentiable & BinaryFloatingPoint,
  Scalar.TangentVector: BinaryFloatingPoint
{
}
// SWIFT_ENABLE_TENSORFLOW END

//===----------------------------------------------------------------------===//
// Derivatives
//===----------------------------------------------------------------------===//

extension SIMD${n}
where
  Scalar: Differentiable & BinaryFloatingPoint,
  Scalar.TangentVector == Scalar
{
  // NOTE(TF-1094): serialized `@derivative` for `.swiftinterface` compilation.
  /// VJP of the element getter. The pullback places the incoming scalar
  /// cotangent in lane `index` of an otherwise all-zero vector.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _vjpSubscript(_ index: Int)
    -> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector)
  {
    let element = self[index]
    return (element, { dElement in
      var dSelf = Self.zero
      dSelf[index] = dElement
      return dSelf
    })
  }

  /// JVP of the element getter. The differential reads lane `index` of the
  /// incoming tangent vector.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _jvpSubscript(index: Int)
    -> (value: Scalar, differential: (TangentVector) -> Scalar.TangentVector)
  {
    let element = self[index]
    return (element, { dSelf in Scalar.TangentVector(dSelf[index]) })
  }

  /// VJP of the element setter. Writing lane `index` discards that lane's old
  /// value, so the pullback hands the lane's cotangent to `newValue` and clears
  /// it in `dSelf`.
  @inlinable
  @derivative(of: subscript(_:).set)
  internal mutating func _vjpSubscriptSetter(_ newValue: Scalar, _ index: Int)
    -> (value: Void, pullback: (inout TangentVector) -> Scalar.TangentVector)
  {
    self[index] = newValue
    return ((), { dSelf in
      // Read the lane first; `defer` zeroes it only after the read completes.
      defer { dSelf[index] = 0 }
      return dSelf[index]
    })
  }
}
%end
// Reverse- and forward-mode derivatives of lane-wise `+`, binary `-` and
// prefix `-` between vectors. These require only that the tangent vector be a
// floating-point SIMD type; they do not require `TangentVector == Self`.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint,
  TangentVector.Scalar: BinaryFloatingPoint
{
  /// VJP of `lhs + rhs`: both operands receive the cotangent unchanged.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    let result = lhs + rhs
    return (result, { cotangent in (cotangent, cotangent) })
  }

  /// JVP of `lhs + rhs`: the output tangent is the sum of the input tangents.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    let result = lhs + rhs
    return (result, { dLhs, dRhs in dLhs + dRhs })
  }

  /// VJP of `lhs - rhs`: `lhs` receives the cotangent, `rhs` its negation.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    let result = lhs - rhs
    return (result, { cotangent in (cotangent, -cotangent) })
  }

  /// JVP of `lhs - rhs`: the output tangent is the difference of the inputs'.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    let result = lhs - rhs
    return (result, { dLhs, dRhs in dLhs - dRhs })
  }

  /// VJP of prefix `-rhs`: the cotangent is negated.
  @inlinable
  @derivative(of: -)
  static func _vjpNegate(rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector))
  {
    let result = -rhs
    return (result, { cotangent in -cotangent })
  }

  /// JVP of prefix `-rhs`: the tangent is negated.
  @inlinable
  @derivative(of: -)
  static func _jvpNegate(rhs: Self)
    -> (value: Self, differential: (TangentVector) -> (TangentVector))
  {
    let result = -rhs
    return (result, { dRhs in -dRhs })
  }
}
// Derivatives of lane-wise `*` and `/` between vectors. The product and
// quotient rules multiply cotangents by the primal operands, which is why these
// require `TangentVector == Self`.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint,
  Self.TangentVector == Self
{
  /// VJP of `lhs * rhs` (product rule): `lhs` gets `v * rhs`, `rhs` gets
  /// `v * lhs`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    let product = lhs * rhs
    return (product, { cotangent in (cotangent * rhs, cotangent * lhs) })
  }

  /// JVP of `lhs * rhs` (product rule): `lhs * dRhs + dLhs * rhs`.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    let product = lhs * rhs
    return (product, { dLhs, dRhs in lhs * dRhs + dLhs * rhs })
  }

  /// VJP of `lhs / rhs` (quotient rule): `lhs` gets `v / rhs`, `rhs` gets
  /// `-lhs / rhs² * v`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    let quotient = lhs / rhs
    return (quotient, { cotangent in
      (cotangent / rhs, -lhs / (rhs * rhs) * cotangent)
    })
  }

  /// JVP of `lhs / rhs` (quotient rule): `(dLhs * rhs - lhs * dRhs) / rhs²`.
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    let quotient = lhs / rhs
    return (quotient, { dLhs, dRhs in
      (dLhs * rhs - lhs * dRhs) / (rhs * rhs)
    })
  }
}
// Derivatives of `+` and `-` that mix a scalar with a vector. The scalar
// operand is broadcast across every lane, so its cotangent is the sum of the
// vector cotangent's lanes.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector.Scalar == Scalar.TangentVector
{
  /// VJP of `scalar + vector`: the scalar gets the lane sum of the cotangent.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    let result = lhs + rhs
    return (result, { cotangent in (cotangent.sum(), cotangent) })
  }

  /// JVP of `scalar + vector`: the scalar tangent is added to every lane.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    let result = lhs + rhs
    return (result, { dLhs, dRhs in dLhs + dRhs })
  }

  /// VJP of `scalar - vector`: the vector operand gets the negated cotangent.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    let result = lhs - rhs
    return (result, { cotangent in (cotangent.sum(), -cotangent) })
  }

  /// JVP of `scalar - vector`.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    let result = lhs - rhs
    return (result, { dLhs, dRhs in dLhs - dRhs })
  }

  /// VJP of `vector + scalar`: the scalar gets the lane sum of the cotangent.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    let result = lhs + rhs
    return (result, { cotangent in (cotangent, cotangent.sum()) })
  }

  /// JVP of `vector + scalar`: the scalar tangent is added to every lane.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    let result = lhs + rhs
    return (result, { dLhs, dRhs in dLhs + dRhs })
  }

  /// VJP of `vector - scalar`: the scalar gets the negated lane sum.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    let result = lhs - rhs
    return (result, { cotangent in (cotangent, -cotangent.sum()) })
  }

  /// JVP of `vector - scalar`.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    let result = lhs - rhs
    return (result, { dLhs, dRhs in dLhs - dRhs })
  }
}
// Derivatives of `*` and `/` that mix a scalar with a vector. The scalar is
// broadcast across every lane, so the scalar operand's cotangent is the lane
// sum of the corresponding per-lane cotangent.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  /// VJP of `vector * scalar` (product rule).
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    let product = lhs * rhs
    return (product, { cotangent in
      (cotangent * rhs, (cotangent * lhs).sum())
    })
  }

  /// JVP of `vector * scalar` (product rule): `lhs * dRhs + dLhs * rhs`.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    let product = lhs * rhs
    return (product, { dLhs, dRhs in lhs * dRhs + dLhs * rhs })
  }

  /// VJP of `vector / scalar` (quotient rule).
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    let quotient = lhs / rhs
    return (quotient, { cotangent in
      (cotangent / rhs, (-lhs / (rhs * rhs) * cotangent).sum())
    })
  }

  /// JVP of `vector / scalar` (quotient rule).
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    let quotient = lhs / rhs
    return (quotient, { dLhs, dRhs in
      (dLhs * rhs - lhs * dRhs) / (rhs * rhs)
    })
  }

  /// VJP of `scalar * vector` (product rule).
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    let product = lhs * rhs
    return (product, { cotangent in
      ((cotangent * rhs).sum(), cotangent * lhs)
    })
  }

  /// JVP of `scalar * vector` (product rule): `lhs * dRhs + dLhs * rhs`.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    let product = lhs * rhs
    return (product, { dLhs, dRhs in lhs * dRhs + dLhs * rhs })
  }

  /// VJP of `scalar / vector` (quotient rule).
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    let quotient = lhs / rhs
    return (quotient, { cotangent in
      ((cotangent / rhs).sum(), -lhs / (rhs * rhs) * cotangent)
    })
  }

  /// JVP of `scalar / vector` (quotient rule).
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    let quotient = lhs / rhs
    return (quotient, { dLhs, dRhs in
      (dLhs * rhs - lhs * dRhs) / (rhs * rhs)
    })
  }
}
// Derivatives of `sum()`. Every lane contributes to the sum with weight one.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector == Self
{
  /// VJP of `sum()`: the scalar cotangent is broadcast to every lane.
  @inlinable
  @derivative(of: sum)
  func _vjpSum() -> (
    value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
  ) {
    let total = sum()
    return (total, { cotangent in Self(repeating: Scalar(cotangent)) })
  }

  /// JVP of `sum()`: the output tangent is the lane sum of the input tangent.
  @inlinable
  @derivative(of: sum)
  func _jvpSum() -> (
    value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
  ) {
    let total = sum()
    return (total, { tangent in Scalar.TangentVector(tangent.sum()) })
  }
}
// Derivatives of `init(repeating:)`, which broadcasts one scalar to every
// lane. They are the transpose of the `sum()` derivatives.
extension SIMD
where
  Self: Differentiable,
  Self.TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  /// VJP of `init(repeating:)`: the cotangents of all lanes are summed.
  @inlinable
  @derivative(of: init(repeating:))
  static func _vjpInit(repeating value: Scalar)
    -> (value: Self, pullback: (TangentVector) -> Scalar.TangentVector)
  {
    let broadcast = Self(repeating: value)
    return (broadcast, { cotangent in cotangent.sum() })
  }

  /// JVP of `init(repeating:)`: the scalar tangent is broadcast to every lane.
  @inlinable
  @derivative(of: init(repeating:))
  static func _jvpInit(repeating value: Scalar)
    -> (value: Self, differential: (Scalar.TangentVector) -> TangentVector)
  {
    let broadcast = Self(repeating: value)
    return (broadcast, { tangent in Self(repeating: tangent) })
  }
}