blob: ff198e430e33ca9221ef079e23c0e92e075d8d45 [file] [log] [blame]
// RUN: %target-swift-emit-silgen -verify -module-name main %s | %FileCheck %s
// RUN: %target-swift-emit-sil -verify -module-name main %s
// NOTE(SR-11950): SILParser crashes for SILGen round-trip.
// This file tests:
// - The "derivative generic signature" of `@differentiable` and `@derivative`
// attributes.
// - The generic signature of lowered SIL differentiability witnesses.
// Context:
// - For `@differentiable` attributes: the derivative generic signature is
// resolved from the original declaration's generic signature and additional
// `where` clause requirements.
// - For `@derivative` attributes: the derivative generic signature is the
// attributed declaration's generic signature.
import _Differentiation
//===----------------------------------------------------------------------===//
// Same-type requirements
//===----------------------------------------------------------------------===//
// Test original declaration with a generic signature and derivative generic
// signature where all generic parameters are concrete (i.e. bound to concrete
// types via same-type requirements).
// Test fixture: a struct with a single generic parameter `T` and no stored
// properties, declared as conforming to `Differentiable`. The extensions that
// follow attach `@differentiable` / `@derivative` methods to it.
struct AllConcrete<T>: Differentiable {}
extension AllConcrete {
  // The method lives in an unconstrained extension; only the `where` clause on
  // the `@differentiable` attribute binds `T` to a concrete type.
  //   - original generic signature:   `<T>`
  //   - derivative generic signature: `<T where T == Float>`
  //   - witness generic signature:    `<T where T == Float>`
  @_silgen_name("allconcrete_where_gensig_constrained")
  @differentiable(where T == Float)
  func whereClauseGenericSignatureConstrained() -> AllConcrete {
    self
  }
}
extension AllConcrete where T == Float {
  // JVP registered for `whereClauseGenericSignatureConstrained`. It returns
  // the original result paired with an identity differential. The enclosing
  // extension's `T == Float` constraint is this derivative's generic signature.
  @derivative(of: whereClauseGenericSignatureConstrained)
  func jvpWhereClauseGenericSignatureConstrained() -> (
    value: AllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    let original = whereClauseGenericSignatureConstrained()
    return (value: original, differential: { tangent in tangent })
  }
}
// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig_constrained
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T where T == Float> @allconcrete_where_gensig_constrained : $@convention(method) <T> (AllConcrete<T>) -> AllConcrete<T> {
// CHECK-NEXT: jvp: @AD__allconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszl : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
// CHECK-NEXT: }
// If a `@differentiable` or `@derivative` attribute satisfies two conditions:
// 1. The derivative generic signature is equal to the original generic signature.
// 2. The derivative generic signature has *all concrete* generic parameters.
//
// Then the attribute should be lowered to a SIL differentiability witness with
// *no* derivative generic signature.
extension AllConcrete where T == Float {
  // Case 1: plain `@differentiable`, so the derivative generic signature is
  // simply the original one.
  //   - original generic signature:   `<T where T == Float>`
  //   - derivative generic signature: `<T where T == Float>`
  //   - witness generic signature:    none
  @_silgen_name("allconcrete_original_gensig")
  @differentiable
  func originalGenericSignature() -> AllConcrete {
    self
  }

  // JVP for `originalGenericSignature`: original value plus an identity
  // differential.
  @derivative(of: originalGenericSignature)
  func jvpOriginalGenericSignature() -> (
    value: AllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    let original = originalGenericSignature()
    return (value: original, differential: { tangent in tangent })
  }

  // CHECK-LABEL: // differentiability witness for allconcrete_original_gensig
  // CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_original_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
  // CHECK-NEXT: jvp: @AD__allconcrete_original_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
  // CHECK-NEXT: }

  // Case 2: the `@differentiable` attribute repeats the extension's
  // requirement in an explicit `where` clause.
  //   - original generic signature:   `<T where T == Float>`
  //   - derivative generic signature: `<T where T == Float>` (explicit `where`)
  //   - witness generic signature:    none
  @_silgen_name("allconcrete_where_gensig")
  @differentiable(where T == Float)
  func whereClauseGenericSignature() -> AllConcrete {
    self
  }

  // JVP for `whereClauseGenericSignature`: original value plus an identity
  // differential.
  @derivative(of: whereClauseGenericSignature)
  func jvpWhereClauseGenericSignature() -> (
    value: AllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    let original = whereClauseGenericSignature()
    return (value: original, differential: { tangent in tangent })
  }

  // CHECK-LABEL: // differentiability witness for allconcrete_where_gensig
  // CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_where_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
  // CHECK-NEXT: jvp: @AD__allconcrete_where_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
  // CHECK-NEXT: }
}
// Test original declaration with a generic signature and derivative generic
// signature where *not* all generic parameters are concrete (i.e. at least one
// generic parameter is not bound to a concrete type via a same-type
// requirement).
// Test fixture: a struct with two generic parameters `T` and `U` and no stored
// properties, declared as conforming to `Differentiable`. The tests below only
// ever constrain `T`, so `U` stays generic.
struct NotAllConcrete<T, U>: Differentiable {}
extension NotAllConcrete {
  // The method lives in an unconstrained extension; the `where` clause on the
  // `@differentiable` attribute binds `T` only, leaving `U` generic.
  //   - original generic signature:   `<T, U>`
  //   - derivative generic signature: `<T, U where T == Float>`
  //   - witness generic signature:    `<T, U where T == Float>` (not all concrete)
  @_silgen_name("notallconcrete_where_gensig_constrained")
  @differentiable(where T == Float)
  func whereClauseGenericSignatureConstrained() -> NotAllConcrete {
    self
  }
}
extension NotAllConcrete where T == Float {
  // JVP registered for `whereClauseGenericSignatureConstrained`. It returns
  // the original result paired with an identity differential. The enclosing
  // extension's `T == Float` constraint is this derivative's generic signature.
  @derivative(of: whereClauseGenericSignatureConstrained)
  func jvpWhereClauseGenericSignatureConstrained() -> (
    value: NotAllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    let original = whereClauseGenericSignatureConstrained()
    return (value: original, differential: { tangent in tangent })
  }
}
// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig_constrained
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig_constrained : $@convention(method) <T, U> (NotAllConcrete<T, U>) -> NotAllConcrete<T, U> {
// CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
// CHECK-NEXT: }
extension NotAllConcrete where T == Float {
  // Case 1: plain `@differentiable`, so the derivative generic signature is
  // simply the original one.
  //   - original generic signature:   `<T, U where T == Float>`
  //   - derivative generic signature: `<T, U where T == Float>`
  //   - witness generic signature:    `<T, U where T == Float>` (not all concrete)
  @_silgen_name("notallconcrete_original_gensig")
  @differentiable
  func originalGenericSignature() -> NotAllConcrete {
    self
  }

  // JVP for `originalGenericSignature`: original value plus an identity
  // differential.
  @derivative(of: originalGenericSignature)
  func jvpOriginalGenericSignature() -> (
    value: NotAllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    let original = originalGenericSignature()
    return (value: original, differential: { tangent in tangent })
  }

  // CHECK-LABEL: // differentiability witness for notallconcrete_original_gensig
  // CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_original_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
  // CHECK-NEXT: jvp: @AD__notallconcrete_original_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
  // CHECK-NEXT: }

  // Case 2: the `@differentiable` attribute repeats the extension's
  // requirement in an explicit `where` clause.
  //   - original generic signature:   `<T, U where T == Float>`
  //   - derivative generic signature: `<T, U where T == Float>` (explicit `where`)
  //   - witness generic signature:    `<T, U where T == Float>` (not all concrete)
  @_silgen_name("notallconcrete_where_gensig")
  @differentiable(where T == Float)
  func whereClauseGenericSignature() -> NotAllConcrete {
    self
  }

  // JVP for `whereClauseGenericSignature`: original value plus an identity
  // differential.
  @derivative(of: whereClauseGenericSignature)
  func jvpWhereClauseGenericSignature() -> (
    value: NotAllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    let original = whereClauseGenericSignature()
    return (value: original, differential: { tangent in tangent })
  }

  // CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig
  // CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
  // CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
  // CHECK-NEXT: }
}