//===--- SIMDDifferentiation.swift.gyb ------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

%{
storagescalarCounts = [2,4,8,16,32,64]
vectorscalarCounts = storagescalarCounts + [3]
}%

%for n in vectorscalarCounts:

//===----------------------------------------------------------------------===//
// Protocol conformances
//===----------------------------------------------------------------------===//

// `SIMD${n}` adopts `AdditiveArithmetic` whenever its scalar is a
// floating-point type. The extension body is empty: no members are declared
// here.
extension SIMD${n}: AdditiveArithmetic
where
  Scalar: FloatingPoint
{
}

extension SIMD${n}: Differentiable
where
  Scalar: Differentiable & BinaryFloatingPoint,
  Scalar.TangentVector: BinaryFloatingPoint
{
  /// The tangent vector type of `SIMD${n}` is `SIMD${n}` itself.
  public typealias TangentVector = SIMD${n}

  /// A closure that, when called, returns a tangent vector built with
  /// `init(repeating: 0)`, i.e. with every lane equal to zero.
  @inlinable
  public var zeroTangentVectorInitializer: () -> TangentVector {
    return { TangentVector(repeating: 0) }
  }
}

// SWIFT_ENABLE_TENSORFLOW
// `SIMD${n}` is `EuclideanDifferentiable` when its scalar is both
// `EuclideanDifferentiable` and `BinaryFloatingPoint`, and the scalar's tangent
// vector is `BinaryFloatingPoint`. No members are declared in the body.
extension SIMD${n}: EuclideanDifferentiable
where
  Scalar: BinaryFloatingPoint & EuclideanDifferentiable,
  Scalar.TangentVector: BinaryFloatingPoint
{}
// SWIFT_ENABLE_TENSORFLOW END

//===----------------------------------------------------------------------===//
// Derivatives
//===----------------------------------------------------------------------===//

extension SIMD${n}
where
  Scalar: Differentiable & BinaryFloatingPoint,
  Scalar.TangentVector == Scalar
{
  // NOTE(TF-1094): serialized `@derivative` for `.swiftinterface` compilation.
  /// Derivative of the lane getter `self[index]`, expressed as a pullback.
  ///
  /// - Parameter index: The lane being read.
  /// - Returns: The value of lane `index`, and a pullback that stores its
  ///   scalar argument in lane `index` of an otherwise zero tangent vector.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _vjpSubscript(_ index: Int)
    -> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector)
  {
    return (
      value: self[index],
      pullback: { laneTangent in
        // Only the lane that was read receives the incoming value.
        var scattered = TangentVector.zero
        scattered[index] = laneTangent
        return scattered
      }
    )
  }

  /// Derivative of the lane getter `self[index]`, expressed as a differential.
  ///
  /// - Parameter index: The lane being read.
  /// - Returns: The value of lane `index`, and a differential that reads lane
  ///   `index` of its tangent-vector argument.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _jvpSubscript(index: Int)
    -> (value: Scalar, differential: (TangentVector) -> Scalar.TangentVector)
  {
    return (
      value: self[index],
      differential: { tangent in Scalar.TangentVector(tangent[index]) }
    )
  }

  /// Derivative of the lane setter `self[index] = newValue`, expressed as a
  /// pullback.
  ///
  /// - Parameters:
  ///   - newValue: The scalar written into the lane.
  ///   - index: The lane being written.
  /// - Returns: An empty value, and a pullback that reads lane `index` from the
  ///   tangent of `self`, resets that lane to zero in place, and returns the
  ///   value it read.
  @inlinable
  @derivative(of: subscript(_:).set)
  internal mutating func _vjpSubscriptSetter(_ newValue: Scalar, _ index: Int)
    -> (value: Void, pullback: (inout TangentVector) -> Scalar.TangentVector)
  {
    // Perform the original mutation before handing back the pullback.
    self[index] = newValue
    return (
      value: (),
      pullback: { selfTangent in
        let newValueTangent = selfTangent[index]
        selfTangent[index] = 0
        return newValueTangent
      }
    )
  }
}

%end

extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint,
  TangentVector.Scalar: BinaryFloatingPoint
{
  /// Derivative of vector `+`, expressed as a pullback.
  ///
  /// - Returns: `lhs + rhs`, and a pullback that hands its argument unchanged
  ///   to both operands.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    return (value: lhs + rhs, pullback: { seed in (seed, seed) })
  }

  /// Derivative of vector `+`, expressed as a differential.
  ///
  /// - Returns: `lhs + rhs`, and a differential that adds the two operand
  ///   tangents.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    return (value: lhs + rhs, differential: { dLhs, dRhs in dLhs + dRhs })
  }

  /// Derivative of binary vector `-`, expressed as a pullback.
  ///
  /// - Returns: `lhs - rhs`, and a pullback that hands its argument to `lhs`
  ///   and the negated argument to `rhs`.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    return (value: lhs - rhs, pullback: { seed in (seed, -seed) })
  }

  /// Derivative of binary vector `-`, expressed as a differential.
  ///
  /// - Returns: `lhs - rhs`, and a differential that subtracts the `rhs`
  ///   tangent from the `lhs` tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    return (value: lhs - rhs, differential: { dLhs, dRhs in dLhs - dRhs })
  }

  /// Derivative of unary vector `-`, expressed as a pullback.
  ///
  /// - Returns: `-rhs`, and a pullback that negates its argument.
  @inlinable
  @derivative(of: -)
  static func _vjpNegate(rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector))
  {
    return (value: -rhs, pullback: { seed in -seed })
  }

  /// Derivative of unary vector `-`, expressed as a differential.
  ///
  /// - Returns: `-rhs`, and a differential that negates its argument.
  @inlinable
  @derivative(of: -)
  static func _jvpNegate(rhs: Self)
    -> (value: Self, differential: (TangentVector) -> (TangentVector))
  {
    return (value: -rhs, differential: { tangent in -tangent })
  }
}

extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint,
  Self.TangentVector == Self
{
  /// Derivative of lane-wise vector `*`, expressed as a pullback.
  ///
  /// - Returns: `lhs * rhs`, and a pullback that yields `v * rhs` for `lhs`
  ///   and `v * lhs` for `rhs`, where `v` is the pullback's argument.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    return (value: lhs * rhs, pullback: { seed in
      let dLhs: TangentVector = seed * rhs
      let dRhs: TangentVector = seed * lhs
      return (dLhs, dRhs)
    })
  }

  /// Derivative of lane-wise vector `*`, expressed as a differential.
  ///
  /// - Returns: `lhs * rhs`, and a differential computing
  ///   `lhs * dRhs + dLhs * rhs` from the operand tangents.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    return (
      value: lhs * rhs,
      differential: { dLhs, dRhs in lhs * dRhs + dLhs * rhs }
    )
  }

  /// Derivative of lane-wise vector `/`, expressed as a pullback.
  ///
  /// - Returns: `lhs / rhs`, and a pullback that yields `v / rhs` for `lhs`
  ///   and `-lhs / (rhs * rhs) * v` for `rhs`, where `v` is the pullback's
  ///   argument.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    return (value: lhs / rhs, pullback: { seed in
      let dLhs: TangentVector = seed / rhs
      let dRhs: TangentVector = -lhs / (rhs * rhs) * seed
      return (dLhs, dRhs)
    })
  }

  /// Derivative of lane-wise vector `/`, expressed as a differential.
  ///
  /// - Returns: `lhs / rhs`, and a differential computing
  ///   `(dLhs * rhs - lhs * dRhs) / (rhs * rhs)` from the operand tangents.
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    return (value: lhs / rhs, differential: { dLhs, dRhs in
      let numerator: TangentVector = dLhs * rhs - lhs * dRhs
      return numerator / (rhs * rhs)
    })
  }
}

extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector.Scalar == Scalar.TangentVector
{
  /// Derivative of `scalar + vector`, expressed as a pullback.
  ///
  /// - Returns: `lhs + rhs`, and a pullback that yields the lane sum of its
  ///   argument for the scalar and the argument itself for the vector.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    return (value: lhs + rhs, pullback: { seed in (seed.sum(), seed) })
  }

  /// Derivative of `scalar + vector`, expressed as a differential.
  ///
  /// - Returns: `lhs + rhs`, and a differential that adds the scalar tangent
  ///   to the vector tangent.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    return (value: lhs + rhs, differential: { dLhs, dRhs in dLhs + dRhs })
  }

  /// Derivative of `scalar - vector`, expressed as a pullback.
  ///
  /// - Returns: `lhs - rhs`, and a pullback that yields the lane sum of its
  ///   argument for the scalar and the negated argument for the vector.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    return (value: lhs - rhs, pullback: { seed in (seed.sum(), -seed) })
  }

  /// Derivative of `scalar - vector`, expressed as a differential.
  ///
  /// - Returns: `lhs - rhs`, and a differential that subtracts the vector
  ///   tangent from the scalar tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    return (value: lhs - rhs, differential: { dLhs, dRhs in dLhs - dRhs })
  }

  /// Derivative of `vector + scalar`, expressed as a pullback.
  ///
  /// - Returns: `lhs + rhs`, and a pullback that yields its argument for the
  ///   vector and the lane sum of the argument for the scalar.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    return (value: lhs + rhs, pullback: { seed in (seed, seed.sum()) })
  }

  /// Derivative of `vector + scalar`, expressed as a differential.
  ///
  /// - Returns: `lhs + rhs`, and a differential that adds the vector tangent
  ///   to the scalar tangent.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    return (value: lhs + rhs, differential: { dLhs, dRhs in dLhs + dRhs })
  }

  /// Derivative of `vector - scalar`, expressed as a pullback.
  ///
  /// - Returns: `lhs - rhs`, and a pullback that yields its argument for the
  ///   vector and the negated lane sum of the argument for the scalar.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    return (value: lhs - rhs, pullback: { seed in (seed, -seed.sum()) })
  }

  /// Derivative of `vector - scalar`, expressed as a differential.
  ///
  /// - Returns: `lhs - rhs`, and a differential that subtracts the scalar
  ///   tangent from the vector tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    return (value: lhs - rhs, differential: { dLhs, dRhs in dLhs - dRhs })
  }
}

extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  /// Derivative of `vector * scalar`, expressed as a pullback.
  ///
  /// - Returns: `lhs * rhs`, and a pullback that yields `v * rhs` for the
  ///   vector and the lane sum of `v * lhs` for the scalar, where `v` is the
  ///   pullback's argument.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    return (value: lhs * rhs, pullback: { seed in
      let dLhs: TangentVector = seed * rhs
      let dRhs: Scalar.TangentVector = (seed * lhs).sum()
      return (dLhs, dRhs)
    })
  }

  /// Derivative of `vector * scalar`, expressed as a differential.
  ///
  /// - Returns: `lhs * rhs`, and a differential computing
  ///   `lhs * dRhs + dLhs * rhs` from the operand tangents.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    return (
      value: lhs * rhs,
      differential: { dLhs, dRhs in lhs * dRhs + dLhs * rhs }
    )
  }

  /// Derivative of `vector / scalar`, expressed as a pullback.
  ///
  /// - Returns: `lhs / rhs`, and a pullback that yields `v / rhs` for the
  ///   vector and the lane sum of `-lhs / (rhs * rhs) * v` for the scalar,
  ///   where `v` is the pullback's argument.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    return (value: lhs / rhs, pullback: { seed in
      let dLhs: TangentVector = seed / rhs
      let dRhs: Scalar.TangentVector = (-lhs / (rhs * rhs) * seed).sum()
      return (dLhs, dRhs)
    })
  }

  /// Derivative of `vector / scalar`, expressed as a differential.
  ///
  /// - Returns: `lhs / rhs`, and a differential computing
  ///   `(dLhs * rhs - lhs * dRhs) / (rhs * rhs)` from the operand tangents.
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    return (value: lhs / rhs, differential: { dLhs, dRhs in
      let numerator: TangentVector = dLhs * rhs - lhs * dRhs
      return numerator / (rhs * rhs)
    })
  }

  /// Derivative of `scalar * vector`, expressed as a pullback.
  ///
  /// - Returns: `lhs * rhs`, and a pullback that yields the lane sum of
  ///   `v * rhs` for the scalar and `v * lhs` for the vector, where `v` is the
  ///   pullback's argument.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    return (value: lhs * rhs, pullback: { seed in
      let dLhs: Scalar.TangentVector = (seed * rhs).sum()
      let dRhs: TangentVector = seed * lhs
      return (dLhs, dRhs)
    })
  }

  /// Derivative of `scalar * vector`, expressed as a differential.
  ///
  /// - Returns: `lhs * rhs`, and a differential computing
  ///   `lhs * dRhs + dLhs * rhs` from the operand tangents.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    return (
      value: lhs * rhs,
      differential: { dLhs, dRhs in lhs * dRhs + dLhs * rhs }
    )
  }

  /// Derivative of `scalar / vector`, expressed as a pullback.
  ///
  /// - Returns: `lhs / rhs`, and a pullback that yields the lane sum of
  ///   `v / rhs` for the scalar and `-lhs / (rhs * rhs) * v` for the vector,
  ///   where `v` is the pullback's argument.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    return (value: lhs / rhs, pullback: { seed in
      let dLhs: Scalar.TangentVector = (seed / rhs).sum()
      let dRhs: TangentVector = -lhs / (rhs * rhs) * seed
      return (dLhs, dRhs)
    })
  }

  /// Derivative of `scalar / vector`, expressed as a differential.
  ///
  /// - Returns: `lhs / rhs`, and a differential computing
  ///   `(dLhs * rhs - lhs * dRhs) / (rhs * rhs)` from the operand tangents.
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    return (value: lhs / rhs, differential: { dLhs, dRhs in
      let numerator: TangentVector = dLhs * rhs - lhs * dRhs
      return numerator / (rhs * rhs)
    })
  }
}

extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector == Self
{
  /// Derivative of `sum()`, expressed as a pullback.
  ///
  /// - Returns: `sum()`, and a pullback that converts its scalar argument to
  ///   `Scalar` and repeats it across every lane.
  @inlinable
  @derivative(of: sum)
  func _vjpSum() -> (
    value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
  ) {
    return (
      value: sum(),
      pullback: { seed in
        let lane = Scalar(seed)
        return Self(repeating: lane)
      }
    )
  }

  /// Derivative of `sum()`, expressed as a differential.
  ///
  /// - Returns: `sum()`, and a differential that sums the lanes of its
  ///   argument and converts the total to `Scalar.TangentVector`.
  @inlinable
  @derivative(of: sum)
  func _jvpSum() -> (
    value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
  ) {
    return (
      value: sum(),
      differential: { tangent in
        let total = tangent.sum()
        return Scalar.TangentVector(total)
      }
    )
  }
}

extension SIMD
where
  Self: Differentiable,
  Self.TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  /// Derivative of `init(repeating:)`, expressed as a pullback.
  ///
  /// - Parameter value: The scalar repeated across every lane.
  /// - Returns: `Self(repeating: value)`, and a pullback that returns the lane
  ///   sum of its argument.
  @inlinable
  @derivative(of: init(repeating:))
  static func _vjpInit(repeating value: Scalar)
    -> (value: Self, pullback: (TangentVector) -> Scalar.TangentVector)
  {
    return (
      value: Self(repeating: value),
      pullback: { seed in seed.sum() }
    )
  }

  /// Derivative of `init(repeating:)`, expressed as a differential.
  ///
  /// - Parameter value: The scalar repeated across every lane.
  /// - Returns: `Self(repeating: value)`, and a differential that repeats its
  ///   scalar argument across every lane.
  @inlinable
  @derivative(of: init(repeating:))
  static func _jvpInit(repeating value: Scalar)
    -> (value: Self, differential: (Scalar.TangentVector) -> TangentVector)
  {
    return (
      value: Self(repeating: value),
      differential: { scalarTangent in Self(repeating: scalarTangent) }
    )
  }
}
