| //===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops -----------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Defines the control flow operations in the SPIR-V dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
| #include "mlir/Interfaces/CallInterfaces.h" |
| |
| #include "SPIRVOpUtils.h" |
| #include "SPIRVParsingUtils.h" |
| |
| using namespace mlir::spirv::AttrNames; |
| |
| namespace mlir::spirv { |
| |
| /// Parses Function, Selection and Loop control attributes. If no control is |
| /// specified, "None" is used as a default. |
| template <typename EnumAttrClass, typename EnumClass> |
| static ParseResult |
| parseControlAttribute(OpAsmParser &parser, OperationState &state, |
| StringRef attrName = spirv::attributeName<EnumClass>()) { |
| if (succeeded(parser.parseOptionalKeyword(kControl))) { |
| EnumClass control; |
| if (parser.parseLParen() || |
| spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) || |
| parser.parseRParen()) |
| return failure(); |
| return success(); |
| } |
| // Set control to "None" otherwise. |
| Builder builder = parser.getBuilder(); |
| state.addAttribute(attrName, |
| builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0))); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.BranchOp |
| //===----------------------------------------------------------------------===// |
| |
| SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) { |
| assert(index == 0 && "invalid successor index"); |
| return SuccessorOperands(0, getTargetOperandsMutable()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.BranchConditionalOp |
| //===----------------------------------------------------------------------===// |
| |
| SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) { |
| assert(index < 2 && "invalid successor index"); |
| return SuccessorOperands(index == kTrueIndex |
| ? getTrueTargetOperandsMutable() |
| : getFalseTargetOperandsMutable()); |
| } |
| |
| ParseResult BranchConditionalOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| auto &builder = parser.getBuilder(); |
| OpAsmParser::UnresolvedOperand condInfo; |
| Block *dest; |
| |
| // Parse the condition. |
| Type boolTy = builder.getI1Type(); |
| if (parser.parseOperand(condInfo) || |
| parser.resolveOperand(condInfo, boolTy, result.operands)) |
| return failure(); |
| |
| // Parse the optional branch weights. |
| if (succeeded(parser.parseOptionalLSquare())) { |
| IntegerAttr trueWeight, falseWeight; |
| NamedAttrList weights; |
| |
| auto i32Type = builder.getIntegerType(32); |
| if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) || |
| parser.parseComma() || |
| parser.parseAttribute(falseWeight, i32Type, "weight", weights) || |
| parser.parseRSquare()) |
| return failure(); |
| |
| StringAttr branchWeightsAttrName = |
| BranchConditionalOp::getBranchWeightsAttrName(result.name); |
| result.addAttribute(branchWeightsAttrName, |
| builder.getArrayAttr({trueWeight, falseWeight})); |
| } |
| |
| // Parse the true branch. |
| SmallVector<Value, 4> trueOperands; |
| if (parser.parseComma() || |
| parser.parseSuccessorAndUseList(dest, trueOperands)) |
| return failure(); |
| result.addSuccessors(dest); |
| result.addOperands(trueOperands); |
| |
| // Parse the false branch. |
| SmallVector<Value, 4> falseOperands; |
| if (parser.parseComma() || |
| parser.parseSuccessorAndUseList(dest, falseOperands)) |
| return failure(); |
| result.addSuccessors(dest); |
| result.addOperands(falseOperands); |
| result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(), |
| builder.getDenseI32ArrayAttr( |
| {1, static_cast<int32_t>(trueOperands.size()), |
| static_cast<int32_t>(falseOperands.size())})); |
| |
| return success(); |
| } |
| |
| void BranchConditionalOp::print(OpAsmPrinter &printer) { |
| printer << ' ' << getCondition(); |
| |
| if (auto weights = getBranchWeights()) { |
| printer << " ["; |
| llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { |
| printer << llvm::cast<IntegerAttr>(a).getInt(); |
| }); |
| printer << "]"; |
| } |
| |
| printer << ", "; |
| printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments()); |
| printer << ", "; |
| printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments()); |
| } |
| |
| LogicalResult BranchConditionalOp::verify() { |
| if (auto weights = getBranchWeights()) { |
| if (weights->getValue().size() != 2) { |
| return emitOpError("must have exactly two branch weights"); |
| } |
| if (llvm::all_of(*weights, [](Attribute attr) { |
| return llvm::cast<IntegerAttr>(attr).getValue().isZero(); |
| })) |
| return emitOpError("branch weights cannot both be zero"); |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.FunctionCall |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult FunctionCallOp::verify() { |
| auto fnName = getCalleeAttr(); |
| |
| auto funcOp = dyn_cast_or_null<spirv::FuncOp>( |
| SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName)); |
| if (!funcOp) { |
| return emitOpError("callee function '") |
| << fnName.getValue() << "' not found in nearest symbol table"; |
| } |
| |
| auto functionType = funcOp.getFunctionType(); |
| |
| if (getNumResults() > 1) { |
| return emitOpError( |
| "expected callee function to have 0 or 1 result, but provided ") |
| << getNumResults(); |
| } |
| |
| if (functionType.getNumInputs() != getNumOperands()) { |
| return emitOpError("has incorrect number of operands for callee: expected ") |
| << functionType.getNumInputs() << ", but provided " |
| << getNumOperands(); |
| } |
| |
| for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { |
| if (getOperand(i).getType() != functionType.getInput(i)) { |
| return emitOpError("operand type mismatch: expected operand type ") |
| << functionType.getInput(i) << ", but provided " |
| << getOperand(i).getType() << " for operand number " << i; |
| } |
| } |
| |
| if (functionType.getNumResults() != getNumResults()) { |
| return emitOpError( |
| "has incorrect number of results has for callee: expected ") |
| << functionType.getNumResults() << ", but provided " |
| << getNumResults(); |
| } |
| |
| if (getNumResults() && |
| (getResult(0).getType() != functionType.getResult(0))) { |
| return emitOpError("result type mismatch: expected ") |
| << functionType.getResult(0) << ", but provided " |
| << getResult(0).getType(); |
| } |
| |
| return success(); |
| } |
| |
| CallInterfaceCallable FunctionCallOp::getCallableForCallee() { |
| return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName()); |
| } |
| |
| void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { |
| (*this)->setAttr(getCalleeAttrName(), callee.get<SymbolRefAttr>()); |
| } |
| |
| Operation::operand_range FunctionCallOp::getArgOperands() { |
| return getArguments(); |
| } |
| |
| MutableOperandRange FunctionCallOp::getArgOperandsMutable() { |
| return getArgumentsMutable(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.mlir.loop |
| //===----------------------------------------------------------------------===// |
| |
| void LoopOp::build(OpBuilder &builder, OperationState &state) { |
| state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>( |
| spirv::LoopControl::None)); |
| state.addRegion(); |
| } |
| |
| ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { |
| if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser, |
| result)) |
| return failure(); |
| return parser.parseRegion(*result.addRegion(), /*arguments=*/{}); |
| } |
| |
| void LoopOp::print(OpAsmPrinter &printer) { |
| auto control = getLoopControl(); |
| if (control != spirv::LoopControl::None) |
| printer << " control(" << spirv::stringifyLoopControl(control) << ")"; |
| printer << ' '; |
| printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, |
| /*printBlockTerminators=*/true); |
| } |
| |
| /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the |
| /// given `dstBlock`. |
| static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) { |
| // Check that there is only one op in the `srcBlock`. |
| if (!llvm::hasSingleElement(srcBlock)) |
| return false; |
| |
| auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back()); |
| return branchOp && branchOp.getSuccessor() == &dstBlock; |
| } |
| |
| /// Returns true if the given `block` only contains one `spirv.mlir.merge` op. |
| static bool isMergeBlock(Block &block) { |
| return !block.empty() && std::next(block.begin()) == block.end() && |
| isa<spirv::MergeOp>(block.front()); |
| } |
| |
| LogicalResult LoopOp::verifyRegions() { |
| auto *op = getOperation(); |
| |
| // We need to verify that the blocks follow the following layout: |
| // |
| // +-------------+ |
| // | entry block | |
| // +-------------+ |
| // | |
| // v |
| // +-------------+ |
| // | loop header | <-----+ |
| // +-------------+ | |
| // | |
| // ... | |
| // \ | / | |
| // v | |
| // +---------------+ | |
| // | loop continue | -----+ |
| // +---------------+ |
| // |
| // ... |
| // \ | / |
| // v |
| // +-------------+ |
| // | merge block | |
| // +-------------+ |
| |
| auto ®ion = op->getRegion(0); |
| // Allow empty region as a degenerated case, which can come from |
| // optimizations. |
| if (region.empty()) |
| return success(); |
| |
| // The last block is the merge block. |
| Block &merge = region.back(); |
| if (!isMergeBlock(merge)) |
| return emitOpError("last block must be the merge block with only one " |
| "'spirv.mlir.merge' op"); |
| |
| if (std::next(region.begin()) == region.end()) |
| return emitOpError( |
| "must have an entry block branching to the loop header block"); |
| // The first block is the entry block. |
| Block &entry = region.front(); |
| |
| if (std::next(region.begin(), 2) == region.end()) |
| return emitOpError( |
| "must have a loop header block branched from the entry block"); |
| // The second block is the loop header block. |
| Block &header = *std::next(region.begin(), 1); |
| |
| if (!hasOneBranchOpTo(entry, header)) |
| return emitOpError( |
| "entry block must only have one 'spirv.Branch' op to the second block"); |
| |
| if (std::next(region.begin(), 3) == region.end()) |
| return emitOpError( |
| "requires a loop continue block branching to the loop header block"); |
| // The second to last block is the loop continue block. |
| Block &cont = *std::prev(region.end(), 2); |
| |
| // Make sure that we have a branch from the loop continue block to the loop |
| // header block. |
| if (llvm::none_of( |
| llvm::seq<unsigned>(0, cont.getNumSuccessors()), |
| [&](unsigned index) { return cont.getSuccessor(index) == &header; })) |
| return emitOpError("second to last block must be the loop continue " |
| "block that branches to the loop header block"); |
| |
| // Make sure that no other blocks (except the entry and loop continue block) |
| // branches to the loop header block. |
| for (auto &block : llvm::make_range(std::next(region.begin(), 2), |
| std::prev(region.end(), 2))) { |
| for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) { |
| if (block.getSuccessor(i) == &header) { |
| return emitOpError("can only have the entry and loop continue " |
| "block branching to the loop header block"); |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| Block *LoopOp::getEntryBlock() { |
| assert(!getBody().empty() && "op region should not be empty!"); |
| return &getBody().front(); |
| } |
| |
| Block *LoopOp::getHeaderBlock() { |
| assert(!getBody().empty() && "op region should not be empty!"); |
| // The second block is the loop header block. |
| return &*std::next(getBody().begin()); |
| } |
| |
| Block *LoopOp::getContinueBlock() { |
| assert(!getBody().empty() && "op region should not be empty!"); |
| // The second to last block is the loop continue block. |
| return &*std::prev(getBody().end(), 2); |
| } |
| |
| Block *LoopOp::getMergeBlock() { |
| assert(!getBody().empty() && "op region should not be empty!"); |
| // The last block is the loop merge block. |
| return &getBody().back(); |
| } |
| |
| void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) { |
| assert(getBody().empty() && "entry and merge block already exist"); |
| OpBuilder::InsertionGuard g(builder); |
| builder.createBlock(&getBody()); |
| builder.createBlock(&getBody()); |
| |
| // Add a spirv.mlir.merge op into the merge block. |
| builder.create<spirv::MergeOp>(getLoc()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.mlir.merge |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult MergeOp::verify() { |
| auto *parentOp = (*this)->getParentOp(); |
| if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp)) |
| return emitOpError( |
| "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'"); |
| |
| // TODO: This check should be done in `verifyRegions` of parent op. |
| Block &parentLastBlock = (*this)->getParentRegion()->back(); |
| if (getOperation() != parentLastBlock.getTerminator()) |
| return emitOpError("can only be used in the last block of " |
| "'spirv.mlir.selection' or 'spirv.mlir.loop'"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.Return |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ReturnOp::verify() { |
| // Verification is performed in spirv.func op. |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ReturnValue |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ReturnValueOp::verify() { |
| // Verification is performed in spirv.func op. |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.Select |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult SelectOp::verify() { |
| if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) { |
| auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType()); |
| if (!resultVectorTy) { |
| return emitOpError("result expected to be of vector type when " |
| "condition is of vector type"); |
| } |
| if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) { |
| return emitOpError("result should have the same number of elements as " |
| "the condition when condition is of vector type"); |
| } |
| } |
| return success(); |
| } |
| |
| // Custom availability implementation is needed for spirv.Select given the |
| // syntax changes starting v1.4. |
| SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() { |
| return {}; |
| } |
| SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() { |
| return {}; |
| } |
| std::optional<spirv::Version> SelectOp::getMinVersion() { |
| // Per the spec, "Before version 1.4, results are only computed per |
| // component." |
| if (isa<spirv::ScalarType>(getCondition().getType()) && |
| isa<spirv::CompositeType>(getType())) |
| return Version::V_1_4; |
| |
| return Version::V_1_0; |
| } |
| std::optional<spirv::Version> SelectOp::getMaxVersion() { |
| return Version::V_1_6; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.mlir.selection |
| //===----------------------------------------------------------------------===// |
| |
| ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) { |
| if (parseControlAttribute<spirv::SelectionControlAttr, |
| spirv::SelectionControl>(parser, result)) |
| return failure(); |
| return parser.parseRegion(*result.addRegion(), /*arguments=*/{}); |
| } |
| |
| void SelectionOp::print(OpAsmPrinter &printer) { |
| auto control = getSelectionControl(); |
| if (control != spirv::SelectionControl::None) |
| printer << " control(" << spirv::stringifySelectionControl(control) << ")"; |
| printer << ' '; |
| printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, |
| /*printBlockTerminators=*/true); |
| } |
| |
| LogicalResult SelectionOp::verifyRegions() { |
| auto *op = getOperation(); |
| |
| // We need to verify that the blocks follow the following layout: |
| // |
| // +--------------+ |
| // | header block | |
| // +--------------+ |
| // / | \ |
| // ... |
| // |
| // |
| // +---------+ +---------+ +---------+ |
| // | case #0 | | case #1 | | case #2 | ... |
| // +---------+ +---------+ +---------+ |
| // |
| // |
| // ... |
| // \ | / |
| // v |
| // +-------------+ |
| // | merge block | |
| // +-------------+ |
| |
| auto ®ion = op->getRegion(0); |
| // Allow empty region as a degenerated case, which can come from |
| // optimizations. |
| if (region.empty()) |
| return success(); |
| |
| // The last block is the merge block. |
| if (!isMergeBlock(region.back())) |
| return emitOpError("last block must be the merge block with only one " |
| "'spirv.mlir.merge' op"); |
| |
| if (std::next(region.begin()) == region.end()) |
| return emitOpError("must have a selection header block"); |
| |
| return success(); |
| } |
| |
| Block *SelectionOp::getHeaderBlock() { |
| assert(!getBody().empty() && "op region should not be empty!"); |
| // The first block is the loop header block. |
| return &getBody().front(); |
| } |
| |
| Block *SelectionOp::getMergeBlock() { |
| assert(!getBody().empty() && "op region should not be empty!"); |
| // The last block is the loop merge block. |
| return &getBody().back(); |
| } |
| |
| void SelectionOp::addMergeBlock(OpBuilder &builder) { |
| assert(getBody().empty() && "entry and merge block already exist"); |
| OpBuilder::InsertionGuard guard(builder); |
| builder.createBlock(&getBody()); |
| |
| // Add a spirv.mlir.merge op into the merge block. |
| builder.create<spirv::MergeOp>(getLoc()); |
| } |
| |
| SelectionOp |
| SelectionOp::createIfThen(Location loc, Value condition, |
| function_ref<void(OpBuilder &builder)> thenBody, |
| OpBuilder &builder) { |
| auto selectionOp = |
| builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None); |
| |
| selectionOp.addMergeBlock(builder); |
| Block *mergeBlock = selectionOp.getMergeBlock(); |
| Block *thenBlock = nullptr; |
| |
| // Build the "then" block. |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| thenBlock = builder.createBlock(mergeBlock); |
| thenBody(builder); |
| builder.create<spirv::BranchOp>(loc, mergeBlock); |
| } |
| |
| // Build the header block. |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| builder.createBlock(thenBlock); |
| builder.create<spirv::BranchConditionalOp>( |
| loc, condition, thenBlock, |
| /*trueArguments=*/ArrayRef<Value>(), mergeBlock, |
| /*falseArguments=*/ArrayRef<Value>()); |
| } |
| |
| return selectionOp; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.Unreachable |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::UnreachableOp::verify() { |
| auto *block = (*this)->getBlock(); |
| // Fast track: if this is in entry block, its invalid. Otherwise, if no |
| // predecessors, it's valid. |
| if (block->isEntryBlock()) |
| return emitOpError("cannot be used in reachable block"); |
| if (block->hasNoPredecessors()) |
| return success(); |
| |
| // TODO: further verification needs to analyze reachability from |
| // the entry block. |
| |
| return success(); |
| } |
| |
| } // namespace mlir::spirv |