blob: 28531e600ffc54bfad8ccd2727cf94fa55cd27cb [file] [log] [blame]
//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a partial lowering of Toy operations to a combination of
// affine loops and standard operations. This lowering expects that all calls
// have been inlined, and all shapes have been resolved.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "toy/Passes.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//
/// Maps a ranked TensorType onto the MemRefType with the same shape and the
/// same element type. Unranked tensors are rejected by an assertion.
static MemRefType convertTensorToMemRef(TensorType type) {
  assert(type.hasRank() && "expected only ranked shapes");
  Type elementType = type.getElementType();
  return MemRefType::get(type.getShape(), elementType);
}
/// Creates an AllocOp of `type` and moves it to the front of the block that
/// holds the rewriter's insertion point. A matching DeallocOp is placed just
/// before that block's last operation. Returns the allocated memref.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
                                   PatternRewriter &rewriter) {
  AllocOp alloc = rewriter.create<AllocOp>(loc, type);
  Operation *allocOp = alloc.getOperation();
  Block *block = allocOp->getBlock();

  // Hoist the allocation to the very beginning of the enclosing block.
  allocOp->moveBefore(&block->front());

  // Free the buffer right before the block's last operation. A single
  // deallocation at the end of the block is enough because Toy functions
  // contain no control flow.
  DeallocOp dealloc = rewriter.create<DeallocOp>(loc, alloc);
  dealloc.getOperation()->moveBefore(&block->back());
  return alloc;
}
/// Callback that emits the body of one iteration of a lowered loop nest.
///
/// It receives three arguments:
///  - the rewriter;
///  - the memref operands of the operation being lowered;
///  - the induction variables of the enclosing loops.
///
/// It returns the value to store at the current iteration's indices.
using LoopIterationFn =
    function_ref<Value(PatternRewriter &rewriter,
                       ArrayRef<Value> memRefOperands,
                       ArrayRef<Value> loopIvs)>;
/// Lowers `op` into an affine loop nest with one loop per dimension of its
/// first result, which must be a tensor. The steps are:
///  1. Allocate a buffer for the result.
///  2. Call `processIteration` once inside the innermost loop to produce the
///     element for the current indices.
///  3. Store that element into the buffer at those indices.
///  4. Replace `op` with the buffer.
static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
                           PatternRewriter &rewriter,
                           LoopIterationFn processIteration) {
  TensorType resultType = (*op->result_type_begin()).cast<TensorType>();
  Location loc = op->getLoc();

  // Allocate the result buffer and schedule its deallocation.
  MemRefType bufferType = convertTensorToMemRef(resultType);
  Value buffer = insertAllocAndDealloc(bufferType, loc, rewriter);

  // Build one nested affine.for per result dimension. Moving the insertion
  // point into each new loop makes the next loop nest inside it.
  SmallVector<Value, 4> ivs;
  for (int64_t extent : resultType.getShape()) {
    AffineForOp forOp =
        rewriter.create<AffineForOp>(loc, /*lb=*/0, extent, /*step=*/1);
    Block *body = forOp.getBody();
    // Empty the freshly built body and re-create the terminator ourselves,
    // leaving the insertion point in front of it.
    body->clear();
    ivs.push_back(forOp.getInductionVar());
    rewriter.setInsertionPointToStart(body);
    rewriter.create<AffineTerminatorOp>(loc);
    rewriter.setInsertionPointToStart(body);
  }

  // Let the callback compute the element for the current indices. Then write
  // it into the result buffer at those same indices.
  Value element = processIteration(rewriter, operands, ivs);
  rewriter.create<AffineStoreOp>(loc, element, buffer,
                                 llvm::makeArrayRef(ivs));

  // Redirect all uses of the original operation to the buffer.
  rewriter.replaceOp(op, buffer);
}
namespace {
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Binary operations
//===----------------------------------------------------------------------===//
/// Generic lowering for element-wise binary Toy operations. Each result
/// element is computed by loading the `lhs` and `rhs` elements at the current
/// indices and combining them with `LoweredBinaryOp`.
template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
  BinaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(BinaryOp::getOperationName(), /*benefit=*/1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    auto emitElement = [loc](PatternRewriter &builder,
                             ArrayRef<Value> memRefOperands,
                             ArrayRef<Value> loopIvs) -> Value {
      // Wrap the remapped operands in the ODS-generated adaptor so the named
      // `lhs`/`rhs` accessors are available.
      typename BinaryOp::OperandAdaptor adaptor(memRefOperands);
      Value lhsElem = builder.create<AffineLoadOp>(loc, adaptor.lhs(), loopIvs);
      Value rhsElem = builder.create<AffineLoadOp>(loc, adaptor.rhs(), loopIvs);
      return builder.create<LoweredBinaryOp>(loc, lhsElem, rhsElem);
    };
    lowerOpToLoops(op, operands, rewriter, emitElement);
    return matchSuccess();
  }
};
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Constant operations
//===----------------------------------------------------------------------===//
/// Lowers `toy.constant` to a freshly allocated memref. The memref is filled
/// with one affine.store per element of the constant's dense value.
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
  using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(toy::ConstantOp op,
                                     PatternRewriter &rewriter) const final {
    DenseElementsAttr constantValue = op.value();
    Location loc = op.getLoc();

    // When lowering the constant operation, we allocate a memref and assign
    // the constant values into it.
    auto tensorType = op.getType().cast<TensorType>();
    auto memRefType = convertTensorToMemRef(tensorType);
    auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);

    // We will generate constant indices up to the largest dimension. Creating
    // them up front avoids large amounts of redundant operations.
    //
    // A rank-0 constant has an empty shape. For an empty range,
    // `std::max_element` returns the end iterator, which must not be
    // dereferenced. No indices are needed in that case: the recursion below
    // hits its base case immediately and stores the single element with an
    // empty index list.
    auto valueShape = memRefType.getShape();
    SmallVector<Value, 8> constantIndices;
    if (!valueShape.empty()) {
      int64_t largestDim =
          *std::max_element(valueShape.begin(), valueShape.end());
      for (auto i : llvm::seq<int64_t>(0, largestDim))
        constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
    }

    // The constant operation represents a multi-dimensional constant, so we
    // need to generate a store for each of its elements. The following functor
    // recursively walks the dimensions of the constant shape and generates a
    // store when the recursion reaches the base case.
    SmallVector<Value, 2> indices;
    auto valueIt = constantValue.getValues<FloatAttr>().begin();
    std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
      // Base case: every dimension has an index, so store the next element
      // at the accumulated indices.
      if (dimension == valueShape.size()) {
        rewriter.create<AffineStoreOp>(
            loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
            llvm::makeArrayRef(indices));
        return;
      }

      // Otherwise, iterate over the current dimension and add each index to
      // the list before recursing into the next dimension.
      for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
        indices.push_back(constantIndices[i]);
        storeElements(dimension + 1);
        indices.pop_back();
      }
    };

    // Start the element-storing recursion from the first dimension.
    storeElements(/*dimension=*/0);

    // Replace this operation with the generated alloc.
    rewriter.replaceOp(op, alloc);
    return matchSuccess();
  }
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Return operations
//===----------------------------------------------------------------------===//
/// Rewrites an operand-less `toy.return` into `std.return`.
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
  using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(toy::ReturnOp op,
                                     PatternRewriter &rewriter) const final {
    // All calls are expected to be inlined by now. A return that still carries
    // a value is therefore not handled by this pattern.
    if (!op.hasOperand()) {
      rewriter.replaceOpWithNewOp<ReturnOp>(op);
      return matchSuccess();
    }
    return matchFailure();
  }
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Transpose operations
//===----------------------------------------------------------------------===//
/// Lowers `toy.transpose` to a loop nest. Each result element is loaded from
/// the input at the reversed loop indices.
struct TransposeOpLowering : public ConversionPattern {
  TransposeOpLowering(MLIRContext *ctx)
      : ConversionPattern(toy::TransposeOp::getOperationName(), /*benefit=*/1,
                          ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    auto loadTransposed = [loc](PatternRewriter &builder,
                                ArrayRef<Value> memRefOperands,
                                ArrayRef<Value> loopIvs) -> Value {
      // The ODS-generated adaptor exposes the remapped `input` operand by name.
      toy::TransposeOpOperandAdaptor adaptor(memRefOperands);
      // Result element (i, j, ...) is read from input element (..., j, i).
      SmallVector<Value, 2> swappedIvs(loopIvs.rbegin(), loopIvs.rend());
      return builder.create<AffineLoadOp>(loc, adaptor.input(), swappedIvs);
    };
    lowerOpToLoops(op, operands, rewriter, loadTransposed);
    return matchSuccess();
  }
};
} // end anonymous namespace.
//===----------------------------------------------------------------------===//
// ToyToAffineLoweringPass
//===----------------------------------------------------------------------===//
namespace {
/// Function pass that partially lowers Toy to the Affine and Standard dialects.
/// The Toy operations that have patterns in this file become affine loop nests
/// or standard operations. `toy.print` is left in the Toy dialect.
struct ToyToAffineLoweringPass : public FunctionPass<ToyToAffineLoweringPass> {
  void runOnFunction() final;
};
} // end anonymous namespace.
void ToyToAffineLoweringPass::runOnFunction() {
  auto function = getFunction();

  // Only `main` is lowered. Every other function is expected to have been
  // inlined into it already.
  if (function.getName() != "main")
    return;

  // `main` must take no arguments and produce no results.
  bool hasInputs = function.getNumArguments() != 0;
  bool hasResults = function.getType().getNumResults() != 0;
  if (hasInputs || hasResults) {
    function.emitError("expected 'main' to have 0 inputs and 0 results");
    return signalPassFailure();
  }

  // Describe the legal end state of the conversion:
  //  - Affine and Standard operations are the lowering targets.
  //  - The Toy dialect is illegal, so any Toy op left unconverted fails the
  //    conversion.
  //  - `toy.print` is the exception: this partial lowering deliberately keeps
  //    it.
  ConversionTarget target(getContext());
  target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
  target.addIllegalDialect<toy::ToyDialect>();
  target.addLegalOp<toy::PrintOp>();

  // Register the patterns that rewrite the illegal Toy operations.
  OwningRewritePatternList patterns;
  patterns.insert<AddOpLowering, ConstantOpLowering, MulOpLowering,
                  ReturnOpLowering, TransposeOpLowering>(&getContext());

  // Run the conversion. It reports failure if any illegal operation could not
  // be converted.
  if (failed(applyPartialConversion(function, target, patterns)))
    signalPassFailure();
}
/// Creates the pass that partially lowers Toy operations to the `Affine` and
/// `Std` dialects. Toy operations without a lowering pattern here, such as
/// `toy.print`, are kept in the Toy dialect.
std::unique_ptr<Pass> mlir::toy::createLowerToAffinePass() {
  return std::make_unique<ToyToAffineLoweringPass>();
}