blob: ecfc60e345af765cac026228d5d78a251e771f91 [file] [log] [blame]
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
// Would fail due to unavailability of swift_autoDiffCreateLinearMapContext.
// UNSUPPORTED: use_os_stdlib
// `inout` parameter differentiation tests.
import DifferentiationUnittest
import StdlibUnittest
// Suite collecting every `inout` parameter differentiation test in this file.
// Tests are registered through `.test(...)` method calls and the binding itself
// is never reassigned, so it is declared with `let`.
let InoutParameterAutoDiffTests = TestSuite("InoutParameterDifferentiation")
// TODO(TF-1173): Move floating-point mutating operation tests to
// `test/AutoDiff/stdlib/floating_point.swift.gyb` when forward-mode
// differentiation supports `inout` parameter differentiation.
InoutParameterAutoDiffTests.test("Float.+=") {
  // result = x + y, so both partial derivatives are 1 and the pullback passes
  // the seed through unchanged to each input.
  func mutatingAddWrapper(_ lhs: Float, _ rhs: Float) -> Float {
    var accumulator: Float = lhs
    accumulator += rhs
    return accumulator
  }
  let seededPullback = pullback(at: 4, 5, in: mutatingAddWrapper)
  expectEqual((10, 10), seededPullback(10))
  expectEqual((1, 1), gradient(at: 4, 5, in: mutatingAddWrapper))
}
InoutParameterAutoDiffTests.test("Float.-=") {
  // result = x - y, so d/dx = 1 and d/dy = -1; the pullback scales both by
  // the seed.
  func mutatingSubtractWrapper(_ x: Float, _ y: Float) -> Float {
    var result: Float = x
    result -= y
    return result
  }
  expectEqual((1, -1), gradient(at: 4, 5, in: mutatingSubtractWrapper))
  expectEqual((10, -10), pullback(at: 4, 5, in: mutatingSubtractWrapper)(10))
}
InoutParameterAutoDiffTests.test("Float.*=") {
  // result = x * y, so d/dx = y and d/dy = x: (5, 4) at (4, 5), and
  // (50, 40) for a pullback seed of 10.
  func mutatingMultiplyWrapper(_ x: Float, _ y: Float) -> Float {
    var result: Float = x
    result *= y
    return result
  }
  expectEqual((5, 4), gradient(at: 4, 5, in: mutatingMultiplyWrapper))
  expectEqual((50, 40), pullback(at: 4, 5, in: mutatingMultiplyWrapper)(10))
}
InoutParameterAutoDiffTests.test("Float./=") {
  // result = x / y, so d/dx = 1 / y and d/dy = -x / y^2.
  // The divisor is a power of two so every expected derivative is exactly
  // representable as a `Float`. With a divisor such as 5, the value of
  // -x / y^2 * seed depends on the order in which the derivative rounds its
  // intermediate results, which would make exact `expectEqual` checks fragile.
  func mutatingDivideWrapper(_ x: Float, _ y: Float) -> Float {
    var result: Float = x
    result /= y
    return result
  }
  expectEqual((0.5, -1), gradient(at: 4, 2, in: mutatingDivideWrapper))
  expectEqual((5, -10), pullback(at: 4, 2, in: mutatingDivideWrapper)(10))
}
// The most basic `inout` case: a callee that takes an `inout` argument and
// never writes to it leaves the value untouched, i.e. it acts as identity.
InoutParameterAutoDiffTests.test("InoutIdentity") {
  func inoutIdentity(_ value: inout Float) {}
  func identity(_ input: Float) -> Float {
    var buffer = input
    inoutIdentity(&buffer)
    return buffer
  }
  // Identity has derivative 1, so the pullback returns its seed unchanged.
  expectEqual(10, pullback(at: 10, in: identity)(10))
  expectEqual(1, gradient(at: 10, in: identity))
}
extension Float {
  /// Multiplies `lhs` by `rhs` in place.
  ///
  /// This is spelled as the non-mutating `*` operator followed by an
  /// assignment, not as `*=`. A test in this file compares the derivative
  /// the compiler synthesizes for this body with the registered derivative
  /// of `Float.*=`.
  @differentiable
  static func multiplyAssign(_ lhs: inout Float, _ rhs: Float) {
    let product = lhs * rhs
    lhs = product
  }
}
InoutParameterAutoDiffTests.test("ControlFlow") {
  // Each reduction loops over indices wrapped in `withoutDerivative(at:)`,
  // marking the index sequence as non-differentiable. Every iteration then
  // mutates a local accumulator in place.

  // Sum via `+=`: every element has a partial derivative of 1.
  func sum(_ values: [Float]) -> Float {
    var total: Float = 0
    for position in withoutDerivative(at: values.indices) {
      total += values[position]
    }
    return total
  }
  expectEqual([1, 1, 1], gradient(at: [1, 2, 3], in: sum))

  // Product via `*=`: each partial derivative is the product of the other
  // elements, i.e. (4 * 5, 3 * 5, 3 * 4) at [3, 4, 5].
  func product(_ values: [Float]) -> Float {
    var running: Float = 1
    for position in withoutDerivative(at: values.indices) {
      running *= values[position]
    }
    return running
  }
  expectEqual([20, 15, 12], gradient(at: [3, 4, 5], in: product))

  // Same product, but through the custom `Float.multiplyAssign`, whose
  // synthesized derivative must agree with the one for `*=`.
  func productCustom(_ values: [Float]) -> Float {
    var running: Float = 1
    for position in withoutDerivative(at: values.indices) {
      Float.multiplyAssign(&running, values[position])
    }
    return running
  }
  expectEqual([20, 15, 12], gradient(at: [3, 4, 5], in: productCustom))
}
// Differentiation through a computed property's `set` accessor.
// `S.computed` has explicit `get`/`set` accessors that read and write the
// stored property `x`.
InoutParameterAutoDiffTests.test("SetAccessor") {
struct S: Differentiable {
var x: Float
// Computed view of `x`; mutating it goes through the getter and setter below.
var computed: Float {
get { x }
set { x = newValue }
}
}
// `squared` implemented using a `set` accessor.
// Starting from 1, the first `*=` (on the stored property) yields x, and the
// second `*=` (through the computed property's accessors) yields x * x.
func squared(_ x: Float) -> Float {
var s = S(x: 1)
s.x *= x
s.computed *= x
return s.x
}
// d(x^2)/dx = 2x: 6 at x = 3 and 8 at x = 4.
expectEqual(6, gradient(at: 3, in: squared))
expectEqual(8, gradient(at: 4, in: squared))
}
// Test differentiation wrt `inout` parameters that have a class type.
InoutParameterAutoDiffTests.test("InoutClassParameter") {
// Local reference type with a single differentiable stored property.
class Class: Differentiable {
@differentiable
var x: Float
init(_ x: Float) {
self.x = x
}
}
do {
// Case 1: the `inout` reference itself is replaced with a new instance
// holding c.x * c.x; the original instance is not mutated.
func squaredViaMutation(_ c: inout Class) {
c = Class(c.x * c.x)
}
func squared(_ x: Float) -> Float {
var c = Class(x)
squaredViaMutation(&c)
return c.x
}
// Value 10^2 = 100 and derivative 2 * 10 = 20; a seed of 10 scales it to 200.
expectEqual((100, 20), valueWithGradient(at: 10, in: squared))
expectEqual(200, pullback(at: 10, in: squared)(10))
}
do {
// Case 2: the existing instance is mutated in place through the property
// accessor, without replacing the reference.
func squaredViaModifyAccessor(_ c: inout Class) {
// The line below calls `Class.x.modify`.
c.x *= c.x
}
func squared(_ x: Float) -> Float {
var c = Class(x)
squaredViaModifyAccessor(&c)
return c.x
}
// FIXME(TF-1080): Fix incorrect class property `modify` accessor derivative values.
// The commented-out expectations are the mathematically correct values
// (same as case 1). The live assertions pin the current, incorrect results,
// so any change in this behavior makes the test fail and signals that the
// correct expectations should be restored.
// expectEqual((100, 20), valueWithGradient(at: 10, in: squared))
// expectEqual(200, pullback(at: 10, in: squared)(10))
expectEqual((100, 1), valueWithGradient(at: 10, in: squared))
expectEqual(10, pullback(at: 10, in: squared)(10))
}
}
// SR-13305: Test function with non-wrt `inout` parameter, which should be
// treated as a differentiability result.
// Both requirements are differentiable only with respect to `x`; `y` is an
// `inout` parameter that is written to but is not a differentiability
// parameter.
protocol SR_13305_Protocol {
// Concrete (`Float`) form of the requirement.
@differentiable(wrt: x)
func method(_ x: Float, _ y: inout Float)
// Generic form of the requirement over any `Differentiable` type.
@differentiable(wrt: x)
func genericMethod<T: Differentiable>(_ x: T, _ y: inout T)
}
// Exercises SR-13305: functions whose `inout` parameter is not in the `wrt:`
// list, called both on a concrete type and through a protocol requirement.
InoutParameterAutoDiffTests.test("non-wrt inout parameter") {
struct SR_13305_Struct: SR_13305_Protocol {
// Writes y * x back into `y`.
@differentiable(wrt: x)
func method(_ x: Float, _ y: inout Float) {
y = y * x
}
// Overwrites `y` with `x`.
// NOTE(review): this witness is not called anywhere in this test; confirm
// whether a generic case was intended.
@differentiable(wrt: x)
func genericMethod<T: Differentiable>(_ x: T, _ y: inout T) {
y = x
}
}
// Calls the concrete struct's method directly and returns the mutated `y`,
// i.e. y * x.
@differentiable(wrt: x)
func foo(_ s: SR_13305_Struct, _ x: Float, _ y: Float) -> Float {
var y = y
s.method(x, &y)
return y
}
// Calls `method` through the protocol requirement. It returns `x`, not the
// mutated `y`, so the result depends only on `x`.
@differentiable(wrt: x)
func fooGeneric<T: SR_13305_Protocol>(_ s: T, _ x: Float, _ y: Float) -> Float {
var y = y
s.method(x, &y)
return x
}
let s = SR_13305_Struct()
do {
// foo = y * x at (x: 2, y: 3): value 6, d/dx = y = 3, d/dy = x = 2.
let (value, (dx, dy)) = valueWithGradient(at: 2, 3, in: { foo(s, $0, $1) })
expectEqual(6, value)
expectEqual((3, 2), (dx, dy))
}
// Same function with `y` fixed to 3: only the gradient wrt `x` is requested.
expectEqual((value: 6, gradient: 3), valueWithGradient(at: 2, in: { foo(s, $0, 3) }))
do {
// fooGeneric = x: value 2, d/dx = 1, d/dy = 0.
let (value, (dx, dy)) = valueWithGradient(at: 2, 3, in: { fooGeneric(s, $0, $1) })
expectEqual(2, value)
expectEqual((1, 0), (dx, dy))
}
expectEqual((value: 2, gradient: 1), valueWithGradient(at: 2, in: { fooGeneric(s, $0, 3) }))
}
// Runs the registered test suites. Top-level code executes in order, so this
// must stay after every `.test(...)` registration above.
runAllTests()