| //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" |
| |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
| |
| using namespace mlir; |
| |
| namespace mlir { |
| namespace scf { |
| namespace { |
| |
| struct ForOpInterface |
| : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> { |
| |
| /// Populate bounds of values/dimensions for iter_args/OpResults. If the |
| /// value/dimension size does not change in an iteration, we can deduce that |
| /// it the same as the initial value/dimension. |
| /// |
| /// Example 1: |
| /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> { |
| /// ... |
| /// %1 = tensor.insert %f into %arg0[...] : tensor<?xf32> |
| /// scf.yield %1 : tensor<?xf32> |
| /// } |
| /// --> bound(%0)[0] == bound(%t)[0] |
| /// --> bound(%arg0)[0] == bound(%t)[0] |
| /// |
| /// Example 2: |
| /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> { |
| /// %sz = tensor.dim %arg0 : tensor<?xf32> |
| /// %incr = arith.addi %sz, %c1 : index |
| /// %1 = tensor.empty(%incr) : tensor<?xf32> |
| /// scf.yield %1 : tensor<?xf32> |
| /// } |
| /// --> The yielded tensor dimension size changes with each iteration. Such |
| /// loops are not supported and no constraints are added. |
| static void populateIterArgBounds(scf::ForOp forOp, Value value, |
| std::optional<int64_t> dim, |
| ValueBoundsConstraintSet &cstr) { |
| // `value` is an iter_arg or an OpResult. |
| int64_t iterArgIdx; |
| if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) { |
| iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars(); |
| } else { |
| iterArgIdx = llvm::cast<OpResult>(value).getResultNumber(); |
| } |
| |
| Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator()) |
| .getOperand(iterArgIdx); |
| Value iterArg = forOp.getRegionIterArg(iterArgIdx); |
| Value initArg = forOp.getInitArgs()[iterArgIdx]; |
| |
| // An EQ constraint can be added if the yielded value (dimension size) |
| // equals the corresponding block argument (dimension size). |
| if (cstr.populateAndCompare( |
| /*lhs=*/{yieldedValue, dim}, |
| ValueBoundsConstraintSet::ComparisonOperator::EQ, |
| /*rhs=*/{iterArg, dim})) { |
| if (dim.has_value()) { |
| cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim); |
| } else { |
| cstr.bound(value) == cstr.getExpr(initArg); |
| } |
| } |
| } |
| |
| void populateBoundsForIndexValue(Operation *op, Value value, |
| ValueBoundsConstraintSet &cstr) const { |
| auto forOp = cast<ForOp>(op); |
| |
| if (value == forOp.getInductionVar()) { |
| // TODO: Take into account step size. |
| cstr.bound(value) >= forOp.getLowerBound(); |
| cstr.bound(value) < forOp.getUpperBound(); |
| return; |
| } |
| |
| // Handle iter_args and OpResults. |
| populateIterArgBounds(forOp, value, std::nullopt, cstr); |
| } |
| |
| void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
| ValueBoundsConstraintSet &cstr) const { |
| auto forOp = cast<ForOp>(op); |
| // Handle iter_args and OpResults. |
| populateIterArgBounds(forOp, value, dim, cstr); |
| } |
| }; |
| |
| struct IfOpInterface |
| : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> { |
| |
| static void populateBounds(scf::IfOp ifOp, Value value, |
| std::optional<int64_t> dim, |
| ValueBoundsConstraintSet &cstr) { |
| unsigned int resultNum = cast<OpResult>(value).getResultNumber(); |
| Value thenValue = ifOp.thenYield().getResults()[resultNum]; |
| Value elseValue = ifOp.elseYield().getResults()[resultNum]; |
| |
| auto boundsBuilder = cstr.bound(value); |
| if (dim) |
| boundsBuilder[*dim]; |
| |
| // Compare yielded values. |
| // If thenValue <= elseValue: |
| // * result <= elseValue |
| // * result >= thenValue |
| if (cstr.populateAndCompare( |
| /*lhs=*/{thenValue, dim}, |
| ValueBoundsConstraintSet::ComparisonOperator::LE, |
| /*rhs=*/{elseValue, dim})) { |
| if (dim) { |
| cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim); |
| cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim); |
| } else { |
| cstr.bound(value) >= thenValue; |
| cstr.bound(value) <= elseValue; |
| } |
| } |
| // If elseValue <= thenValue: |
| // * result <= thenValue |
| // * result >= elseValue |
| if (cstr.populateAndCompare( |
| /*lhs=*/{elseValue, dim}, |
| ValueBoundsConstraintSet::ComparisonOperator::LE, |
| /*rhs=*/{thenValue, dim})) { |
| if (dim) { |
| cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim); |
| cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim); |
| } else { |
| cstr.bound(value) >= elseValue; |
| cstr.bound(value) <= thenValue; |
| } |
| } |
| } |
| |
| void populateBoundsForIndexValue(Operation *op, Value value, |
| ValueBoundsConstraintSet &cstr) const { |
| populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr); |
| } |
| |
| void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
| ValueBoundsConstraintSet &cstr) const { |
| populateBounds(cast<IfOp>(op), value, dim, cstr); |
| } |
| }; |
| |
| } // namespace |
| } // namespace scf |
| } // namespace mlir |
| |
| void mlir::scf::registerValueBoundsOpInterfaceExternalModels( |
| DialectRegistry ®istry) { |
| registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) { |
| scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx); |
| scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx); |
| }); |
| } |