| //===--- DifferentiationUnittest.swift.gyb --------------------------------===// |
| // |
| // This source file is part of the Swift.org open source project |
| // |
| // Copyright (c) 2020 Apple Inc. and the Swift project authors |
| // Licensed under Apache License v2.0 with Runtime Library Exception |
| // |
| // See https://swift.org/LICENSE.txt for license information |
| // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| // |
| //===----------------------------------------------------------------------===// |
| |
| @_exported import _Differentiation |
| import StdlibUnittest |
| |
/// Namespace for the global counter used by the leak-checking utilities.
///
/// The `Box` classes nested in `Tracked` and `NonresilientTracked` increment
/// `count` in their initializers and decrement it in their deinitializers.
public enum _GlobalLeakCount {
  /// Number of boxes that have been initialized and not yet deinitialized.
  public static var count: Int = 0
}
| |
/// Runs `body` and expects the global leak count to change by exactly
/// `expectedLeakCount` while it runs.
///
/// - Parameters:
///   - expectedLeakCount: Expected difference in `_GlobalLeakCount.count`
///     between just before and just after `body` executes.
///   - file: Source file forwarded to `expectEqual` for failure reporting.
///   - line: Source line forwarded to `expectEqual` for failure reporting.
///   - body: The closure to execute.
public func withLeakChecking(
  expectedLeakCount: Int = 0,
  file: String = #file,
  line: UInt = #line,
  _ body: () -> Void
) {
  // The check is relative: it measures the change in the leak count across
  // `body` rather than resetting the counter to zero and comparing against
  // zero. Resetting is stateful and causes issues, so the relative approach
  // is more robust.
  let baselineLeakCount = _GlobalLeakCount.count
  body()
  let leakCount = _GlobalLeakCount.count - baselineLeakCount
  expectEqual(
    expectedLeakCount,
    leakCount,
    "Leaks detected: \(leakCount)",
    file: file,
    line: line)
}
| |
extension TestSuite {
  /// Registers a test named `name` that runs `testFunction` under
  /// `withLeakChecking`.
  ///
  /// - Parameters:
  ///   - name: Test name passed to `test(_:file:line:_:)`.
  ///   - expectedLeakCount: Leak count expected after `testFunction` runs.
  ///   - file: Source file forwarded to both `test` and `withLeakChecking`.
  ///   - line: Source line forwarded to both `test` and `withLeakChecking`.
  ///   - testFunction: The test body.
  public func testWithLeakChecking(
    _ name: String,
    expectedLeakCount: Int = 0,
    file: String = #file,
    line: UInt = #line,
    _ testFunction: @escaping () -> Void
  ) {
    test(name, file: file, line: line) {
      withLeakChecking(
        expectedLeakCount: expectedLeakCount,
        file: file,
        line: line,
        testFunction)
    }
  }
}
| |
/// A resilient type that tracks the number of live instances of a wrapped
/// value type.
///
/// `Tracked<T>` is used to check for memory leaks in functions created via
/// automatic differentiation.
///
/// The wrapped value lives in a class instance (`Box`), so copies of a
/// `Tracked` value share one box and the `value` setter writes through to it.
public struct Tracked<T> {
  /// Heap storage for the wrapped value. Creating a box increments
  /// `_GlobalLeakCount.count`; destroying it decrements the count.
  internal class Box {
    fileprivate var value: T

    init(_ value: T) {
      self.value = value
      _GlobalLeakCount.count += 1
    }

    deinit {
      _GlobalLeakCount.count -= 1
    }
  }

  /// The box holding the wrapped value.
  internal var handle: Box

  /// Wraps `value` in a newly allocated box.
  @differentiable(where T: Differentiable, T == T.TangentVector)
  public init(_ value: T) {
    handle = Box(value)
  }

  /// The wrapped value, read from and written to the underlying box.
  @differentiable(where T: Differentiable, T == T.TangentVector)
  public var value: T {
    get { return handle.value }
    set(updatedValue) { handle.value = updatedValue }
  }
}
| |
/// A non-resilient (`@frozen`) type that tracks the number of live instances
/// of a wrapped value type.
///
/// `NonresilientTracked<T>` is used to check for memory leaks in functions
/// created via automatic differentiation.
///
/// The wrapped value lives in a class instance (`Box`), so copies of a
/// `NonresilientTracked` value share one box and the `value` setter writes
/// through to it.
@frozen
public struct NonresilientTracked<T> {
  /// Heap storage for the wrapped value. Creating a box increments
  /// `_GlobalLeakCount.count`; destroying it decrements the count.
  @usableFromInline
  internal class Box {
    fileprivate var value: T

    init(_ value: T) {
      self.value = value
      _GlobalLeakCount.count += 1
    }

    deinit {
      _GlobalLeakCount.count -= 1
    }
  }

  /// The box holding the wrapped value.
  @usableFromInline
  internal var handle: Box

  /// Wraps `value` in a newly allocated box.
  @differentiable(where T: Differentiable, T == T.TangentVector)
  public init(_ value: T) {
    handle = Box(value)
  }

  /// The wrapped value, read from and written to the underlying box.
  @differentiable(where T: Differentiable, T == T.TangentVector)
  public var value: T {
    get { return handle.value }
    set(updatedValue) { handle.value = updatedValue }
  }
}
| |
| % for Self in ['Tracked', 'NonresilientTracked']: |
| |
// Float-literal support: the literal is converted to `T` and boxed.
extension ${Self}: ExpressibleByFloatLiteral
where T: ExpressibleByFloatLiteral {
  /// Creates an instance wrapping `T(floatLiteral: value)`.
  public init(floatLiteral value: T.FloatLiteralType) {
    let wrapped = T(floatLiteral: value)
    handle = Box(wrapped)
  }
}
| |
extension ${Self}: CustomStringConvertible {
  /// The type name followed by the wrapped value in parentheses.
  public var description: String {
    "${Self}(\(value))"
  }
}
| |
// Integer-literal support: the literal is converted to `T` and boxed.
extension ${Self}: ExpressibleByIntegerLiteral
where T: ExpressibleByIntegerLiteral {
  /// Creates an instance wrapping `T(integerLiteral: value)`.
  public init(integerLiteral value: T.IntegerLiteralType) {
    let wrapped = T(integerLiteral: value)
    handle = Box(wrapped)
  }
}
| |
// Ordering is delegated to the wrapped values. Each operator forwards to the
// same operator on `T`.
extension ${Self}: Comparable where T: Comparable {
  public static func < (lhs: ${Self}, rhs: ${Self}) -> Bool {
    lhs.value < rhs.value
  }

  public static func <= (lhs: ${Self}, rhs: ${Self}) -> Bool {
    lhs.value <= rhs.value
  }

  public static func > (lhs: ${Self}, rhs: ${Self}) -> Bool {
    lhs.value > rhs.value
  }

  public static func >= (lhs: ${Self}, rhs: ${Self}) -> Bool {
    lhs.value >= rhs.value
  }
}
| |
// Additive arithmetic operates on the wrapped values and wraps each result in
// a new instance (and therefore a new box).
extension ${Self}: AdditiveArithmetic where T: AdditiveArithmetic {
  /// A new instance wrapping `T.zero`.
  public static var zero: ${Self} {
    ${Self}(T.zero)
  }

  /// A new instance wrapping the sum of the wrapped values.
  public static func + (lhs: ${Self}, rhs: ${Self}) -> ${Self} {
    let sum = lhs.value + rhs.value
    return ${Self}(sum)
  }

  /// A new instance wrapping the difference of the wrapped values.
  public static func - (lhs: ${Self}, rhs: ${Self}) -> ${Self} {
    let difference = lhs.value - rhs.value
    return ${Self}(difference)
  }
}
| |
// Two instances are equal when their wrapped values are equal. Box identity
// is not considered.
extension ${Self}: Equatable where T: Equatable {
  public static func == (lhs: ${Self}, rhs: ${Self}) -> Bool {
    lhs.value == rhs.value
  }
}
| |
// Numeric operations are forwarded to the wrapped values. Each result is
// wrapped in a new instance (and therefore a new box).
extension ${Self}: SignedNumeric & Numeric
where T: SignedNumeric, T == T.Magnitude {
  public typealias Magnitude = ${Self}<T.Magnitude>

  /// Creates an instance wrapping `T(exactly: source)`.
  ///
  /// Returns `nil` only when `source` cannot be represented exactly as a `T`.
  ///
  /// - Parameter source: The integer value to convert.
  public init?<U>(exactly source: U) where U: BinaryInteger {
    // Exit early on failure. Previously the initializer fell through to
    // `return nil` even after a successful `self.init(t)`, so it always
    // produced `nil`.
    guard let exactValue = T(exactly: source) else {
      return nil
    }
    self.init(exactValue)
  }

  /// A new instance wrapping the magnitude of the wrapped value.
  public var magnitude: Magnitude {
    return Magnitude(value.magnitude)
  }

  /// A new instance wrapping the product of the wrapped values.
  public static func * (lhs: ${Self}, rhs: ${Self}) -> ${Self} {
    return ${Self}(lhs.value * rhs.value)
  }

  /// Replaces `lhs` with `lhs * rhs`.
  public static func *= (lhs: inout ${Self}, rhs: ${Self}) {
    lhs = lhs * rhs
  }
}
| |
// Division is available when the wrapped type is floating-point.
extension ${Self} where T: FloatingPoint {
  /// A new instance wrapping the quotient of the wrapped values.
  public static func / (lhs: ${Self}, rhs: ${Self}) -> ${Self} {
    let quotient = lhs.value / rhs.value
    return ${Self}(quotient)
  }

  /// Replaces `lhs` with `lhs / rhs`.
  public static func /= (lhs: inout ${Self}, rhs: ${Self}) {
    lhs = lhs / rhs
  }
}
| |
// Striding is forwarded to the wrapped values. Strides are themselves wrapped
// in `${Self}`.
extension ${Self}: Strideable
where T: Strideable, T.Stride == T.Stride.Magnitude {
  public typealias Stride = ${Self}<T.Stride>

  /// The wrapped distance from this instance's value to `other`'s value.
  public func distance(to other: ${Self}) -> Stride {
    let rawDistance = value.distance(to: other.value)
    return Stride(rawDistance)
  }

  /// A new instance wrapping this instance's value advanced by `n.value`.
  public func advanced(by n: Stride) -> ${Self} {
    let advancedValue = value.advanced(by: n.value)
    return ${Self}(advancedValue)
  }
}
| |
// `Differentiable` conformance. The tangent vector is the same wrapper applied
// to `T`'s tangent vector.
//
// For now, `T` must be restricted to trivial types (like `Float` or `Tensor`).
extension ${Self}: Differentiable
where T: Differentiable, T == T.TangentVector {
  public typealias TangentVector = ${Self}<T.TangentVector>
}
| |
// Derivatives registered for `init(_:)` and the `value` property. Wrapping
// and unwrapping act as the identity on tangents, so each derivative closure
// only converts between `T.TangentVector` and `Self.TangentVector`.
extension ${Self} where T: Differentiable, T == T.TangentVector {
  /// VJP of `init(_:)`: the pullback unwraps the incoming tangent.
  @usableFromInline
  @derivative(of: init)
  internal static func _vjpInit(_ value: T)
    -> (value: Self, pullback: (Self.TangentVector) -> (T.TangentVector))
  {
    return (value: ${Self}(value), pullback: { $0.value })
  }

  /// JVP of `init(_:)`: the differential wraps the incoming tangent.
  @usableFromInline
  @derivative(of: init)
  internal static func _jvpInit(_ value: T)
    -> (value: Self, differential: (T.TangentVector) -> (Self.TangentVector))
  {
    return (value: ${Self}(value), differential: { ${Self}($0) })
  }

  /// VJP of `value`: the pullback wraps the incoming tangent.
  @usableFromInline
  @derivative(of: value)
  internal func _vjpValue() -> (
    value: T, pullback: (T.TangentVector) -> Self.TangentVector
  ) {
    return (value: value, pullback: { ${Self}($0) })
  }

  /// JVP of `value`: the differential unwraps the incoming tangent.
  @usableFromInline
  @derivative(of: value)
  internal func _jvpValue() -> (
    value: T, differential: (Self.TangentVector) -> T.TangentVector
  ) {
    return (value: value, differential: { $0.value })
  }
}
| |
// Derivatives registered for `+` and `-`.
extension ${Self} where T: Differentiable, T == T.TangentVector {
  /// VJP of `+`: the upstream tangent flows unchanged to both operands.
  @usableFromInline
  @derivative(of: +)
  internal static func _vjpAdd(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self))
  {
    return (value: lhs + rhs, pullback: { ($0, $0) })
  }

  /// JVP of `+`: the output tangent is the sum of the operand tangents.
  @usableFromInline
  @derivative(of: +)
  internal static func _jvpAdd(lhs: Self, rhs: Self)
    -> (value: Self, differential: (Self, Self) -> Self)
  {
    return (value: lhs + rhs, differential: { (dx, dy) in dx + dy })
  }

  /// VJP of `-`: the upstream tangent flows unchanged to `lhs` and negated
  /// (computed as `.zero - v`) to `rhs`.
  @usableFromInline
  @derivative(of: -)
  internal static func _vjpSubtract(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self))
  {
    return (value: lhs - rhs, pullback: { ($0, .zero - $0) })
  }

  /// JVP of `-`: the output tangent is the difference of the operand
  /// tangents.
  @usableFromInline
  @derivative(of: -)
  internal static func _jvpSubtract(lhs: Self, rhs: Self)
    -> (value: Self, differential: (Self, Self) -> Self)
  {
    return (value: lhs - rhs, differential: { (dx, dy) in dx - dy })
  }
}
| |
// Derivatives registered for `*` (product rule).
extension ${Self}
where
  T: Differentiable & SignedNumeric, T == T.Magnitude,
  T == T.TangentVector
{
  /// VJP of `*`: the pullback scales the upstream tangent by the opposite
  /// operand, returning `(v * rhs, v * lhs)`.
  @usableFromInline
  @derivative(of: *)
  internal static func _vjpMultiply(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self))
  {
    return (
      value: lhs * rhs,
      pullback: { upstream in (upstream * rhs, upstream * lhs) }
    )
  }

  /// JVP of `*`: the differential is `dx * rhs + dy * lhs`.
  @usableFromInline
  @derivative(of: *)
  internal static func _jvpMultiply(lhs: Self, rhs: Self)
    -> (value: Self, differential: (Self, Self) -> (Self))
  {
    return (
      value: lhs * rhs,
      differential: { (dx, dy) in dx * rhs + dy * lhs }
    )
  }
}
| |
// Derivatives registered for `/` (quotient rule).
extension ${Self}
where T: Differentiable & FloatingPoint, T == T.TangentVector {
  /// VJP of `/`: the pullback returns `v / rhs` for `lhs` and
  /// `-lhs / (rhs * rhs) * v` for `rhs`.
  @usableFromInline
  @derivative(of: /)
  internal static func _vjpDivide(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self))
  {
    return (
      value: lhs / rhs,
      pullback: { upstream in
        (upstream / rhs, -lhs / (rhs * rhs) * upstream)
      }
    )
  }

  /// JVP of `/`: the differential is `dx / rhs - lhs / (rhs * rhs) * dy`.
  @usableFromInline
  @derivative(of: /)
  internal static func _jvpDivide(lhs: Self, rhs: Self)
    -> (value: Self, differential: (Self, Self) -> (Self))
  {
    return (
      value: lhs / rhs,
      differential: { (dx, dy) in dx / rhs - lhs / (rhs * rhs) * dy }
    )
  }
}
| |
| // Differential operators for `${Self}<T>`. |
| |
/// Returns the gradient of `f` at `x`, obtained by applying the pullback of
/// `f` at `x` to the literal `1`.
///
/// - Parameters:
///   - x: The point at which to differentiate.
///   - f: A differentiable function returning `${Self}<R>`.
public func gradient<T, R: FloatingPoint>(
  at x: T, in f: @differentiable (T) -> ${Self}<R>
) -> T.TangentVector where R.TangentVector == R {
  let backpropagate = pullback(at: x, in: f)
  return backpropagate(1)
}
| |
/// Returns the gradient of `f` at `(x, y)`, obtained by applying the pullback
/// of `f` at `(x, y)` to the literal `1`.
///
/// - Parameters:
///   - x: The first argument at which to differentiate.
///   - y: The second argument at which to differentiate.
///   - f: A differentiable two-argument function returning `${Self}<R>`.
public func gradient<T, U, R: FloatingPoint>(
  at x: T, _ y: U, in f: @differentiable (T, U) -> ${Self}<R>
) -> (T.TangentVector, U.TangentVector) where R.TangentVector == R {
  let backpropagate = pullback(at: x, y, in: f)
  return backpropagate(1)
}
| |
/// Returns the derivative of `f` at `x`, obtained by applying the
/// differential of `f` at `x` to the literal `1`.
///
/// - Parameters:
///   - x: The `${Self}<T>` point at which to differentiate.
///   - f: A differentiable function of `${Self}<T>`.
public func derivative<T: FloatingPoint, R>(
  at x: ${Self}<T>, in f: @differentiable (${Self}<T>) -> R
) -> R.TangentVector where T.TangentVector == T {
  let pushforward = differential(at: x, in: f)
  return pushforward(1)
}
| |
/// Returns the derivative of `f` at `(x, y)`, obtained by applying the
/// differential of `f` at `(x, y)` to `(1, 1)`.
///
/// - Parameters:
///   - x: The first `${Self}` argument at which to differentiate.
///   - y: The second `${Self}` argument at which to differentiate.
///   - f: A differentiable function of two `${Self}` arguments.
public func derivative<T: FloatingPoint, U: FloatingPoint, R>(
  at x: ${Self}<T>, _ y: ${Self}<U>,
  in f: @differentiable (${Self}<T>, ${Self}<U>) -> R
) -> R.TangentVector where T.TangentVector == T, U.TangentVector == U {
  let pushforward = differential(at: x, y, in: f)
  return pushforward(1, 1)
}
| |
/// Returns the result of `f(x)` together with the gradient of `f` at `x`.
///
/// The gradient is obtained by applying the pullback returned by
/// `valueWithPullback(at:in:)` to the literal `1`.
///
/// - Parameters:
///   - x: The point at which to evaluate and differentiate.
///   - f: A differentiable function returning `${Self}<R>`.
public func valueWithGradient<T, R: FloatingPoint>(
  at x: T, in f: @differentiable (T) -> ${Self}<R>
) -> (value: ${Self}<R>, gradient: T.TangentVector) {
  let (result, backpropagate) = valueWithPullback(at: x, in: f)
  return (value: result, gradient: backpropagate(1))
}
| |
/// Returns the result of `f(x)` together with the derivative of `f` at `x`.
///
/// The derivative is obtained by applying the differential returned by
/// `valueWithDifferential(at:in:)` to the literal `1`.
///
/// - Parameters:
///   - x: The `${Self}<T>` point at which to evaluate and differentiate.
///   - f: A differentiable function of `${Self}<T>`.
public func valueWithDerivative<T: FloatingPoint, R>(
  at x: ${Self}<T>, in f: @differentiable (${Self}<T>) -> R
) -> (value: R, derivative: R.TangentVector) {
  let (result, pushforward) = valueWithDifferential(at: x, in: f)
  return (value: result, derivative: pushforward(1))
}
| |
| % end |