blob: 95edc67390fe68692df06ff17eaa0c66163d95b9 [file] [log] [blame]
//===--- DifferentiationSupport.swift -------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// SWIFT_ENABLE_TENSORFLOW
//
// This file defines not-yet-upstreamed support for differentiable programming
// and deep learning APIs.
//
//===----------------------------------------------------------------------===//
/// Pointwise multiplication operator. It is declared in the same precedence
/// group as `*` (`MultiplicationPrecedence`); `PointwiseMultiplicative`
/// declares it as a requirement.
infix operator .* : MultiplicationPrecedence
/// Pointwise multiply-and-assign operator, declared with assignment
/// precedence like the other compound assignment operators.
infix operator .*= : AssignmentPrecedence
//===----------------------------------------------------------------------===//
// Compiler Protocols
//===----------------------------------------------------------------------===//
/// A type with values that support pointwise multiplication.
///
/// Conforming types provide a multiplicative identity (`one`), a
/// multiplicative inverse (`reciprocal`), and the pointwise product operators
/// `.*` and `.*=`. Conforming types must also conform to `AdditiveArithmetic`.
public protocol PointwiseMultiplicative : AdditiveArithmetic {
/// The one value.
///
/// One is the identity element for multiplication. For any value,
/// `x .* .one == x` and `.one .* x == x`.
static var one: Self { get }
/// The multiplicative inverse of self.
///
/// For any value `x` that has a multiplicative inverse,
/// `x .* x.reciprocal == .one` and `x.reciprocal .* x == .one`.
///
/// NOTE(review): this protocol does not specify the result for values that
/// have no inverse (such as zero); consult the conforming type.
var reciprocal: Self { get }
/// Multiplies two values and produces their product.
///
/// - Parameters:
/// - lhs: The first value to multiply.
/// - rhs: The second value to multiply.
/// - Returns: The pointwise product of `lhs` and `rhs`.
static func .*(lhs: Self, rhs: Self) -> Self
/// Multiplies two values and stores the result in the left-hand-side
/// variable.
///
/// - Parameters:
/// - lhs: The first value to multiply. On return, it holds the product.
/// - rhs: The second value to multiply.
static func .*=(lhs: inout Self, rhs: Self)
}
public extension PointwiseMultiplicative {
  /// Default in-place pointwise multiplication.
  ///
  /// Computes `lhs .* rhs` and writes the product back into `lhs`.
  ///
  /// - Parameters:
  ///   - lhs: The value to multiply. On return, it holds the product.
  ///   - rhs: The value to multiply `lhs` by.
  static func .*=(lhs: inout Self, rhs: Self) {
    let product = lhs .* rhs
    lhs = product
  }
}
public extension PointwiseMultiplicative
  where Self : ExpressibleByIntegerLiteral {
  /// Default multiplicative identity for conforming types that can be
  /// written as integer literals: the value of the literal `1`.
  static var one: Self { 1 }
}
/// A type that represents an unranked vector space. Values of this type are
/// elements in this vector space and have either no shape or a static shape.
///
/// In addition to the `AdditiveArithmetic` operations between two values,
/// conforming types support adding, subtracting, and scaling by a value of
/// the associated `VectorSpaceScalar` type.
public protocol VectorProtocol : AdditiveArithmetic {
/// The type of scalars in the vector space.
associatedtype VectorSpaceScalar : AdditiveArithmetic
/// Returns the result of adding the given scalar to `self`.
///
/// - Parameter x: The scalar to add.
func adding(_ x: VectorSpaceScalar) -> Self
/// Adds the given scalar to `self` in place.
///
/// - Parameter x: The scalar to add.
mutating func add(_ x: VectorSpaceScalar)
/// Returns the result of subtracting the given scalar from `self`.
///
/// - Parameter x: The scalar to subtract.
func subtracting(_ x: VectorSpaceScalar) -> Self
/// Subtracts the given scalar from `self` in place.
///
/// - Parameter x: The scalar to subtract.
mutating func subtract(_ x: VectorSpaceScalar)
/// Returns `self` multiplied by the given scalar.
///
/// - Parameter scalar: The scalar to multiply by.
func scaled(by scalar: VectorSpaceScalar) -> Self
/// Multiplies `self` by the given scalar.
///
/// - Parameter scalar: The scalar to multiply by.
mutating func scale(by scalar: VectorSpaceScalar)
}
public extension VectorProtocol {
  /// Default in-place scalar addition: replaces `self` with the value
  /// returned by `adding(_:)`.
  mutating func add(_ x: VectorSpaceScalar) {
    let sum = self.adding(x)
    self = sum
  }

  /// Default in-place scalar subtraction: replaces `self` with the value
  /// returned by `subtracting(_:)`.
  mutating func subtract(_ x: VectorSpaceScalar) {
    let difference = self.subtracting(x)
    self = difference
  }

  /// Default in-place scaling: replaces `self` with the value returned by
  /// `scaled(by:)`.
  mutating func scale(by scalar: VectorSpaceScalar) {
    let product = self.scaled(by: scalar)
    self = product
  }
}
/*
// Note: These default operator implementations are disabled because they slow
// down type-checking and break existing code.
public extension VectorProtocol {
static func + (lhs: Self, rhs: VectorSpaceScalar) -> Self {
lhs.adding(rhs)
}
static func + (lhs: VectorSpaceScalar, rhs: Self) -> Self {
rhs.adding(lhs)
}
static func += (lhs: inout Self, rhs: VectorSpaceScalar) {
lhs.add(rhs)
}
static func - (lhs: Self, rhs: VectorSpaceScalar) -> Self {
lhs.subtracting(rhs)
}
static func -= (lhs: inout Self, rhs: VectorSpaceScalar) {
lhs.subtract(rhs)
}
static func * (lhs: Self, rhs: VectorSpaceScalar) -> Self {
lhs.scaled(by: rhs)
}
static func * (lhs: VectorSpaceScalar, rhs: Self) -> Self {
rhs.scaled(by: lhs)
}
static func *= (lhs: inout Self, rhs: VectorSpaceScalar) {
lhs.scale(by: rhs)
}
}
public extension VectorProtocol where VectorSpaceScalar : SignedNumeric {
static func - (lhs: VectorSpaceScalar, rhs: Self) -> Self {
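// `lhs - rhs == (-rhs) + lhs`. The parentheses are required: without them,
// `-rhs.adding(lhs)` parses as `-(rhs.adding(lhs))`.
(-rhs).adding(lhs)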
}
static prefix func - (x: Self) -> Self {
.zero - x
}
}
*/
/// A type that is differentiable in the Euclidean space.
/// The type may represent a vector space, or consist of a vector space and some
/// other non-differentiable component.
///
/// Mathematically, this represents a product manifold that consists of
/// a differentiable vector space and some arbitrary manifold, where the tangent
/// bundle of the entire product manifold is equal to the vector space
/// component.
///
/// This abstraction is useful for representing common differentiable data
/// structures that contain both differentiable vector properties and other
/// stored properties that do not have a derivative, for example:
///
/// ```swift
/// struct Perceptron: @memberwise EuclideanDifferentiable {
///   var weight: SIMD16<Float>
///   var bias: Float
///   @noDerivative var useBias: Bool
/// }
/// ```
///
/// - Note: Conform a type to `EuclideanDifferentiable` if it is differentiable
///   only with respect to its vector space component and when its
///   `TangentVector` is equal to its vector space component.
public protocol EuclideanDifferentiable: Differentiable {
/// The differentiable vector component of `self`.
///
/// The component is expressed as a `TangentVector`. When `TangentVector` is
/// `Self`, a default implementation provides `self` itself.
var differentiableVectorView: TangentVector { get }
}
public extension EuclideanDifferentiable where TangentVector == Self {
  /// When a type is its own tangent vector, the whole value is its
  /// differentiable vector component, so `self` is yielded directly from a
  /// read coroutine.
  var differentiableVectorView: TangentVector {
    _read {
      yield self
    }
  }
}
//===----------------------------------------------------------------------===//
// Functional utilities
//===----------------------------------------------------------------------===//
/// Wraps `body` so that its pullback recomputes `body` instead of reusing
/// state from the original call. This technique is known as "checkpointing"
/// in traditional automatic differentiation.
///
/// The returned function computes the same value as `body`. Its pullback
/// captures only the input and `body`, and calls `pullback(at:in:)` on `body`
/// each time it is invoked.
///
/// - Parameter body: The differentiable function to wrap.
/// - Returns: A differentiable function whose pullback re-evaluates `body`.
@inlinable
public func withRecomputationInPullbacks<T, U>(
  _ body: @escaping @differentiable (T) -> U
) -> @differentiable (T) -> U where T : Differentiable, U : Differentiable {
  return differentiableFunction { (input: T) in
    (
      value: body(input),
      // Deferred: `body` is differentiated again only when the pullback runs.
      pullback: { pullback(at: input, in: body)($0) }
    )
  }
}
public extension Differentiable {
  /// Applies `body` to `self`; when differentiated, the pullback recomputes
  /// `body` rather than reusing state from the original call.
  ///
  /// - Parameter body: The differentiable function to apply to `self`.
  /// - Returns: `body(self)`.
  @inlinable
  @differentiable(wrt: self)
  func withRecomputationInPullbacks<Result : Differentiable>(
    _ body: @escaping @differentiable (Self) -> Result
  ) -> Result {
    body(self)
  }

  /// Registered derivative (value and pullback) for
  /// `withRecomputationInPullbacks(_:)` with respect to `self`.
  @inlinable
  @derivative(of: withRecomputationInPullbacks)
  internal func _vjp_withRecomputationInPullbacks<Result : Differentiable>(
    _ body: @escaping @differentiable (Self) -> Result
  ) -> (value: Result, pullback: (Result.TangentVector) -> TangentVector) {
    // Both calls are qualified with `Swift.` so that the module-level free
    // functions are chosen rather than this extension's method of the same
    // name.
    let checkpointed = Swift.withRecomputationInPullbacks(body)
    return Swift.valueWithPullback(at: self, in: checkpointed)
  }
}