blob: 108c61bc8f71c794e4578476b654837162db1da2 [file] [log] [blame]
//===--- TFDeabstraction.cpp - Lowering & canonicalization for tensor ops -===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2018 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This pass is in charge of lowering general code coming out of the mandatory
// SIL passes and producing a canonicalized and deabstracted standard form.
// It combines together standard techniques like inlining, generics
// specialization, and scalarization of structs and tuples.
//
// This is intended to be part of the mandatory passes, so its behavior is
// defined to be as simple and predictable as possible. We don't want to use
// heuristic techniques to resolve virtual calls for example, we'd rather leave
// them, so the user has a simple and predictable model for what this can
// handle.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "tf-deabstraction"
#include "TFUtilities.h"
#include "TFConstExpr.h"
#include "swift/SILOptimizer/PassManager/Passes.h"
#include "swift/SILOptimizer/PassManager/Transforms.h"
#include "swift/SILOptimizer/Utils/SILInliner.h"
#include "swift/SIL/SILConstants.h"
#include "swift/AST/DiagnosticsSIL.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/PrettyStackTrace.h"
using namespace swift;
using namespace tf;
using llvm::DenseMap;
/// Command line flag: when set, logCurrentState() also emits the dumps that
/// are marked as "detailed". When it is left at its default (false), those
/// detailed dumps are skipped.
static llvm::cl::opt<bool>
TFDumpDeabstractionDetails("tf-dump-deabstraction-details",
llvm::cl::init(false),
llvm::cl::desc("Dump extra details about TensorFlow deabstraction"));
/// Command line flag: when set, the workarounds in this file that stand in for
/// a proper constexpr model are turned off. That covers the name-based
/// literal-conversion inlining in inlineCalls() and the look-through of simple
/// builtin arithmetic ops (isSimpleBuiltinArithmeticOp() then always returns
/// false).
static llvm::cl::opt<bool>
TFStrictDeabstraction("tf-strict-deabstraction", llvm::cl::init(false),
llvm::cl::desc("Verify #tfop's are valid without the perf optimizer"));
/// Convenience wrapper that emits the diagnostic `diag` at `loc` through the
/// diagnostic engine owned by `Context`, forwarding any extra arguments, and
/// hands back the in-flight diagnostic produced by the engine.
template<typename...T, typename...U>
static InFlightDiagnostic
diagnose(ASTContext &Context, SourceLoc loc, Diag<T...> diag, U &&...args) {
  auto &engine = Context.Diags;
  return engine.diagnose(loc, diag, std::forward<U>(args)...);
}
/// Erase `inst` from its parent (like inst->eraseFromParent()), and clean up
/// after it: every instruction that defined one of its operands and is left
/// without any uses once the operand is dropped gets handed to
/// recursivelyDeleteTriviallyDeadInstructions.
static void deleteInstAndAbandonedUses(SILInstruction *inst) {
  for (auto &use : inst->getAllOperands()) {
    // Remember who produced the value before we sever the use.
    auto *producer = use.get()->getDefiningInstruction();
    use.drop();

    // Non-instruction values and producers that are still used stay put.
    if (!producer || producer->hasUsesOfAnyResult())
      continue;
    recursivelyDeleteTriviallyDeadInstructions(producer);
  }

  // With all operands dropped, the instruction itself can go.
  inst->eraseFromParent();
}
namespace {
  /// Bundles the state and the individual phases needed to deabstract code
  /// into one particular SIL function, which has been selected as a potential
  /// top-level host for tensor code.
  class TFDeabstraction {
    SILFunction &fn;
    TensorFunctionClassifier &tfc;
    ConstExprEvaluator &constantEvaluator;
    SILPassManager *passManager;

    /// Becomes true during the early inlining phase when the function was
    /// forcibly flattened so that every reference to a global variable is
    /// visible inside the current function.  This happens for top level code
    /// in Playgrounds and the REPL.
    bool forciblyFlattened = false;

    /// Records whether 'aboutToChangeFunction' has ever been invoked for this
    /// function, so debug log info is only printed for interesting functions.
    bool changedFunction = false;

    /// The tensor operations of the current function, as collected by
    /// simplifyTensorOperands.  It holds the builtin instructions that reflect
    /// the #tfop() invocations, plus any retain/release instructions that use
    /// TensorHandle values.
    SmallVector<SILInstruction*, 32> tensorOps;

  public:
    TFDeabstraction(SILFunction &function,
                    TensorFunctionClassifier &classifier,
                    ConstExprEvaluator &evaluator, SILPassManager *pm)
      : fn(function), tfc(classifier), constantEvaluator(evaluator),
        passManager(pm) {}

    /// Deabstract the specified top level function as a deabstraction context.
    void doIt();

    /// Called from the key entrypoints that mutate the SIL function.  Its only
    /// job is to cut down on debug spew: the "Input" state is logged the first
    /// time a function is about to change, and never again afterwards.
    void aboutToChangeFunction() {
      if (!changedFunction) {
        changedFunction = true;
        logCurrentState("Input", /*detailed*/false);
      }
    }

  private:
    void logCurrentState(const char *name, bool isDetailed);
    void inlineCalls();
    void simplifyTensorOperands();
    void promoteToSSA(MutableArrayRef<AllocStackInst*> allocs);
    void prepareStackAllocForPromotion(AllocStackInst *alloc);
    void propagateTensorValues();
    void checkAndCanonicalizeAttributes();
  };
} // end anonymous namespace
/// Print the current body of the function to the TF intermediate dump stream,
/// labelled with `name`.  Dumps flagged as detailed are only produced when
/// -tf-dump-deabstraction-details was requested; nothing is printed at all if
/// no dump stream is available.
void TFDeabstraction::logCurrentState(const char *name, bool isDetailed) {
  // Detailed dumps are opt-in.
  if (isDetailed && !TFDumpDeabstractionDetails)
    return;

  if (auto outs = getTFDumpIntermediateStream()) {
    *outs << "--- TFDeabstraction " << name << ": ";
    *outs << fn.getName() << "\n";
    fn.print(*outs);
    *outs << "----\n";
    outs->flush();
  }
}
/// Determine whether `call` is an "array.uninitialized" call, i.e. one that
/// creates an array and returns it with uninitialized elements for the caller
/// to fill in.  Anything that is not an apply instruction yields false.
static bool isArrayUninitialized(SILInstruction *call) {
  if (auto *applyInst = dyn_cast<ApplyInst>(call)) {
    ArraySemanticsCall arrayCall(applyInst, "array.uninitialized");
    return arrayCall.getKind() == ArrayCallKind::kArrayUninitialized;
  }
  return false;
}
/// Scan the function looking for call sites that should be inlined to expose
/// tensor operations, and inline them to expose those ops.
///
/// Once inlining is done, every callee that was inlined is revisited: its
/// analyses are invalidated, and it is erased from the module unless it is
/// still referenced, possibly used externally, or an ObjC method.
void TFDeabstraction::inlineCalls() {
llvm::PrettyStackTraceFormat X("TFDeabstraction::inlineCalls");
// We generally want to carefully and deliberately choose which functions to
// inline into our 'fn' function, but if this is a main function with top
// level code (e.g. in a playground) then we want to aggressively inline
// when/if we see any global_addr's with a TensorHandle in them. This allows
// us to promote these global_addrs to registers safely.
//
// TODO: This should be enough for now, but isn't really the right long term
// approach. Long term we should build a full callgraph and look for call
// paths that can touch tensor flavored global variables. If a function
// doesn't do so, then there is no reason to inline it. This can start to
// matter for larger examples.
//
// TODO: This should handle playgrounds and #! scripts, but probably isn't
// enough to handle REPL generated code. How do we identify the functions it
// produces for each entered statement? Matching on __repl or whatever prefix
// LLDB and the integrated REPL use is probably enough.
//
if (fn.getName() == SWIFT_ENTRY_POINT_FUNCTION) {
// The function counts as forcibly flattened as soon as one global_addr in it
// has a type that contains a TensorFlow value.
forciblyFlattened = [&]() -> bool {
for (auto &bb : fn)
for (auto &i : bb)
if (auto *inst = dyn_cast<GlobalAddrInst>(&i)) {
if (tfc.containsTensorFlowValue(inst->getType()))
return true;
}
return false;
}();
}
/// This predicate decides whether we should mandatory inline the specified
/// call site.
auto shouldInline = [&](FullApplySite site,
const SILFunction &callee) -> bool {
// If this is a call of an explicitly noinline function, don't inline it!
if (callee.getInlineStrategy() == NoInline)
return false;
// Check for array internals which could be inlined, but which we prefer to
// leave in abstracted form for easier analysis. For things like
// Tensor<Float>([[1,2],[3,4]]), we prefer to see higher level array
// construction calls because we end up removing them anyway.
if (isArrayUninitialized(site.getInstruction()))
return false;
// FIXME: This is a specific hack to inline literal conversion operations
// that prevent SSA promotion and bloat code. This should be eliminated
// when we have a proper constexpr model and can just constant fold through
// the memory indirections.
//
// The match is done on substrings of the callee's name, and is skipped
// entirely in strict deabstraction mode.
if (callee.getName().contains("ExpressibleByBuiltinIntegerLiteral") ||
callee.getName().contains("ExpressibleByIntegerLiteral") ||
callee.getName().contains("SSfs13FloatingPoint") ||
callee.getName().contains("S10TensorFlow0A0VAAs13FloatingPointRzrlE12"
"randomNormal4mean6stddev5s") ||
callee.getName().contains("S10TensorFlow0A5ShapeV12arrayLiteral"
"ACs5Int32Vd_tcfC") ||
callee.getName().contains("_allocateUninitializedArray"))
if (!TFStrictDeabstraction)
return true;
// If we're forcibly flattening code into the top level function, and if the
// callee is in the same source file as that top-level function (and thus
// has visibility into its global variables) then force inline it.
//
// NOTE: The actual check is compiled out below (#if 0), so this block
// currently has no effect on the result.
if (forciblyFlattened) {
if (auto *apply = dyn_cast<ApplyInst>(site.getInstruction())) {
if (auto *callee = apply->getCalleeFunction()) {
// FIXME: We will miscompile functions that use variables in top level
// code right now. We need to implement this properly.
#if 0
if (shouldBeForciblyFlattened(*callee))
return true;
#endif
}
}
}
// Get the type of the function being called after applying substitutions
// at the call site.
auto type = site.getSubstCalleeType();
// If the call we found is to something that processes TensorFlow values,
// then we want it inlined.
if (!tfc.containsTensorFlowValue(type))
return false;
return true;
};
// The set of callees that the callback below approved for inlining. They
// are candidates for deletion once inlining has finished.
SmallPtrSet<SILFunction*, 16> inlinedCallees;
// Use the mandatory inlining algorithm to expose call sites that contain
// TensorFlow values as their argument or result lists.
inlineForTFDeabstraction(fn,
[&](FullApplySite site, const SILFunction &callee) -> bool {
if (!shouldInline(site, callee))
return false;
// Recognize that we're about to change this function.
aboutToChangeFunction();
inlinedCallees.insert(const_cast<SILFunction*>(&callee));
return true;
}
);
auto &module = fn.getModule();
module.invalidateSILLoaderCaches();
// Now that we've inlined some functions, clean them up to avoid burning
// compile time in later passes. We do this with a simple linear scan,
// because functions that reference each other have already been flattened
// so there should be no interdependencies.
for (auto *callee : inlinedCallees) {
// We shouldn't be trying to delete the thing we're inlining into, doing so
// would invalidate iterators.
assert(callee != &fn && "inlining self into self??");
passManager->invalidateAnalysis(callee,
SILAnalysis::InvalidationKind::Everything);
// We can't delete this function if something is still using it. That could
// be because there is some other tensor program in this module that is
// using it or (most likely) that there is a now-dead witness table.
//
// TODO: Build infra to find unused witness tables and remove them.
if (callee->getRefCount() != 0) {
// FIXME: As a super hack, disable all optimization of the inlined
// methods that are defined and kept alive by the DifferentiableModule
// abstraction in the TensorFlow module. These don't get removed because
// of the witness tables for DifferentiableModule, but we know that they
// don't matter. Don't burn time optimizing them.
if (callee->getName().contains("adjoint3for4with") || //adjoint(for:with:
callee->getName().contains("6primal3for")) // primal(for:)
callee->setOptimizationMode(OptimizationMode::NoOptimization);
continue;
}
// If this is a public function then we can't remove it either.
if (callee->isPossiblyUsedExternally())
continue;
// ObjC functions are called through the runtime and are therefore alive
// even if not referenced inside SIL.
if (callee->getRepresentation() ==SILFunctionTypeRepresentation::ObjCMethod)
continue;
// Let the pass manager know before the function goes away.
passManager->notifyDeleteFunction(callee);
// Okay, erase the function from the module.
module.eraseFunction(callee);
}
}
/// Strip away StructInst wrappers that have exactly one operand.  Chains of
/// such wrappers are peeled one after another; the first value that is not
/// produced by a single-operand StructInst is returned.
static SILValue lookThroughSingleElementStructInsts(SILValue value) {
  while (true) {
    auto *wrapper =
      dyn_cast_or_null<StructInst>(value->getDefiningInstruction());
    if (!wrapper || wrapper->getNumOperands() != 1)
      return value;

    // Step inside the wrapper and keep digging.
    value = wrapper->getOperand(0);
  }
}
/// Scan the operand list of the builtin. If any operand is passed indirectly
/// (i.e., an address of a stack location is passed instead of the value itself)
/// then rewrite the builtin to use a loaded version of that value.
///
/// Similarly, if a primitive integer or floating point value is passed as a
/// struct value, extract out the underlying integer or float value.
///
/// \param inst the builtin whose operands are examined.
/// \param TFDA the deabstraction context; it is notified through
/// aboutToChangeFunction() before anything is rewritten.
/// \returns \p inst itself when no operand needs rewriting. Otherwise a newly
/// created builtin with the same name and type that takes over all uses of
/// \p inst, which is deleted.
static BuiltinInst *simplifyOperands(BuiltinInst *inst, TFDeabstraction &TFDA) {
/// Return a VarDecl if this is a struct wrapping a single field which is a
/// primitive integer or floating point value. We accept multiple layers of
/// struct wrappers as well, but return the decl for the top level field
/// type. This returns null in any other case.
auto getPrimitiveStructField = [&](Type type) -> VarDecl* {
VarDecl *result = nullptr;
while (1) {
auto decl = type->getAnyNominal();
if (!decl || !isa<StructDecl>(decl)) return nullptr;
// Check to see if there is a single stored field.
auto fieldIt = decl->getStoredProperties().begin();
if (fieldIt == decl->getStoredProperties().end()) return nullptr;
// If this is the top level of the struct, retain the field decl.
if (result == nullptr) result = *fieldIt;
// Descend into the field's type, and bail out if there is a second stored
// property.
type = (*fieldIt++)->getType();
if (fieldIt != decl->getStoredProperties().end()) return nullptr;
// If we unwrapped a level and got to a builtin type, then this is a
// wrapper.
if (type->is<BuiltinIntegerType>() ||
type->is<BuiltinFloatType>())
return result;
}
};
// Predicate that returns true if the specified type is an address type for
// a loadable (non-address-only) value.
auto isLoadableAddressType = [&](SILType type) -> bool {
return type.isAddress() && type.isLoadable(inst->getModule());
};
// Predicate that returns true if an operand of the specified type should be
// rewritten - either to load an address argument or expand a struct
// parameter.
auto canSimplifyOperand = [&](SILType type) -> bool {
return isLoadableAddressType(type) ||
getPrimitiveStructField(type.getSwiftRValueType()) != nullptr;
};
// If we don't have to change any operands, don't rewrite the builtin.
bool mustChangeBuiltin = false;
for (auto &op : inst->getAllOperands()) {
if (canSimplifyOperand(op.get()->getType())) {
mustChangeBuiltin = true;
break;
}
}
if (!mustChangeBuiltin) return inst;
// Mark the function as being mutated.
TFDA.aboutToChangeFunction();
// Okay, we do have to simplify something. Scan through and rewrite operands.
// The builder is positioned at the old builtin, so the loads and extracts
// created below are inserted in front of it.
SILBuilder B(inst);
SmallVector<SILValue, 8> operands;
for (auto &op : inst->getAllOperands()) {
auto operand = op.get();
// If this is an address operand, emit a load of the value.
if (isLoadableAddressType(operand->getType())) {
// Pick the load qualifier that matches the function's ownership mode.
bool hasOwnership = inst->getFunction()->hasQualifiedOwnership();
auto loadOwnership = hasOwnership ? LoadOwnershipQualifier::Trivial
: LoadOwnershipQualifier::Unqualified;
auto load = B.createLoad(inst->getLoc(), operand, loadOwnership);
load->setDebugLocation(inst->getDebugLocation());
operand = load;
}
// If the operand is a StructInst building the value that we want to
// extract, just get the element out of it, to avoid generating bloated IR.
operand = lookThroughSingleElementStructInsts(operand);
// If this is a struct value, emit struct extraction instruction(s).
while (auto fieldDecl = getPrimitiveStructField(
operand->getType().getSwiftRValueType())) {
auto extract = B.createStructExtract(inst->getLoc(), operand, fieldDecl);
extract->setDebugLocation(inst->getDebugLocation());
operand = extract;
}
operands.push_back(operand);
}
// Now that we've rebuilt the operand list, create a new builtin and replace
// the old one.
auto *newInst =
B.createBuiltin(inst->getLoc(), inst->getName(),
inst->getType(), /*no substitutions*/{}, operands);
newInst->setDebugLocation(inst->getDebugLocation());
// Replace the old with the new and delete the old instruction.
inst->replaceAllUsesPairwiseWith(newInst);
// Remove the StructInst and other random values that we leave around in the
// program, now that we directly refer to the TensorFlow values.
deleteInstAndAbandonedUses(inst);
return newInst;
}
/// If the specified instruction is a high-level aggregate operation like
/// copy_addr or destroy_addr, break it down into its more primitive operations
/// and return true. Otherwise, return false.
///
/// The instruction kinds handled are copy_addr, destroy_addr, retain_value,
/// release_value, strong_retain and strong_release.
///
/// If 'tfc' is non-null, this will only promote ops working on a type that
/// contains a TensorFlow value.
///
/// This leaves the input instruction in place and inserts the additional
/// instructions immediately after the input instruction that is exploded.
static bool explodeAggregateInst(SILInstruction *inst,
TensorFunctionClassifier *tfc) {
// Check to see if this is an instruction we can handle below, early exiting
// if not.
if (!isa<CopyAddrInst>(inst) &&
!isa<DestroyAddrInst>(inst) &&
!isa<RetainValueInst>(inst) &&
!isa<ReleaseValueInst>(inst) &&
!isa<StrongRetainInst>(inst) &&
!isa<StrongReleaseInst>(inst))
return false;
// Check to make sure that this operation is doing something on a value
// containing a TensorFlow value. If not, just leave it alone.
auto type = inst->getOperand(0)->getType();
if (tfc && !tfc->containsTensorFlowValue(type))
return false;
// TODO: This is currently just handling loadable types. We should be able to
// scalarize address-only elements, by turning them into by-address operations
// on each element. This can occur when a struct/tuple contains tensors and
// also has some address-only type.
auto &TL = inst->getModule().getTypeLowering(type);
if (!TL.isLoadable())
return false;
// Insert any new instructions right after the one we're going to explode.
// A terminator has no following instruction in its block, so bail out.
if (isa<TermInst>(inst)) return false;
SILBuilder B(++SILBasicBlock::iterator(inst));
B.setCurrentDebugScope(inst->getDebugScope());
// Lower a copy_addr into a load and store + retain/release instructions.
if (auto *copyAddr = dyn_cast<CopyAddrInst>(inst)) {
// Note, we don't use TL.emitCopyInto because that will produce a copy_addr.
auto loc = copyAddr->getLoc();
SILValue value =
TL.emitLoadOfCopy(B, loc, copyAddr->getSrc(), copyAddr->isTakeOfSrc());
TL.emitStoreOfCopy(B, loc, value, copyAddr->getDest(),
copyAddr->isInitializationOfDest());
} else if (auto *destroy = dyn_cast<DestroyAddrInst>(inst)) {
/// Turn a destroy_addr into a load+release_value pair.
TL.emitDestroyAddress(B, destroy->getLoc(), destroy->getOperand());
} else if (isa<RetainValueInst>(inst) || isa<StrongRetainInst>(inst)) {
// Turn a retain_value into a retain_value on its elements. We peephole
// StructInst values because they are so common and this generates cleaner
// IR and faster compile times.
auto op = lookThroughSingleElementStructInsts(inst->getOperand(0));
if (op != inst->getOperand(0) && op->getType().isAnyClassReferenceType())
B.createStrongRetain(inst->getLoc(), op, Atomicity::Atomic);
else
TL.emitLoweredCopyValueMostDerivedDescendents(B, inst->getLoc(),
inst->getOperand(0));
} else if (isa<ReleaseValueInst>(inst) || isa<StrongReleaseInst>(inst)) {
// Turn a release_value into a release_value on its elements. We peephole
// StructInst values because they are so common and this generates cleaner
// IR and faster compile times.
auto op = lookThroughSingleElementStructInsts(inst->getOperand(0));
if (op != inst->getOperand(0) && op->getType().isAnyClassReferenceType())
B.createStrongRelease(inst->getLoc(), op, Atomicity::Atomic);
else
TL.emitLoweredDestroyValueMostDerivedDescendents(B, inst->getLoc(),
inst->getOperand(0));
} else {
llvm_unreachable("unhandled instructions should be filtered above");
}
return true;
}
/// Identify all of the tensor operations in the current function, and scan them
/// to see if there are any indirect arguments, where the address of a stack
/// allocation is passed to the builtin. These occur when the tensor op was in
/// a generic context and was passed a scalar attribute value of generic type.
///
/// If we find one of these indirect values, transform it into a load of the
/// address and a use of the loaded value. This allows the stack allocation to
/// be promoted, allowing us to construct SSA def-use chains.
///
/// Similarly, if we see a struct operand that wraps a primitive value, we
/// extract out the underlying scalar value until we get to a builtin integer or
/// floating point value.
///
/// Since we're scanning the function, keep track of all of the tensor
/// operations to avoid additional linear scans over the function.
///
/// On return, 'tensorOps' holds the collected operations, or is empty if the
/// function contains no actual tensor op (only retains/releases).
///
void TFDeabstraction::simplifyTensorOperands() {
llvm::PrettyStackTraceFormat X("TFDeabstraction::simplifyTensorOperands");
// Set once we have seen a real tensor op, as opposed to a retain/release.
bool containsOpBuiltin = false;
// Makes sure the "After Inlining" dump is emitted at most once.
bool alreadyPrinted = false;
auto logIfFirstChange = [&]() {
if (alreadyPrinted) return;
logCurrentState("After Inlining", /*detailed*/true);
alreadyPrinted = true;
};
for (auto &BB : fn) {
for (auto I = BB.begin(), E = BB.end(); I != E; ) {
// Manually move iterator to avoid invalidation if we replace 'inst'.
auto *inst = &*I++;
// Try to decode this instruction as an op. If it isn't one, ignore it.
if (auto opInfo = SILTensorOpInfo::decode(inst)) {
logIfFirstChange();
// Simplify operands if possible.
opInfo->inst = simplifyOperands(opInfo->inst, *this);
// Remember this for later passes.
tensorOps.push_back(opInfo->inst);
containsOpBuiltin = true;
continue;
}
// If we have a call to a function that is conditionally promotable to a
// tensor op, we add it to the set of tensor operations we're trying to
// deabstract. This ensures that we deabstract its operands, which makes
// it possible to tell if it is getting a variable or constant value.
if (auto *apply = dyn_cast<ApplyInst>(inst)) {
if (SILTensorOpInfo::isDecodableApply(apply)) {
logIfFirstChange();
// Remember this for later passes.
tensorOps.push_back(apply);
containsOpBuiltin = true;
continue;
}
}
// Find retain and release instructions that directly use TensorFlow
// values. We treat them as tensorOps to ensure that their operands are
// deabstracted.
if (isa<StrongRetainInst>(inst) || isa<StrongReleaseInst>(inst)) {
if (isTensorFlowValue(inst->getOperand(0)->getType())) {
tensorOps.push_back(inst);
continue;
}
}
// Check to see if this is an aggregate operation (like a copy_addr, a
// retain or release, etc) that involves a TensorFlow value. If so,
// explode it out into its components and reprocess the components. This
// ensures that nothing later in deabstraction or partitioning has to
// worry about them.
if (explodeAggregateInst(inst, &tfc)) {
logIfFirstChange();
// Reset our iterator to the first instruction we just produced so we
// walk through them and recursively expand or remember them as
// appropriate.
I = ++SILBasicBlock::iterator(inst);
// We frequently produce dead code by exploding things, for example a
// retain of a StructInst value will end up being a retain of the
// original value, and therefore strand the StructInst. Clean this
// stuff up as we go. This is better for compile time and it makes it
// a lot easier to read the debugging dumps.
deleteInstAndAbandonedUses(inst);
continue;
}
// Otherwise we leave the instruction alone.
}
}
// If the tensorOps list just contained retain/release instructions but had
// no actual tensor builtins, we'll ignore the function because there is
// nothing to partition out of it. This is probably something actually
// working on the host-side tensor operation.
if (!containsOpBuiltin)
tensorOps.clear();
}
// Decide whether the builtin is one of the standard LLVM arithmetic,
// conversion or comparison operations.  We often see silly allocas in the way
// of them.  In strict deabstraction mode nothing is ever treated as such.
// FIXME: This is only necessary because we don't have a constexpr model.
// When that is in place and working well, this should be removed.
static bool isSimpleBuiltinArithmeticOp(BuiltinInst *builtin) {
  if (TFStrictDeabstraction)
    return false;

  switch (builtin->getBuiltinInfo().ID) {
  // Integer arithmetic.
  case BuiltinValueKind::Add:
  case BuiltinValueKind::Sub:
  case BuiltinValueKind::Mul:
  case BuiltinValueKind::SDiv:
  case BuiltinValueKind::ExactSDiv:
  case BuiltinValueKind::UDiv:
  case BuiltinValueKind::ExactUDiv:
  case BuiltinValueKind::SRem:
  case BuiltinValueKind::URem:
  // Floating point arithmetic.
  case BuiltinValueKind::FAdd:
  case BuiltinValueKind::FSub:
  case BuiltinValueKind::FMul:
  case BuiltinValueKind::FDiv:
  case BuiltinValueKind::FRem:
  case BuiltinValueKind::FNeg:
  // Bitwise operations and shifts.
  case BuiltinValueKind::And:
  case BuiltinValueKind::Or:
  case BuiltinValueKind::Xor:
  case BuiltinValueKind::Shl:
  case BuiltinValueKind::AShr:
  case BuiltinValueKind::LShr:
  // Arithmetic with overflow reporting.
  case BuiltinValueKind::SAddOver:
  case BuiltinValueKind::UAddOver:
  case BuiltinValueKind::SSubOver:
  case BuiltinValueKind::USubOver:
  case BuiltinValueKind::SMulOver:
  case BuiltinValueKind::UMulOver:
  // Conversions.
  case BuiltinValueKind::Trunc:
  case BuiltinValueKind::ZExt:
  case BuiltinValueKind::SExt:
  case BuiltinValueKind::TruncOrBitCast:
  case BuiltinValueKind::ZExtOrBitCast:
  case BuiltinValueKind::SExtOrBitCast:
  case BuiltinValueKind::FPToUI:
  case BuiltinValueKind::FPToSI:
  case BuiltinValueKind::UIToFP:
  case BuiltinValueKind::SIToFP:
  case BuiltinValueKind::FPTrunc:
  case BuiltinValueKind::FPExt:
  // Miscellaneous.
  case BuiltinValueKind::AssumeNonNegative:
  // Integer comparisons.
  case BuiltinValueKind::ICMP_EQ:
  case BuiltinValueKind::ICMP_NE:
  case BuiltinValueKind::ICMP_SLT:
  case BuiltinValueKind::ICMP_SLE:
  case BuiltinValueKind::ICMP_SGT:
  case BuiltinValueKind::ICMP_SGE:
  case BuiltinValueKind::ICMP_ULT:
  case BuiltinValueKind::ICMP_ULE:
  case BuiltinValueKind::ICMP_UGT:
  case BuiltinValueKind::ICMP_UGE:
  // Floating point comparisons.
  case BuiltinValueKind::FCMP_OEQ:
  case BuiltinValueKind::FCMP_ONE:
  case BuiltinValueKind::FCMP_OLT:
  case BuiltinValueKind::FCMP_OLE:
  case BuiltinValueKind::FCMP_OGT:
  case BuiltinValueKind::FCMP_OGE:
  case BuiltinValueKind::FCMP_ORD:
  case BuiltinValueKind::FCMP_UEQ:
  case BuiltinValueKind::FCMP_UNE:
  case BuiltinValueKind::FCMP_ULT:
  case BuiltinValueKind::FCMP_ULE:
  case BuiltinValueKind::FCMP_UGT:
  case BuiltinValueKind::FCMP_UGE:
  case BuiltinValueKind::FCMP_UNO:
    return true;

  default:
    return false;
  }
}
namespace {
/// This helper is used to find promotable memory in the operand chains of
/// tensor operations. This operates on the pre-deabstraction code, so it has
/// to be able to look through the various cases that will be eliminated
/// later.
class PromotableMemoryFinder {
/// The function being analyzed.
SILFunction &fn;
/// Output list owned by the caller: every promotable stack allocation that is
/// discovered gets appended here.
SmallVectorImpl<AllocStackInst*> &stackAllocs;
/// Instructions that have already been examined, so that nothing is scanned
/// twice.
SmallPtrSet<SILInstruction*, 32> visited;
/// Used to ask whether a type contains a TensorFlow value.
TensorFunctionClassifier &tfc;
public:
PromotableMemoryFinder(SmallVectorImpl<AllocStackInst*> &stackAllocs,
TensorFunctionClassifier &tfc, SILFunction &fn)
: fn(fn), stackAllocs(stackAllocs), tfc(tfc) {}
/// Analyze the values feeding into the given tensor operations. Returns true
/// if any address roots were promoted to stack values.
bool run(ArrayRef<SILInstruction*> tensorOps);
private:
void findPromotableMemoryFromValue(SILValue value);
void findPromotableMemoryFromLoadedAddress(SILValue pointer);
// Each entry of 'addressRoots' pairs a root pointer with a flag that says
// whether the value starts out uninitialized.
void findMainFunctionGlobalAddressRootCandidates(
SmallVectorImpl<std::pair<SILValue, bool>> &addressRoots);
bool canAddressRootBeReliablyPromoted(SILValue root);
void promoteAddressRootsToStack(
ArrayRef<std::pair<SILValue, bool>> addressRoots);
};
} // end anonymous namespace
/// Trace the dataflow values that feed the given tensor operations in order to
/// discover promotable stack values, then collect, validate and promote
/// address root references.
///
/// The result is true if any address roots were promoted to stack values.
///
bool PromotableMemoryFinder::run(ArrayRef<SILInstruction*> tensorOps) {
  llvm::PrettyStackTraceFormat traceEntry("PromotableMemoryFinder::run");

  // Walk upward from every operand of every tensor op, so that all the
  // promotable memory reachable from them is found and their use-def edges can
  // be connected directly.
  for (SILInstruction *tensorOp : tensorOps)
    for (auto &use : tensorOp->getAllOperands())
      findPromotableMemoryFromValue(use.get());

  // Address roots are pointers other than stack allocations that we need to
  // promote.  Candidates are gathered first and validated afterwards.  Each
  // entry holds the root pointer plus a flag telling whether the value starts
  // out uninitialized (which is the case for many global roots).
  SmallVector<std::pair<SILValue, bool>, 8> candidateRoots;

  // Indirect structs/tuples containing tensors among the SIL function
  // arguments are candidates.  Such functions are generally inlined into the
  // caller, but show up like this when the user explicitly specifies
  // @noinline.  Promoting the pointer as a root allows turning the entire body
  // into SSA.
  for (auto argument : fn.getArguments()) {
    auto argConvention =
      cast<SILFunctionArgument>(argument)->getArgumentConvention();
    if (!argConvention.isIndirectConvention())
      continue;
    if (!tfc.containsTensorFlowValue(argument->getType()))
      continue;
    candidateRoots.push_back({ argument, /*startsUninitialized*/false });
  }

  // For top level code in the main function, global_addr instructions provide
  // further address roots: we want to promote tensor-related globals and the
  // values that feed into them.
  if (fn.getName() == SWIFT_ENTRY_POINT_FUNCTION)
    findMainFunctionGlobalAddressRootCandidates(candidateRoots);

  if (candidateRoots.empty())
    return false;

  // Weed out the roots whose feeding computation cannot be reliably promoted.
  // A rejected entry is overwritten with the last one, so the same index has
  // to be examined again before moving on.
  unsigned idx = 0;
  while (idx != candidateRoots.size()) {
    if (canAddressRootBeReliablyPromoted(candidateRoots[idx].first)) {
      ++idx;
      continue;
    }
    std::swap(candidateRoots[idx], candidateRoots.back());
    candidateRoots.pop_back();
  }

  if (candidateRoots.empty())
    return false;

  // Whatever survived gets predictably promoted to the stack to unblock
  // analysis.
  promoteAddressRootsToStack(candidateRoots);
  return true;
}
/// Scan upward through the def-use chains of the specified operand value,
/// looking through operations that we can deabstract.  If we find stack
/// allocations along the way, add them to our set.
///
/// \param value the value to trace back.  Values that are not produced by an
/// instruction, and instructions that were visited before, are ignored.
void PromotableMemoryFinder::findPromotableMemoryFromValue(SILValue value) {
  // If we found a non-instruction operand, or an instruction we've already
  // visited, then we're done scanning it.
  auto *inst = value->getDefiningInstruction();
  if (!inst || !visited.insert(inst).second)
    return;

  // If this is one of the instructions we can deabstract by scalarizing, just
  // look through it.
  if (isa<TupleInst>(inst) || isa<StructInst>(inst) ||
      isa<StructExtractInst>(inst) || isa<TupleExtractInst>(inst)) {
    for (auto &operand : inst->getAllOperands())
      findPromotableMemoryFromValue(operand.get());
    return;
  }

  // Look through standard LLVM arithmetic operations.  We often see silly
  // allocas in the way of them.  Every operand has to be followed: most of
  // these builtins are binary, and promotable memory can feed any of their
  // inputs, not just the first one.
  // FIXME: This is only necessary because we don't have a constexpr model.
  // When that is in place and working well, this should be removed.
  if (!TFStrictDeabstraction) {
    if (auto builtin = dyn_cast<BuiltinInst>(inst)) {
      if (isSimpleBuiltinArithmeticOp(builtin)) {
        for (auto &operand : builtin->getAllOperands())
          findPromotableMemoryFromValue(operand.get());
        return;
      }
    }
  }

  // If this is a load, then we can deabstract it if it is a SRoA'able pointer
  // to a stack allocation.
  if (auto *load = dyn_cast<LoadInst>(inst))
    findPromotableMemoryFromLoadedAddress(load->getOperand());
}
/// The specific pointer is being loaded by a tensor operation operand.
/// Recursively process the pointer - if it is to a stack allocation that we can
/// deabstract, then recursively process any stores into it as values that feed
/// the tensor operation.
///
/// \param pointer the address being loaded.  Tuple/struct element projections
/// and begin_access instructions are stripped to get to the base address;
/// nothing happens unless that base is an alloc_stack that has not been seen
/// before.
void PromotableMemoryFinder::
findPromotableMemoryFromLoadedAddress(SILValue pointer) {
  while (isa<TupleElementAddrInst>(pointer) ||
         isa<StructElementAddrInst>(pointer) ||
         isa<BeginAccessInst>(pointer)) {
    pointer = cast<SingleValueInstruction>(pointer)->getOperand(0);
  }

  // If we've already processed this instruction, then we're done.
  auto *pointerInst = pointer->getDefiningInstruction();
  if (!pointerInst || !visited.insert(pointerInst).second)
    return;

  // If the base of the pointer is something other than a stack allocation,
  // then we're done.
  auto *alloc = dyn_cast<AllocStackInst>(pointerInst);
  if (!alloc)
    return;

  // Ok, this is a stack allocation we want to promote, remember it.
  stackAllocs.push_back(alloc);

  // Walk the use-def chains of the allocation, finding any stores that feed
  // into it, and recursively processing the values that are stored into it.
  // The worklist holds the allocation itself plus the element projections of
  // its address that are discovered along the way.
  SmallVector<SILInstruction*, 4> instrsToProcess;
  instrsToProcess.push_back(alloc);

  while (!instrsToProcess.empty()) {
    auto *inst = instrsToProcess.pop_back_val();

    for (auto result : inst->getResults()) {
      for (auto use : result->getUses()) {
        auto *user = use->getUser();

        // If we found a store instruction on the upward pass, and if the store
        // is *to* the alloc then we can recursively process the value stored
        // into it.
        if (auto *store = dyn_cast<StoreInst>(user)) {
          // Operand #1 of a store is the destination address.
          if (use->getOperandNumber() == 1)
            findPromotableMemoryFromValue(store->getSrc());
          continue;
        }

        // copy_addr's are a load+store pair.
        if (auto *copyaddr = dyn_cast<CopyAddrInst>(user)) {
          // If we found a copy_addr into this address during an upward scan,
          // then this is a load of the other operand.
          if (use->getOperandNumber() == 1)
            findPromotableMemoryFromLoadedAddress(copyaddr->getSrc());
          continue;
        }

        // If the user is an SRoA'able projection of the address, queue it so
        // that stores and copies into the projected element are found too.
        // The check has to look at the *user*: the worklist starts out with
        // the alloc_stack, so testing the instruction being processed would
        // never match and projections would never be followed.
        if (isa<TupleElementAddrInst>(user) ||
            isa<StructElementAddrInst>(user)) {
          instrsToProcess.push_back(user);
          continue;
        }

        // Otherwise we don't know what kind of user this is, ignore it.
      }
    }
  }
}
/// Find all global_addr instructions in the function that refer to globals
/// containing TensorFlow values, whether or not they are directly used by
/// tensor operations, and append one canonical address root per global to
/// `addressRoots`.
///
/// The representation of global addresses is a bit wonky: There can be
/// multiple global_addr instructions for each global.  Later code wants to
/// have a single pointer to reason about, so we canonicalize to one of them,
/// replacing and erasing the others.
///
/// The bool in each appended pair is true when an alloc_global for the global
/// was found in this function (the root starts out uninitialized), and false
/// otherwise (e.g. in the REPL, where it is known to be previously
/// initialized).
///
void PromotableMemoryFinder::
findMainFunctionGlobalAddressRootCandidates(
    SmallVectorImpl<std::pair<SILValue, bool>> &addressRoots) {
  // First collect all the alloc_globals that may be present in the function,
  // to ensure we have them all when we start scanning for global_addr's.
  DenseMap<SILGlobalVariable*, AllocGlobalInst*> allocGlobals;
  for (auto &bb : fn) {
    for (auto &inst : bb) {
      // If we see an alloc global of interest, remember where it is.
      auto *agi = dyn_cast<AllocGlobalInst>(&inst);
      if (!agi)
        continue;

      auto *gv = agi->getReferencedGlobal();
      if (!tfc.containsTensorFlowValue(gv->getLoweredType()))
        continue;

      assert(allocGlobals.lookup(gv) == nullptr &&
             "more than one alloc_global instruction in the function?");
      allocGlobals[gv] = agi;
    }
  }

  // FIXME: We are missing an important validity check here that checks to
  // verify that there are no references to the global *other* than from the
  // main function.  This is generally true because we inline tensor ops
  // aggressively, but can be incorrect in some cases: e.g. a tensor-using
  // function is marked @noinline, or such a function just contains a copy.
  DenseMap<SILGlobalVariable*, GlobalAddrInst*> globalAddrRoots;
  for (auto &bb : fn) {
    for (auto bbi = bb.begin(), e = bb.end(); bbi != e; ) {
      // Advance before processing: the instruction may be erased or moved.
      auto &inst = *(bbi++);

      // Process GlobalAddrInst's.
      auto *ga = dyn_cast<GlobalAddrInst>(&inst);
      if (!ga || !tfc.containsTensorFlowValue(ga->getType()))
        continue;

      auto &entry = globalAddrRoots[ga->getReferencedGlobal()];

      // The canonical global_addr may have been moved to a position that this
      // scan has not reached yet (e.g. after an alloc_global in a block that
      // is listed later).  Never replace the canonical root with itself and
      // erase it: that would leave a dangling pointer in addressRoots.
      if (entry == ga)
        continue;

      // Check to see if this is the first global_addr for this global
      // variable.  If not, we reuse the existing one, which we know dominates
      // our current code.
      if (entry) {
        ga->replaceAllUsesWith(entry);
        ga->eraseFromParent();
        continue;
      }

      // Otherwise, this is the first one, and it will be our canonical
      // pointer.  If we have an alloc_global, then it starts out uninitialized
      // but if we don't (as in the case of the REPL) it is known to be
      // previously initialized.  Use lookup() so that we don't insert null
      // entries into the alloc_global map.
      auto *allocGlobal = allocGlobals.lookup(ga->getReferencedGlobal());
      entry = ga;
      addressRoots.push_back({ ga, /*isUninit*/allocGlobal != nullptr });

      // If this global_addr is in the entry block, then it will dominate any
      // other ones: we know it is the first in the entry block (because we
      // scan top to bottom) and we know the entry block dominates everything
      // else.
      if (ga->getParent() == fn.getEntryBlock())
        continue;

      // Otherwise, we aren't sure it will dominate all uses.  If we saw an
      // alloc_global instruction, move it right after that.  We know it will
      // dominate all uses.
      if (allocGlobal) {
        ga->moveAfter(allocGlobal);
        continue;
      }

      // Otherwise, move this to the entry block.
      ga->moveBefore(fn.getEntryBlock()->getTerminator());
    }
  }
}
/// Decide whether the given address root can be promoted.  Every use of the
/// root (and of any begin_access derived from it) must be one of a small set
/// of understood instructions; anything else makes us give up on the root so
/// that we never promote an address that escapes.  Values stored into the root
/// are scanned for further promotable memory along the way.
///
/// Returns true when all uses are understood, false otherwise.
///
bool PromotableMemoryFinder::canAddressRootBeReliablyPromoted(SILValue root) {
  // Worklist of addresses to inspect: the root itself plus the direct aliases
  // of it formed by begin_access.
  SmallVector<SILValue, 4> pendingAddrs;
  pendingAddrs.push_back(root);

  while (!pendingAddrs.empty()) {
    SILValue currentAddr = pendingAddrs.pop_back_val();

    for (auto *operandUse : currentAddr->getUses()) {
      auto *userInst = operandUse->getUser();

      // We are deliberately very conservative: only whitelisted users are
      // accepted.  Access markers, reads, and debug info are harmless, and
      // element projections can keep projecting out of the promoted value.
      if (isa<EndAccessInst>(userInst) ||
          isa<LoadInst>(userInst) ||
          isa<DebugValueAddrInst>(userInst) ||
          isa<StructElementAddrInst>(userInst) ||
          isa<TupleElementAddrInst>(userInst))
        continue;

      // A store whose destination (operand #1) is this address provides a
      // value that feeds it; scan that value.
      if (auto *store = dyn_cast<StoreInst>(userInst)) {
        if (operandUse->getOperandNumber() == 1) {
          findPromotableMemoryFromValue(store->getOperand(0));
          continue;
        }
      }

      // begin_access yields another name for the same address, so its uses
      // have to be checked as well.
      if (auto *access = dyn_cast<BeginAccessInst>(userInst)) {
        pendingAddrs.push_back(access);
        continue;
      }

      // Passing the address as an indirect argument of an apply is accepted.
      // Such calls generally get inlined, but can remain when the user
      // specifies @noinline on a method, for example.
      if (auto *apply = dyn_cast<ApplyInst>(userInst)) {
        // FIXME: This seems wrong, because it is not counting indirect
        // results.  See DIMemoryUseCollector's use of getSubstCalleeConv for
        // an example.
        auto calleeConv = apply->getSubstCalleeConv();
        assert(calleeConv.getNumIndirectSILResults() == 0 &&
               "FIXME: Handle this");

        unsigned operandNo = operandUse->getOperandNumber();
        if (auto argNo = apply->getArgumentIndexForOperandIndex(operandNo)) {
          auto paramInfo = calleeConv.getParameters()[argNo.getValue()];
          if (isIndirectFormalParameter(paramInfo.getConvention()))
            continue;
        }
      }

      // We found a user that we don't understand.  We should handle this some
      // day, but for now leave the access unchanged to avoid miscompiling
      // code.  When the intermediate dump stream is stdout (as in the
      // testsuite), make this a hard error instead.
      if (getTFDumpIntermediateStream() == &llvm::outs()) {
        // The complaint is printed both before and after the function dump.
        auto printComplaint = [&] {
          llvm::errs() << "unexpected global_addr user in top level code"
                       << " promotion: " << *userInst << "\n\n";
        };
        printComplaint();
        llvm::errs() << *userInst->getFunction();
        printComplaint();
        abort();
      }
      return false;
    }
  }

  return true;
}
/// Rewrite each of the given address roots to go through a fresh stack
/// allocation: all existing uses of the root are redirected to an alloc_stack
/// in the entry block, the stack slot is seeded from the root when the root is
/// already initialized, and at every function exit the final value is moved
/// back into the root before the slot is deallocated.  The new alloc_stacks
/// are appended to `stackAllocs` so that they get promoted to SSA.
///
/// Each element of `addressRoots` pairs a root with a flag that is true when
/// the root starts out uninitialized.
void PromotableMemoryFinder::
promoteAddressRootsToStack(ArrayRef<std::pair<SILValue, bool>> addressRoots) {
  llvm::PrettyStackTraceFormat X("PromotableMemoryFinder::"
                                 "promoteAddressRootsToStack");

  DenseMap<SILValue, AllocStackInst*> stackAllocForRoot;

  // Step 1: give every root a stack slot at the top of the entry block and
  // redirect all current users of the root to that slot.
  for (const auto &rootEntry : addressRoots) {
    SILValue root = rootEntry.first;

    SILBuilder entryBuilder(&fn.getEntryBlock()->front());
    auto *slot = entryBuilder.createAllocStack(fn.getLocation(),
                                               root->getType().getObjectType());
    stackAllocForRoot[root] = slot;

    // Make sure to convert the generated alloc_stack to SSA.
    stackAllocs.push_back(slot);

    root->replaceAllUsesWith(slot);
  }

  // Step 2: collect the blocks through which control leaves the function.
  SmallVector<SILBasicBlock*, 4> exitBlocks;
  for (auto &bb : fn) {
    auto *term = bb.getTerminator();
    if (isa<ReturnInst>(term) || isa<ThrowInst>(term) || isa<UnwindInst>(term))
      exitBlocks.push_back(&bb);
  }

  // Step 3: emit the prolog copy (when needed) and the epilog copy-back and
  // deallocation for each root.
  for (const auto &rootEntry : addressRoots) {
    SILValue root = rootEntry.first;
    const bool startsUninitialized = rootEntry.second;

    auto *slot = stackAllocForRoot[root];
    assert(slot && "where'd our alloc_stack go?");

    // Global variables in top level code start out uninitialized, but other
    // roots (indirect function arguments, REPL code reusing a global) already
    // hold a value.  In the latter case, copy that value into the slot.
    if (!startsUninitialized) {
      // Insert the initialization right after the instruction defining the
      // root, or after the stack allocation if the root has none.
      SILInstruction *initAfter = root->getDefiningInstruction();
      if (!initAfter)
        initAfter = slot;

      SILBuilder initBuilder(++SILBasicBlock::iterator(initAfter));
      auto initLoc = fn.getLocation();
      auto &initLowering = initBuilder.getTypeLowering(slot->getType());
      initLowering.emitCopyInto(initBuilder, initLoc, root, slot,
                                IsTake_t::IsNotTake,
                                IsInitialization_t::IsInitialization);
    }

    // At each exit, move the slot's value back into the root so that it ends
    // up holding our final state, then free the slot.  This initializes the
    // root if it started out uninitialized and reassigns it otherwise.
    for (auto *exitBB : exitBlocks) {
      SILBuilder exitBuilder(exitBB->getTerminator());
      auto exitLoc = fn.getLocation();

      auto &exitLowering = exitBuilder.getTypeLowering(slot->getType());
      exitLowering.emitCopyInto(exitBuilder, exitLoc, slot, root,
                                IsTake_t::IsTake,
                                IsInitialization_t(startsUninitialized));
      exitBuilder.createDeallocStack(exitLoc, slot);
    }
  }
}
/// Promote the specified stack allocations to SSA form.  Each allocation is
/// first preprocessed with prepareStackAllocForPromotion, and then all of them
/// are handed to promoteAllocsToSSA together with the function's dominance
/// information.  Does nothing when `allocs` is empty.
void TFDeabstraction::promoteToSSA(MutableArrayRef<AllocStackInst*> allocs) {
  // If there is nothing to promote, don't bother calculating dominator info.
  if (allocs.empty())
    return;

  // Label crash traces with the function we are actually in, matching the
  // other TFDeabstraction methods.
  llvm::PrettyStackTraceFormat X("TFDeabstraction::promoteToSSA");

  // Do any necessary preprocessing of the stack allocations before promoting
  // them.
  for (auto alloc : allocs)
    prepareStackAllocForPromotion(alloc);

  // Promote the stack allocations out of the way so we can do simple dataflow
  // analysis.
  auto domInfo = passManager->getAnalysis<DominanceAnalysis>()->get(&fn);
  promoteAllocsToSSA(allocs, domInfo);
}
/// Rewrite the users of the specified stack allocation into forms that are
/// easier to promote to SSA: loads of a struct element address become a
/// struct_extract of a whole load, aggregate by-address instructions such as
/// copy_addr are exploded, and begin_access/end_access markers are removed.
void TFDeabstraction::prepareStackAllocForPromotion(AllocStackInst *alloc) {
  // TODO: We will eventually need to do real SRoA to handle the case when
  // we have tensor values mixed in with other random values that shouldn't
  // (or can't) be loaded.  For now, we can just fail to deabstract these
  // cases.

  // Users get erased as we go, so the use iterator is always advanced *before*
  // the instruction it currently refers to is deleted.
  auto useIt = alloc->use_begin();
  while (useIt != alloc->use_end()) {
    auto *user = (*useIt)->getUser();

    // load(struct_element_addr(alloc)) ==> struct_extract(load(alloc)), but
    // only when the load is the projection's single use.
    if (auto *sea = dyn_cast<StructElementAddrInst>(user)) {
      if (auto *onlyUse = sea->getSingleUse()) {
        if (auto *load = dyn_cast<LoadInst>(onlyUse->getUser())) {
          SILBuilder B(load);
          auto *wholeLoad = B.createLoad(load->getLoc(), sea->getOperand(),
                                         load->getOwnershipQualifier());
          auto *field = B.createStructExtract(load->getLoc(), wholeLoad,
                                              sea->getField(),
                                              load->getType());
          load->replaceAllUsesWith(field);
          load->eraseFromParent();
          ++useIt;
          sea->eraseFromParent();
          continue;
        }
      }
    }

    // Explode aggregate by-address instructions like copy-addr.
    if (explodeAggregateInst(user, /*all types*/nullptr)) {
      ++useIt;
      user->eraseFromParent();
      continue;
    }

    // Any other user that isn't a begin_access is left alone.
    auto *begin = dyn_cast<BeginAccessInst>(user);
    if (!begin) {
      ++useIt;
      continue;
    }

    // Look through the begin_access: its end_access markers are deleted, and
    // every other use of it is redirected to the allocation itself.
    for (auto accessUseIt = begin->use_begin();
         accessUseIt != begin->use_end();) {
      auto *accessUse = *accessUseIt++;
      auto *accessUser = accessUse->getUser();
      if (isa<EndAccessInst>(accessUser))
        accessUser->eraseFromParent();
      else
        accessUse->set(alloc);
    }

    ++useIt;
    begin->eraseFromParent();
  }
}
/// The specified argument has tuple type that deabstraction needs to scalarize.
/// Explode it into its deabstracted elements, rebuilding it and the branch
/// instructions that feed it.  This returns a value of the original type that
/// can be used for further analysis.
///
/// Every predecessor of the argument's block must end in a BranchInst.  The
/// argument itself is erased by this call.
static SILValue explodeSILTupleArgument(SILPHIArgument *arg) {
  SmallVector<SILValue, 4> newArgs;

  auto *argBB = arg->getParent();

  // Collect all the fields and add new BB arguments to the block for each of
  // them.
  auto tuple = arg->getType();
  unsigned numElements = tuple.castTo<TupleType>()->getNumElements();
  for (unsigned i = 0; i != numElements; ++i) {
    auto newArg = argBB->createPHIArgument(tuple.getTupleElementType(i),
                                           arg->getOwnershipKind());
    newArgs.push_back(newArg);
  }

  // Now that we have created all of the BB arguments, we can create a new
  // tuple inst, replace the old argument, and remove it.
  SILBuilder B(&argBB->front());
  auto replacement = B.createTuple(argBB->front().getLoc(),
                                   arg->getType(), newArgs);
  arg->replaceAllUsesWith(replacement);
  unsigned argNo = arg->getIndex();
  argBB->eraseArgument(argNo);

  // Ok, now that we've exploded the BB argument itself, we need to explode the
  // values passed in the predecessor blocks.  Rewriting a predecessor erases
  // its branch, which is what links it into the predecessor list, so take a
  // snapshot of the predecessors before mutating anything.
  SmallVector<SILBasicBlock*, 8> predBlocks;
  for (auto *pred : argBB->getPredecessorBlocks())
    predBlocks.push_back(pred);

  for (auto *pred : predBlocks) {
    auto *br = cast<BranchInst>(pred->getTerminator());

    // Keep every operand except the one for the exploded argument.
    SmallVector<SILValue, 8> operands;
    for (unsigned i = 0, e = br->getNumOperands(); i != e; ++i)
      if (i != argNo)
        operands.push_back(br->getOperand(i));

    auto origValue = br->getOperand(argNo);

    B.setInsertionPoint(br);

    // Add all of the extracted versions of the elements; these line up with
    // the new BB arguments, which were appended to the end of the list.
    for (unsigned i = 0; i != numElements; ++i)
      operands.push_back(B.createTupleExtract(br->getLoc(), origValue, i));

    // Replace the branch itself.
    SILBuilder(br).createBranch(br->getLoc(), br->getDestBB(), operands);
    br->eraseFromParent();
  }

  // Ok, we're done.  Return the generated TupleInst that aggregates the
  // arguments back to the caller.
  return replacement;
}
/// The specified argument has struct type that deabstraction needs to
/// scalarize.  Explode it into its deabstracted elements, rebuilding it and the
/// branch instructions that feed it.  This returns a value of the original type
/// that can be used for further analysis.
///
/// Every predecessor of the argument's block must end in a BranchInst.  The
/// argument itself is erased by this call.
static SILValue explodeSILStructArgument(SILPHIArgument *arg) {
  SmallVector<VarDecl*, 4> elementDecls;
  SmallVector<SILValue, 4> newArgs;

  auto &M = arg->getFunction()->getModule();
  auto *argBB = arg->getParent();
  auto fnLoc = argBB->getParent()->getLocation();

  // Collect all the fields and add new BB arguments to the block for each of
  // them.
  auto structType = arg->getType();
  auto decl = structType.getStructOrBoundGenericStruct();
  for (auto fieldDecl : decl->getStoredProperties()) {
    elementDecls.push_back(fieldDecl);
    auto fieldTy = structType.getFieldType(fieldDecl, M);

    auto newArg = argBB->createPHIArgument(fieldTy, arg->getOwnershipKind());
    newArgs.push_back(newArg);
  }

  // Now that we have created all of the BB arguments, we can create a new
  // struct inst, replace the old argument, and remove it.
  SILBuilder B(&argBB->front());
  auto replacement = B.createStruct(fnLoc, arg->getType(), newArgs);
  arg->replaceAllUsesWith(replacement);
  unsigned argNo = arg->getIndex();
  argBB->eraseArgument(argNo);

  // Ok, now that we've exploded the BB argument itself, we need to explode the
  // values passed in the predecessor blocks.  Rewriting a predecessor erases
  // its branch, which is what links it into the predecessor list, so take a
  // snapshot of the predecessors before mutating anything.
  SmallVector<SILBasicBlock*, 8> predBlocks;
  for (auto *pred : argBB->getPredecessorBlocks())
    predBlocks.push_back(pred);

  for (auto *pred : predBlocks) {
    auto *br = cast<BranchInst>(pred->getTerminator());

    // Keep every operand except the one for the exploded argument.
    SmallVector<SILValue, 8> operands;
    for (unsigned i = 0, e = br->getNumOperands(); i != e; ++i)
      if (i != argNo)
        operands.push_back(br->getOperand(i));

    B.setInsertionPoint(br);

    // Add all of the extracted versions of the elements; these line up with
    // the new BB arguments, which were appended to the end of the list.
    auto origValue = br->getOperand(argNo);
    for (auto fieldDecl : elementDecls)
      operands.push_back(B.createStructExtract(fnLoc, origValue, fieldDecl));

    // Replace the branch itself.
    SILBuilder(br).createBranch(br->getLoc(), br->getDestBB(), operands);
    br->eraseFromParent();
  }

  // Ok, we're done.  Return the generated StructInst that aggregates the
  // arguments back to the caller.
  return replacement;
}
/// We've promoted any stack allocations that are in the way of tensor operands
/// so we now have proper SSA. Look through struct and tuple injection and
/// projection instructions to find the underlying value that feeds the tensor
/// operation. This is typically another tensor operation or a constant (for
/// attributes) but may be variables or other things that cause a send.
///
/// \param v the value to start looking through.
/// \param checkedPhis the set of basic block arguments that have already been
///        visited; it is shared across calls so that each one is processed
///        only once.
/// \returns the last value reached while no extract was pending, i.e. a value
///          equivalent to the original `v`.
///
/// Note that this mutates the function: aggregate basic block arguments that
/// are projected from get exploded into their elements, and the operands of
/// branches feeding a visited basic block argument are rewritten in place.
///
static SILValue
propagateTensorOperand(SILValue v,
SmallPtrSet<SILPHIArgument*, 8> &checkedPhis) {
// This is the series of struct/tuple extract indices that the value is
// currently being projected through. Consider an access like this:
// B = struct { #1, #2 }
// C = tuple { #3, B }
// Y = tuple_extract C, 1
// Z = struct_extract Y, 0
// We start analysis at Z, and add the access indices of Z and Y. When we get
// to C, we know that we're accessing element 1 from the tuple because that is
// the top of our access path. When we get to B, we know we're accessing
// element 0 from the access path, so we return the #1 value.
SmallVector<unsigned, 4> accessPath;
// The most recent value seen with an empty access path; this is what gets
// returned whenever we have to stop.
SILValue lastRootValue;
while (1) {
// If our access path is empty, this is a candidate that we could return.
if (accessPath.empty())
lastRootValue = v;
if (auto *arg = dyn_cast<SILPHIArgument>(v)) {
// Don't reprocess a PHI argument if we've already seen it.
if (!checkedPhis.insert(arg).second)
break;
// If this is an aggregate basic block argument, explode it into its
// component values.
if (!accessPath.empty()) {
// Do a quick pass over all of the predecessors to see if they are
// unconditional branches. If not, we can't explode them.
// TODO: We should handle things like switch_enum someday.
for (auto pi : arg->getParent()->getPredecessorBlocks()) {
if (!isa<BranchInst>(pi->getTerminator()))
// Cannot explode this BB argument.
return lastRootValue;
}
// We're going to erase 'arg', so don't leave dangling pointers in the
// set.
checkedPhis.erase(arg);
if (arg->getType().is<TupleType>())
v = explodeSILTupleArgument(arg);
else if (arg->getType().is<StructType>() ||
arg->getType().is<BoundGenericStructType>())
v = explodeSILStructArgument(arg);
else
return lastRootValue; // Cannot handle this.
// 'v' is now the tuple/struct instruction that rebuilds the aggregate; the
// next iteration projects the pending access path element out of it.
continue;
}
// Otherwise simplify inputs in predecessor blocks.
for (auto pi : arg->getParent()->getPredecessorBlocks()) {
if (auto *br = dyn_cast<BranchInst>(pi->getTerminator())) {
// We intentionally recalculate arg->getIndex() because its index can
// shift. We know that recursive processing won't delete the bb arg
// though, as it is in checkedPhis.
auto incomingVal = br->getOperand(arg->getIndex());
incomingVal = propagateTensorOperand(incomingVal, checkedPhis);
br->setOperand(arg->getIndex(), incomingVal);
}
}
// 'v' is unchanged here, so the next iteration finds 'arg' already in
// checkedPhis and leaves the loop, returning it as the root value.
continue;
}
// Otherwise, peer through instructions.
auto inst = v->getDefiningInstruction();
if (!inst)
break;
// Extractions add to the access path.
if (auto extract = dyn_cast<TupleExtractInst>(inst)) {
accessPath.push_back(extract->getFieldNo());
v = extract->getOperand();
continue;
}
if (auto extract = dyn_cast<StructExtractInst>(inst)) {
accessPath.push_back(extract->getFieldNo());
v = extract->getOperand();
continue;
}
// Constructions provide values to extract from if we have an access inside
// of it.
if (!accessPath.empty()) {
if (auto str = dyn_cast<StructInst>(inst)) {
v = str->getOperand(accessPath.pop_back_val());
continue;
}
if (auto tuple = dyn_cast<TupleInst>(inst)) {
v = tuple->getOperand(accessPath.pop_back_val());
continue;
}
}
// Otherwise, this is an unhandled instruction - we're done.
break;
}
return lastRootValue;
}
/// Propagate the operand values for all tensors: this ensures that all tensor
/// operands and results are directly linked together in the SSA graph at the
/// TensorFlow value level, without going through intervening struct/tuple
/// wrappers.
void TFDeabstraction::propagateTensorValues() {
llvm::PrettyStackTraceFormat X("TFDeabstraction::propagateTensorValues");
// Now that we have directly exposed retain/release instructions and tensor
// operations, go through and make sure they are directly linked to each
// other.
SmallPtrSet<SILPHIArgument*, 8> checkedPhis;
for (auto *op : tensorOps) {
for (auto &operand : op->getAllOperands()) {
// Get the propagated value. This call can change the tensor operand.
auto newVal = propagateTensorOperand(operand.get(), checkedPhis);
// Get the (possibly-changed) instruction that used to be feeding the
// tensor operation and set the new value.
auto opInst = operand.get()->getDefiningInstruction();
operand.set(newVal);
// If the old instruction is unused, try to clean up the code.
if (opInst && !opInst->hasUsesOfAnyResult())
recursivelyDeleteTriviallyDeadInstructions(opInst);
}
}
}
/// Create and return a new constant literal instruction for the specified
/// scalar constant value.
///
/// \param symVal the constant to materialize; must satisfy isConstant().
/// \param type the SIL type given to integer and floating point literals.
/// \param loc the location to attach to the new instruction.
/// \param B the builder used to create the instruction.
/// \returns the new instruction, or nullptr for kinds of constants that
///          cannot be materialized (aggregates and functions are not
///          supported yet).
///
/// TODO: this should eventually go away when we stop using literal instructions
/// and builtin instructions to represent #tfop.  We should switch to a more
/// principled design when we have a custom SIL instruction for graph ops.
static SingleValueInstruction *
emitConstantInst(SymbolicValue symVal, SILType type, SILLocation loc,
                 SILBuilder &B) {
  assert(symVal.isConstant() && "Not a constant value");

  switch (symVal.getKind()) {
  case SymbolicValue::Unknown:
  case SymbolicValue::UninitMemory:
  case SymbolicValue::Address:
    assert(0 && "Shouldn't happen");
    // In builds without assertions, report failure explicitly rather than
    // relying on an implicit fallthrough into the next case.
    return nullptr;

  case SymbolicValue::Aggregate:
  case SymbolicValue::Function:
    // TODO: Unsupported right now.
    return nullptr;

  case SymbolicValue::Metatype: {
    auto mt = MetatypeType::get(symVal.getMetatypeValue())->getCanonicalType();
    return B.createMetatype(loc, SILType::getPrimitiveObjectType(mt));
  }
  case SymbolicValue::Integer:
    return B.createIntegerLiteral(loc, type, symVal.getIntegerValue());
  case SymbolicValue::Float:
    return B.createFloatLiteral(loc, type, symVal.getFloatValue());
  case SymbolicValue::String:
    return B.createStringLiteral(loc, symVal.getStringValue(),
                                 StringLiteralInst::Encoding::UTF8);
  }

  // Every case above returns; make sure control can never run off the end of
  // this non-void function.
  llvm_unreachable("unhandled SymbolicValue kind");
}
/// Decode the specified array constant value (which should be an
/// array of constant integer or fp values) and add it as an expanded operand
/// to the specified op that is being built up.
///
/// \param arrayElements the constant elements of the array.
/// \param arrayEltType the SIL type of each element.
/// \param attrName the attribute name to append to `name`.
/// \param attrKind the operand class of the attribute; `Normal` is turned
///        into `Array`.
/// \param name the op name being built up; the attribute and one marker per
///        element are appended to it.
/// \param operands the operand list being built up; a metatype for the
///        element type is appended, followed by one value per element.
/// \param forInst the instruction that new instructions are created before,
///        and whose location they are given.
static void expandArrayConstant(ArrayRef<SymbolicValue> arrayElements,
                                SILType arrayEltType,
                                StringRef attrName,
                                SILTensorOpInfo::OperandClass attrKind,
                                std::string &name,
                                SmallVectorImpl<SILValue> &operands,
                                SILInstruction *forInst) {
  SILBuilder B(forInst);

  // Add the first operand, which is the metatype for the element.  If it was
  // a 'Normal' operand, change it to an Array so we can distinguish it in the
  // case of an empty array.
  if (attrKind == SILTensorOpInfo::OperandClass::Normal)
    attrKind = SILTensorOpInfo::OperandClass::Array;
  name += ","+attrName.str();
  name += SILTensorOpInfo::getOperandClassSuffix(attrKind);

  auto metatypeType =
    MetatypeType::get(arrayEltType.getSwiftRValueType(),
                      MetatypeRepresentation::Thin)
      ->getCanonicalType();
  operands.push_back(B.createMetatype(forInst->getLoc(),
                              SILType::getPrimitiveObjectType(metatypeType)));

  // Add all of the operands as explicit values.  If the instructions came
  // from an out of line array initializer, make sure to clone them over to
  // our function.
  for (auto eltVal : arrayElements) {
    auto elt = eltVal.getConstantInstIfPresent();
    if (!elt || elt->getFunction() != forInst->getFunction()) {
      // Make a copy of the value, it may be computed.
      elt = emitConstantInst(eltVal, arrayEltType, forInst->getLoc(), B);
      // emitConstantInst returns null for constants it cannot materialize
      // (e.g. aggregates); callers must only pass arrays of scalar constants.
      assert(elt && "array element cannot be materialized as a literal "
                    "instruction");
      elt->setDebugLocation(B.getSILDebugLocation(forInst->getLoc()));
    }

    operands.push_back(elt);
    name += ",";
    auto eltKind = SILTensorOpInfo::OperandClass::ArrayElement;
    name += SILTensorOpInfo::getOperandClassSuffix(eltKind);
  }
}
/// If the specified type is a Swift.Array of some element type, then return the
/// element type. Otherwise, return a null Type.
///
/// NOTE: this is not implemented yet. It currently asserts and then aborts
/// unconditionally; the intended logic is kept below under `#if 0`.
static SILType getArrayElementType(SILType ty) {
assert(0 && "FIXME: Implement when array constant prop is up and running");
// Make sure we stop here even in builds where the assertion is compiled out.
abort();
#if 0
if (auto bgst = ty->getAs<BoundGenericStructType>())
if (bgst->getDecl() == bgst->getASTContext().getArrayDecl())
return bgst->getGenericArgs()[0];
return Type();
#endif
}
/// If all the operands to a call to __tf_tensor_from_scalars are constants, we
/// can promote this to a 'Const' node with an attached TF_Tensor attribute.
/// It takes a 1D array of scalars and a shape as a 1D array of integers.
///
/// \param inst the apply of __tf_tensor_from_scalars; operand 1 is the array
///        of scalars and operand 2 is the shape array.
/// \param constants the known constant values, keyed by SIL value.
///
/// On success, this removes the ApplyInst and returns a pointer to the new
/// BuiltinInst that is created.  On failure, it returns a nullptr; when the
/// number of scalars does not match the shape, an error is diagnosed as well.
///
/// FIXME: This is a near duplication of the logic used by TFPartitioning in
/// SILTensorOpInfo::decodeTensorFromScalars.  When constexpr propagation is
/// done, we should remove the logic in SILTensorOpInfo.
static BuiltinInst *
tryToPromoteTensorFromScalars(ApplyInst *inst,
                    const DenseMap<SILValue, SymbolicValue> &constants) {
  assert(inst->getNumOperands() == 3 && isTensorHandle(inst->getType()) &&
         "Unexpected type signature for __tf_tensor_from_scalars");

  // If we can't analyze the operands as arrays of constants, give up.
  auto scalarIt = constants.find(inst->getOperand(1));
  if (scalarIt == constants.end() || !scalarIt->second.isConstant())
    return nullptr;
  auto shapeIt = constants.find(inst->getOperand(2));
  if (shapeIt == constants.end() || !shapeIt->second.isConstant())
    return nullptr;

  // Okay, we were able to resolve the two arrays of constants.  Transform this
  // into the correct Const operation.

  // We transform this into a __tfop_Const instruction, where the values are
  // part of the 'value' tensor attribute and the shape is specified as a shape
  // attribute.
  SmallVector<SILValue, 8> operands;
  std::string name = "__tfop_Const";

  // Try to expand the array and the shape into their scalars.
  expandArrayConstant(scalarIt->second.getAggregateValue(),
                      getArrayElementType(inst->getOperand(1)->getType()),
                      "value",
                      SILTensorOpInfo::OperandClass::Tensor,
                      name, operands, inst);

  // The first operand is the element metatype, the rest are the scalars.
  unsigned numElements = operands.size()-1;

  // The shape is its own attribute, named "shape" just like in the 1D form.
  expandArrayConstant(shapeIt->second.getAggregateValue(),
                      getArrayElementType(inst->getOperand(2)->getType()),
                      "shape",
                      SILTensorOpInfo::OperandClass::Shape,
                      name, operands, inst);

  // Verify we have the right number of scalars.  If not, emit an error and
  // leave the broken code without promoting it to an op.  The shape dimensions
  // follow the tensor metatype, the scalars, and the shape metatype.
  uint64_t scalarCount = 1;
  std::string errorInfo;
  for (auto elt : ArrayRef<SILValue>(operands).drop_front(numElements+2)) {
    auto *eltCst = cast<IntegerLiteralInst>(elt);
    scalarCount *= eltCst->getValue().getLimitedValue();
  }
  if (scalarCount != numElements && errorInfo.empty()) {
    errorInfo = "tensor literal should have " + llvm::utostr(scalarCount) +
           " scalars for this shape, but has " + llvm::utostr(numElements);
  }

  if (!errorInfo.empty()) {
    auto loc = getUserSourceLocation(inst);
    diagnose(inst->getType().getSwiftRValueType()->getASTContext(),
             loc.getSourceLoc(), diag::tf_op_misuse, errorInfo)
      .highlight(loc.getSourceRange());
    return nullptr;
  }

  // This takes a Tensor and a Shape operand, but needs a DType added.  The
  // dtype is the type of the Tensor elements, which we conveniently already
  // have available as the first operand.
  operands.push_back(operands[0]);
  name += ",dtype";

  // Grab everything we still need from the apply before it is erased.
  auto scalarV = inst->getOperand(1);
  auto shapeV = inst->getOperand(2);
  auto instLoc = inst->getLoc();

  SILBuilder B(inst);
  // Finally build a new builtin instruction with the simplified operands.
  auto newInst =
    B.createBuiltin(instLoc,
                    B.getASTContext().getIdentifier(name),
                    inst->getType(), /*no substitions*/{},
                    operands);
  newInst->setDebugLocation(inst->getDebugLocation());
  inst->replaceAllUsesPairwiseWith(newInst);
  inst->eraseFromParent();

  // We are dropping a reference to the element and shape array initializers, so
  // we need to remove the arrays themselves or at least release them.  The
  // builder was positioned at the apply we just erased, so reposition it at
  // the replacement first, and don't touch the erased apply again.
  B.setInsertionPoint(newInst);
  SILTensorOpInfo::removeOrDestroyArrayValue(scalarV, instLoc, B);
  SILTensorOpInfo::removeOrDestroyArrayValue(shapeV, instLoc, B);
  return newInst;
}
/// If all the operands to a call to __tf_tensor_from_scalars_1d are constants,
/// we can promote this to a 'Const' node with an attached TF_Tensor attribute.
/// This is a specialized form of __tf_tensor_from_scalars, because the latter
/// is defined in terms of a shape of "[scalars.count]" but the performance
/// optimizer is not reliably constant propagating this.  When we have a
/// reliable deabstraction pass we can re-evaluate this and hopefully eliminate
/// it in favor of library code in the TensorFlow module.
///
/// \param inst the apply of __tf_tensor_from_scalars_1d; operand 1 is the
///        array of scalars.
/// \param constants the known constant values, keyed by SIL value.
///
/// On success, this removes the ApplyInst and returns a pointer to the new
/// BuiltinInst that is created.  On failure, it returns a nullptr.
///
/// FIXME: This is a near duplication of the logic used by TFPartitioning in
/// SILTensorOpInfo::decodeTensorFromScalars1D.  When constexpr propagation is
/// done, we should remove the logic in SILTensorOpInfo.
static BuiltinInst *
tryToPromoteTensorFromScalars1D(ApplyInst *inst,
                      const DenseMap<SILValue, SymbolicValue> &constants) {
  assert(inst->getNumOperands() == 2 && isTensorHandle(inst->getType()) &&
         "Unexpected type signature for __tf_tensor_from_scalars_1d");

  // If we can't analyze the scalars as an array of constants, give up.
  auto scalarIt = constants.find(inst->getOperand(1));
  if (scalarIt == constants.end() || !scalarIt->second.isConstant())
    return nullptr;

  // We transform this into a __tfop_Const instruction, where the values are
  // part of the 'value' tensor attribute and the shape is hard coded.
  SmallVector<SILValue, 8> operands;
  std::string name = "__tfop_Const";

  // Try to expand the array into its scalars.
  expandArrayConstant(scalarIt->second.getAggregateValue(),
                      getArrayElementType(inst->getOperand(1)->getType()),
                      "value",
                      SILTensorOpInfo::OperandClass::Tensor,
                      name, operands, inst);

  SILBuilder B(inst);

  // This takes a Tensor operand, but needs a Shape and a DType added.  At
  // this point, the operands list will have a metatype for the tensor as
  // the first operand then all the elements.
  uint64_t scalarCount = operands.size()-1;

  // The shape needs a metatype to be well formed, but nothing actually
  // cares what it is.  Just re-push the metatype for the tensor elements,
  // even though it might be floating point or something else weird.
  operands.push_back(operands[0]);
  name += ",shape";
  auto shapeKind = SILTensorOpInfo::OperandClass::Shape;
  name += SILTensorOpInfo::getOperandClassSuffix(shapeKind);

  // The shape of a 1d tensor is just the count of elements.
  auto &ctx = inst->getFunction()->getASTContext();
  auto scalarCountVal =
    B.createIntegerLiteral(inst->getLoc(),
                           SILType::getBuiltinIntegerType(64, ctx),
                           scalarCount);
  operands.push_back(scalarCountVal);
  name += ",";
  auto arrayEltKind = SILTensorOpInfo::OperandClass::ArrayElement;
  name += SILTensorOpInfo::getOperandClassSuffix(arrayEltKind);

  // The dtype is the type of the Tensor elements, which we conveniently
  // already have available as the first operand.
  operands.push_back(operands[0]);
  name += ",dtype";

  // Grab everything we still need from the apply before it is erased.
  auto arrayValue = inst->getOperand(1);
  auto instLoc = inst->getLoc();

  // Finally build a new builtin instruction with the simplified operands.
  auto newInst =
    B.createBuiltin(instLoc,
                    B.getASTContext().getIdentifier(name),
                    inst->getType(), /*no substitions*/{},
                    operands);
  newInst->setDebugLocation(inst->getDebugLocation());
  inst->replaceAllUsesPairwiseWith(newInst);
  inst->eraseFromParent();

  // We dropped a reference to the element initializer, so we need to
  // remove the array itself or at least release it.  This happens after
  // creating the replacement builtin, so that element initializers aren't
  // dropped.  The apply has been erased, so use the location saved above
  // rather than reading it back from the dead instruction.
  B.setInsertionPoint(newInst);
  SILTensorOpInfo::removeOrDestroyArrayValue(arrayValue, instLoc, B);
  return newInst;
}
/// Canonicalize tensor ops, validating their attribute arguments have
/// constants, and flattening array parameters.
///
/// This proceeds in three phases:
///  1) Gather every operand of the ops in 'tensorOps' whose type is not a
///     TensorFlow value, deduplicate the list, and ask 'constantEvaluator' to
///     compute a SymbolicValue for each one.  The results are stored in a map
///     keyed by the SILValue.
///  2) Walk 'tensorOps' and try to promote decodable applies
///     (__tf_tensor_from_scalars and __tf_tensor_from_scalars_1d) into
///     builtins.  On success the entry in 'tensorOps' is updated in place.
///  3) Walk every instruction in 'fn', decode it as a tensor op, and - only
///     when -tf-strict-deabstraction is enabled - replace attribute operands
///     that are not already literals with materialized constants, diagnose
///     attributes that could not be folded, validate the op, and finally
///     canonicalize its operands.
void TFDeabstraction::checkAndCanonicalizeAttributes() {
// Register a pretty stack trace entry naming this phase for the duration of
// the function.
llvm::PrettyStackTraceFormat
X("TFDeabstraction::checkAndCanonicalizeAttributes");
// Do a big sweep over all of the operands to tensor values, collecting ones
// that we might be interested in being constants into a single list.
SmallVector<SILValue, 32> valuesToCheck;
for (auto *op : tensorOps) {
for (auto &operand : op->getAllOperands()) {
// Dump anything that might be an attribute into the list without too much
// filtering. We take out TensorFlow values since they are the most
// obvious ones we don't care about later, but there may be other minor
// things we over-query on.
auto value = operand.get();
if (!isTensorFlowValue(value->getType()))
valuesToCheck.push_back(value);
}
}
// Eliminate duplicates and sort the array of values so we have an efficient
// way to query it later.  Sorting first is what makes std::unique remove
// every duplicate, since it only collapses adjacent equal elements.
llvm::array_pod_sort(valuesToCheck.begin(), valuesToCheck.end());
valuesToCheck.erase(std::unique(valuesToCheck.begin(), valuesToCheck.end()),
valuesToCheck.end());
// Determine whether each value is a constant or not.
// TODO: Capture information about *WHY* values are not constants, e.g. the
// first SIL instruction that could not be folded.
SmallVector<SymbolicValue, 32> results;
constantEvaluator.computeConstantValues(valuesToCheck, results);
assert(valuesToCheck.size() == results.size() && "incorrect values returned");
// Transform the returned information about constants into a map that we can
// query. The results list should correspond directly to the values we asked
// about, i.e. results[i] describes valuesToCheck[i].
DenseMap<SILValue, SymbolicValue> constants;
for (unsigned i = 0, e = valuesToCheck.size(); i != e; ++i)
constants.insert({valuesToCheck[i], results[i]});
// Now that we've computed whether any of the operands are constants,
// substitute them into the operations that we have, eliminating abstractions.
// This makes it immediately obvious to partitioning what is and isn't a
// constant.
//
// Note that 'inst' is bound by reference to the slot in 'tensorOps', so the
// assignment at the bottom of the loop body rewrites the list entry itself.
for (auto *&inst : tensorOps) {
// Take a look at the various well known function calls that we can promote
// to tensor operations. We can promote them if we are able to constant
// fold all of the operands to these calls. If so, we rewrite them in terms
// of a proper op, and partitioning will continue to treat them that way.
if (auto apply = dyn_cast<ApplyInst>(inst)) {
// FIXME: Move this upgrading logic out of SILTensorOpInfo into
// Deabstraction once partitioning is moved up to the mandatory passes.
if (!SILTensorOpInfo::isDecodableApply(apply))
continue;
// Dispatch on the callee name.  'result' is assigned on every path that
// does not hit the llvm_unreachable below.
auto name = apply->getCalleeFunction()->getName();
BuiltinInst *result;
if (name == "__tf_tensor_from_scalars")
result = tryToPromoteTensorFromScalars(apply, constants);
else if (name == "__tf_tensor_from_scalars_1d")
result = tryToPromoteTensorFromScalars1D(apply, constants);
else
llvm_unreachable("out of sync with isDecodableApply");
// If promotion failed, no change is necessary.
if (!result) continue;
// Otherwise, we got a new instruction, so remember it in our tensor op
// list.
inst = result;
// Fall into the normal op processing code.
}
// TODO: Handle normal tensor ops with their attributes, subsuming the
// following loop.
}
// Second sweep: visit every instruction in the function (not just the ones
// recorded in 'tensorOps') and process the ones that decode as tensor ops.
for (auto &BB : fn) {
for (auto I = BB.begin(), E = BB.end(); I != E; ) {
// Manually move iterator to avoid invalidation if we replace 'inst'.
auto *inst = &*I++;
// If this is a well known function that can be transformed into an op, do
// so first.  This happens regardless of the strict deabstraction flag.
// FIXME: This should take into consideration the constants we just
// computed!
if (auto apply = dyn_cast<ApplyInst>(inst))
inst = SILTensorOpInfo::decodeApply(apply);
// Try to decode this instruction as an op. If it isn't one, ignore it.
auto opInfo = SILTensorOpInfo::decode(inst);
if (!opInfo)
continue;
// TODO: Deabstraction isn't fully handling all constant expressions and
// other canonicalizations that we expect, so for now we depend on the
// performance optimizer. When deabstraction is done, we will run the
// partitioner as part of deabstraction (including at -O0). Until we are
// ready for that, we gate the validation of tensor operations on a flag.
// This allows us to write testcases without breaking current use of the
// compiler.
if (!TFStrictDeabstraction)
continue;
// Use the constants we just computed to substitute into parameter values
// if we don't already have them.
// TODO: this should eventually create a new SILInstruction to represent
// the graph operation instead of using SIL instructions to represent the
// constants.
bool isError = false;
// Any constants we materialize are created through this builder, which is
// positioned at the op itself.
SILBuilder B(opInfo->inst);
for (unsigned i = 0, e = opInfo->operandClasses.size(); i != e; ++i) {
// Ignore input operands.
if (opInfo->isInput(i))
continue;
// Ok, we have an attribute operand. If it is already trivially a
// constant, just leave it alone.
auto operand = inst->getOperand(i);
if (isa<FloatLiteralInst>(operand) ||
isa<IntegerLiteralInst>(operand) ||
isa<StringLiteralInst>(operand) ||
isa<MetatypeInst>(operand))
continue;
// Otherwise, we should have been able to fold it through our constexpr
// evaluation logic.
// NOTE(review): the lookup is only guarded by an assert; in a build without
// assertions a missing entry would dereference the end iterator below.
// Confirm that every non-literal attribute operand reaching this point is
// guaranteed to have been collected by the scanning loop above.
SILValue newVal;
auto it = constants.find(operand);
assert(it != constants.end() &&
"out of sync with constant scanning loop above");
// Given that we found a constant, materialize it as an instruction and
// swap it in for our variable argument.  'newVal' stays null if the value
// is not a constant or if emitConstantInst returns a null value.
if (it->second.isConstant())
newVal = emitConstantInst(it->second, operand->getType(),
opInfo->inst->getLoc(), B);
if (!newVal) {
auto opClass = opInfo->operandClasses[i];
auto error = "attribute '" + opClass.first.str() +
"' requires a constant argument";
// TODO: improve the diagnostic to talk about the parameter label in
// the user code, not the internal op attribute. The bookkeeping for
// this isn't obvious though.
auto loc = getUserSourceLocation(inst);
diagnose(fn.getModule().getASTContext(), loc.getSourceLoc(),
diag::tf_op_misuse, error)
.highlight(loc.getSourceRange());
isError = true;
// If we have more specific information about what went wrong, emit
// notes.
if (it->second.getKind() == SymbolicValue::Unknown)
it->second.emitUnknownDiagnosticNotes();
// Stop at the first bad attribute; the remaining attributes of this op
// are not examined.
break;
}
inst->setOperand(i, newVal);
}
// Don't emit a second error for this op if we already emitted one.
if (isError)
continue;
// Check to see if the usage of this op looks ok. If not, reject it with
// an error and ignore it.  An empty string from checkAndDiagnoseOperands
// is treated as "no problem found".
auto error = opInfo->checkAndDiagnoseOperands();
if (!error.empty()) {
// TODO: improve the diagnostic to talk about the parameter label in the
// user code, not the internal op attribute. The bookkeeping for this
// isn't obvious though.
auto loc = getUserSourceLocation(inst);
diagnose(fn.getModule().getASTContext(), loc.getSourceLoc(),
diag::tf_op_misuse, error)
.highlight(loc.getSourceRange());
continue;
}
// If the tensor operation uses array parameters or has scalar values that
// are passed through memory, promote them to being simple arguments to
// make all subsequent analyses and promotion of the tensor operations
// simpler.
opInfo->canonicalizeOperands(/*configuration*/ nullptr);
}
}
}
/// Process the specified top level function as a deabstraction context: if it
/// contains Tensor operations simplify the code using predictable rules until
/// the tensor operations are exposed in a canonical form inside of this
/// function.
///
/// We currently make use of the following techniques to do this:
/// 1) Inlining. We look for direct calls to functions that take and return
/// values of TensorFlow values, possibly wrapped by structs and tuples.
/// 2) Promotion of globals to stack allocations for Playgrounds, REPL, and
/// top level code in scripts.
/// 3) SSA Promotion of stack values to registers.
/// 4) Scalarization of struct/tuple values.
///
/// TODO:
/// *) Move tensor op canonicalization up from tf-partition.
/// *) Enums. What can we reliably do with them? Should they be out of
/// model? We can definitely do ones without payload values.
///
void TFDeabstraction::doIt() {
// Start by inlining functions that take and return Tensor values.
inlineCalls();
// Scan for any Tensor operations, removing indirect operands and structs that
// interfere with SSA construction.
simplifyTensorOperands();
// If we didn't find any ops, early exit processing of this function to save
// compile time.
if (tensorOps.empty())
return;
logCurrentState("After simplifyTensorOperands", /*detailed*/true);
// Scan over all of the operands of the tensor ops, finding stack allocations
// that we want to promote to SSA.
SmallVector<AllocStackInst*, 16> stackAllocs;
if (PromotableMemoryFinder(stackAllocs, tfc, fn).run(tensorOps)) {
logCurrentState("After promoteAddressRootsToStack",
/*detailed*/true);
}
// Promote stack allocations to SSA, this allows us to do dataflow analysis,
// and eliminates mutation from tensor values.
promoteToSSA(stackAllocs);
logCurrentState("After promoteToSSA", /*detailed*/true);
// Now that we've promoted all the allocations in the way of our dataflow,
// go through and propagate any tuple/struct values that are in the way of
// our analysis.
propagateTensorValues();
logCurrentState("After propagateTensorValues", /*detailed*/true);
// Canonicalize attribute arguments, checking that they have constants, and
// flattening array attributes.
checkAndCanonicalizeAttributes();
logCurrentState("Result", /*detailed*/false);
// We're currently relying on the performance optimizer to do some stuff, but
// for large testcacses it is doing bad stuff.
// FIXME: Should be eliminated when partitioning happens as part of
// deabstraction, because then the optimizer won't be seeing all of our tensor
// stuff.
if (!TFStrictDeabstraction)
fn.getModule().getOptions().EnableARCOptimizations = false;
}
namespace {
  /// Module-level transform that drives TensorFlow deabstraction.
  class TFDeabstractionPass : public SILModuleTransform {
  public:
    /// Entry point of the transformation: performs deabstraction across the
    /// whole module.
    void run() override;
  };
} // end anonymous namespace
void TFDeabstractionPass::run() {
SILModule *module = getModule();
auto &ctx = module->getASTContext();
// If the TensorFlow module hasn't been imported by the program, don't do
// anything. This avoids impacting compile time for non-TensorFlow using
// Swift programs by doing extraneous analysis.
auto tfModule = ctx.getLoadedModule(ctx.getIdentifier("TensorFlow"));
if (!tfModule)
return;
// If we are running on the TensorFlow module itself, do not perform
// deabstraction. It contains a lot of code that processes TensorHandle and
// other types as host values, and we do not want to force inline all of these
// things together.
//
// TODO: Rework the heuristics in inlineCalls() to be smarter. In an ideal
// world, we would be lazy about inlining, and only inline calls due to actual
// inter-op value uses.
if (module->getSwiftModule() == tfModule)
return;
TensorFunctionClassifier tfc;
ConstExprEvaluator constantEvaluator(*module);
// Loop over all of the functions in the current module processing them -
// iff they look like they could be the top level of a deabstraction
// context.
for (auto &fn : *module) {
// If this function is a building block of larger tensor programs (e.g.
// the ops defined in the TensorFlow module), then don't transform it in
// isolation.
if (!tfc.shouldBePartitioned(&fn))
continue;
// If something crashes, make sure the pretty stack trace says what we
// were doing.
llvm::PrettyStackTraceFormat X("TFDeabstraction on function %s",
fn.getName().str().c_str());
TFDeabstraction(fn, tfc, constantEvaluator, PM).doIt();
// TODO(clattner): This should eventually be the driver that kicks off
// the partitioning pass as part of it, and the partitioning and later
// passes are just function passes that are invoked by this one. Until
// we are ready for that, let them run later in the pipeline after the
// other optimization and cleanup passes.
}
}
/// Factory for the TensorFlow deabstraction pass; returns a newly allocated
/// pass object.
SILTransform *swift::createTFDeabstraction() {
  SILTransform *pass = new TFDeabstractionPass();
  return pass;
}