| //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" |
| |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
| |
| using namespace mlir; |
| |
| namespace mlir { |
| namespace arith { |
| namespace { |
| |
| struct AddIOpInterface |
| : public ValueBoundsOpInterface::ExternalModel<AddIOpInterface, AddIOp> { |
| void populateBoundsForIndexValue(Operation *op, Value value, |
| ValueBoundsConstraintSet &cstr) const { |
| auto addIOp = cast<AddIOp>(op); |
| assert(value == addIOp.getResult() && "invalid value"); |
| |
| // Note: `getExpr` has a side effect: it may add a new column to the |
| // constraint system. The evaluation order of addition operands is |
| // unspecified in C++. To make sure that all compilers produce the exact |
| // same results (that can be FileCheck'd), it is important that `getExpr` |
| // is called first and assigned to temporary variables, and the addition |
| // is performed afterwards. |
| AffineExpr lhs = cstr.getExpr(addIOp.getLhs()); |
| AffineExpr rhs = cstr.getExpr(addIOp.getRhs()); |
| cstr.bound(value) == lhs + rhs; |
| } |
| }; |
| |
| struct ConstantOpInterface |
| : public ValueBoundsOpInterface::ExternalModel<ConstantOpInterface, |
| ConstantOp> { |
| void populateBoundsForIndexValue(Operation *op, Value value, |
| ValueBoundsConstraintSet &cstr) const { |
| auto constantOp = cast<ConstantOp>(op); |
| assert(value == constantOp.getResult() && "invalid value"); |
| |
| if (auto attr = llvm::dyn_cast<IntegerAttr>(constantOp.getValue())) |
| cstr.bound(value) == attr.getInt(); |
| } |
| }; |
| |
| struct SubIOpInterface |
| : public ValueBoundsOpInterface::ExternalModel<SubIOpInterface, SubIOp> { |
| void populateBoundsForIndexValue(Operation *op, Value value, |
| ValueBoundsConstraintSet &cstr) const { |
| auto subIOp = cast<SubIOp>(op); |
| assert(value == subIOp.getResult() && "invalid value"); |
| |
| AffineExpr lhs = cstr.getExpr(subIOp.getLhs()); |
| AffineExpr rhs = cstr.getExpr(subIOp.getRhs()); |
| cstr.bound(value) == lhs - rhs; |
| } |
| }; |
| |
| struct MulIOpInterface |
| : public ValueBoundsOpInterface::ExternalModel<MulIOpInterface, MulIOp> { |
| void populateBoundsForIndexValue(Operation *op, Value value, |
| ValueBoundsConstraintSet &cstr) const { |
| auto mulIOp = cast<MulIOp>(op); |
| assert(value == mulIOp.getResult() && "invalid value"); |
| |
| AffineExpr lhs = cstr.getExpr(mulIOp.getLhs()); |
| AffineExpr rhs = cstr.getExpr(mulIOp.getRhs()); |
| cstr.bound(value) == lhs *rhs; |
| } |
| }; |
| |
| struct SelectOpInterface |
| : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface, |
| SelectOp> { |
| |
| static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim, |
| ValueBoundsConstraintSet &cstr) { |
| Value value = selectOp.getResult(); |
| Value condition = selectOp.getCondition(); |
| Value trueValue = selectOp.getTrueValue(); |
| Value falseValue = selectOp.getFalseValue(); |
| |
| if (isa<ShapedType>(condition.getType())) { |
| // If the condition is a shaped type, the condition is applied |
| // element-wise. All three operands must have the same shape. |
| cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim); |
| cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim); |
| cstr.bound(value)[*dim] == cstr.getExpr(condition, dim); |
| return; |
| } |
| |
| // Populate constraints for the true/false values (and all values on the |
| // backward slice, as long as the current stop condition is not satisfied). |
| cstr.populateConstraints(trueValue, dim); |
| cstr.populateConstraints(falseValue, dim); |
| auto boundsBuilder = cstr.bound(value); |
| if (dim) |
| boundsBuilder[*dim]; |
| |
| // Compare yielded values. |
| // If trueValue <= falseValue: |
| // * result <= falseValue |
| // * result >= trueValue |
| if (cstr.compare(/*lhs=*/{trueValue, dim}, |
| ValueBoundsConstraintSet::ComparisonOperator::LE, |
| /*rhs=*/{falseValue, dim})) { |
| if (dim) { |
| cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim); |
| cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim); |
| } else { |
| cstr.bound(value) >= trueValue; |
| cstr.bound(value) <= falseValue; |
| } |
| } |
| // If falseValue <= trueValue: |
| // * result <= trueValue |
| // * result >= falseValue |
| if (cstr.compare(/*lhs=*/{falseValue, dim}, |
| ValueBoundsConstraintSet::ComparisonOperator::LE, |
| /*rhs=*/{trueValue, dim})) { |
| if (dim) { |
| cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim); |
| cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim); |
| } else { |
| cstr.bound(value) >= falseValue; |
| cstr.bound(value) <= trueValue; |
| } |
| } |
| } |
| |
| void populateBoundsForIndexValue(Operation *op, Value value, |
| ValueBoundsConstraintSet &cstr) const { |
| populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr); |
| } |
| |
| void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
| ValueBoundsConstraintSet &cstr) const { |
| populateBounds(cast<SelectOp>(op), dim, cstr); |
| } |
| }; |
| } // namespace |
| } // namespace arith |
| } // namespace mlir |
| |
| void mlir::arith::registerValueBoundsOpInterfaceExternalModels( |
| DialectRegistry ®istry) { |
| registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) { |
| arith::AddIOp::attachInterface<arith::AddIOpInterface>(*ctx); |
| arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx); |
| arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx); |
| arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx); |
| arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx); |
| }); |
| } |