//===--- TgmathDerivatives.swift.gyb --------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
// This file defines derivatives for tgmath functions.
//===----------------------------------------------------------------------===//

// SWIFT_ENABLE_TENSORFLOW
// `_Differentiation` module sources are currently compiled as part of the core
// standard library on the tensorflow branch, so import declarations are not
// necessary.
#if false
// SWIFT_ENABLE_TENSORFLOW END

import Swift

#if os(macOS) || os(iOS) || os(tvOS) || os(watchOS)
  import Darwin.C.tgmath
#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku)
  import Glibc
#elseif os(Windows)
  import CRT
#else
#error("Unsupported platform")
#endif

// SWIFT_ENABLE_TENSORFLOW
#endif
// SWIFT_ENABLE_TENSORFLOW END

/// Forward-mode derivative (JVP) registered for the fused multiply-add `fma`.
///
/// - Returns: The value `fma(x, y, z)` paired with a differential that maps
///   the input tangents `(dx, dy, dz)` to `dx * y + dy * x + dz`.
@usableFromInline
@derivative(of: fma)
func _jvpFma<T: FloatingPoint & Differentiable> (
  _ x: T,
  _ y: T,
  _ z: T
) -> (value: T, differential: (T, T, T) -> T) where T == T.TangentVector {
  let value = fma(x, y, z)
  // Product rule on the multiplied operands; the addend's tangent passes
  // through unchanged.
  func differential(_ dx: T, _ dy: T, _ dz: T) -> T {
    return dx * y + dy * x + dz
  }
  return (value, differential)
}

/// Reverse-mode derivative (VJP) registered for the fused multiply-add `fma`.
///
/// - Returns: The value `fma(x, y, z)` paired with a pullback that maps an
///   incoming cotangent `v` to the gradients `(v * y, v * x, v)` with respect
///   to `x`, `y` and `z`.
@usableFromInline
@derivative(of: fma)
func _vjpFma<T: FloatingPoint & Differentiable> (
  _ x: T,
  _ y: T,
  _ z: T
) -> (value: T, pullback: (T) -> (T, T, T)) where T == T.TangentVector {
  let value = fma(x, y, z)
  func pullback(_ v: T) -> (T, T, T) {
    let gradX = v * y
    let gradY = v * x
    return (gradX, gradY, v)
  }
  return (value, pullback)
}

/// Forward-mode derivative (JVP) registered for `remainder`.
///
/// The differential mirrors the pullback used by `_vjpRemainder`: the quotient
/// `x / y` rounded to nearest-or-even is treated as a constant, so the partial
/// derivative with respect to `x` is `1` and with respect to `y` is the
/// negated rounded quotient.
///
/// - Returns: The value `remainder(x, y)` paired with a differential that maps
///   the input tangents `(dx, dy)` to
///   `dx - dy * (x / y).rounded(.toNearestOrEven)`.
@usableFromInline
@derivative(of: remainder)
func _jvpRemainder<T: FloatingPoint & Differentiable> (
  _ x: T,
  _ y: T
) -> (value: T, differential: (T, T) -> T) where T == T.TangentVector {
  let value = remainder(x, y)
  return (value, { (dx, dy) in
    // Same rounded quotient as the VJP, so both modes agree.
    dx - dy * ((x / y).rounded(.toNearestOrEven))
  })
}

/// Reverse-mode derivative (VJP) registered for `remainder`.
///
/// - Returns: The value `remainder(x, y)` paired with a pullback that maps a
///   cotangent `v` to `(v, -v * q)`, where `q` is `x / y` rounded to
///   nearest-or-even.
@usableFromInline
@derivative(of: remainder)
func _vjpRemainder<T: FloatingPoint & Differentiable> (
  _ x: T,
  _ y: T
) -> (value: T, pullback: (T) -> (T, T)) where T == T.TangentVector {
  let value = remainder(x, y)
  func pullback(_ v: T) -> (T, T) {
    let quotient = (x / y).rounded(.toNearestOrEven)
    return (v, -v * quotient)
  }
  return (value, pullback)
}

/// Forward-mode derivative (JVP) registered for `fmod`.
///
/// The differential mirrors the pullback used by `_vjpFmod`: the quotient
/// `x / y` rounded toward zero is treated as a constant, so the partial
/// derivative with respect to `x` is `1` and with respect to `y` is the
/// negated truncated quotient.
///
/// - Returns: The value `fmod(x, y)` paired with a differential that maps the
///   input tangents `(dx, dy)` to `dx - dy * (x / y).rounded(.towardZero)`.
@usableFromInline
@derivative(of: fmod)
func _jvpFmod<T: FloatingPoint & Differentiable> (
  _ x: T,
  _ y: T
) -> (value: T, differential: (T, T) -> T) where T == T.TangentVector {
  let value = fmod(x, y)
  return (value, { (dx, dy) in
    // Same truncated quotient as the VJP, so both modes agree.
    dx - dy * ((x / y).rounded(.towardZero))
  })
}

/// Reverse-mode derivative (VJP) registered for `fmod`.
///
/// - Returns: The value `fmod(x, y)` paired with a pullback that maps a
///   cotangent `v` to `(v, -v * q)`, where `q` is `x / y` rounded toward zero.
@usableFromInline
@derivative(of: fmod)
func _vjpFmod<T: FloatingPoint & Differentiable> (
  _ x: T,
  _ y: T
) -> (value: T, pullback: (T) -> (T, T)) where T == T.TangentVector {
  let value = fmod(x, y)
  func pullback(_ v: T) -> (T, T) {
    let quotient = (x / y).rounded(.towardZero)
    return (v, -v * quotient)
  }
  return (value, pullback)
}

%for derivative_kind in ['jvp', 'vjp']:
%  linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback'
/// Derivative registered for `sqrt`; the template emits this once as a JVP and
/// once as a VJP with the same linear map.
///
/// - Returns: The value `sqrt(x)` paired with a linear map dividing its
///   argument by twice that value.
@usableFromInline
@derivative(of: sqrt)
func _${derivative_kind}Sqrt<T: FloatingPoint & Differentiable> (
  _ x: T
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
  let root = sqrt(x)
  // Reuse the already computed root instead of calling `sqrt` again.
  return (root, { $0 / (2 * root) })
}

/// Derivative registered for `ceil`.
///
/// - Returns: The value `ceil(x)` paired with a linear map that ignores its
///   argument and always produces zero.
@usableFromInline
@derivative(of: ceil)
func _${derivative_kind}Ceil<T: FloatingPoint & Differentiable> (
  _ x: T
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
  let value = ceil(x)
  return (value, { _ in 0 })
}

/// Derivative registered for `floor`.
///
/// - Returns: The value `floor(x)` paired with a linear map that ignores its
///   argument and always produces zero.
@usableFromInline
@derivative(of: floor)
func _${derivative_kind}Floor<T: FloatingPoint & Differentiable> (
  _ x: T
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
  let value = floor(x)
  return (value, { _ in 0 })
}

/// Derivative registered for `round`.
///
/// - Returns: The value `round(x)` paired with a linear map that ignores its
///   argument and always produces zero.
@usableFromInline
@derivative(of: round)
func _${derivative_kind}Round<T: FloatingPoint & Differentiable> (
  _ x: T
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
  let value = round(x)
  return (value, { _ in 0 })
}

/// Derivative registered for `trunc`.
///
/// - Returns: The value `trunc(x)` paired with a linear map that ignores its
///   argument and always produces zero.
@usableFromInline
@derivative(of: trunc)
func _${derivative_kind}Trunc<T: FloatingPoint & Differentiable> (
  _ x: T
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
  let value = trunc(x)
  return (value, { _ in 0 })
}
%end # for derivative_kind in ['jvp', 'vjp']:

// Unary functions
%for derivative_kind in ['jvp', 'vjp']:
%  linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback'
%  for T in ['Float', 'Double', 'Float80']:
%    if T == 'Float80':
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
%    end
/// Derivative registered for `exp`: the linear map scales its argument by the
/// already computed `exp(x)`.
@inlinable
@derivative(of: exp)
func _${derivative_kind}Exp(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let result = exp(x)
  return (result, { result * $0 })
}

/// Derivative registered for `exp2`: the linear map scales its argument by
/// `M_LN2` and by the already computed `exp2(x)`.
@inlinable
@derivative(of: exp2)
func _${derivative_kind}Exp2(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let result = exp2(x)
  return (result, { $0 * ${T}(M_LN2) * result })
}

/// Derivative registered for `log`: the linear map divides its argument by
/// `x`.
@inlinable
@derivative(of: log)
func _${derivative_kind}Log(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = log(x)
  return (value, { $0 / x })
}

/// Derivative registered for `log10`: the linear map multiplies its argument
/// by `M_LOG10E` and divides by `x`.
@inlinable
@derivative(of: log10)
func _${derivative_kind}Log10(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = log10(x)
  return (value, { $0 * ${T}(M_LOG10E) / x })
}

/// Derivative registered for `log2`: the linear map divides its argument by
/// `M_LN2 * x`.
@inlinable
@derivative(of: log2)
func _${derivative_kind}Log2(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = log2(x)
  return (value, { $0 / (${T}(M_LN2) * x) })
}

/// Derivative registered for `sin`: the linear map scales its argument by
/// `cos(x)`.
@inlinable
@derivative(of: sin)
func _${derivative_kind}Sin(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = sin(x)
  return (value, { $0 * cos(x) })
}

/// Derivative registered for `cos`: the linear map negates its argument and
/// scales it by `sin(x)`.
@inlinable
@derivative(of: cos)
func _${derivative_kind}Cos(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = cos(x)
  return (value, { -$0 * sin(x) })
}

/// Derivative registered for `tan`: the linear map scales its argument by
/// `1 + tan(x) * tan(x)`, reusing the already computed tangent.
@inlinable
@derivative(of: tan)
func _${derivative_kind}Tan(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let tanX = tan(x)
  return (tanX, { $0 * (1 + tanX * tanX) })
}

/// Derivative registered for `asin`: the linear map divides its argument by
/// `sqrt(1 - x * x)`.
@inlinable
@derivative(of: asin)
func _${derivative_kind}Asin(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = asin(x)
  return (value, { $0 / sqrt(1 - x * x) })
}

/// Derivative registered for `acos`: the linear map negates its argument and
/// divides it by `sqrt(1 - x * x)`.
@inlinable
@derivative(of: acos)
func _${derivative_kind}Acos(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = acos(x)
  return (value, { -$0 / sqrt(1 - x * x) })
}

/// Derivative registered for `atan`: the linear map divides its argument by
/// `1 + x * x`.
@inlinable
@derivative(of: atan)
func _${derivative_kind}Atan(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = atan(x)
  return (value, { $0 / (1 + x * x) })
}

/// Derivative registered for `sinh`: the linear map scales its argument by
/// `cosh(x)`.
@inlinable
@derivative(of: sinh)
func _${derivative_kind}Sinh(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = sinh(x)
  return (value, { $0 * cosh(x) })
}

/// Derivative registered for `cosh`: the linear map scales its argument by
/// `sinh(x)`.
@inlinable
@derivative(of: cosh)
func _${derivative_kind}Cosh(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = cosh(x)
  return (value, { $0 * sinh(x) })
}

/// Derivative registered for `tanh`: the linear map scales its argument by
/// `1 - tanh(x) * tanh(x)`, reusing the already computed value.
@inlinable
@derivative(of: tanh)
func _${derivative_kind}Tanh(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let tanhX = tanh(x)
  return (tanhX, { $0 * (1 - tanhX * tanhX) })
}

/// Derivative registered for `asinh`: the linear map divides its argument by
/// `sqrt(1 + x * x)`.
@inlinable
@derivative(of: asinh)
func _${derivative_kind}Asinh(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = asinh(x)
  return (value, { $0 / sqrt(1 + x * x) })
}

/// Derivative registered for `acosh`: the linear map divides its argument by
/// `sqrt(x * x - 1)`.
@inlinable
@derivative(of: acosh)
func _${derivative_kind}Acosh(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = acosh(x)
  return (value, { $0 / sqrt(x * x - 1) })
}

/// Derivative registered for `atanh`: the linear map divides its argument by
/// `1 - x * x`.
@inlinable
@derivative(of: atanh)
func _${derivative_kind}Atanh(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = atanh(x)
  return (value, { $0 / (1 - x * x) })
}

/// Derivative registered for `expm1`: the linear map scales its argument by
/// `exp(x)`.
@inlinable
@derivative(of: expm1)
func _${derivative_kind}Expm1(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = expm1(x)
  return (value, { exp(x) * $0 })
}

/// Derivative registered for `log1p`: the linear map divides its argument by
/// `x + 1`.
@inlinable
@derivative(of: log1p)
func _${derivative_kind}Log1p(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = log1p(x)
  return (value, { $0 / (x + 1) })
}

/// Derivative registered for `erf`: the linear map scales its argument by
/// `M_2_SQRTPI * exp(-x * x)`.
@inlinable
@derivative(of: erf)
func _${derivative_kind}Erf(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = erf(x)
  return (value, { $0 * ${T}(M_2_SQRTPI) * exp(-x * x) })
}

/// Derivative registered for `erfc`: the linear map scales its argument by
/// `-M_2_SQRTPI * exp(-x * x)`, the negation of the `erf` derivative.
@inlinable
@derivative(of: erfc)
func _${derivative_kind}Erfc(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
  let value = erfc(x)
  return (value, { $0 * -${T}(M_2_SQRTPI) * exp(-x * x) })
}

%    if T == 'Float80':
#endif
%    end # if T == 'Float80':
%  end # for T in ['Float', 'Double', 'Float80']:
%end # for derivative_kind in ['jvp', 'vjp']:

// Binary functions
%for T in ['Float', 'Float80']:
%  if T == 'Float80':
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
%  end
/// Reverse-mode derivative (VJP) registered for `pow`.
///
/// - Returns: The value `pow(x, y)` paired with a pullback that maps a
///   cotangent `v` to the pair
///   `(v * y * pow(x, y - 1), v * pow(x, y) * log(x))`, where `x` is replaced
///   by `1` inside `log` whenever `x <= 0`.
@inlinable
@derivative(of: pow)
func _vjpPow(_ x: ${T}, _ y: ${T}) -> (value: ${T}, pullback: (${T}) -> (${T}, ${T})) {
  let power = pow(x, y)
  return (power, { v in
    // For a non-positive base, take the logarithm of 1 instead, so the
    // exponent gradient uses `log(1)`.
    let logArgument = x <= 0 ? ${T}(1) : x
    let gradX = v * y * pow(x, y - 1)
    let gradY = v * power * log(logArgument)
    return (gradX, gradY)
  })
}

/// Forward-mode derivative (JVP) registered for `pow`.
///
/// - Returns: The value `pow(x, y)` paired with a differential that maps the
///   input tangents `(dx, dy)` to
///   `dx * y * pow(x, y - 1) + dy * pow(x, y) * log(x)`, where `x` is replaced
///   by `1` inside `log` whenever `x <= 0`.
@inlinable
@derivative(of: pow)
func _jvpPow(_ x: ${T}, _ y: ${T}) -> (value: ${T}, differential: (${T}, ${T}) -> ${T}) {
  let power = pow(x, y)
  return (power, { (dx, dy) in
    // For a non-positive base, take the logarithm of 1 instead, so the
    // exponent term uses `log(1)`.
    let logArgument = x <= 0 ? ${T}(1) : x
    let baseTerm = dx * y * pow(x, y - 1)
    let exponentTerm = dy * power * log(logArgument)
    return baseTerm + exponentTerm
  })
}
%  if T == 'Float80':
#endif
%  end # if T == 'Float80':
%end # for T in ['Float', 'Float80']:
