blob: 0471557654b549b653da317c4f315848f80913e2 [file] [log] [blame]
// RUN: %target-run-simple-swift
// RUN: %target-swift-emit-sil -Xllvm -debug-only=differentiation -module-name null -o /dev/null 2>&1 %s | %FileCheck %s
// REQUIRES: executable_test
// REQUIRES: asserts
// Test differentiation of `Optional` properties.
import DifferentiationUnittest
import StdlibUnittest
// Suite on which the `OptionalTests.test(...)` cases in this file are registered.
var OptionalTests = TestSuite("OptionalPropertyDifferentiation")
// Test `Optional` struct stored properties.
/// Differentiable struct pairing a non-optional `Float` stored property with an
/// `Optional<Float>` stored property.
struct Struct: Differentiable {
var stored: Float
var optional: Float?
/// Returns `optional * stored` when `optional` is non-nil, and `stored`
/// otherwise.
///
/// NOTE: the FileCheck `CHECK` lines that follow this declaration match SIL
/// instructions reported as active for this method (and, presumably, for the
/// memberwise initializer it calls), so the statement shapes below should not
/// be restructured without updating those lines.
@differentiable
func method() -> Float {
// Rebuild `self` through the memberwise initializer and pass the copy through
// a tuple before reading its properties. Presumably this is what puts the
// `struct`, `tuple` and `destructure_tuple` instructions named in the `CHECK`
// lines into the differentiated code — confirm against the SIL output.
let s: Struct
do {
let tmp = Struct(stored: stored, optional: optional)
let tuple = (tmp, tmp)
s = tuple.0
}
// Unwrapped case: the result depends on both properties.
if let x = s.optional {
return x * s.stored
}
// Nil case: the result depends on `stored` only.
return s.stored
}
}
// Check active SIL instructions in representative original functions.
// This tests SIL instruction coverage of derivative function cloners (e.g. PullbackCloner).
// CHECK-LABEL: [AD] Activity info for ${{.*}}Struct{{.*}}method{{.*}} at parameter indices (0) and result indices (0)
// CHECK: [ACTIVE] {{.*}} struct_extract {{%.*}} : $Struct, #Struct.stored
// CHECK: [ACTIVE] {{.*}} struct_extract {{%.*}} : $Struct, #Struct.optional
// CHECK: [ACTIVE] {{.*}} tuple ({{%.*}} : $Struct, {{%.*}} : $Struct)
// CHECK: [ACTIVE] {{.*}} destructure_tuple {{%.*}} : $(Struct, Struct)
// CHECK: [ACTIVE] {{.*}} struct_element_addr {{%.*}} : $*Struct, #Struct.optional
// CHECK: [ACTIVE] {{.*}} struct_element_addr {{%.*}} : $*Struct, #Struct.stored
// CHECK-LABEL: [AD] Activity info for $s4null6StructV6stored8optionalACSf_SfSgtcfC at parameter indices (0, 1) and result indices (0)
// CHECK: [ACTIVE] {{%.*}} struct $Struct ({{%.*}} : $Float, {{%.*}} : $Optional<Float>)
/// Variant of `Struct` whose properties are `NonresilientTracked<Float>`
/// (a type not declared in this file; presumably provided by
/// `DifferentiationUnittest` — confirm).
struct StructTracked: Differentiable {
var stored: NonresilientTracked<Float>
var optional: NonresilientTracked<Float>?
/// Returns `optional * stored` when `optional` is non-nil, and `stored`
/// otherwise.
@differentiable
func method() -> NonresilientTracked<Float> {
// Same shape as `Struct.method()`: rebuild `self` through the memberwise
// initializer and route the copy through a tuple before reading properties.
let s: StructTracked
do {
let tmp = StructTracked(stored: stored, optional: optional)
let tuple = (tmp, tmp)
s = tuple.0
}
// Unwrapped case: the result depends on both properties.
if let x = s.optional {
return x * s.stored
}
// Nil case: the result depends on `stored` only.
return s.stored
}
}
/// Generic variant of `Struct`: both properties use the `Differentiable` type
/// parameter `T`.
struct StructGeneric<T: Differentiable>: Differentiable {
var stored: T
var optional: T?
/// Returns the value wrapped by `optional` when it is non-nil, and `stored`
/// otherwise. Unlike `Struct.method()`, the two values are never multiplied.
@differentiable
func method() -> T {
// Rebuild `self` through the memberwise initializer and route the copy
// through a tuple before reading its properties.
let s: StructGeneric
do {
let tmp = StructGeneric(stored: stored, optional: optional)
let tuple = (tmp, tmp)
s = tuple.0
}
// Unwrapped case: the result is the optional's payload alone.
if let x = s.optional {
return x
}
// Nil case: the result is `stored`.
return s.stored
}
}
OptionalTests.test("Optional struct stored properties") {
  // `Float` properties: the value is `optional * stored` when `optional` is
  // set, and `stored` alone when it is nil.
  let floatSome = Struct(stored: 3, optional: 4)
  expectEqual(
    valueWithGradient(at: floatSome, in: { input in input.method() }),
    (12, .init(stored: 4, optional: .init(3))))
  let floatNone = Struct(stored: 3, optional: nil)
  expectEqual(
    valueWithGradient(at: floatNone, in: { input in input.method() }),
    (3, .init(stored: 1, optional: .init(0))))

  // Same expectations with `NonresilientTracked<Float>` properties.
  let trackedSome = StructTracked(stored: 3, optional: 4)
  expectEqual(
    valueWithGradient(at: trackedSome, in: { input in input.method() }),
    (12, .init(stored: 4, optional: .init(3))))
  let trackedNone = StructTracked(stored: 3, optional: nil)
  expectEqual(
    valueWithGradient(at: trackedNone, in: { input in input.method() }),
    (3, .init(stored: 1, optional: .init(0))))

  // Generic variant: the value is the optional's payload when set, so the
  // gradient flows only to `optional`; otherwise it flows only to `stored`.
  let genericSome = StructGeneric<Float>(stored: 3, optional: 4)
  expectEqual(
    valueWithGradient(at: genericSome, in: { input in input.method() }),
    (4, .init(stored: 0, optional: .init(1))))
  let genericNone = StructGeneric<Float>(stored: 3, optional: nil)
  expectEqual(
    valueWithGradient(at: genericNone, in: { input in input.method() }),
    (3, .init(stored: 1, optional: .init(0))))
}
// Test `Optional` class stored properties.
/// Differentiable type with a non-optional `Float` stored property and an
/// `Optional<Float>` stored property, constructed through an explicit
/// initializer.
///
/// NOTE(review): this is declared as a `struct` although its name and the
/// section comment preceding it say "class" — confirm whether a `class`
/// declaration was intended.
struct Class: Differentiable {
var stored: Float
var optional: Float?
/// Assigns both stored properties from the given arguments.
init(stored: Float, optional: Float?) {
self.stored = stored
self.optional = optional
}
/// Returns `optional * stored` when `optional` is non-nil, and `stored`
/// otherwise.
@differentiable
func method() -> Float {
// Rebuild `self` through the explicit initializer and route the copy through
// a tuple before reading its properties.
let c: Class
do {
let tmp = Class(stored: stored, optional: optional)
let tuple = (tmp, tmp)
c = tuple.0
}
// Unwrapped case: the result depends on both properties.
if let x = c.optional {
return x * c.stored
}
// Nil case: the result depends on `stored` only.
return c.stored
}
}
/// Variant of `Class` whose properties are `NonresilientTracked<Float>`
/// (a type not declared in this file; presumably provided by
/// `DifferentiationUnittest` — confirm).
///
/// NOTE(review): declared as a `struct` although the name says "class" —
/// confirm whether a `class` declaration was intended.
struct ClassTracked: Differentiable {
var stored: NonresilientTracked<Float>
var optional: NonresilientTracked<Float>?
/// Assigns both stored properties from the given arguments.
init(stored: NonresilientTracked<Float>, optional: NonresilientTracked<Float>?) {
self.stored = stored
self.optional = optional
}
/// Returns `optional * stored` when `optional` is non-nil, and `stored`
/// otherwise.
@differentiable
func method() -> NonresilientTracked<Float> {
// Rebuild `self` through the explicit initializer and route the copy through
// a tuple before reading its properties.
let c: ClassTracked
do {
let tmp = ClassTracked(stored: stored, optional: optional)
let tuple = (tmp, tmp)
c = tuple.0
}
// Unwrapped case: the result depends on both properties.
if let x = c.optional {
return x * c.stored
}
// Nil case: the result depends on `stored` only.
return c.stored
}
}
/// Generic variant of `Class`: both properties use the `Differentiable` type
/// parameter `T`.
///
/// NOTE(review): declared as a `struct` although the name says "class" —
/// confirm whether a `class` declaration was intended.
struct ClassGeneric<T: Differentiable>: Differentiable {
var stored: T
var optional: T?
/// Assigns both stored properties from the given arguments.
init(stored: T, optional: T?) {
self.stored = stored
self.optional = optional
}
/// Returns the value wrapped by `optional` when it is non-nil, and `stored`
/// otherwise. The two values are never multiplied.
@differentiable
func method() -> T {
// Rebuild `self` through the explicit initializer and route the copy through
// a tuple before reading its properties.
let c: ClassGeneric
do {
let tmp = ClassGeneric(stored: stored, optional: optional)
let tuple = (tmp, tmp)
c = tuple.0
}
// Unwrapped case: the result is the optional's payload alone.
if let x = c.optional {
return x
}
// Nil case: the result is `stored`.
return c.stored
}
}
OptionalTests.test("Optional class stored properties") {
  // `Float` properties: the value is `optional * stored` when `optional` is
  // set, and `stored` alone when it is nil.
  let floatSome = Class(stored: 3, optional: 4)
  expectEqual(
    valueWithGradient(at: floatSome, in: { input in input.method() }),
    (12, .init(stored: 4, optional: .init(3))))
  let floatNone = Class(stored: 3, optional: nil)
  expectEqual(
    valueWithGradient(at: floatNone, in: { input in input.method() }),
    (3, .init(stored: 1, optional: .init(0))))

  // Same expectations with `NonresilientTracked<Float>` properties.
  let trackedSome = ClassTracked(stored: 3, optional: 4)
  expectEqual(
    valueWithGradient(at: trackedSome, in: { input in input.method() }),
    (12, .init(stored: 4, optional: .init(3))))
  let trackedNone = ClassTracked(stored: 3, optional: nil)
  expectEqual(
    valueWithGradient(at: trackedNone, in: { input in input.method() }),
    (3, .init(stored: 1, optional: .init(0))))

  // Generic variant instantiated with `Tracked<Float>`: the value is the
  // optional's payload when set, so the gradient flows only to `optional`;
  // otherwise it flows only to `stored`.
  let genericSome = ClassGeneric<Tracked<Float>>(stored: 3, optional: 4)
  expectEqual(
    valueWithGradient(at: genericSome, in: { input in input.method() }),
    (4, .init(stored: 0, optional: .init(1))))
  let genericNone = ClassGeneric<Tracked<Float>>(stored: 3, optional: nil)
  expectEqual(
    valueWithGradient(at: genericNone, in: { input in input.method() }),
    (3, .init(stored: 1, optional: .init(0))))
}
// Entry point of the executable test: called after every
// `OptionalTests.test(...)` registration in this file.
runAllTests()