// blob: 49ee8df0257e2891a6c55b3fab8ed6b85300d4f1 [file] [log] [blame]
// RUN: %target-swift-emit-sil -enable-experimental-forward-mode-differentiation -verify -module-name main %s | %FileCheck %s
// Test derivative generic signatures:
// - In `@differentiable` and `@derivative` attributes.
// - In SIL differentiability witnesses.
// - In generated derivative functions and derivative function thunks.
//===----------------------------------------------------------------------===//
// Same-type requirements
//===----------------------------------------------------------------------===//
// If all generic parameters are concrete (e.g. bound via same-type requirements
// to concrete types), `@differentiable` attribute should not have a derivative
// generic signature.
// Test `@differentiable` attribute where original declaration has generic
// signature and all generic parameters are concrete (e.g. bound to concrete
// types via same-type requirements). SILGen lowers the original declaration to
// a function with no generic signature, so the differentiability witness should
// have no derivative generic signature.
// NOTE(SR-11950): SILParser crashes for SILGen round-trip.
// Test same-type requirements.
// If all generic parameters are concrete (e.g. bound via same-type requirements
// to concrete types), `@differentiable` attribute should have no derivative
// generic signature. Otherwise, the differentiability witness should have a
// derivative generic signature.
// Empty `Differentiable` test type with a single generic parameter `T`.
// The constrained extensions in this file bind `T` to `Float`, which makes
// every generic parameter concrete.
struct AllConcrete<T>: Differentiable {}
extension AllConcrete {
  // Generic signature of the original declaration: `<T>`.
  // Generic signature introduced by the `where` clause: `<T where T == Float>`.
  //
  // Identity method declared in an unconstrained extension; it is marked
  // differentiable only under the additional requirement `T == Float`.
  @_silgen_name("allconcrete_where_gensig_constrained")
  @differentiable(where T == Float)
  func whereClauseGenericSignatureConstrained() -> AllConcrete {
    self
  }

  // The witness carries the `<T where T == Float>` derivative generic
  // signature, while the original function type stays generic over `<T>`.
  // CHECK-LABEL: // differentiability witness for allconcrete_where_gensig_constrained
  // CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T where T == Float> @allconcrete_where_gensig_constrained : $@convention(method) <T> (AllConcrete<T>) -> AllConcrete<T> {
  // CHECK-NEXT: jvp: @AD__allconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszl : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
  // CHECK-NEXT: vjp: @AD__allconcrete_where_gensig_constrained__vjp_src_0_wrt_0_SfRszl : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
  // CHECK-NEXT: }
}
extension AllConcrete where T == Float {
  // Custom derivative registered for `whereClauseGenericSignatureConstrained`:
  // pairs the original result with an identity differential.
  @derivative(of: whereClauseGenericSignatureConstrained)
  func jvpWhereClauseGenericSignatureConstrained()
    -> (value: AllConcrete, differential: (TangentVector) -> TangentVector)
  {
    return (
      value: whereClauseGenericSignatureConstrained(),
      differential: { tangent in tangent }
    )
  }
}
extension AllConcrete where T == Float {
  // Generic signature of the original declaration: `<T where T == Float>`.
  // Generic signature introduced by the `where` clause: none.
  //
  // Identity method whose only generic parameter is already concrete through
  // the extension's `T == Float` requirement.
  @_silgen_name("allconcrete_original_gensig")
  @differentiable
  func originalGenericSignature() -> AllConcrete {
    self
  }

  // Custom derivative for `originalGenericSignature`: original result plus an
  // identity differential.
  @derivative(of: originalGenericSignature)
  func jvpOriginalGenericSignature()
    -> (value: AllConcrete, differential: (TangentVector) -> TangentVector)
  {
    return (
      value: originalGenericSignature(),
      differential: { tangent in tangent }
    )
  }

  // The witness below has no derivative generic signature and the original
  // function type is fully concrete.
  // CHECK-LABEL: // differentiability witness for allconcrete_original_gensig
  // CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_original_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
  // CHECK-NEXT: jvp: @AD__allconcrete_original_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
  // CHECK-NEXT: vjp: @AD__allconcrete_original_gensig__vjp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
  // CHECK-NEXT: }

  // Generic signature of the original declaration: `<T where T == Float>`.
  // Generic signature introduced by the `where` clause: `<T where T == Float>`.
  //
  // Same as above, but the attribute restates the `T == Float` requirement
  // that the extension already imposes.
  @_silgen_name("allconcrete_where_gensig")
  @differentiable(where T == Float)
  func whereClauseGenericSignature() -> AllConcrete {
    self
  }

  // Custom derivative for `whereClauseGenericSignature`: original result plus
  // an identity differential.
  @derivative(of: whereClauseGenericSignature)
  func jvpWhereClauseGenericSignature()
    -> (value: AllConcrete, differential: (TangentVector) -> TangentVector)
  {
    return (
      value: whereClauseGenericSignature(),
      differential: { tangent in tangent }
    )
  }

  // The witness below likewise has no derivative generic signature.
  // CHECK-LABEL: // differentiability witness for allconcrete_where_gensig
  // CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_where_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
  // CHECK-NEXT: jvp: @AD__allconcrete_where_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
  // CHECK-NEXT: vjp: @AD__allconcrete_where_gensig__vjp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
  // CHECK-NEXT: }
}
extension AllConcrete where T == Float {
  // Forms an `@differentiable` closure around each of the three methods
  // declared on `AllConcrete`; the closures are discarded, so this only
  // exercises compilation.
  func testDifferentiability() {
    let _: @differentiable (AllConcrete) -> AllConcrete = { value in
      value.originalGenericSignature()
    }
    let _: @differentiable (AllConcrete) -> AllConcrete = { value in
      value.whereClauseGenericSignature()
    }
    let _: @differentiable (AllConcrete) -> AllConcrete = { value in
      value.whereClauseGenericSignatureConstrained()
    }
  }
}
// Test `@differentiable` attribute where original declaration has generic
// signature and not all generic parameters are concrete. The lowered SIL
// function and the differentiability witness should both have a derivative
// generic signature.
// NOTE(SR-11950): SILParser crashes for SILGen round-trip.
// Empty `Differentiable` test type with two generic parameters. The
// constrained extensions in this file bind only `T` (to `Float`), so `U`
// stays generic.
struct NotAllConcrete<T, U>: Differentiable {}
extension NotAllConcrete {
  // Generic signature of the original declaration: `<T, U>`.
  // Generic signature introduced by the `where` clause:
  // `<T, U where T == Float>`.
  //
  // Identity method declared in an unconstrained extension; it is marked
  // differentiable only under the additional requirement `T == Float`.
  @_silgen_name("notallconcrete_where_gensig_constrained")
  @differentiable(where T == Float)
  func whereClauseGenericSignatureConstrained() -> NotAllConcrete {
    self
  }

  // The witness carries the `<T, U where T == Float>` derivative generic
  // signature, while the original function type stays generic over `<T, U>`.
  // CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig_constrained
  // CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig_constrained : $@convention(method) <T, U> (NotAllConcrete<T, U>) -> NotAllConcrete<T, U> {
  // CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
  // CHECK-NEXT: vjp: @AD__notallconcrete_where_gensig_constrained__vjp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
  // CHECK-NEXT: }
}
extension NotAllConcrete where T == Float {
  // Custom derivative registered for `whereClauseGenericSignatureConstrained`:
  // pairs the original result with an identity differential.
  @derivative(of: whereClauseGenericSignatureConstrained)
  func jvpWhereClauseGenericSignatureConstrained()
    -> (value: NotAllConcrete, differential: (TangentVector) -> TangentVector)
  {
    return (
      value: whereClauseGenericSignatureConstrained(),
      differential: { tangent in tangent }
    )
  }
}
extension NotAllConcrete where T == Float {
  // Generic signature of the original declaration: `<T, U where T == Float>`.
  // Generic signature introduced by the `where` clause: none.
  //
  // Identity method; `T` is concrete through the extension's requirement
  // while `U` remains generic.
  @_silgen_name("notallconcrete_original_gensig")
  @differentiable
  func originalGenericSignature() -> NotAllConcrete {
    self
  }

  // Custom derivative for `originalGenericSignature`: original result plus an
  // identity differential.
  @derivative(of: originalGenericSignature)
  func jvpOriginalGenericSignature()
    -> (value: NotAllConcrete, differential: (TangentVector) -> TangentVector)
  {
    return (
      value: originalGenericSignature(),
      differential: { tangent in tangent }
    )
  }

  // Both the witness and the original function type below carry the
  // `<T, U where T == Float>` generic signature.
  // CHECK-LABEL: // differentiability witness for notallconcrete_original_gensig
  // CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_original_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
  // CHECK-NEXT: jvp: @AD__notallconcrete_original_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
  // CHECK-NEXT: vjp: @AD__notallconcrete_original_gensig__vjp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
  // CHECK-NEXT: }

  // Generic signature of the original declaration: `<T, U where T == Float>`.
  // Generic signature introduced by the `where` clause:
  // `<T, U where T == Float>`.
  //
  // Same as above, but the attribute restates the `T == Float` requirement
  // that the extension already imposes.
  @_silgen_name("notallconcrete_where_gensig")
  @differentiable(where T == Float)
  func whereClauseGenericSignature() -> NotAllConcrete {
    self
  }

  // Custom derivative for `whereClauseGenericSignature`: original result plus
  // an identity differential.
  @derivative(of: whereClauseGenericSignature)
  func jvpWhereClauseGenericSignature()
    -> (value: NotAllConcrete, differential: (TangentVector) -> TangentVector)
  {
    return (
      value: whereClauseGenericSignature(),
      differential: { tangent in tangent }
    )
  }

  // The witness below carries the same `<T, U where T == Float>` signature.
  // CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig
  // CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
  // CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
  // CHECK-NEXT: vjp: @AD__notallconcrete_where_gensig__vjp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
  // CHECK-NEXT: }
}
extension NotAllConcrete where T == Float {
  // Forms an `@differentiable` closure around each of the three methods
  // declared on `NotAllConcrete`; the closures are discarded, so this only
  // exercises compilation.
  func testDifferentiability() {
    let _: @differentiable (NotAllConcrete) -> NotAllConcrete = { value in
      value.originalGenericSignature()
    }
    let _: @differentiable (NotAllConcrete) -> NotAllConcrete = { value in
      value.whereClauseGenericSignature()
    }
    let _: @differentiable (NotAllConcrete) -> NotAllConcrete = { value in
      value.whereClauseGenericSignatureConstrained()
    }
  }
}