// RUN: %target-swift-frontend-typecheck -verify -disable-availability-checking %s
// RUN: %target-swift-frontend-typecheck -enable-testing -verify -disable-availability-checking %s

// Swift.AdditiveArithmetic:3:17: note: cannot yet register derivative default implementation for protocol requirements

import _Differentiation

// Dummy `Differentiable`-conforming type.
struct DummyTangentVector: Differentiable & AdditiveArithmetic {
  // The type serves as its own tangent vector.
  typealias TangentVector = Self

  // Every arithmetic member simply produces a fresh instance.
  static var zero: Self { .init() }
  static func + (_: Self, _: Self) -> Self { .init() }
  static func - (_: Self, _: Self) -> Self { .init() }
}

// Test top-level functions.

// Identity on `Float`; original declaration for the two derivatives that follow.
func id(_ x: Float) -> Float { x }
// JVP of `id`; no `wrt:` clause, so the differentiation parameters are inferred.
@derivative(of: id)
func jvpId(
  x: Float
) -> (value: Float, differential: (Float) -> (Float)) {
  let differential: (Float) -> Float = { tangent in tangent }
  return (value: x, differential: differential)
}
// VJP of `id` with the differentiation parameter named explicitly.
@derivative(of: id, wrt: x)
func vjpIdExplicitWrt(
  x: Float
) -> (value: Float, pullback: (Float) -> Float) {
  let pullback: (Float) -> Float = { cotangent in cotangent }
  return (value: x, pullback: pullback)
}

// Returns its first argument and ignores the second. Original declaration for
// the two generic derivatives that follow.
func generic<T: Differentiable>(_ x: T, _ y: T) -> T { x }
// JVP of `generic` using the same generic signature as the original.
@derivative(of: generic)
func jvpGeneric<T: Differentiable>(x: T, y: T) -> (
  value: T, differential: (T.TangentVector, T.TangentVector) -> T.TangentVector
) {
  let differential: (T.TangentVector, T.TangentVector) -> T.TangentVector = { dx, dy in
    dx + dy
  }
  return (value: x, differential: differential)
}
// VJP of `generic` whose signature adds requirements the original does not have
// (`FloatingPoint`, `T == T.TangentVector`).
@derivative(of: generic)
func vjpGenericExtraGenericRequirements<T: Differentiable & FloatingPoint>(
  x: T, y: T
) -> (value: T, pullback: (T) -> (T, T)) where T == T.TangentVector {
  let pullback: (T) -> (T, T) = { cotangent in (cotangent, cotangent) }
  return (value: x, pullback: pullback)
}

// Test `wrt` parameter clauses.

// Sum of two labeled `Float` parameters; original for the `wrt:` clause tests.
func add(x: Float, y: Float) -> Float {
  let sum = x + y
  return sum
}
// VJP of `add` with respect to `x` only, so the pullback yields a single value.
@derivative(of: add, wrt: x) // ok
func vjpAddWrtX(
  x: Float, y: Float
) -> (value: Float, pullback: (Float) -> (Float)) {
  let pullback: (Float) -> Float = { cotangent in cotangent }
  return (value: x + y, pullback: pullback)
}
// VJP of `add` with respect to both parameters, so the pullback yields a pair.
@derivative(of: add, wrt: (x, y)) // ok
func vjpAddWrtXY(
  x: Float, y: Float
) -> (value: Float, pullback: (Float) -> (Float, Float)) {
  let pullback: (Float) -> (Float, Float) = { cotangent in (cotangent, cotangent) }
  return (value: x + y, pullback: pullback)
}

// Test index-based `wrt` parameters.

// Difference of two labeled `Float` parameters; original for the index-based
// `wrt:` tests.
func subtract(x: Float, y: Float) -> Float {
  let difference = x - y
  return difference
}
// `wrt:` mixing a parameter index (0) with a parameter name (`y`).
@derivative(of: subtract, wrt: (0, y)) // ok
func vjpSubtractWrt0Y(
  x: Float, y: Float
) -> (value: Float, pullback: (Float) -> (Float, Float)) {
  let pullback: (Float) -> (Float, Float) = { cotangent in (cotangent, cotangent) }
  return (value: x - y, pullback: pullback)
}
// `wrt:` with a single parenthesized parameter index.
@derivative(of: subtract, wrt: (1)) // ok
func vjpSubtractWrt1(
  x: Float, y: Float
) -> (value: Float, pullback: (Float) -> Float) {
  let pullback: (Float) -> Float = { cotangent in cotangent }
  return (value: x - y, pullback: pullback)
}

// Test invalid original function.

// `nonexistentFunction` is not declared in this file, so registration must fail
// at name lookup; the diagnostic is pinned on the attribute line.
// expected-error @+1 {{cannot find 'nonexistentFunction' in scope}}
@derivative(of: nonexistentFunction)
func vjpOriginalFunctionNotFound(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Test `@derivative` attribute where `value:` result does not conform to `Differentiable`.
// Invalid original function should be diagnosed first: the only diagnostic
// pinned here is the lookup failure, none about the `Int` result type.
// expected-error @+1 {{cannot find 'nonexistentFunction' in scope}}
@derivative(of: nonexistentFunction)
func vjpOriginalFunctionNotFound2(_ x: Float) -> (value: Int, pullback: (Float) -> Float) {
  fatalError()
}

// Test incorrect `@derivative` declaration type.

// Original `(Float) -> Float` function shared by the malformed derivative
// registrations that follow. The two notes pinned here are attached to this
// declaration and are triggered by those registrations.
// expected-note @+2 {{'incorrectDerivativeType' defined here}}
// expected-note @+1 {{candidate global function does not have expected type '(Int) -> Int'}}
func incorrectDerivativeType(_ x: Float) -> Float {
  return x
}

// Result is a bare `Float` rather than a labeled two-element tuple.
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple; first element must have label 'value:' and second element must have label 'pullback:' or 'differential:'}}
@derivative(of: incorrectDerivativeType)
func jvpResultIncorrect(x: Float) -> Float {
  return x
}
// Result is a two-element tuple with no labels; the pinned diagnostic reports
// the missing `value:` label on the first element.
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple; first element must have label 'value:'}}
@derivative(of: incorrectDerivativeType)
func vjpResultIncorrectFirstLabel(x: Float) -> (Float, (Float) -> Float) {
  return (x, { $0 })
}
// First element is labeled `value:` but the second element has no label.
// expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple; second element must have label 'pullback:' or 'differential:'}}
@derivative(of: incorrectDerivativeType)
func vjpResultIncorrectSecondLabel(x: Float) -> (value: Float, (Float) -> Float) {
  return (x, { $0 })
}
// All-`Int` signature: the only candidate named `incorrectDerivativeType` is
// `(Float) -> Float`, so no original of type `(Int) -> Int` can be resolved
// (see the candidate note pinned on the original declaration).
// expected-error @+1 {{referenced declaration 'incorrectDerivativeType' could not be resolved}}
@derivative(of: incorrectDerivativeType)
func vjpResultNotDifferentiable(x: Int) -> (
  value: Int, pullback: (Int) -> Int
) {
  return (x, { $0 })
}
// Pullback is spelled with `Double` where the pinned note states that
// `(Float) -> Float` is required. The error lands on the attribute line and the
// note on the line declaring the result tuple.
// expected-error @+2 {{function result's 'pullback' type does not match 'incorrectDerivativeType'}}
// expected-note @+3 {{'pullback' does not have expected type '(Float.TangentVector) -> Float.TangentVector' (aka '(Float) -> Float')}}
@derivative(of: incorrectDerivativeType)
func vjpResultIncorrectPullbackType(x: Float) -> (
  value: Float, pullback: (Double) -> Double
) {
  return (x, { $0 })
}

// Test invalid `wrt:` differentiation parameters.

// Two-parameter original (returns its first argument) referenced by the invalid
// `wrt:` clauses that follow.
func invalidWrtParam(_ x: Float, _ y: Float) -> Float { x }

// `z` is not the name of any parameter of this function.
// expected-error @+1 {{unknown parameter name 'z'}}
@derivative(of: add, wrt: z)
func vjpUnknownParam(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float)) {
  return (x + y, { $0 })
}
// `(y, x)` names the parameters in the reverse of their declaration order.
// expected-error @+1 {{parameters must be specified in original order}}
@derivative(of: invalidWrtParam, wrt: (y, x))
func vjpParamOrderNotIncreasing(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
  return (x + y, { ($0, $0) })
}
// `self` in a `wrt:` clause on a global function, which has no `self`.
// expected-error @+1 {{'self' parameter is only applicable to instance methods}}
@derivative(of: invalidWrtParam, wrt: self)
func vjpInvalidSelfParam(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
  return (x + y, { ($0, $0) })
}
// Parameter index 2 on a two-parameter function is out of range.
// expected-error @+1 {{parameter index is larger than total number of parameters}}
@derivative(of: invalidWrtParam, wrt: 2)
func vjpSubtractWrt2(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
  return (x - y, { ($0, $0) })
}
// Mixed index/name clause: index 1 is listed before `x` (the first parameter),
// which the pinned diagnostic reports as an ordering violation.
// expected-error @+1 {{parameters must be specified in original order}}
@derivative(of: invalidWrtParam, wrt: (1, x))
func vjpSubtractWrt1x(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
  return (x - y, { ($0, $0) })
}
// Index-only clause with the indices in descending order.
// expected-error @+1 {{parameters must be specified in original order}}
@derivative(of: invalidWrtParam, wrt: (1, 0))
func vjpSubtractWrt10(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
  return (x - y, { ($0, $0) })
}

// Parameterless original: always returns 1.
func noParameters() -> Float { 1 }
// The function takes no parameters, so there is nothing to differentiate with
// respect to.
// expected-error @+1 {{'vjpNoParameters()' has no parameters to differentiate with respect to}}
@derivative(of: noParameters)
func vjpNoParameters() -> (value: Float, pullback: (Float) -> Float) {
  return (1, { $0 })
}

// Original whose only parameter is an `Int`; ignores it and returns 1.
func noDifferentiableParameters(x: Int) -> Float { 1 }
// No `wrt:` clause and the sole parameter is `Int`; per the pinned diagnostic no
// `Differentiable` parameter can be inferred.
// expected-error @+1 {{no differentiation parameters could be inferred; must differentiate with respect to at least one parameter conforming to 'Differentiable'}}
@derivative(of: noDifferentiableParameters)
func vjpNoDifferentiableParameters(x: Int) -> (
  value: Float, pullback: (Float) -> Int
) {
  return (1, { _ in 0 })
}

// Applies the supplied function to the constant 1 and returns the result.
func functionParameter(_ fn: (Float) -> Float) -> Float {
  let result = fn(1)
  return result
}
// Explicit `wrt: fn` where `fn` has function type, which does not conform to
// `Differentiable` (per the pinned diagnostic).
// expected-error @+1 {{can only differentiate with respect to parameters that conform to 'Differentiable', but '(Float) -> Float' does not conform to 'Differentiable'}}
@derivative(of: functionParameter, wrt: fn)
func vjpFunctionParameter(_ fn: (Float) -> Float) -> (
  value: Float, pullback: (Float) -> Float
) {
  return (functionParameter(fn), { $0 })
}

// Test static methods.

// Protocol with two static requirements. The extensions that follow provide
// default implementations and register derivatives for them.
protocol StaticMethod: Differentiable {
  static func foo(_ x: Float) -> Float
  static func generic<T: Differentiable>(_ x: T) -> T
}

// Default implementations: both static requirements return their argument.
extension StaticMethod {
  static func foo(_ x: Float) -> Float {
    return x
  }

  static func generic<T: Differentiable>(_ x: T) -> T {
    return x
  }
}

// Derivative registrations for the static methods of `StaticMethod`. The first
// three are accepted; the last one pins the diagnostic for using `self` in a
// `wrt:` clause of a static method.
extension StaticMethod {
  // Unqualified original name.
  @derivative(of: foo)
  static func jvpFoo(x: Float) -> (value: Float, differential: (Float) -> Float)
  {
    return (x, { $0 })
  }

  // Test qualified declaration name.
  @derivative(of: StaticMethod.foo)
  static func vjpFoo(x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (x, { $0 })
  }

  // Generic static original; the pullback is typed over `T.TangentVector`.
  @derivative(of: generic)
  static func vjpGeneric<T: Differentiable>(_ x: T) -> (
    value: T, pullback: (T.TangentVector) -> (T.TangentVector)
  ) {
    return (x, { $0 })
  }

  // expected-error @+1 {{'self' parameter is only applicable to instance methods}}
  @derivative(of: foo, wrt: (self, x))
  static func vjpFooWrtSelf(x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (x, { $0 })
  }
}

// Test instance methods.

// Protocol with two instance requirements. The extensions that follow provide
// default implementations and register derivatives for them.
protocol InstanceMethod: Differentiable {
  func foo(_ x: Self) -> Self
  func generic<T: Differentiable>(_ x: T) -> Self
}

// Default implementations of the `InstanceMethod` requirements. Each carries two
// pinned notes: one from the duplicate declarations in the next extension and
// one emitted when a derivative registration for it is rejected.
extension InstanceMethod {
  // expected-note @+2 {{'foo' previously declared here}}
  // expected-note @+1 {{'foo' defined here}}
  func foo(_ x: Self) -> Self { x }

  // expected-note @+2 {{'generic' previously declared here}}
  // expected-note @+1 {{'generic' defined here}}
  func generic<T: Differentiable>(_ x: T) -> Self { self }
}

// Deliberate duplicates of the two default implementations declared in the
// previous extension; each is pinned as an invalid redeclaration.
extension InstanceMethod {
  // expected-error @+1 {{invalid redeclaration of 'foo'}}
  func foo(_ x: Self) -> Self { self }

  // expected-error @+1 {{invalid redeclaration of 'generic'}}
  func generic<T: Differentiable>(_ x: T) -> Self { self }
}

// Derivative registrations for the instance methods of `InstanceMethod`. The
// first four are accepted; the last one pins the diagnostic for listing `self`
// after another parameter in `wrt:`.
extension InstanceMethod {
  // No `wrt:` clause: the differential takes two tangent arguments, one for
  // `self` and one for `x`.
  @derivative(of: foo)
  func jvpFoo(x: Self) -> (
    value: Self, differential: (TangentVector, TangentVector) -> (TangentVector)
  ) {
    return (x, { $0 + $1 })
  }

  // Test qualified declaration name.
  @derivative(of: InstanceMethod.foo, wrt: x)
  func jvpFooWrtX(x: Self) -> (
    value: Self, differential: (TangentVector) -> (TangentVector)
  ) {
    return (x, { $0 })
  }

  // No `wrt:` clause: the pullback returns a tangent for `self` and one for `x`.
  @derivative(of: generic)
  func vjpGeneric<T: Differentiable>(_ x: T) -> (
    value: Self, pullback: (TangentVector) -> (TangentVector, T.TangentVector)
  ) {
    return (self, { ($0, .zero) })
  }

  // Explicit `wrt:` with `self` listed first.
  @derivative(of: generic, wrt: (self, x))
  func jvpGenericWrt<T: Differentiable>(_ x: T) -> (value: Self, differential: (TangentVector, T.TangentVector) -> TangentVector) {
    return (self, { dself, dx in dself })
  }

  // expected-error @+1 {{'self' parameter must come first in the parameter list}}
  @derivative(of: generic, wrt: (x, self))
  func jvpGenericWrtSelf<T: Differentiable>(_ x: T) -> (value: Self, differential: (TangentVector, T.TangentVector) -> TangentVector) {
    return (self, { dself, dx in dself })
  }
}

// Pullbacks that omit the `self` component. Each pinned note spells out the
// required pullback type, which returns a tangent for `self` as well.
extension InstanceMethod {
  // If `Self` conforms to `Differentiable`, then `Self` is inferred to be a differentiation parameter.
  // expected-error @+2 {{function result's 'pullback' type does not match 'foo'}}
  // expected-note @+3 {{'pullback' does not have expected type '(Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)'}}
  @derivative(of: foo)
  func vjpFoo(x: Self) -> (
    value: Self, pullback: (TangentVector) -> TangentVector
  ) {
    return (x, { $0 })
  }

  // If `Self` conforms to `Differentiable`, then `Self` is inferred to be a differentiation parameter.
  // expected-error @+2 {{function result's 'pullback' type does not match 'generic'}}
  // expected-note @+3 {{'pullback' does not have expected type '(Self.TangentVector) -> (Self.TangentVector, T.TangentVector)'}}
  @derivative(of: generic)
  func vjpGeneric<T: Differentiable>(_ x: T) -> (
    value: Self, pullback: (TangentVector) -> T.TangentVector
  ) {
    return (self, { _ in .zero })
  }
}

// Test `@derivative` declaration with more constrained generic signature.

// Unconstrained generic identity; original for a derivative that adds a
// `Differentiable` conformance requirement.
func req1<T>(_ x: T) -> T { x }
// VJP of `req1` that constrains `T` to `Differentiable`, which the original
// does not require.
@derivative(of: req1)
func vjpExtraConformanceConstraint<T: Differentiable>(_ x: T) -> (
  value: T, pullback: (T.TangentVector) -> T.TangentVector
) {
  let pullback: (T.TangentVector) -> T.TangentVector = { cotangent in cotangent }
  return (value: x, pullback: pullback)
}

// Unconstrained two-parameter generic; returns the first argument.
func req2<T, U>(_ x: T, _ y: U) -> T { x }
// VJP of `req2` that adds conformance and same-type requirements on both
// generic parameters.
@derivative(of: req2)
func vjpExtraConformanceConstraints<T: Differentiable, U: Differentiable>( _ x: T, _ y: U) -> (
  value: T, pullback: (T) -> (T, U)
) where T == T.TangentVector, U == U.TangentVector, T: CustomStringConvertible {
  let pullback: (T) -> (T, U) = { cotangent in (cotangent, .zero) }
  return (value: x, pullback: pullback)
}

// Test `@derivative` declaration with extra same-type requirements.
func req3<T>(_ x: T) -> T { x }
// VJP of `req3` whose `where` clause fixes `T.TangentVector` to the concrete
// type `Float`.
@derivative(of: req3)
func vjpSameTypeRequirementsGenericParametersAllConcrete<T>(_ x: T) -> (
  value: T, pullback: (T.TangentVector) -> T.TangentVector
) where T: Differentiable, T.TangentVector == Float {
  let pullback: (T.TangentVector) -> T.TangentVector = { cotangent in cotangent }
  return (value: x, pullback: pullback)
}

// Generic single-field wrapper with an unlabeled initializer.
struct Wrapper<T: Equatable>: Equatable {
  var x: T

  init(_ x: T) {
    self.x = x
  }
}
// Arithmetic is forwarded to the wrapped value.
extension Wrapper: AdditiveArithmetic where T: AdditiveArithmetic {
  static var zero: Self {
    return Self(T.zero)
  }

  static func + (lhs: Self, rhs: Self) -> Self {
    return Self(lhs.x + rhs.x)
  }

  static func - (lhs: Self, rhs: Self) -> Self {
    return Self(lhs.x - rhs.x)
  }
}
// Conditional `Differentiable` conformance, available only when `T` is its own
// tangent vector; the tangent vector is again a `Wrapper`.
extension Wrapper: Differentiable where T: Differentiable, T == T.TangentVector {
  typealias TangentVector = Wrapper<T.TangentVector>
}
// Registers a VJP for `Wrapper.init(_:)` from an extension that repeats the
// requirements of the conditional `Differentiable` conformance.
extension Wrapper where T: Differentiable, T == T.TangentVector {
  @derivative(of: init(_:))
  static func vjpInit(_ x: T) -> (value: Self, pullback: (Wrapper<T>.TangentVector) -> (T)) {
    fatalError()
  }
}

// Test class methods.

// Base class declaring a differentiable method and a derivative for it in the
// same type context.
class Super {
  // The pinned note is triggered by the registration attempted in `Sub`.
  @differentiable
  // expected-note @+1 {{candidate instance method is not defined in the current type context}}
  func foo(_ x: Float) -> Float {
    return x
  }

  // Accepted: original and derivative live in the same class.
  @derivative(of: foo)
  func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (foo(x), { v in v })
  }
}

// Subclass that tries to register a derivative for the inherited `foo`; this is
// currently rejected because `foo` is declared in the superclass.
class Sub: Super {
  // TODO(TF-649): Enable `@derivative` to override derivatives for original
  // declaration defined in superclass.
  // expected-error @+1 {{referenced declaration 'foo' could not be resolved}}
  @derivative(of: foo)
  override func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float)
  {
    return (foo(x), { v in v })
  }
}

// Test non-`func` original declarations.

// Generic struct fixture; later extensions add the computed property,
// initializers and subscripts used as original declarations.
struct Struct<T> {
  var x: T
}
// Conditional `Equatable` conformance with no explicit members.
extension Struct: Equatable where T: Equatable {}
// Conditional `Differentiable & AdditiveArithmetic` conformance. The arithmetic
// members trap if called; `move(along:)` forwards to the stored `x`.
extension Struct: Differentiable & AdditiveArithmetic
where T: Differentiable & AdditiveArithmetic {
  static var zero: Self {
    fatalError()
  }
  static func + (lhs: Self, rhs: Self) -> Self {
    fatalError()
  }
  static func - (lhs: Self, rhs: Self) -> Self {
    fatalError()
  }
  typealias TangentVector = Struct<T.TangentVector>
  mutating func move(along direction: TangentVector) {
    x.move(along: direction.x)
  }
}

// Generic class fixture with a single stored property and an unlabeled
// initializer.
class Class<T> {
  var x: T

  init(_ x: T) { self.x = x }
}
// Conditional `Differentiable` conformance with no explicit members.
extension Class: Differentiable where T: Differentiable {}

// Test computed properties.

// Computed property with `get`, `set` and `_modify` accessors, all forwarding
// to the stored `x`. Each accessor kind is referenced by a registration in the
// next extension.
extension Struct {
  var computedProperty: T {
    get { x }
    set { x = newValue }
    _modify { yield &x }
  }
}
// Derivative registrations for the accessors of `computedProperty`. The getter
// and setter registrations are accepted; the `_modify` registration pins the
// diagnostic for an unsupported accessor kind.
extension Struct where T: Differentiable & AdditiveArithmetic {
  // Bare property name, with no accessor suffix.
  @derivative(of: computedProperty)
  func vjpProperty() -> (value: T, pullback: (T.TangentVector) -> TangentVector) {
    return (x, { v in .init(x: v) })
  }
  
  // Explicit `.get` accessor.
  @derivative(of: computedProperty.get)
  func jvpProperty() -> (value: T, differential: (TangentVector) -> T.TangentVector) {
    fatalError()
  }
  
  // Explicit `.set` accessor: a `mutating` derivative whose pullback takes the
  // `self` tangent `inout`.
  @derivative(of: computedProperty.set)
  mutating func vjpPropertySetter(_ newValue: T) -> (
    value: (), pullback: (inout TangentVector) -> T.TangentVector
  ) {
    fatalError()
  }

  // expected-error @+1 {{cannot register derivative for _modify accessor}}
  @derivative(of: computedProperty._modify)
  mutating func vjpPropertyModify(_ newValue: T) -> (
    value: (), pullback: (inout TangentVector) -> T.TangentVector
  ) {
    fatalError()
  }
}

// Test initializers.

// Two initializers used as original declarations. Their bodies are empty and
// never assign `x`; NOTE(review): presumably accepted only because this file is
// type-checked and never compiled further (see the RUN lines) — confirm.
extension Struct {
  init(_ x: Float) {}
  init(_ x: T, y: Float) {}
}
// Derivative registrations for the two `Struct` initializers. The derivatives
// are `static` functions returning `Struct` as the `value:` element.
extension Struct where T: Differentiable & AdditiveArithmetic {
  // Bare `init`, with no argument labels.
  @derivative(of: init)
  static func vjpInit(_ x: Float) -> (
    value: Struct, pullback: (TangentVector) -> Float
  ) {
    return (.init(x), { _ in .zero })
  }

  // Fully labeled initializer name.
  @derivative(of: init(_:y:))
  static func vjpInit2(_ x: T, _ y: Float) -> (
    value: Struct, pullback: (TangentVector) -> (T.TangentVector, Float)
  ) {
    return (.init(x, y: y), { _ in (.zero, .zero) })
  }
}

// Test subscripts.

// Three subscripts used as original declarations: a parameterless one and a
// labeled one, both with getter and setter, plus a generic get-only one.
extension Struct {
  subscript() -> Float {
    get { 1 }
    set {}
  }

  subscript(float float: Float) -> Float {
    get { 1 }
    set {}
  }

  // expected-note @+1 {{candidate subscript does not have a setter}}
  subscript<T: Differentiable>(x: T) -> T { x }
}
// Derivative registrations for the `Struct` subscripts. For each subscript a
// VJP is registered via `.get` and then again via the bare subscript name; the
// second registration pins the duplicate-derivative diagnostic, with its note
// pointing back at the first attribute. Setter registrations follow, and the
// final case pins the failure for a subscript that has no setter.
extension Struct where T: Differentiable & AdditiveArithmetic {
  @derivative(of: subscript.get)
  func vjpSubscriptGetter() -> (value: Float, pullback: (Float) -> TangentVector) {
    return (1, { _ in .zero })
  }

  // expected-error @+2 {{a derivative already exists for '_'}}
  // expected-note @-6 {{other attribute declared here}}
  @derivative(of: subscript)
  func vjpSubscript() -> (value: Float, pullback: (Float) -> TangentVector) {
    return (1, { _ in .zero })
  }

  // JVP for the same getter, spelled with empty argument parentheses.
  @derivative(of: subscript().get)
  func jvpSubscriptGetter() -> (value: Float, differential: (TangentVector) -> Float) {
    return (1, { _ in .zero })
  }

  @derivative(of: subscript(float:).get, wrt: self)
  func vjpSubscriptLabeledGetter(float: Float) -> (value: Float, pullback: (Float) -> TangentVector)  {
    return (1, { _ in .zero })
  }

  // expected-error @+2 {{a derivative already exists for '_'}}
  // expected-note @-6 {{other attribute declared here}}
  @derivative(of: subscript(float:), wrt: self)
  func vjpSubscriptLabeled(float: Float) -> (value: Float, pullback: (Float) -> TangentVector) {
    return (1, { _ in .zero })
  }

  // JVP with no `wrt:` clause: the differential takes tangents for both `self`
  // and `float`.
  @derivative(of: subscript(float:).get)
  func jvpSubscriptLabeledGetter(float: Float) -> (value: Float, differential: (TangentVector, Float) -> Float)   {
    return (1, { (_,_) in 1})
  }

  @derivative(of: subscript(_:).get, wrt: self)
  func vjpSubscriptGenericGetter<T: Differentiable>(x: T) -> (value: T, pullback: (T.TangentVector) -> TangentVector)   {
    return (x, { _ in .zero })
  }

  // expected-error @+2 {{a derivative already exists for '_'}}
  // expected-note @-6 {{other attribute declared here}}
  @derivative(of: subscript(_:), wrt: self)
  func vjpSubscriptGeneric<T: Differentiable>(x: T) -> (value: T, pullback: (T.TangentVector) -> TangentVector)   {
    return (x, { _ in .zero })
  }

  // Setter derivatives are `mutating` and take the `self` tangent `inout`.
  @derivative(of: subscript.set)
  mutating func vjpSubscriptSetter(_ newValue: Float) -> (
    value: (), pullback: (inout TangentVector) -> Float
  ) {
    fatalError()
  }

  @derivative(of: subscript().set)
  mutating func jvpSubscriptSetter(_ newValue: Float) -> (
    value: (), differential: (inout TangentVector, Float) -> ()
  ) {
    fatalError()
  }

  @derivative(of: subscript(float:).set)
  mutating func vjpSubscriptLabeledSetter(float: Float, newValue: Float) -> (
    value: (), pullback: (inout TangentVector) -> (Float, Float)
  ) {
    fatalError()
  }

  @derivative(of: subscript(float:).set)
  mutating func jvpSubscriptLabeledSetter(float: Float, _ newValue: Float) -> (
    value: (), differential: (inout TangentVector, Float, Float) -> Void
  ) {
    fatalError()
  }

  // Error: original subscript has no setter.
  // expected-error @+1 {{referenced declaration 'subscript(_:)' could not be resolved}}
  @derivative(of: subscript(_:).set, wrt: self)
  mutating func vjpSubscriptGeneric_NoSetter<T: Differentiable>(x: T) -> (
    value: T, pullback: (T.TangentVector) -> TangentVector
  ) {
    return (x, { _ in .zero })
  }
}

// Parameterless subscript on the class fixture. The note pinned on the setter
// is triggered by the setter registration in the next extension.
extension Class {
  subscript() -> Float {
    get { 1 }
    // expected-note @+1 {{'subscript()' declared here}}
    set {}
  }
}
// Derivative registrations for the class subscript: an accepted getter VJP, a
// duplicate of it via the bare subscript name, and a setter registration that
// is not yet supported for classes.
extension Class where T: Differentiable {
  @derivative(of: subscript.get)
  func vjpSubscriptGetter() -> (value: Float, pullback: (Float) -> TangentVector) {
    return (1, { _ in .zero })
  }

  // expected-error @+2 {{a derivative already exists for '_'}}
  // expected-note @-6 {{other attribute declared here}}
  @derivative(of: subscript)
  func vjpSubscript() -> (value: Float, pullback: (Float) -> TangentVector) {
    return (1, { _ in .zero })
  }

  // FIXME(SR-13096): Enable derivative registration for class property/subscript setters.
  // This requires changing derivative type calculation rules for functions with
  // class-typed parameters. We need to assume that all functions taking
  // class-typed operands may mutate those operands.
  // expected-error @+1 {{cannot yet register derivative for class property or subscript setters}}
  @derivative(of: subscript.set)
  func vjpSubscriptSetter(_ newValue: Float) -> (
    value: (), pullback: (inout TangentVector) -> Float
  ) {
    fatalError()
  }
}

// Test duplicate `@derivative` attribute.

// Identity on `Float`; original for the two competing JVP registrations that
// follow.
func duplicate(_ x: Float) -> Float {
  return x
}
// First JVP registration for `duplicate`; the second registration's error
// points back at this attribute.
// expected-note @+1 {{other attribute declared here}}
@derivative(of: duplicate)
func jvpDuplicate1(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
  let differential: (Float) -> Float = { tangent in tangent }
  return (value: duplicate(x), differential: differential)
}
// Second JVP registration for the same original with the same signature.
// expected-error @+1 {{a derivative already exists for 'duplicate'}}
@derivative(of: duplicate)
func jvpDuplicate2(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
  return (duplicate(x), { $0 })
}

// Test invalid original declaration kind.

// A stored global variable used as an (invalid) original declaration.
// expected-note @+1 {{candidate var does not have a getter}}
var globalVariable: Float

// Names the global variable as the original; the candidate is rejected with the
// note pinned on its declaration.
// expected-error @+1 {{referenced declaration 'globalVariable' could not be resolved}}
@derivative(of: globalVariable)
func invalidOriginalDeclaration(x: Float) -> (
  value: Float, differential: (Float) -> (Float)
) {
  return (x, { $0 })
}

// Test ambiguous original declaration.

// Two empty marker protocols used to build a pair of overloads that differ only
// in their generic constraint.
protocol P1 {}
protocol P2 {}
// Overload constrained to `P1`.
// expected-note @+1 {{candidate global function found here}}
func ambiguous<T: P1>(_ x: T) -> T { x }
// Overload constrained to `P2`.
// expected-note @+1 {{candidate global function found here}}
func ambiguous<T: P2>(_ x: T) -> T { x }

// `T` satisfies both `P1` and `P2`, so both `ambiguous` overloads are viable
// originals and the reference cannot be resolved to a single one.
// expected-error @+1 {{referenced declaration 'ambiguous' is ambiguous}}
@derivative(of: ambiguous)
func jvpAmbiguous<T: P1 & P2 & Differentiable>(x: T)
  -> (value: T, differential: (T.TangentVector) -> (T.TangentVector))
{
  return (x, { $0 })
}

// Test no valid original declaration.
// Original declarations are invalid because they have extra generic
// requirements unsatisfied by the `@derivative` function.

// Overload requiring `BinaryFloatingPoint`, which the derivative does not.
// expected-note @+1 {{candidate global function does not have type equal to or less constrained than '<T where T : Differentiable> (x: T) -> T'}}
func invalid<T: BinaryFloatingPoint>(x: T) -> T { x }
// Overload requiring `CustomStringConvertible`, which the derivative does not.
// expected-note @+1 {{candidate global function does not have type equal to or less constrained than '<T where T : Differentiable> (x: T) -> T'}}
func invalid<T: CustomStringConvertible>(x: T) -> T { x }
// Overload requiring `FloatingPoint`, which the derivative does not.
// expected-note @+1 {{candidate global function does not have type equal to or less constrained than '<T where T : Differentiable> (x: T) -> T'}}
func invalid<T: FloatingPoint>(x: T) -> T { x }

// Constrains `T` only to `Differentiable`; every `invalid` overload demands an
// additional conformance, so none qualifies as the original.
// expected-error @+1 {{referenced declaration 'invalid' could not be resolved}}
@derivative(of: invalid)
func jvpInvalid<T: Differentiable>(x: T) -> (
  value: T, differential: (T.TangentVector) -> T.TangentVector
) {
  return (x, { $0 })
}

// Test invalid derivative type context: instance vs static method mismatch.

// An instance-method derivative naming a static method as its original. The
// candidate note shows the instance-method type that was searched for.
struct InvalidTypeContext<T: Differentiable> {
  // expected-note @+1 {{candidate static method does not have type equal to or less constrained than '<T where T : Differentiable> (InvalidTypeContext<T>) -> (T) -> T'}}
  static func staticMethod(_ x: T) -> T { x }

  // expected-error @+1 {{referenced declaration 'staticMethod' could not be resolved}}
  @derivative(of: staticMethod)
  func jvpStatic(_ x: T) -> (
    value: T, differential: (T.TangentVector) -> (T.TangentVector)
  ) {
    return (x, { $0 })
  }
}

// Test stored property original declaration.

// Struct with a single stored property; the pinned note is triggered by the
// derivative registration attempted for that property further down.
struct HasStoredProperty {
  // expected-note @+1 {{'stored' declared here}}
  var stored: Float
}
// `Differentiable & AdditiveArithmetic` conformance with `Self` as the tangent
// vector. The arithmetic members trap if called.
extension HasStoredProperty: Differentiable & AdditiveArithmetic {
  static var zero: Self {
    fatalError()
  }
  static func + (lhs: Self, rhs: Self) -> Self {
    fatalError()
  }
  static func - (lhs: Self, rhs: Self) -> Self {
    fatalError()
  }
  typealias TangentVector = Self
}
// Attempts to register a derivative for a stored property, which is rejected.
extension HasStoredProperty {
  // expected-error @+1 {{cannot register derivative for stored property 'stored'}}
  @derivative(of: stored)
  func vjpStored() -> (value: Float, pullback: (Float) -> TangentVector) {
    return (stored, { _ in .zero })
  }
}

// Test derivative registration for protocol requirements. Currently unsupported.
// TODO(TF-982): Lift this restriction and add proper support.

// Protocol whose requirement has no default implementation in this file, so the
// requirement itself is the only candidate original.
protocol ProtocolRequirementDerivative {
  // expected-note @+1 {{cannot yet register derivative default implementation for protocol requirements}}
  func requirement(_ x: Float) -> Float
}
// Attempts to register a derivative directly for the protocol requirement.
extension ProtocolRequirementDerivative {
  // expected-error @+1 {{referenced declaration 'requirement' could not be resolved}}
  @derivative(of: requirement)
  func vjpRequirement(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    fatalError()
  }
}

// Test `inout` parameters.

// Takes an `inout` parameter and also returns a value; returns the parameter's
// current value without modifying it.
func multipleSemanticResults(_ x: inout Float) -> Float { x }
// The original has both an `inout` parameter and a formal result, which is
// rejected.
// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
@derivative(of: multipleSemanticResults)
func vjpMultipleSemanticResults(x: inout Float) -> (
  value: Float, pullback: (Float) -> Float
) {
  return (multipleSemanticResults(&x), { $0 })
}

// Empty `Differentiable` struct whose tangent vector is `DummyTangentVector`;
// `move(along:)` is a no-op. Host type for the `inout` parameter tests.
struct InoutParameters: Differentiable {
  typealias TangentVector = DummyTangentVector
  mutating func move(along _: TangentVector) {}
}

// Derivative registrations for a static method whose first parameter is
// `inout`. Each accepted registration is followed by a `...Mismatch` twin whose
// pullback/differential type is deliberately wrong; the verifier lines pin the
// resulting diagnostics (four mismatches, hence the note count on the original).
extension InoutParameters {
  // expected-note @+1 4 {{'staticMethod(_:rhs:)' defined here}}
  static func staticMethod(_ lhs: inout Self, rhs: Self) {}

  // Test wrt `inout` parameter.

  @derivative(of: staticMethod)
  static func vjpWrtInout(_ lhs: inout Self, _ rhs: Self) -> (
    value: Void, pullback: (inout TangentVector) -> TangentVector
  ) { fatalError() }

  // expected-error @+1 {{function result's 'pullback' type does not match 'staticMethod(_:rhs:)'}}
  @derivative(of: staticMethod)
  static func vjpWrtInoutMismatch(_ lhs: inout Self, _ rhs: Self) -> (
    // expected-note @+1 {{'pullback' does not have expected type '(inout InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(inout DummyTangentVector) -> DummyTangentVector')}}
    value: Void, pullback: (TangentVector) -> TangentVector
  ) { fatalError() }

  @derivative(of: staticMethod)
  static func jvpWrtInout(_ lhs: inout Self, _ rhs: Self) -> (
    value: Void, differential: (inout TangentVector, TangentVector) -> Void
  ) { fatalError() }

  // expected-error @+1 {{function result's 'differential' type does not match 'staticMethod(_:rhs:)'}}
  @derivative(of: staticMethod)
  static func jvpWrtInoutMismatch(_ lhs: inout Self, _ rhs: Self) -> (
    // expected-note @+1 {{'differential' does not have expected type '(inout InoutParameters.TangentVector, InoutParameters.TangentVector) -> ()' (aka '(inout DummyTangentVector, DummyTangentVector) -> ()')}}
    value: Void, differential: (TangentVector, TangentVector) -> Void
  ) { fatalError() }

  // Test non-wrt `inout` parameter.

  @derivative(of: staticMethod, wrt: rhs)
  static func vjpNotWrtInout(_ lhs: inout Self, _ rhs: Self) -> (
    value: Void, pullback: (TangentVector) -> TangentVector
  ) { fatalError() }

  // expected-error @+1 {{function result's 'pullback' type does not match 'staticMethod(_:rhs:)'}}
  @derivative(of: staticMethod, wrt: rhs)
  static func vjpNotWrtInoutMismatch(_ lhs: inout Self, _ rhs: Self) -> (
    // expected-note @+1 {{'pullback' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}}
    value: Void, pullback: (inout TangentVector) -> TangentVector
  ) { fatalError() }

  @derivative(of: staticMethod, wrt: rhs)
  static func jvpNotWrtInout(_ lhs: inout Self, _ rhs: Self) -> (
    value: Void, differential: (TangentVector) -> TangentVector
  ) { fatalError() }

  // Named with the `Mismatch` suffix like every other negative case, so it no
  // longer overloads the valid `jvpNotWrtInout` above purely by return type.
  // expected-error @+1 {{function result's 'differential' type does not match 'staticMethod(_:rhs:)'}}
  @derivative(of: staticMethod, wrt: rhs)
  static func jvpNotWrtInoutMismatch(_ lhs: inout Self, _ rhs: Self) -> (
    // expected-note @+1 {{'differential' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}}
    value: Void, differential: (inout TangentVector) -> TangentVector
  ) { fatalError() }
}

// Derivative registrations for a `mutating` method, where `self` is the `inout`
// parameter. Each accepted registration is followed by a `...Mismatch` twin
// whose pullback/differential type is deliberately wrong (four mismatches,
// hence the note count on the original).
extension InoutParameters {
  // expected-note @+1 4 {{'mutatingMethod' defined here}}
  mutating func mutatingMethod(_ other: Self) {}

  // Test wrt `inout` `self` parameter.

  @derivative(of: mutatingMethod)
  mutating func vjpWrtInout(_ other: Self) -> (
    value: Void, pullback: (inout TangentVector) -> TangentVector
  ) { fatalError() }

  // expected-error @+1 {{function result's 'pullback' type does not match 'mutatingMethod'}}
  @derivative(of: mutatingMethod)
  mutating func vjpWrtInoutMismatch(_ other: Self) -> (
    // expected-note @+1 {{'pullback' does not have expected type '(inout InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(inout DummyTangentVector) -> DummyTangentVector')}}
    value: Void, pullback: (TangentVector) -> TangentVector
  ) { fatalError() }

  @derivative(of: mutatingMethod)
  mutating func jvpWrtInout(_ other: Self) -> (
    value: Void, differential: (inout TangentVector, TangentVector) -> Void
  ) { fatalError() }

  // expected-error @+1 {{function result's 'differential' type does not match 'mutatingMethod'}}
  @derivative(of: mutatingMethod)
  mutating func jvpWrtInoutMismatch(_ other: Self) -> (
    // expected-note @+1 {{'differential' does not have expected type '(inout InoutParameters.TangentVector, InoutParameters.TangentVector) -> ()' (aka '(inout DummyTangentVector, DummyTangentVector) -> ()')}}
    value: Void, differential: (TangentVector, TangentVector) -> Void
  ) { fatalError() }

  // Test non-wrt `inout` `self` parameter.

  @derivative(of: mutatingMethod, wrt: other)
  mutating func vjpNotWrtInout(_ other: Self) -> (
    value: Void, pullback: (TangentVector) -> TangentVector
  ) { fatalError() }

  // expected-error @+1 {{function result's 'pullback' type does not match 'mutatingMethod'}}
  @derivative(of: mutatingMethod, wrt: other)
  mutating func vjpNotWrtInoutMismatch(_ other: Self) -> (
    // expected-note @+1 {{'pullback' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}}
    value: Void, pullback: (inout TangentVector) -> TangentVector
  ) { fatalError() }

  @derivative(of: mutatingMethod, wrt: other)
  mutating func jvpNotWrtInout(_ other: Self) -> (
    value: Void, differential: (TangentVector) -> TangentVector
  ) { fatalError() }

  // expected-error @+1 {{function result's 'differential' type does not match 'mutatingMethod'}}
  @derivative(of: mutatingMethod, wrt: other)
  mutating func jvpNotWrtInoutMismatch(_ other: Self) -> (
    // expected-note @+1 {{'differential' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}}
    value: Void, differential: (TangentVector, TangentVector) -> Void
  ) { fatalError() }
}

// Test no semantic results.

// Takes a `Float`, returns nothing and has no `inout` parameters.
func noSemanticResults(_ x: Float) {
}

// The original returns `Void` and has no `inout` parameter, so there is no
// result to differentiate.
// expected-error @+1 {{cannot differentiate void function 'noSemanticResults'}}
@derivative(of: noSemanticResults)
func vjpNoSemanticResults(_ x: Float) -> (value: Void, pullback: Void) {}

// Test multiple semantic results.

// Instance methods that have both an `inout` parameter and a formal result.
// Registering a derivative for either one is rejected.
extension InoutParameters {
  func multipleSemanticResults(_ x: inout Float) -> Float { x }
  // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
  @derivative(of: multipleSemanticResults)
  func vjpMultipleSemanticResults(_ x: inout Float) -> (
    value: Float, pullback: (inout Float) -> Void
  ) { fatalError() }

  // Returns `x` so the declaration stays well-formed beyond `-typecheck`; the
  // `inout Void` parameter is what the registration below is rejected for.
  func inoutVoid(_ x: Float, _ void: inout Void) -> Float { x }
  // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
  @derivative(of: inoutVoid)
  func vjpInoutVoidParameter(_ x: Float, _ void: inout Void) -> (
    value: Float, pullback: (inout Float) -> Void
  ) { fatalError() }
}

// Test original/derivative function `inout` parameter mismatches.

// In both cases below the derivative's signature differs from the original's
// in `inout`-ness, so the original cannot be resolved from the derivative.
extension InoutParameters {
  // Case 1: the original takes `x` by value, but the derivative declares
  // `x` as `inout Float`.
  // expected-note @+1 {{candidate instance method does not have expected type '(InoutParameters) -> (inout Float) -> Void'}}
  func inoutParameterMismatch(_ x: Float) {}

  // expected-error @+1 {{referenced declaration 'inoutParameterMismatch' could not be resolved}}
  @derivative(of: inoutParameterMismatch)
  func vjpInoutParameterMismatch(_ x: inout Float) -> (value: Void, pullback: (inout Float) -> Void) {
    fatalError()
  }

  // Case 2: the original is non-mutating, but the derivative is declared
  // `mutating` (so its `self` is `inout`).
  // expected-note @+1 {{candidate instance method does not have expected type '(inout InoutParameters) -> (Float) -> Void'}}
  func mutatingMismatch(_ x: Float) {}

  // expected-error @+1 {{referenced declaration 'mutatingMismatch' could not be resolved}}
  @derivative(of: mutatingMismatch)
  mutating func vjpMutatingMismatch(_ x: Float) -> (value: Void, pullback: (inout Float) -> Void) {
    fatalError()
  }
}

// Test cross-file derivative registration.

// Registers a derivative for `rounded` from a constrained extension of
// `FloatingPoint`. No diagnostic is listed for this registration.
// NOTE(review): `rounded` is not declared in the visible part of this file;
// per the section comment it is presumably resolved from another file or
// module — confirm.
extension FloatingPoint where Self: Differentiable {
  @usableFromInline
  @derivative(of: rounded)
  func vjpRounded() -> (
    value: Self,
    pullback: (Self.TangentVector) -> (Self.TangentVector)
  ) {
    fatalError()
  }
}

// Attempts to register a static derivative for the `+` operator; the
// registration is rejected because `+` cannot be resolved.
// NOTE(review): presumably this is because `+` is a protocol requirement of
// `AdditiveArithmetic` (see the note quoted at the top of the file) — confirm.
extension Differentiable where Self: AdditiveArithmetic {
  // expected-error @+1 {{referenced declaration '+' could not be resolved}}
  @derivative(of: +)
  static func vjpPlus(x: Self, y: Self) -> (
    value: Self,
    pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)
  ) {
    return (x + y, { v in (v, v) })
  }
}

// Same `+` registration as a non-static (instance) method, in an extension
// that additionally requires `Self == Self.TangentVector`. The same
// "could not be resolved" error is listed.
extension AdditiveArithmetic
where Self: Differentiable, Self == Self.TangentVector {
  // expected-error @+1 {{referenced declaration '+' could not be resolved}}
  @derivative(of: +)
  func vjpPlusInstanceMethod(x: Self, y: Self) -> (
    value: Self, pullback: (Self) -> (Self, Self)
  ) {
    return (x + y, { v in (v, v) })
  }
}

// Test derivatives of default implementations.
// Protocol with a single requirement `req`; a default implementation and a
// derivative for that default are supplied in the extension that follows.
protocol HasADefaultImplementation {
  func req(_ x: Float) -> Float
}
// Provides a default implementation of `req` and registers a derivative for
// it. The derivative is itself named `req` and differs from the original
// only in its return type. No diagnostic is listed.
extension HasADefaultImplementation {
  func req(_ x: Float) -> Float { x }
  // ok
  @derivative(of: req)
  func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    (x, { 10 * $0 })
  }
}

// Test default derivatives of requirements.
// Protocol whose requirement `req` has no default implementation in the
// visible part of this file. The note on the requirement accompanies the
// error raised on the derivative registration in the extension that follows.
protocol HasADefaultDerivative {
  // expected-note @+1 {{cannot yet register derivative default implementation for protocol requirements}}
  func req(_ x: Float) -> Float
}
// Attempts to register a default derivative directly for the protocol
// requirement `req`. This is rejected; the TODO below tracks adding support.
extension HasADefaultDerivative {
  // TODO(TF-982): Support default derivatives for protocol requirements.
  // expected-error @+1 {{referenced declaration 'req' could not be resolved}}
  @derivative(of: req)
  func vjpReq(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    (x, { 10 * $0 })
  }
}

// MARK: - Original function visibility = derivative function visibility

// Public original, public derivative. No diagnostic is listed.
public func public_original_public_derivative(_ x: Float) -> Float { x }
@derivative(of: public_original_public_derivative)
public func _public_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Public original; the derivative has no explicit access modifier but is
// marked `@usableFromInline`. No diagnostic is listed.
public func public_original_usablefrominline_derivative(_ x: Float) -> Float { x }
@usableFromInline
@derivative(of: public_original_usablefrominline_derivative)
func _public_original_usablefrominline_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Neither declaration has an access modifier (both default to internal).
// No diagnostic is listed.
func internal_original_internal_derivative(_ x: Float) -> Float { x }
@derivative(of: internal_original_internal_derivative)
func _internal_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Private original, private derivative. No diagnostic is listed.
private func private_original_private_derivative(_ x: Float) -> Float { x }
@derivative(of: private_original_private_derivative)
private func _private_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Fileprivate original, fileprivate derivative. No diagnostic is listed.
fileprivate func fileprivate_original_fileprivate_derivative(_ x: Float) -> Float { x }
@derivative(of: fileprivate_original_fileprivate_derivative)
fileprivate func _fileprivate_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Internal original; the derivative is also internal and additionally marked
// `@usableFromInline`. No diagnostic is listed.
func internal_original_usablefrominline_derivative(_ x: Float) -> Float { x }
@usableFromInline
@derivative(of: internal_original_usablefrominline_derivative)
func _internal_original_usablefrominline_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Internal original; the derivative is also internal and additionally marked
// `@inlinable`. No diagnostic is listed.
func internal_original_inlinable_derivative(_ x: Float) -> Float { x }
@inlinable
@derivative(of: internal_original_inlinable_derivative)
func _internal_original_inlinable_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Internal original; the derivative is also internal and additionally marked
// `@_alwaysEmitIntoClient`. No diagnostic is listed.
func internal_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
@_alwaysEmitIntoClient
@derivative(of: internal_original_alwaysemitintoclient_derivative)
func _internal_original_alwaysemitintoclient_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// MARK: - Original function visibility < derivative function visibility

// The original is internal (with `@usableFromInline`) and the derivative is
// public. The error reports the original as internal, and the fix-it on the
// note replaces the derivative's `public` modifier with `internal`.
@usableFromInline
func usablefrominline_original_public_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_usablefrominline_original_public_derivative' is public, but original function 'usablefrominline_original_public_derivative' is internal}}
@derivative(of: usablefrominline_original_public_derivative)
// expected-note @+1 {{mark the derivative function as 'internal' to match the original function}} {{1-7=internal}}
public func _usablefrominline_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Internal original, public derivative. The fix-it on the note replaces the
// derivative's `public` modifier with `internal`.
func internal_original_public_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_internal_original_public_derivative' is public, but original function 'internal_original_public_derivative' is internal}}
@derivative(of: internal_original_public_derivative)
// expected-note @+1 {{mark the derivative function as 'internal' to match the original function}} {{1-7=internal}}
public func _internal_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Private original; the derivative is internal with `@usableFromInline`.
// The error reports the derivative as internal, and the fix-it on the note
// inserts `private ` at the start of the derivative's `func` line.
private func private_original_usablefrominline_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_usablefrominline_derivative' is internal, but original function 'private_original_usablefrominline_derivative' is private}}
@derivative(of: private_original_usablefrominline_derivative)
@usableFromInline
// expected-note @+1 {{mark the derivative function as 'private' to match the original function}} {{1-1=private }}
func _private_original_usablefrominline_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Private original, public derivative. The fix-it on the note replaces the
// derivative's `public` modifier with `private`.
private func private_original_public_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_public_derivative' is public, but original function 'private_original_public_derivative' is private}}
@derivative(of: private_original_public_derivative)
// expected-note @+1 {{mark the derivative function as 'private' to match the original function}} {{1-7=private}}
public func _private_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Private original; the derivative has no access modifier (internal).
// NOTE(review): unlike the neighbouring cases, the note here does not check
// a fix-it — confirm whether that omission is intentional.
private func private_original_internal_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_internal_derivative' is internal, but original function 'private_original_internal_derivative' is private}}
@derivative(of: private_original_internal_derivative)
// expected-note @+1 {{mark the derivative function as 'private' to match the original function}}
func _private_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Fileprivate original, private derivative. The error reports the two levels
// as different, and the fix-it on the note replaces the derivative's
// `private` modifier with `fileprivate`.
fileprivate func fileprivate_original_private_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_fileprivate_original_private_derivative' is private, but original function 'fileprivate_original_private_derivative' is fileprivate}}
@derivative(of: fileprivate_original_private_derivative)
// expected-note @+1 {{mark the derivative function as 'fileprivate' to match the original function}} {{1-8=fileprivate}}
private func _fileprivate_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Private original, fileprivate derivative. The fix-it on the note replaces
// the derivative's `fileprivate` modifier with `private`.
private func private_original_fileprivate_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_fileprivate_derivative' is fileprivate, but original function 'private_original_fileprivate_derivative' is private}}
@derivative(of: private_original_fileprivate_derivative)
// expected-note @+1 {{mark the derivative function as 'private' to match the original function}} {{1-12=private}}
fileprivate func _private_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// MARK: - Original function visibility > derivative function visibility

// Public original. Despite the "private" in its name, the derivative is
// declared `fileprivate`, and the error reports it as fileprivate. The note
// suggests `@usableFromInline`, with a fix-it that inserts the attribute at
// the start of the derivative's `func` line.
public func public_original_private_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_public_original_private_derivative' is fileprivate, but original function 'public_original_private_derivative' is public}}
@derivative(of: public_original_private_derivative)
// expected-note @+1 {{mark the derivative function as '@usableFromInline' to match the original function}} {{1-1=@usableFromInline }}
fileprivate func _public_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Public original; the derivative has no access modifier (internal) and no
// `@usableFromInline`. The note suggests adding `@usableFromInline`, with a
// fix-it that inserts it at the start of the derivative's `func` line.
public func public_original_internal_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_public_original_internal_derivative' is internal, but original function 'public_original_internal_derivative' is public}}
@derivative(of: public_original_internal_derivative)
// expected-note @+1 {{mark the derivative function as '@usableFromInline' to match the original function}} {{1-1=@usableFromInline }}
func _public_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Internal original, fileprivate derivative. The fix-it on the note replaces
// the derivative's `fileprivate` modifier with `internal`.
func internal_original_fileprivate_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_internal_original_fileprivate_derivative' is fileprivate, but original function 'internal_original_fileprivate_derivative' is internal}}
@derivative(of: internal_original_fileprivate_derivative)
// expected-note @+1 {{mark the derivative function as 'internal' to match the original function}} {{1-12=internal}}
fileprivate func _internal_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Test invalid reference to an accessor of a non-storage declaration.

// `function` is a plain global function, so it has no accessors. The
// `@derivative` attribute below names its `.get` accessor, which makes the
// reference unresolvable; the note is attached to the function itself.
// expected-note @+1 {{candidate global function does not have a getter}}
func function(_ x: Float) -> Float {
  x
}

// expected-error @+1 {{referenced declaration 'function' could not be resolved}}
@derivative(of: function(_:).get)
func vjpFunction(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}

// Test the ambiguity that arises when a method's name is the same as an
// accessor label (e.g. `Float.set`).

// Adds two methods to `Float` — one named like an accessor (`set`) and one
// not (`method`) — and registers a derivative for each through a
// type-qualified name, to show how `Float.set` is parsed.
extension Float {
  // Original function name conflicts with an accessor name ("set").
  func set() -> Float {
    self
  }

  // Original function name does not conflict with an accessor name.
  func method() -> Float {
    self
  }

  // Test ambiguous parse.
  // Expected:
  // - Base type: `Float`
  // - Declaration name: `set`
  // - Accessor kind: <none>
  // Actual:
  // - Base type: <none>
  // - Declaration name: `Float`
  // - Accessor kind: `set`
  // expected-error @+1 {{cannot find 'Float' in scope}}
  @derivative(of: Float.set)
  func jvpSet() -> (value: Float, differential: (Float) -> Float) {
    fatalError()
  }

  // Control case: `method` is not an accessor name, and no diagnostic is
  // listed for the type-qualified reference `Float.method`.
  @derivative(of: Float.method)
  func jvpMethod() -> (value: Float, differential: (Float) -> Float) {
    fatalError()
  }
}

// Test original function with opaque result type.

// The original returns an opaque `some Differentiable`, while the derivative
// declares `value: Float`. The note reports that the candidate does not have
// type `(Float) -> Float`, so the reference cannot be resolved.
// expected-note @+1 {{candidate global function does not have expected type '(Float) -> Float'}}
func opaqueResult(_ x: Float) -> some Differentiable { x }

// expected-error @+1 {{referenced declaration 'opaqueResult' could not be resolved}}
@derivative(of: opaqueResult)
func vjpOpaqueResult(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  fatalError()
}
