blob: 3abe2229ea33f4ef5b159046645d1c3686cc97df [file] [log] [blame]
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -verify %s
// Test forward-mode differentiation transform diagnostics.
// TODO: Move these tests back into `autodiff_diagnostics.swift` once
// forward mode reaches feature parity with reverse mode.
import _Differentiation
import DifferentiationUnittest
//===----------------------------------------------------------------------===//
// Basic function
//===----------------------------------------------------------------------===//
// Baseline case: an identity function marked `@differentiable`.
// No diagnostic annotations are attached to this function.
@differentiable
func basic(_ x: Float) -> Float {
return x
}
//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//
// Branches on `x > 0` and returns either `x * x` or `x + x`.
// The annotations below pin the diagnostics for this function: it is reported
// as not differentiable, with a note that forward mode does not yet support
// control flow.
// NOTE: the `@+N` offsets are relative line counts. Do not insert lines between
// an annotation and the line it targets.
// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+2 {{when differentiating this function definition}}
// expected-note @+1 {{forward-mode differentiation does not yet support control flow}}
func cond(_ x: Float) -> Float {
if x > 0 {
return x * x
}
return x + x
}
//===----------------------------------------------------------------------===//
// Non-varied results
//===----------------------------------------------------------------------===//
// Returns the constant `0` and never reads `x`.
// The warning annotation inside the body is intentionally disabled (its
// leading `e` is dropped) until TF-788 is resolved, per the TODO below. Keep
// that spelling as-is; it is not a typo.
@differentiable
func nonVariedResult(_ x: Float) -> Float {
// TODO(TF-788): Re-enable non-varied result warning.
// xpected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}} {{15-15=)}}
return 0
}
//===----------------------------------------------------------------------===//
// Multiple results
//===----------------------------------------------------------------------===//
// Helper that returns `x` twice as a two-element tuple.
// It is not itself marked `@differentiable`; it is called from
// `usesMultipleResults`.
func multipleResults(_ x: Float) -> (Float, Float) {
return (x, x)
}
// Calls `multipleResults` and returns the sum of both tuple elements.
// No diagnostic annotations are attached to this function.
@differentiable
func usesMultipleResults(_ x: Float) -> Float {
let tuple = multipleResults(x)
return tuple.0 + tuple.1
}
//===----------------------------------------------------------------------===//
// `inout` parameter differentiation
//===----------------------------------------------------------------------===//
// Accumulates `x` with `+=` into a local that starts as the constant `1`,
// i.e. the accumulator's initial value does not depend on `x`.
// No diagnostic annotations are attached to this function.
@differentiable
func activeInoutParamNonactiveInitialResult(_ x: Float) -> Float {
var result: Float = 1
result += x
return result
}
// Builds a local tuple `(x, x)`, applies `*=` to element 0, and returns
// `x * tuple.0`. Element 1 is never read.
// No diagnostic annotations are attached to this function.
@differentiable
func activeInoutParamTuple(_ x: Float) -> Float {
var tuple = (x, x)
tuple.0 *= x
return x * tuple.0
}
// Sums the elements of `array` into a local accumulator using a `for` loop.
// The loop's index range is obtained through `withoutDerivative(at:)`.
// The annotations below pin the diagnostics: the function is reported as not
// differentiable because forward mode does not yet support control flow.
// NOTE: the `@+N` offsets are relative line counts. Do not insert lines between
// an annotation and the line it targets.
// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+2 {{when differentiating this function definition}}
// expected-note @+1 {{forward-mode differentiation does not yet support control flow}}
func activeInoutParamControlFlow(_ array: [Float]) -> Float {
var result: Float = 1
for i in withoutDerivative(at: array).indices {
result += array[i]
}
return result
}
// Differentiable struct with a single stored `Float`, used by
// `activeMutatingMethod`.
struct X: Differentiable {
var x: Float
// Overwrites `self.x` with `y.x`. Declared differentiable with respect to
// `y` only.
@differentiable(wrt: y)
mutating func mutate(_ y: X) { self.x = y.x }
}
// Builds `x1` from `x`, mutates a separate value `x2` with `x1`, and returns
// `x1.x`. Note that the returned value comes from `x1`, not from the mutated
// `x2`.
// No diagnostic annotations are attached to this function.
@differentiable
func activeMutatingMethod(_ x: Float) -> Float {
let x1 = X.init(x: x)
var x2 = X.init(x: 0)
x2.mutate(x1)
return x1.x
}
// Empty `Differentiable` struct (no stored properties).
struct Mut: Differentiable {}
extension Mut {
// Mutating method with an empty body, declared differentiable with respect
// to the argument `x`.
@differentiable(wrt: x)
mutating func mutatingMethod(_ x: Mut) {}
}
// Copies `x` into `result`, then calls `mutatingMethod` with `result` as both
// the receiver and the argument, and returns `result`.
// No diagnostic annotations are attached to this function.
@differentiable(wrt: x)
func activeInoutParamMutatingMethod(_ x: Mut) -> Mut {
var result = x
result.mutatingMethod(result)
return result
}
//===----------------------------------------------------------------------===//
// Subset parameter differentiation thunks
//===----------------------------------------------------------------------===//
// FIXME(SR-13046): Non-differentiability diagnostic crash due to invalid source location.
/*
func testNoDerivativeParameter(_ f: @differentiable (Float, @noDerivative Float) -> Float) -> Float {
return derivative(at: 2, 3) { (x, y) in f(x * x, y) }
}
*/
//===----------------------------------------------------------------------===//
// Stored property access differentiation
//===----------------------------------------------------------------------===//
// Test differentiation of invalid stored property access instructions:
// `struct_extract`, `struct_element_addr`, `ref_element_addr`.
// Fixture: a `Differentiable` struct whose hand-written `TangentVector` is an
// enum instead of a struct.
struct StructTangentVectorNotStruct: Differentiable {
var x: Float
// The enum is its own `TangentVector`. Its `AdditiveArithmetic` members are
// stubbed with `fatalError()`; they only need to exist for type checking.
enum TangentVector: Differentiable, AdditiveArithmetic {
case x(Float)
typealias TangentVector = Self
static func ==(_: Self, _: Self) -> Bool { fatalError() }
static var zero: Self { fatalError() }
static func +(_: Self, _: Self) -> Self { fatalError() }
static func -(_: Self, _: Self) -> Self { fatalError() }
}
// Empty implementation of `move(along:)`.
mutating func move(along direction: TangentVector) {}
}
// Reads the stored property `s.x` directly from the parameter.
// The annotations pin the diagnostics: the function is reported as not
// differentiable, with a note on the property access stating that the
// tangent vector type is not a struct.
// NOTE: the `@+N` offsets are relative line counts. Do not insert lines between
// an annotation and the line it targets.
// expected-error @+2 {{function is not differentiable}}
// expected-note @+3 {{when differentiating this function definition}}
@differentiable
@_silgen_name("test_struct_tangent_vector_not_struct")
func testStructTangentVectorNotStruct(_ s: StructTangentVectorNotStruct) -> Float {
// expected-note @+1 {{cannot differentiate access to property 'StructTangentVectorNotStruct.x' because 'StructTangentVectorNotStruct.TangentVector' is not a struct}}
return s.x
}
// CHECK-LABEL: sil {{.*}} @test_struct_tangent_vector_not_struct
// CHECK: struct_extract {{%.*}} : $StructTangentVectorNotStruct, #StructTangentVectorNotStruct.x
// Fixture: a `Differentiable` struct whose stored property has a type that
// declares no `Differentiable` conformance.
struct StructOriginalPropertyNotDifferentiable: Differentiable {
// Plain struct with no declared protocol conformances.
struct Nondiff {
var x: Float
}
var nondiff: Nondiff
// Hand-written tangent vector. It has a `nondiff` member, but of type
// `Float` rather than anything derived from `Nondiff`.
struct TangentVector: Differentiable & AdditiveArithmetic {
var nondiff: Float
}
// Empty implementation of `move(along:)`.
mutating func move(along direction: TangentVector) {}
}
// Reads `s.nondiff.x`, going through the non-`Differentiable` property.
// The annotations pin the diagnostics: the function is reported as not
// differentiable, with a note that the property's type does not conform to
// `Differentiable`.
// NOTE: the `@+N` offsets are relative line counts. Do not insert lines between
// an annotation and the line it targets.
// expected-error @+2 {{function is not differentiable}}
// expected-note @+3 {{when differentiating this function definition}}
@differentiable
@_silgen_name("test_struct_original_property_not_differentiable")
func testStructOriginalPropertyNotDifferentiable(_ s: StructOriginalPropertyNotDifferentiable) -> Float {
// expected-note @+1 {{cannot differentiate access to property 'StructOriginalPropertyNotDifferentiable.nondiff' because property type 'StructOriginalPropertyNotDifferentiable.Nondiff' does not conform to 'Differentiable'}}
return s.nondiff.x
}
// CHECK-LABEL: sil {{.*}} @test_struct_original_property_not_differentiable
// CHECK: struct_extract {{%.*}} : $StructOriginalPropertyNotDifferentiable, #StructOriginalPropertyNotDifferentiable.nondiff
// Fixture: the original struct stores `x`, but the hand-written tangent
// vector only stores `y`, so there is no tangent member named `x`.
struct StructTangentVectorPropertyNotFound: Differentiable {
var x: Float
struct TangentVector: Differentiable, AdditiveArithmetic {
var y: Float
}
// Empty implementation of `move(along:)`.
mutating func move(along direction: TangentVector) {}
}
// Copies `s` into a `var` local and reads `tmp.x` from the copy.
// The `var` is never mutated, hence the pinned warning. It is presumably kept
// as `var` so the access goes through an address (see the SIL pattern comment
// after this function) -- TODO confirm.
// The annotations pin the diagnostics: the function is reported as not
// differentiable, with a note that the tangent vector has no stored property
// named `x`.
// NOTE: the `@+N` offsets are relative line counts. Do not insert lines between
// an annotation and the line it targets.
// expected-error @+2 {{function is not differentiable}}
// expected-note @+3 {{when differentiating this function definition}}
@differentiable
@_silgen_name("test_struct_tangent_property_not_found")
func testStructTangentPropertyNotFound(_ s: StructTangentVectorPropertyNotFound) -> Float {
// expected-warning @+1 {{variable 'tmp' was never mutated}}
var tmp = s
// expected-note @+1 {{cannot differentiate access to property 'StructTangentVectorPropertyNotFound.x' because 'StructTangentVectorPropertyNotFound.TangentVector' does not have a stored property named 'x'}}
return tmp.x
}
// CHECK-LABEL: sil {{.*}} @test_struct_tangent_property_not_found
// CHECK: struct_element_addr {{%.*}} : $*StructTangentVectorPropertyNotFound, #StructTangentVectorPropertyNotFound.x
// Fixture: the original struct stores `x: Float`, but the hand-written
// tangent vector stores `x: Double`.
struct StructTangentPropertyWrongType: Differentiable {
var x: Float
struct TangentVector: Differentiable, AdditiveArithmetic {
var x: Double
}
// Empty implementation of `move(along:)`.
mutating func move(along direction: TangentVector) {}
}
// Copies `s` into a `var` local and reads `tmp.x` from the copy.
// The `var` is never mutated, hence the pinned warning. It is presumably kept
// as `var` so the access goes through an address (see the SIL pattern comment
// after this function) -- TODO confirm.
// The annotations pin the diagnostics: the function is reported as not
// differentiable, with a note that the tangent member `x` does not have the
// type `Float.TangentVector`.
// NOTE: the `@+N` offsets are relative line counts. Do not insert lines between
// an annotation and the line it targets.
// expected-error @+2 {{function is not differentiable}}
// expected-note @+3 {{when differentiating this function definition}}
@differentiable
@_silgen_name("test_struct_tangent_property_wrong_type")
func testStructTangentPropertyWrongType(_ s: StructTangentPropertyWrongType) -> Float {
// expected-warning @+1 {{variable 'tmp' was never mutated}}
var tmp = s
// expected-note @+1 {{cannot differentiate access to property 'StructTangentPropertyWrongType.x' because 'StructTangentPropertyWrongType.TangentVector.x' does not have expected type 'Float.TangentVector' (aka 'Float')}}
return tmp.x
}
// CHECK-LABEL: sil {{.*}} @test_struct_tangent_property_wrong_type
// CHECK: struct_element_addr {{%.*}} : $*StructTangentPropertyWrongType, #StructTangentPropertyWrongType.x
// Fixture: class counterpart of `StructTangentPropertyWrongType`. The class
// stores `x: Float`, but the hand-written tangent vector stores `x: Double`.
// The test function that uses this class is currently commented out
// (SR-13464).
final class ClassTangentPropertyWrongType: Differentiable {
var x: Float = 0
struct TangentVector: Differentiable, AdditiveArithmetic {
var x: Double
}
// Empty implementation of `move(along:)`.
func move(along direction: TangentVector) {}
}
// SR-13464: Missing support for classes in forward-mode AD
/*
// xpected-error @+2 {{function is not differentiable}}
// xpected-note @+3 {{when differentiating this function definition}}
@differentiable
@_silgen_name("test_class_tangent_property_wrong_type")
func testClassTangentPropertyWrongType(_ c: ClassTangentPropertyWrongType) -> Float {
// xpected-warning @+1 {{variable 'tmp' was never mutated}}
var tmp = c
// xpected-note @+1 {{cannot differentiate access to property 'ClassTangentPropertyWrongType.x' because 'ClassTangentPropertyWrongType.TangentVector.x' does not have expected type 'Float.TangentVector' (aka 'Float')}}
return tmp.x
}
*/
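// NOTE: `testClassTangentPropertyWrongType` is commented out above (SR-13464),
// so the SIL pattern below currently describes no live function.
// CHECK-LABEL: sil {{.*}} @test_class_tangent_property_wrong_type
// CHECK: ref_element_addr {{%.*}} : $ClassTangentPropertyWrongType, #ClassTangentPropertyWrongType.x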
// Fixture: the original struct stores `x`, but in the hand-written tangent
// vector `x` is a computed property (it always returns `0`) rather than a
// stored one.
struct StructTangentPropertyNotStored: Differentiable {
var x: Float
struct TangentVector: Differentiable, AdditiveArithmetic {
var x: Float { 0 }
}
// Empty implementation of `move(along:)`.
mutating func move(along direction: TangentVector) {}
}
// Copies `s` into a `var` local and reads `tmp.x` from the copy.
// The `var` is never mutated, hence the pinned warning. It is presumably kept
// as `var` so the access goes through an address (see the SIL pattern comment
// after this function) -- TODO confirm.
// The annotations pin the diagnostics: the function is reported as not
// differentiable, with a note that the tangent member `x` is not a stored
// property.
// NOTE: the `@+N` offsets are relative line counts. Do not insert lines between
// an annotation and the line it targets.
// expected-error @+2 {{function is not differentiable}}
// expected-note @+3 {{when differentiating this function definition}}
@differentiable
@_silgen_name("test_struct_tangent_property_not_stored")
func testStructTangentPropertyNotStored(_ s: StructTangentPropertyNotStored) -> Float {
// expected-warning @+1 {{variable 'tmp' was never mutated}}
var tmp = s
// expected-note @+1 {{cannot differentiate access to property 'StructTangentPropertyNotStored.x' because 'StructTangentPropertyNotStored.TangentVector.x' is not a stored property}}
return tmp.x
}
// CHECK-LABEL: sil {{.*}} @test_struct_tangent_property_not_stored
// CHECK: struct_element_addr {{%.*}} : $*StructTangentPropertyNotStored, #StructTangentPropertyNotStored.x
// Fixture: class counterpart of `StructTangentPropertyNotStored`. The class
// stores `x`, but in the hand-written tangent vector `x` is a computed
// property that always returns `0`.
// The test function that uses this class is currently commented out
// (SR-13464).
final class ClassTangentPropertyNotStored: Differentiable {
var x: Float = 0
struct TangentVector: Differentiable, AdditiveArithmetic {
var x: Float { 0 }
}
// Empty implementation of `move(along:)`.
func move(along direction: TangentVector) {}
}
// SR-13464: Missing support for classes in forward-mode AD
/*
// xpected-error @+2 {{function is not differentiable}}
// xpected-note @+3 {{when differentiating this function definition}}
@differentiable
@_silgen_name("test_class_tangent_property_not_stored")
func testClassTangentPropertyNotStored(_ c: ClassTangentPropertyNotStored) -> Float {
// xpected-warning @+1 {{variable 'tmp' was never mutated}}
var tmp = c
// xpected-note @+1 {{cannot differentiate access to property 'ClassTangentPropertyNotStored.x' because 'ClassTangentPropertyNotStored.TangentVector.x' is not a stored property}}
return tmp.x
}
*/
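// NOTE: `testClassTangentPropertyNotStored` is commented out above (SR-13464),
// so the SIL pattern below currently describes no live function.
// CHECK-LABEL: sil {{.*}} @test_class_tangent_property_not_stored
// CHECK: ref_element_addr {{%.*}} : $ClassTangentPropertyNotStored, #ClassTangentPropertyNotStored.x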