blob: 282f89f93324f9e3d5f0b2ef4d7489c2eb3ae951 [file] [log] [blame]
//===--- DifferentiationUnittest.swift.gyb --------------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
@_exported import _Differentiation
import StdlibUnittest
/// Namespace for the global counter of live tracked boxes.
///
/// The counter starts at zero and is adjusted by code elsewhere in this file
/// (box initializers increment it, box deinitializers decrement it).
public enum _GlobalLeakCount {
  /// The current counter value.
  public static var count: Int = 0
}
/// Runs `body` and checks how much the global leak counter changed.
///
/// - Parameters:
///   - expectedLeakCount: The expected change in `_GlobalLeakCount.count`
///     across the call to `body`. Defaults to zero.
///   - file: The source file forwarded to `expectEqual`.
///   - line: The source line forwarded to `expectEqual`.
///   - body: The closure to execute.
public func withLeakChecking(
  expectedLeakCount: Int = 0, file: String = #file, line: UInt = #line,
  _ body: () -> Void
) {
  // Measure the counter relative to its value before `body` runs instead of
  // resetting it to zero: resetting would be stateful and cause issues.
  let baseline = _GlobalLeakCount.count
  body()
  let leakCount = _GlobalLeakCount.count - baseline
  expectEqual(
    expectedLeakCount,
    leakCount,
    "Leaks detected: \(leakCount)",
    file: file,
    line: line)
}
extension TestSuite {
  /// Adds a test whose body runs `testFunction` under `withLeakChecking`.
  ///
  /// - Parameters:
  ///   - name: The test name passed to `test`.
  ///   - expectedLeakCount: The expected leak count forwarded to
  ///     `withLeakChecking`. Defaults to zero.
  ///   - file: The source file forwarded to both `test` and
  ///     `withLeakChecking`.
  ///   - line: The source line forwarded to both `test` and
  ///     `withLeakChecking`.
  ///   - testFunction: The test body.
  public func testWithLeakChecking(
    _ name: String,
    expectedLeakCount: Int = 0,
    file: String = #file, line: UInt = #line,
    _ testFunction: @escaping () -> Void
  ) {
    self.test(name, file: file, line: line) {
      withLeakChecking(
        expectedLeakCount: expectedLeakCount,
        file: file,
        line: line,
        testFunction)
    }
  }
}
/// A resilient wrapper (not marked `@frozen`) that counts live instances of
/// its wrapped value's storage.
///
/// The wrapped value lives in a class instance (`Box`). Creating a box
/// increments `_GlobalLeakCount.count`; destroying it decrements the counter.
/// `Tracked<T>` is used to check for memory leaks in functions created via
/// automatic differentiation.
public struct Tracked<T> {
  /// Class storage for the wrapped value; its lifetime drives the global
  /// leak counter.
  internal class Box {
    fileprivate var value: T

    init(_ initialValue: T) {
      value = initialValue
      _GlobalLeakCount.count += 1
    }

    deinit {
      _GlobalLeakCount.count -= 1
    }
  }

  /// The box holding the wrapped value. The setter of `value` writes through
  /// this reference rather than replacing it.
  internal var handle: Box

  /// Wraps `value` in a newly created box.
  @differentiable(where T: Differentiable, T == T.TangentVector)
  public init(_ value: T) {
    handle = Box(value)
  }

  /// The wrapped value, read from and written to the box.
  @differentiable(where T: Differentiable, T == T.TangentVector)
  public var value: T {
    get { return handle.value }
    set(updated) { handle.value = updated }
  }
}
/// A non-resilient (`@frozen`) wrapper that counts live instances of its
/// wrapped value's storage.
///
/// The wrapped value lives in a class instance (`Box`). Creating a box
/// increments `_GlobalLeakCount.count`; destroying it decrements the counter.
/// `NonresilientTracked<T>` is used to check for memory leaks in functions
/// created via automatic differentiation.
@frozen
public struct NonresilientTracked<T> {
  /// Class storage for the wrapped value; its lifetime drives the global
  /// leak counter.
  @usableFromInline
  internal class Box {
    fileprivate var value: T

    init(_ initialValue: T) {
      value = initialValue
      _GlobalLeakCount.count += 1
    }

    deinit {
      _GlobalLeakCount.count -= 1
    }
  }

  /// The box holding the wrapped value. The setter of `value` writes through
  /// this reference rather than replacing it.
  @usableFromInline
  internal var handle: Box

  /// Wraps `value` in a newly created box.
  @differentiable(where T: Differentiable, T == T.TangentVector)
  public init(_ value: T) {
    handle = Box(value)
  }

  /// The wrapped value, read from and written to the box.
  @differentiable(where T: Differentiable, T == T.TangentVector)
  public var value: T {
    get { return handle.value }
    set(updated) { handle.value = updated }
  }
}
% for Self in ['Tracked', 'NonresilientTracked']:
extension ${Self}: ExpressibleByFloatLiteral
where T: ExpressibleByFloatLiteral {
  /// Builds the wrapped value from a float literal and stores it in a new
  /// box.
  public init(floatLiteral value: T.FloatLiteralType) {
    let wrapped = T(floatLiteral: value)
    handle = Box(wrapped)
  }
}
extension ${Self}: CustomStringConvertible {
  /// The type name followed by the wrapped value in parentheses.
  public var description: String {
    "${Self}(\(value))"
  }
}
extension ${Self}: ExpressibleByIntegerLiteral
where T: ExpressibleByIntegerLiteral {
  /// Builds the wrapped value from an integer literal and stores it in a new
  /// box.
  public init(integerLiteral value: T.IntegerLiteralType) {
    let wrapped = T(integerLiteral: value)
    handle = Box(wrapped)
  }
}
// Ordering comparisons forward directly to the wrapped values.
extension ${Self}: Comparable where T: Comparable {
  public static func < (lhs: ${Self}, rhs: ${Self}) -> Bool {
    lhs.value < rhs.value
  }

  public static func <= (lhs: ${Self}, rhs: ${Self}) -> Bool {
    lhs.value <= rhs.value
  }

  public static func > (lhs: ${Self}, rhs: ${Self}) -> Bool {
    lhs.value > rhs.value
  }

  public static func >= (lhs: ${Self}, rhs: ${Self}) -> Bool {
    lhs.value >= rhs.value
  }
}
// Addition and subtraction operate on the wrapped values and wrap the result
// in a new instance.
extension ${Self}: AdditiveArithmetic where T: AdditiveArithmetic {
  /// A new instance wrapping `T.zero`.
  public static var zero: ${Self} {
    ${Self}(T.zero)
  }

  public static func + (lhs: ${Self}, rhs: ${Self}) -> ${Self} {
    let sum = lhs.value + rhs.value
    return ${Self}(sum)
  }

  public static func - (lhs: ${Self}, rhs: ${Self}) -> ${Self} {
    let difference = lhs.value - rhs.value
    return ${Self}(difference)
  }
}
// Two instances are equal when their wrapped values are equal.
extension ${Self}: Equatable where T: Equatable {
  public static func == (lhs: ${Self}, rhs: ${Self}) -> Bool {
    lhs.value == rhs.value
  }
}
// Numeric operations forward to the wrapped values and wrap the result in a
// new instance.
extension ${Self}: SignedNumeric & Numeric
where T: SignedNumeric, T == T.Magnitude {
  public typealias Magnitude = ${Self}<T.Magnitude>

  /// Creates an instance wrapping `source` converted to `T`, or returns `nil`
  /// when `T(exactly:)` cannot convert `source`.
  public init?<U>(exactly source: U) where U: BinaryInteger {
    // Fail only when the conversion fails. Previously the initializer fell
    // through to an unconditional `return nil` even after a successful
    // conversion, so it could never produce a value.
    guard let exact = T(exactly: source) else {
      return nil
    }
    self.init(exact)
  }

  /// A new instance wrapping the magnitude of the wrapped value.
  public var magnitude: Magnitude { return Magnitude(value.magnitude) }

  public static func * (lhs: ${Self}, rhs: ${Self}) -> ${Self} {
    return ${Self}(lhs.value * rhs.value)
  }

  public static func *= (lhs: inout ${Self}, rhs: ${Self}) {
    lhs = lhs * rhs
  }
}
// Division operates on the wrapped values and wraps the quotient in a new
// instance.
extension ${Self} where T: FloatingPoint {
  public static func / (lhs: ${Self}, rhs: ${Self}) -> ${Self} {
    let quotient = lhs.value / rhs.value
    return ${Self}(quotient)
  }

  public static func /= (lhs: inout ${Self}, rhs: ${Self}) {
    lhs = lhs / rhs
  }
}
// Striding forwards to the wrapped values; distances are wrapped as well.
extension ${Self}: Strideable
where T: Strideable, T.Stride == T.Stride.Magnitude {
  public typealias Stride = ${Self}<T.Stride>

  /// The wrapped distance from this instance's value to `other`'s value.
  public func distance(to other: ${Self}) -> Stride {
    let delta = value.distance(to: other.value)
    return Stride(delta)
  }

  /// A new instance wrapping this value advanced by `n`'s wrapped value.
  public func advanced(by n: Stride) -> ${Self} {
    let moved = value.advanced(by: n.value)
    return ${Self}(moved)
  }
}
// For now, `T` must be restricted to trivial types (like `Float` or `Tensor`).
// The conformance applies only when `T` is its own tangent vector type.
extension ${Self}: Differentiable where T: Differentiable, T == T.TangentVector {
  /// The tangent vector is the same wrapper applied to `T.TangentVector`.
  public typealias TangentVector = ${Self}<T.TangentVector>
}
// Custom derivatives for `init(_:)` and `value`. Wrapping and unwrapping pass
// tangents through unchanged apart from the wrapper itself.
extension ${Self} where T: Differentiable, T == T.TangentVector {
  /// Reverse-mode derivative of `init(_:)`: the pullback unwraps the incoming
  /// cotangent.
  @usableFromInline
  @derivative(of: init)
  internal static func _vjpInit(_ value: T)
    -> (value: Self, pullback: (Self.TangentVector) -> (T.TangentVector))
  {
    let wrapped = ${Self}(value)
    return (wrapped, { cotangent in cotangent.value })
  }

  /// Forward-mode derivative of `init(_:)`: the differential wraps the
  /// incoming tangent.
  @usableFromInline
  @derivative(of: init)
  internal static func _jvpInit(_ value: T)
    -> (value: Self, differential: (T.TangentVector) -> (Self.TangentVector))
  {
    let wrapped = ${Self}(value)
    return (wrapped, { tangent in ${Self}(tangent) })
  }

  /// Reverse-mode derivative of `value`: the pullback wraps the incoming
  /// cotangent.
  @usableFromInline
  @derivative(of: value)
  internal func _vjpValue() -> (
    value: T, pullback: (T.TangentVector) -> Self.TangentVector
  ) {
    return (value, { cotangent in ${Self}(cotangent) })
  }

  /// Forward-mode derivative of `value`: the differential unwraps the
  /// incoming tangent.
  @usableFromInline
  @derivative(of: value)
  internal func _jvpValue() -> (
    value: T, differential: (Self.TangentVector) -> T.TangentVector
  ) {
    return (value, { tangent in tangent.value })
  }
}
// Custom derivatives for `+` and `-`.
extension ${Self} where T: Differentiable, T == T.TangentVector {
  /// Reverse-mode derivative of `+`: the cotangent flows unchanged to both
  /// operands.
  @usableFromInline
  @derivative(of: +)
  internal static func _vjpAdd(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self))
  {
    let sum = lhs + rhs
    return (sum, { cotangent in (cotangent, cotangent) })
  }

  /// Forward-mode derivative of `+`: the differential adds the operand
  /// tangents.
  @usableFromInline
  @derivative(of: +)
  internal static func _jvpAdd(lhs: Self, rhs: Self)
    -> (value: Self, differential: (Self, Self) -> Self)
  {
    let sum = lhs + rhs
    return (sum, { dlhs, drhs in dlhs + drhs })
  }

  /// Reverse-mode derivative of `-`: the cotangent flows unchanged to `lhs`
  /// and negated (computed as `.zero - v`) to `rhs`.
  @usableFromInline
  @derivative(of: -)
  internal static func _vjpSubtract(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self))
  {
    let difference = lhs - rhs
    return (difference, { cotangent in (cotangent, .zero - cotangent) })
  }

  /// Forward-mode derivative of `-`: the differential subtracts the operand
  /// tangents.
  @usableFromInline
  @derivative(of: -)
  internal static func _jvpSubtract(lhs: Self, rhs: Self)
    -> (value: Self, differential: (Self, Self) -> Self)
  {
    let difference = lhs - rhs
    return (difference, { dlhs, drhs in dlhs - drhs })
  }
}
// Custom derivatives for `*` (product rule).
extension ${Self}
where
  T: Differentiable & SignedNumeric,
  T == T.Magnitude,
  T == T.TangentVector
{
  /// Reverse-mode derivative of `*`: each operand receives the cotangent
  /// scaled by the other operand.
  @usableFromInline
  @derivative(of: *)
  internal static func _vjpMultiply(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self))
  {
    let product = lhs * rhs
    return (product, { cotangent in (cotangent * rhs, cotangent * lhs) })
  }

  /// Forward-mode derivative of `*`: the differential is
  /// `dx * rhs + dy * lhs`.
  @usableFromInline
  @derivative(of: *)
  internal static func _jvpMultiply(lhs: Self, rhs: Self)
    -> (value: Self, differential: (Self, Self) -> (Self))
  {
    let product = lhs * rhs
    return (product, { dlhs, drhs in dlhs * rhs + drhs * lhs })
  }
}
// Custom derivatives for `/` (quotient rule).
extension ${Self}
where T: Differentiable & FloatingPoint, T == T.TangentVector {
  /// Reverse-mode derivative of `/`: `lhs` receives `v / rhs` and `rhs`
  /// receives `-lhs / (rhs * rhs) * v`.
  @usableFromInline
  @derivative(of: /)
  internal static func _vjpDivide(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self))
  {
    let quotient = lhs / rhs
    return (
      quotient,
      { cotangent in
        (cotangent / rhs, -lhs / (rhs * rhs) * cotangent)
      }
    )
  }

  /// Forward-mode derivative of `/`: the differential is
  /// `dx / rhs - lhs / (rhs * rhs) * dy`.
  @usableFromInline
  @derivative(of: /)
  internal static func _jvpDivide(lhs: Self, rhs: Self)
    -> (value: Self, differential: (Self, Self) -> (Self))
  {
    let quotient = lhs / rhs
    return (
      quotient,
      { dlhs, drhs in
        dlhs / rhs - lhs / (rhs * rhs) * drhs
      }
    )
  }
}
// Differential operators for `${Self}<T>`.

/// Returns the result of applying the pullback of `f` at `x` to a seed of
/// `1`.
public func gradient<T, R: FloatingPoint>(
  at x: T, in f: @differentiable (T) -> ${Self}<R>
) -> T.TangentVector where R.TangentVector == R {
  let backpropagate = pullback(at: x, in: f)
  return backpropagate(1)
}
/// Returns the result of applying the pullback of `f` at `(x, y)` to a seed
/// of `1`, as a pair of tangents for `x` and `y`.
public func gradient<T, U, R: FloatingPoint>(
  at x: T, _ y: U, in f: @differentiable (T, U) -> ${Self}<R>
) -> (T.TangentVector, U.TangentVector) where R.TangentVector == R {
  let backpropagate = pullback(at: x, y, in: f)
  return backpropagate(1)
}
/// Returns the result of applying the differential of `f` at `x` to a seed
/// of `1`.
public func derivative<T: FloatingPoint, R>(
  at x: ${Self}<T>, in f: @differentiable (${Self}<T>) -> R
) -> R.TangentVector where T.TangentVector == T {
  let pushforward = differential(at: x, in: f)
  return pushforward(1)
}
/// Returns the result of applying the differential of `f` at `(x, y)` to a
/// seed of `1` for each argument.
public func derivative<T: FloatingPoint, U: FloatingPoint, R>(
  at x: ${Self}<T>, _ y: ${Self}<U>,
  in f: @differentiable (${Self}<T>, ${Self}<U>) -> R
) -> R.TangentVector where T.TangentVector == T, U.TangentVector == U {
  let pushforward = differential(at: x, y, in: f)
  return pushforward(1, 1)
}
/// Returns the value of `f` at `x` together with the result of applying its
/// pullback to a seed of `1`.
public func valueWithGradient<T, R: FloatingPoint>(
  at x: T, in f: @differentiable (T) -> ${Self}<R>
) -> (value: ${Self}<R>, gradient: T.TangentVector) {
  let (output, backpropagate) = valueWithPullback(at: x, in: f)
  let seeded = backpropagate(1)
  return (output, seeded)
}
/// Returns the value of `f` at `x` together with the result of applying its
/// differential to a seed of `1`.
public func valueWithDerivative<T: FloatingPoint, R>(
  at x: ${Self}<T>, in f: @differentiable (${Self}<T>) -> R
) -> (value: R, derivative: R.TangentVector) {
  let (output, pushforward) = valueWithDifferential(at: x, in: f)
  let seeded = pushforward(1)
  return (output, seeded)
}
% end