blob: 1922e0e71d19fbcb9c4364c2ef80471f763e8baf [file] [log] [blame]
// SWIFT_ENABLE_TENSORFLOW
// RUN: %target-swift-frontend -typecheck -verify -primary-file %s %S/Inputs/struct_key_path_iterable_other_module.swift
struct Tensor<Scalar> {
var scalar: Scalar
init(_ scalar: Scalar) {
self.scalar = scalar
}
}
extension Tensor : Equatable where Scalar : Equatable {}
extension Tensor : AdditiveArithmetic where Scalar : AdditiveArithmetic {}
extension Tensor : VectorProtocol where Scalar : AdditiveArithmetic {
typealias VectorSpaceScalar = Scalar
func adding(_: Scalar) -> Self { self }
func subtracting(_: Scalar) -> Self { self }
func scaled(by scalar: Scalar) -> Self { self }
}
// Synthesis should work for empty structs.
// `allKeyPaths` simply returns `[]`.
struct Empty : KeyPathIterable {}
struct Parameters : KeyPathIterable {
var w: Float
var b: Float
}
func testParameters() {
var params = Parameters(w: 1, b: 2)
assert(params.allKeyPaths.count == 2)
assert(params.allKeyPaths(to: Float.self).count == 2)
assert(params.allKeyPaths(to: Int.self).count == 0)
for kp in params.allWritableKeyPaths(to: Float.self) {
params[keyPath: kp] *= 2
}
}
struct TensorParameters : KeyPathIterable {
var w: Tensor<Float>
var b: Tensor<Float>
// Non-stored-property members should not affect synthesis.
var computed: Float {
return (w + b).scalar
}
func foo() {}
typealias Foo = Int
}
extension TensorParameters : VectorProtocol {
static var zero: TensorParameters {
return TensorParameters(w: Tensor(0), b: Tensor(0))
}
static func + (lhs: TensorParameters, rhs: TensorParameters) -> TensorParameters {
return TensorParameters(w: lhs.w + rhs.w, b: lhs.b + rhs.b)
}
static func - (lhs: TensorParameters, rhs: TensorParameters) -> TensorParameters {
return TensorParameters(w: lhs.w + rhs.w, b: lhs.b + rhs.b)
}
typealias VectorSpaceScalar = Float
func adding(_ x: VectorSpaceScalar) -> TensorParameters {
return TensorParameters(w: w.adding(x), b: b.adding(x))
}
func subtracting(_ x: VectorSpaceScalar) -> TensorParameters {
return TensorParameters(w: w.subtracting(x), b: b.subtracting(x))
}
func scaled(by scalar: VectorSpaceScalar) -> TensorParameters {
return TensorParameters(w: w.scaled(by: scalar), b: b.scaled(by: scalar))
}
}
struct HeterogeneousParameters : KeyPathIterable {
var float: Float
var double: Double
var tensor: Tensor<Float>
var params: Parameters
}
func testHeterogenousParameters(_ params: Parameters) {
let hetero = HeterogeneousParameters(float: 0, double: 0,
tensor: Tensor(0), params: params)
assert(hetero.allKeyPaths.count == 4)
assert(hetero.recursivelyAllKeyPaths.count == 6)
assert(hetero.allKeyPaths(to: Float.self).count == 1)
assert(hetero.recursivelyAllKeyPaths(to: Float.self).count == 3)
assert(hetero.allKeyPaths(to: Tensor<Float>.self).count == 1)
assert(hetero.allKeyPaths(to: Parameters.self).count == 1)
assert(hetero.allKeyPaths(to: Int.self).count == 0)
}
// Test type in generic context.
struct A<T> {
struct B<U, V> {
struct GenericContextParams : KeyPathIterable {
var params: Parameters
var float: Float
}
}
}
// Test generic optimizer.
struct DummyOptimizer<P : KeyPathIterable, Scalar : BinaryFloatingPoint>
where P : VectorProtocol, P.VectorSpaceScalar == Scalar
{
let learningRate: Scalar
var firstMoments: P = P.zero
mutating func fitParameters(
parameters: inout P, withGradients gradients: P
) {
for kp in parameters.recursivelyAllWritableKeyPaths(to: Tensor<Scalar>.self) {
firstMoments[keyPath: kp].scale(by: learningRate)
parameters[keyPath: kp] -= parameters[keyPath: kp].scaled(by: learningRate)
}
}
}
// TF-575: Test overloaded key path component name.
protocol NameLookupConflictProtocol {}
extension NameLookupConflictProtocol {
func member() {}
}
struct NameLookupConflict: NameLookupConflictProtocol & KeyPathIterable {
// Note: `NameLookupConflict.member` is overloaded with
// `MemberNameConflictProtocol.member`.
// This makes the following generated code fail:
//
// var allKeyPaths: [PartialKeyPath<Self>] {
// [\Self.member]
// }
//
// error: cannot convert value of type
// 'WritableKeyPath<NameLookupConflict, Float>' to expected element type
// 'PartialKeyPath<NameLookupConflict>'
var member: Float
}
// Test derived conformances in disallowed contexts.
// expected-error @+2 {{type 'OtherFileNonconforming' does not conform to protocol 'KeyPathIterable'}}
// expected-error @+1 {{extension outside of file declaring struct 'OtherFileNonconforming' prevents automatic synthesis of 'AllKeyPaths' for protocol 'KeyPathIterable'}}
extension OtherFileNonconforming : KeyPathIterable {}
// expected-error @+2 {{type 'GenericOtherFileNonconforming<T>' does not conform to protocol 'KeyPathIterable'}}
// expected-error @+1 {{extension outside of file declaring generic struct 'GenericOtherFileNonconforming' prevents automatic synthesis of 'AllKeyPaths' for protocol 'KeyPathIterable'}}
extension GenericOtherFileNonconforming : KeyPathIterable {}