blob: 5fb7953f9370076621ee70586855960090c40144 [file] [log] [blame]
//===- ReifyValueBounds.cpp --- Reify value bounds with arith ops -------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
using namespace mlir::arith;
/// Build Arith IR for the given affine map and its operands.
static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
ValueRange operands) {
assert(map.getNumResults() == 1 && "multiple results not supported yet");
std::function<Value(AffineExpr)> buildExpr = [&](AffineExpr e) -> Value {
switch (e.getKind()) {
case AffineExprKind::Constant:
return b.create<ConstantIndexOp>(loc,
cast<AffineConstantExpr>(e).getValue());
case AffineExprKind::DimId:
return operands[cast<AffineDimExpr>(e).getPosition()];
case AffineExprKind::SymbolId:
return operands[cast<AffineSymbolExpr>(e).getPosition() +
map.getNumDims()];
case AffineExprKind::Add: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
return b.create<AddIOp>(loc, buildExpr(binaryExpr.getLHS()),
buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::Mul: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
return b.create<MulIOp>(loc, buildExpr(binaryExpr.getLHS()),
buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::FloorDiv: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
return b.create<DivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::CeilDiv: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
return b.create<CeilDivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::Mod: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
return b.create<RemSIOp>(loc, buildExpr(binaryExpr.getLHS()),
buildExpr(binaryExpr.getRHS()));
}
}
llvm_unreachable("unsupported AffineExpr kind");
};
return buildExpr(map.getResult(0));
}
FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
OpBuilder &b, Location loc, presburger::BoundType type,
const ValueBoundsConstraintSet::Variable &var,
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
// Compute bound.
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeBound(
boundMap, mapOperands, type, var, stopCondition, closedUB)))
return failure();
// Materialize tensor.dim/memref.dim ops.
SmallVector<Value> operands;
for (auto valueDim : mapOperands) {
Value value = valueDim.first;
std::optional<int64_t> dim = valueDim.second;
if (!dim.has_value()) {
// This is an index-typed value.
assert(value.getType().isIndex() && "expected index type");
operands.push_back(value);
continue;
}
assert(cast<ShapedType>(value.getType()).isDynamicDim(*dim) &&
"expected dynamic dim");
if (isa<RankedTensorType>(value.getType())) {
// A tensor dimension is used: generate a tensor.dim.
operands.push_back(b.create<tensor::DimOp>(loc, value, *dim));
} else if (isa<MemRefType>(value.getType())) {
// A memref dimension is used: generate a memref.dim.
operands.push_back(b.create<memref::DimOp>(loc, value, *dim));
} else {
llvm_unreachable("cannot generate DimOp for unsupported shaped type");
}
}
// Check for special cases where no arith ops are needed.
if (boundMap.isSingleConstant()) {
// Bound is a constant: return an IntegerAttr.
return static_cast<OpFoldResult>(
b.getIndexAttr(boundMap.getSingleConstantResult()));
}
// No arith ops are needed if the bound is a single SSA value.
if (auto expr = dyn_cast<AffineDimExpr>(boundMap.getResult(0)))
return static_cast<OpFoldResult>(operands[expr.getPosition()]);
if (auto expr = dyn_cast<AffineSymbolExpr>(boundMap.getResult(0)))
return static_cast<OpFoldResult>(
operands[expr.getPosition() + boundMap.getNumDims()]);
// General case: build Arith ops.
return static_cast<OpFoldResult>(buildArithValue(b, loc, boundMap, operands));
}
FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB) {
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
ValueBoundsConstraintSet &cstr) {
// We are trying to reify a bound for `value` in terms of the owning op's
// operands. Construct a stop condition that evaluates to "true" for any SSA
// value expect for `value`. I.e., the bound will be computed in terms of
// any SSA values expect for `value`. The first such values are operands of
// the owner of `value`.
return v != value;
};
return reifyValueBound(b, loc, type, {value, dim},
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
ValueBoundsConstraintSet &cstr) {
return v != value;
};
return reifyValueBound(b, loc, type, value,
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}