blob: 63b9eb6eb824b460cd5c178f2265e4bce45697dd [file] [log] [blame]
//===--- DifferentiationInvoker.h -----------------------------*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// Class that represents an invoker of differentiation.
// Used to track diagnostic source locations.
//
//===----------------------------------------------------------------------===//
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_DIFFERENTIATIONINVOKER_H
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_DIFFERENTIATIONINVOKER_H
#include "swift/Basic/SourceLoc.h"
#include <utility>
namespace swift {
class ApplyInst;
class DifferentiableFunctionInst;
class LinearFunctionInst;
class SILDifferentiabilityWitness;
namespace autodiff {
/// The invoker of a differentiation task. It can be some user syntax, e.g.
/// an `differentiable_function` instruction lowered from an
/// `DifferentiableFunctionExpr` expression, the differentiation pass, or
/// nothing at all. This will be used to emit informative diagnostics.
struct DifferentiationInvoker {
public:
/// The kind of the invoker of a differentiation task.
enum class Kind {
// Invoked by an `differentiable_function` instruction, which may or may not
// be linked to a Swift AST node (e.g. an `DifferentiableFunctionExpr`
// expression).
DifferentiableFunctionInst,
// Invoked by an `linear_function` instruction, which may or may not
// be linked to a Swift AST node (e.g. an `LinearFunctionExpr` expression).
LinearFunctionInst,
// Invoked by the indirect application of differentiation. This case has an
// associated original `apply` instruction and
// `SILDifferentiabilityWitness`.
IndirectDifferentiation,
// Invoked by a `SILDifferentiabilityWitness` **without** being linked to a
// Swift AST attribute. This case has an associated
// `SILDifferentiabilityWitness`.
SILDifferentiabilityWitnessInvoker
};
private:
Kind kind;
union Value {
/// The instruction associated with the `DifferentiableFunctionInst` case.
DifferentiableFunctionInst *diffFuncInst;
Value(DifferentiableFunctionInst *inst) : diffFuncInst(inst) {}
/// The instruction associated with the `LinearFunctionInst` case.
LinearFunctionInst *linearFuncInst;
Value(LinearFunctionInst *inst) : linearFuncInst(inst) {}
/// The parent `apply` instruction and the witness associated with the
/// `IndirectDifferentiation` case.
std::pair<ApplyInst *, SILDifferentiabilityWitness *>
indirectDifferentiation;
Value(ApplyInst *applyInst, SILDifferentiabilityWitness *witness)
: indirectDifferentiation({applyInst, witness}) {}
/// The witness associated with the `SILDifferentiabilityWitnessInvoker`
/// case.
SILDifferentiabilityWitness *witness;
Value(SILDifferentiabilityWitness *witness) : witness(witness) {}
} value;
/*implicit*/
DifferentiationInvoker(Kind kind, Value value) : kind(kind), value(value) {}
public:
DifferentiationInvoker(DifferentiableFunctionInst *inst)
: kind(Kind::DifferentiableFunctionInst), value(inst) {}
DifferentiationInvoker(LinearFunctionInst *inst)
: kind(Kind::LinearFunctionInst), value(inst) {}
DifferentiationInvoker(ApplyInst *applyInst,
SILDifferentiabilityWitness *witness)
: kind(Kind::IndirectDifferentiation), value({applyInst, witness}) {}
DifferentiationInvoker(SILDifferentiabilityWitness *witness)
: kind(Kind::SILDifferentiabilityWitnessInvoker), value(witness) {}
Kind getKind() const { return kind; }
DifferentiableFunctionInst *getDifferentiableFunctionInst() const {
assert(kind == Kind::DifferentiableFunctionInst);
return value.diffFuncInst;
}
LinearFunctionInst *getLinearFunctionInst() const {
assert(kind == Kind::LinearFunctionInst);
return value.linearFuncInst;
}
std::pair<ApplyInst *, SILDifferentiabilityWitness *>
getIndirectDifferentiation() const {
assert(kind == Kind::IndirectDifferentiation);
return value.indirectDifferentiation;
}
SILDifferentiabilityWitness *getSILDifferentiabilityWitnessInvoker() const {
assert(kind == Kind::SILDifferentiabilityWitnessInvoker);
return value.witness;
}
SourceLoc getLocation() const;
void print(llvm::raw_ostream &os) const;
};
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
DifferentiationInvoker invoker) {
invoker.print(os);
return os;
}
} // end namespace autodiff
} // end namespace swift
#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_DIFFERENTIATIONINVOKER_H