| //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines the operations in the SPIR-V dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| |
| #include "SPIRVOpUtils.h" |
| #include "SPIRVParsingUtils.h" |
| |
| #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
| #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/OpDefinition.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/Operation.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/Interfaces/FunctionImplementation.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "llvm/ADT/APFloat.h" |
| #include "llvm/ADT/APInt.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include <cassert> |
| #include <numeric> |
| #include <optional> |
| #include <type_traits> |
| |
| using namespace mlir; |
| using namespace mlir::spirv::AttrNames; |
| |
| //===----------------------------------------------------------------------===// |
| // Common utility functions |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) { |
| auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op); |
| if (!constOp) { |
| return failure(); |
| } |
| auto valueAttr = constOp.getValue(); |
| auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr); |
| if (!integerValueAttr) { |
| return failure(); |
| } |
| |
| if (integerValueAttr.getType().isSignlessInteger()) |
| value = integerValueAttr.getInt(); |
| else |
| value = integerValueAttr.getSInt(); |
| |
| return success(); |
| } |
| |
| LogicalResult |
| spirv::verifyMemorySemantics(Operation *op, |
| spirv::MemorySemantics memorySemantics) { |
| // According to the SPIR-V specification: |
| // "Despite being a mask and allowing multiple bits to be combined, it is |
| // invalid for more than one of these four bits to be set: Acquire, Release, |
| // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and |
| // Release semantics is done by setting the AcquireRelease bit, not by setting |
| // two bits." |
| auto atMostOneInSet = spirv::MemorySemantics::Acquire | |
| spirv::MemorySemantics::Release | |
| spirv::MemorySemantics::AcquireRelease | |
| spirv::MemorySemantics::SequentiallyConsistent; |
| |
| auto bitCount = |
| llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet)); |
| if (bitCount > 1) { |
| return op->emitError( |
| "expected at most one of these four memory constraints " |
| "to be set: `Acquire`, `Release`," |
| "`AcquireRelease` or `SequentiallyConsistent`"); |
| } |
| return success(); |
| } |
| |
| void spirv::printVariableDecorations(Operation *op, OpAsmPrinter &printer, |
| SmallVectorImpl<StringRef> &elidedAttrs) { |
| // Print optional descriptor binding |
| auto descriptorSetName = llvm::convertToSnakeFromCamelCase( |
| stringifyDecoration(spirv::Decoration::DescriptorSet)); |
| auto bindingName = llvm::convertToSnakeFromCamelCase( |
| stringifyDecoration(spirv::Decoration::Binding)); |
| auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName); |
| auto binding = op->getAttrOfType<IntegerAttr>(bindingName); |
| if (descriptorSet && binding) { |
| elidedAttrs.push_back(descriptorSetName); |
| elidedAttrs.push_back(bindingName); |
| printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() |
| << ")"; |
| } |
| |
| // Print BuiltIn attribute if present |
| auto builtInName = llvm::convertToSnakeFromCamelCase( |
| stringifyDecoration(spirv::Decoration::BuiltIn)); |
| if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) { |
| printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; |
| elidedAttrs.push_back(builtInName); |
| } |
| |
| printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); |
| } |
| |
| static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, |
| OperationState &result) { |
| SmallVector<OpAsmParser::UnresolvedOperand, 2> ops; |
| Type type; |
| // If the operand list is in-between parentheses, then we have a generic form. |
| // (see the fallback in `printOneResultOp`). |
| SMLoc loc = parser.getCurrentLocation(); |
| if (!parser.parseOptionalLParen()) { |
| if (parser.parseOperandList(ops) || parser.parseRParen() || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColon() || parser.parseType(type)) |
| return failure(); |
| auto fnType = llvm::dyn_cast<FunctionType>(type); |
| if (!fnType) { |
| parser.emitError(loc, "expected function type"); |
| return failure(); |
| } |
| if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) |
| return failure(); |
| result.addTypes(fnType.getResults()); |
| return success(); |
| } |
| return failure(parser.parseOperandList(ops) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(type) || |
| parser.resolveOperands(ops, type, result.operands) || |
| parser.addTypeToList(type, result.types)); |
| } |
| |
| static void printOneResultOp(Operation *op, OpAsmPrinter &p) { |
| assert(op->getNumResults() == 1 && "op should have one result"); |
| |
| // If not all the operand and result types are the same, just use the |
| // generic assembly form to avoid omitting information in printing. |
| auto resultType = op->getResult(0).getType(); |
| if (llvm::any_of(op->getOperandTypes(), |
| [&](Type type) { return type != resultType; })) { |
| p.printGenericOp(op, /*printOpName=*/false); |
| return; |
| } |
| |
| p << ' '; |
| p.printOperands(op->getOperands()); |
| p.printOptionalAttrDict(op->getAttrs()); |
| // Now we can output only one type for all operands and the result. |
| p << " : " << resultType; |
| } |
| |
| template <typename Op> |
| static LogicalResult verifyImageOperands(Op imageOp, |
| spirv::ImageOperandsAttr attr, |
| Operation::operand_range operands) { |
| if (!attr) { |
| if (operands.empty()) |
| return success(); |
| |
| return imageOp.emitError("the Image Operands should encode what operands " |
| "follow, as per Image Operands"); |
| } |
| |
| // TODO: Add the validation rules for the following Image Operands. |
| spirv::ImageOperands noSupportOperands = |
| spirv::ImageOperands::Bias | spirv::ImageOperands::Lod | |
| spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset | |
| spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets | |
| spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod | |
| spirv::ImageOperands::MakeTexelAvailable | |
| spirv::ImageOperands::MakeTexelVisible | |
| spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend; |
| |
| if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands)) |
| llvm_unreachable("unimplemented operands of Image Operands"); |
| |
| return success(); |
| } |
| |
| template <typename BlockReadWriteOpTy> |
| static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, |
| Value ptr, Value val) { |
| auto valType = val.getType(); |
| if (auto valVecTy = llvm::dyn_cast<VectorType>(valType)) |
| valType = valVecTy.getElementType(); |
| |
| if (valType != |
| llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) { |
| return op.emitOpError("mismatch in result type and pointer type"); |
| } |
| return success(); |
| } |
| |
| /// Walks the given type hierarchy with the given indices, potentially down |
| /// to component granularity, to select an element type. Returns null type and |
| /// emits errors with the given loc on failure. |
| static Type |
| getElementType(Type type, ArrayRef<int32_t> indices, |
| function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) { |
| if (indices.empty()) { |
| emitErrorFn("expected at least one index for spirv.CompositeExtract"); |
| return nullptr; |
| } |
| |
| for (auto index : indices) { |
| if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) { |
| if (cType.hasCompileTimeKnownNumElements() && |
| (index < 0 || |
| static_cast<uint64_t>(index) >= cType.getNumElements())) { |
| emitErrorFn("index ") << index << " out of bounds for " << type; |
| return nullptr; |
| } |
| type = cType.getElementType(index); |
| } else { |
| emitErrorFn("cannot extract from non-composite type ") |
| << type << " with index " << index; |
| return nullptr; |
| } |
| } |
| return type; |
| } |
| |
| static Type |
| getElementType(Type type, Attribute indices, |
| function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) { |
| auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices); |
| if (!indicesArrayAttr) { |
| emitErrorFn("expected a 32-bit integer array attribute for 'indices'"); |
| return nullptr; |
| } |
| if (indicesArrayAttr.empty()) { |
| emitErrorFn("expected at least one index for spirv.CompositeExtract"); |
| return nullptr; |
| } |
| |
| SmallVector<int32_t, 2> indexVals; |
| for (auto indexAttr : indicesArrayAttr) { |
| auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr); |
| if (!indexIntAttr) { |
| emitErrorFn("expected an 32-bit integer for index, but found '") |
| << indexAttr << "'"; |
| return nullptr; |
| } |
| indexVals.push_back(indexIntAttr.getInt()); |
| } |
| return getElementType(type, indexVals, emitErrorFn); |
| } |
| |
| static Type getElementType(Type type, Attribute indices, Location loc) { |
| auto errorFn = [&](StringRef err) -> InFlightDiagnostic { |
| return ::mlir::emitError(loc, err); |
| }; |
| return getElementType(type, indices, errorFn); |
| } |
| |
| static Type getElementType(Type type, Attribute indices, OpAsmParser &parser, |
| SMLoc loc) { |
| auto errorFn = [&](StringRef err) -> InFlightDiagnostic { |
| return parser.emitError(loc, err); |
| }; |
| return getElementType(type, indices, errorFn); |
| } |
| |
| template <typename ExtendedBinaryOp> |
| static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) { |
| auto resultType = llvm::cast<spirv::StructType>(op.getType()); |
| if (resultType.getNumElements() != 2) |
| return op.emitOpError("expected result struct type containing two members"); |
| |
| if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(), |
| resultType.getElementType(0), |
| resultType.getElementType(1)})) |
| return op.emitOpError( |
| "expected all operand types and struct member types are the same"); |
| |
| return success(); |
| } |
| |
| static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, |
| OperationState &result) { |
| SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; |
| if (parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseOperandList(operands) || parser.parseColon()) |
| return failure(); |
| |
| Type resultType; |
| SMLoc loc = parser.getCurrentLocation(); |
| if (parser.parseType(resultType)) |
| return failure(); |
| |
| auto structType = llvm::dyn_cast<spirv::StructType>(resultType); |
| if (!structType || structType.getNumElements() != 2) |
| return parser.emitError(loc, "expected spirv.struct type with two members"); |
| |
| SmallVector<Type, 2> operandTypes(2, structType.getElementType(0)); |
| if (parser.resolveOperands(operands, operandTypes, loc, result.operands)) |
| return failure(); |
| |
| result.addTypes(resultType); |
| return success(); |
| } |
| |
| static void printArithmeticExtendedBinaryOp(Operation *op, |
| OpAsmPrinter &printer) { |
| printer << ' '; |
| printer.printOptionalAttrDict(op->getAttrs()); |
| printer.printOperands(op->getOperands()); |
| printer << " : " << op->getResultTypes().front(); |
| } |
| |
| static LogicalResult verifyShiftOp(Operation *op) { |
| if (op->getOperand(0).getType() != op->getResult(0).getType()) { |
| return op->emitError("expected the same type for the first operand and " |
| "result, but provided ") |
| << op->getOperand(0).getType() << " and " |
| << op->getResult(0).getType(); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.mlir.addressof |
| //===----------------------------------------------------------------------===// |
| |
| void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state, |
| spirv::GlobalVariableOp var) { |
| build(builder, state, var.getType(), SymbolRefAttr::get(var)); |
| } |
| |
| LogicalResult spirv::AddressOfOp::verify() { |
| auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>( |
| SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), |
| getVariableAttr())); |
| if (!varOp) { |
| return emitOpError("expected spirv.GlobalVariable symbol"); |
| } |
| if (getPointer().getType() != varOp.getType()) { |
| return emitOpError( |
| "result type mismatch with the referenced global variable's type"); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.CompositeConstruct |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::CompositeConstructOp::verify() { |
| operand_range constituents = this->getConstituents(); |
| |
| // There are 4 cases with varying verification rules: |
| // 1. Cooperative Matrices (1 constituent) |
| // 2. Structs (1 constituent for each member) |
| // 3. Arrays (1 constituent for each array element) |
| // 4. Vectors (1 constituent (sub-)element for each vector element) |
| |
| auto coopElementType = |
| llvm::TypeSwitch<Type, Type>(getType()) |
| .Case<spirv::CooperativeMatrixType, spirv::JointMatrixINTELType>( |
| [](auto coopType) { return coopType.getElementType(); }) |
| .Default([](Type) { return nullptr; }); |
| |
| // Case 1. -- matrices. |
| if (coopElementType) { |
| if (constituents.size() != 1) |
| return emitOpError("has incorrect number of operands: expected ") |
| << "1, but provided " << constituents.size(); |
| if (coopElementType != constituents.front().getType()) |
| return emitOpError("operand type mismatch: expected operand type ") |
| << coopElementType << ", but provided " |
| << constituents.front().getType(); |
| return success(); |
| } |
| |
| // Case 2./3./4. -- number of constituents matches the number of elements. |
| auto cType = llvm::cast<spirv::CompositeType>(getType()); |
| if (constituents.size() == cType.getNumElements()) { |
| for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { |
| if (constituents[index].getType() != cType.getElementType(index)) { |
| return emitOpError("operand type mismatch: expected operand type ") |
| << cType.getElementType(index) << ", but provided " |
| << constituents[index].getType(); |
| } |
| } |
| return success(); |
| } |
| |
| // Case 4. -- check that all constituents add up tp the expected vector type. |
| auto resultType = llvm::dyn_cast<VectorType>(cType); |
| if (!resultType) |
| return emitOpError( |
| "expected to return a vector or cooperative matrix when the number of " |
| "constituents is less than what the result needs"); |
| |
| SmallVector<unsigned> sizes; |
| for (Value component : constituents) { |
| if (!llvm::isa<VectorType>(component.getType()) && |
| !component.getType().isIntOrFloat()) |
| return emitOpError("operand type mismatch: expected operand to have " |
| "a scalar or vector type, but provided ") |
| << component.getType(); |
| |
| Type elementType = component.getType(); |
| if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) { |
| sizes.push_back(vectorType.getNumElements()); |
| elementType = vectorType.getElementType(); |
| } else { |
| sizes.push_back(1); |
| } |
| |
| if (elementType != resultType.getElementType()) |
| return emitOpError("operand element type mismatch: expected to be ") |
| << resultType.getElementType() << ", but provided " << elementType; |
| } |
| unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0); |
| if (totalCount != cType.getNumElements()) |
| return emitOpError("has incorrect number of operands: expected ") |
| << cType.getNumElements() << ", but provided " << totalCount; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
// spirv.CompositeExtract
| //===----------------------------------------------------------------------===// |
| |
| void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state, |
| Value composite, |
| ArrayRef<int32_t> indices) { |
| auto indexAttr = builder.getI32ArrayAttr(indices); |
| auto elementType = |
| getElementType(composite.getType(), indexAttr, state.location); |
| if (!elementType) { |
| return; |
| } |
| build(builder, state, elementType, composite, indexAttr); |
| } |
| |
| ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| OpAsmParser::UnresolvedOperand compositeInfo; |
| Attribute indicesAttr; |
| StringRef indicesAttrName = |
| spirv::CompositeExtractOp::getIndicesAttrName(result.name); |
| Type compositeType; |
| SMLoc attrLocation; |
| |
| if (parser.parseOperand(compositeInfo) || |
| parser.getCurrentLocation(&attrLocation) || |
| parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) || |
| parser.parseColonType(compositeType) || |
| parser.resolveOperand(compositeInfo, compositeType, result.operands)) { |
| return failure(); |
| } |
| |
| Type resultType = |
| getElementType(compositeType, indicesAttr, parser, attrLocation); |
| if (!resultType) { |
| return failure(); |
| } |
| result.addTypes(resultType); |
| return success(); |
| } |
| |
| void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) { |
| printer << ' ' << getComposite() << getIndices() << " : " |
| << getComposite().getType(); |
| } |
| |
| LogicalResult spirv::CompositeExtractOp::verify() { |
| auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices()); |
| auto resultType = |
| getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); |
| if (!resultType) |
| return failure(); |
| |
| if (resultType != getType()) { |
| return emitOpError("invalid result type: expected ") |
| << resultType << " but provided " << getType(); |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.CompositeInsert |
| //===----------------------------------------------------------------------===// |
| |
| void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state, |
| Value object, Value composite, |
| ArrayRef<int32_t> indices) { |
| auto indexAttr = builder.getI32ArrayAttr(indices); |
| build(builder, state, composite.getType(), object, composite, indexAttr); |
| } |
| |
| ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; |
| Type objectType, compositeType; |
| Attribute indicesAttr; |
| StringRef indicesAttrName = |
| spirv::CompositeInsertOp::getIndicesAttrName(result.name); |
| auto loc = parser.getCurrentLocation(); |
| |
| return failure( |
| parser.parseOperandList(operands, 2) || |
| parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) || |
| parser.parseColonType(objectType) || |
| parser.parseKeywordType("into", compositeType) || |
| parser.resolveOperands(operands, {objectType, compositeType}, loc, |
| result.operands) || |
| parser.addTypesToList(compositeType, result.types)); |
| } |
| |
| LogicalResult spirv::CompositeInsertOp::verify() { |
| auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices()); |
| auto objectType = |
| getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); |
| if (!objectType) |
| return failure(); |
| |
| if (objectType != getObject().getType()) { |
| return emitOpError("object operand type should be ") |
| << objectType << ", but found " << getObject().getType(); |
| } |
| |
| if (getComposite().getType() != getType()) { |
| return emitOpError("result type should be the same as " |
| "the composite type, but found ") |
| << getComposite().getType() << " vs " << getType(); |
| } |
| |
| return success(); |
| } |
| |
| void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) { |
| printer << " " << getObject() << ", " << getComposite() << getIndices() |
| << " : " << getObject().getType() << " into " |
| << getComposite().getType(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.Constant |
| //===----------------------------------------------------------------------===// |
| |
| ParseResult spirv::ConstantOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| Attribute value; |
| StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name); |
| if (parser.parseAttribute(value, valueAttrName, result.attributes)) |
| return failure(); |
| |
| Type type = NoneType::get(parser.getContext()); |
| if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value)) |
| type = typedAttr.getType(); |
| if (llvm::isa<NoneType, TensorType>(type)) { |
| if (parser.parseColonType(type)) |
| return failure(); |
| } |
| |
| return parser.addTypeToList(type, result.types); |
| } |
| |
| void spirv::ConstantOp::print(OpAsmPrinter &printer) { |
| printer << ' ' << getValue(); |
| if (llvm::isa<spirv::ArrayType>(getType())) |
| printer << " : " << getType(); |
| } |
| |
| static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, |
| Type opType) { |
| if (llvm::isa<IntegerAttr, FloatAttr>(value)) { |
| auto valueType = llvm::cast<TypedAttr>(value).getType(); |
| if (valueType != opType) |
| return op.emitOpError("result type (") |
| << opType << ") does not match value type (" << valueType << ")"; |
| return success(); |
| } |
| if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) { |
| auto valueType = llvm::cast<TypedAttr>(value).getType(); |
| if (valueType == opType) |
| return success(); |
| auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType); |
| auto shapedType = llvm::dyn_cast<ShapedType>(valueType); |
| if (!arrayType) |
| return op.emitOpError("result or element type (") |
| << opType << ") does not match value type (" << valueType |
| << "), must be the same or spirv.array"; |
| |
| int numElements = arrayType.getNumElements(); |
| auto opElemType = arrayType.getElementType(); |
| while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) { |
| numElements *= t.getNumElements(); |
| opElemType = t.getElementType(); |
| } |
| if (!opElemType.isIntOrFloat()) |
| return op.emitOpError("only support nested array result type"); |
| |
| auto valueElemType = shapedType.getElementType(); |
| if (valueElemType != opElemType) { |
| return op.emitOpError("result element type (") |
| << opElemType << ") does not match value element type (" |
| << valueElemType << ")"; |
| } |
| |
| if (numElements != shapedType.getNumElements()) { |
| return op.emitOpError("result number of elements (") |
| << numElements << ") does not match value number of elements (" |
| << shapedType.getNumElements() << ")"; |
| } |
| return success(); |
| } |
| if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) { |
| auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType); |
| if (!arrayType) |
| return op.emitOpError( |
| "must have spirv.array result type for array value"); |
| Type elemType = arrayType.getElementType(); |
| for (Attribute element : arrayAttr.getValue()) { |
| // Verify array elements recursively. |
| if (failed(verifyConstantType(op, element, elemType))) |
| return failure(); |
| } |
| return success(); |
| } |
| return op.emitOpError("cannot have attribute: ") << value; |
| } |
| |
| LogicalResult spirv::ConstantOp::verify() { |
| // ODS already generates checks to make sure the result type is valid. We just |
| // need to additionally check that the value's attribute type is consistent |
| // with the result type. |
| return verifyConstantType(*this, getValueAttr(), getType()); |
| } |
| |
| bool spirv::ConstantOp::isBuildableWith(Type type) { |
| // Must be valid SPIR-V type first. |
| if (!llvm::isa<spirv::SPIRVType>(type)) |
| return false; |
| |
| if (isa<SPIRVDialect>(type.getDialect())) { |
| // TODO: support constant struct |
| return llvm::isa<spirv::ArrayType>(type); |
| } |
| |
| return true; |
| } |
| |
| spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, |
| OpBuilder &builder) { |
| if (auto intType = llvm::dyn_cast<IntegerType>(type)) { |
| unsigned width = intType.getWidth(); |
| if (width == 1) |
| return builder.create<spirv::ConstantOp>(loc, type, |
| builder.getBoolAttr(false)); |
| return builder.create<spirv::ConstantOp>( |
| loc, type, builder.getIntegerAttr(type, APInt(width, 0))); |
| } |
| if (auto floatType = llvm::dyn_cast<FloatType>(type)) { |
| return builder.create<spirv::ConstantOp>( |
| loc, type, builder.getFloatAttr(floatType, 0.0)); |
| } |
| if (auto vectorType = llvm::dyn_cast<VectorType>(type)) { |
| Type elemType = vectorType.getElementType(); |
| if (llvm::isa<IntegerType>(elemType)) { |
| return builder.create<spirv::ConstantOp>( |
| loc, type, |
| DenseElementsAttr::get(vectorType, |
| IntegerAttr::get(elemType, 0).getValue())); |
| } |
| if (llvm::isa<FloatType>(elemType)) { |
| return builder.create<spirv::ConstantOp>( |
| loc, type, |
| DenseFPElementsAttr::get(vectorType, |
| FloatAttr::get(elemType, 0.0).getValue())); |
| } |
| } |
| |
| llvm_unreachable("unimplemented types for ConstantOp::getZero()"); |
| } |
| |
| spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, |
| OpBuilder &builder) { |
| if (auto intType = llvm::dyn_cast<IntegerType>(type)) { |
| unsigned width = intType.getWidth(); |
| if (width == 1) |
| return builder.create<spirv::ConstantOp>(loc, type, |
| builder.getBoolAttr(true)); |
| return builder.create<spirv::ConstantOp>( |
| loc, type, builder.getIntegerAttr(type, APInt(width, 1))); |
| } |
| if (auto floatType = llvm::dyn_cast<FloatType>(type)) { |
| return builder.create<spirv::ConstantOp>( |
| loc, type, builder.getFloatAttr(floatType, 1.0)); |
| } |
| if (auto vectorType = llvm::dyn_cast<VectorType>(type)) { |
| Type elemType = vectorType.getElementType(); |
| if (llvm::isa<IntegerType>(elemType)) { |
| return builder.create<spirv::ConstantOp>( |
| loc, type, |
| DenseElementsAttr::get(vectorType, |
| IntegerAttr::get(elemType, 1).getValue())); |
| } |
| if (llvm::isa<FloatType>(elemType)) { |
| return builder.create<spirv::ConstantOp>( |
| loc, type, |
| DenseFPElementsAttr::get(vectorType, |
| FloatAttr::get(elemType, 1.0).getValue())); |
| } |
| } |
| |
| llvm_unreachable("unimplemented types for ConstantOp::getOne()"); |
| } |
| |
| void mlir::spirv::ConstantOp::getAsmResultNames( |
| llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) { |
| Type type = getType(); |
| |
| SmallString<32> specialNameBuffer; |
| llvm::raw_svector_ostream specialName(specialNameBuffer); |
| specialName << "cst"; |
| |
| IntegerType intTy = llvm::dyn_cast<IntegerType>(type); |
| |
| if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) { |
| if (intTy && intTy.getWidth() == 1) { |
| return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); |
| } |
| |
| if (intTy.isSignless()) { |
| specialName << intCst.getInt(); |
| } else if (intTy.isUnsigned()) { |
| specialName << intCst.getUInt(); |
| } else { |
| specialName << intCst.getSInt(); |
| } |
| } |
| |
| if (intTy || llvm::isa<FloatType>(type)) { |
| specialName << '_' << type; |
| } |
| |
| if (auto vecType = llvm::dyn_cast<VectorType>(type)) { |
| specialName << "_vec_"; |
| specialName << vecType.getDimSize(0); |
| |
| Type elementType = vecType.getElementType(); |
| |
| if (llvm::isa<IntegerType>(elementType) || |
| llvm::isa<FloatType>(elementType)) { |
| specialName << "x" << elementType; |
| } |
| } |
| |
| setNameFn(getResult(), specialName.str()); |
| } |
| |
| void mlir::spirv::AddressOfOp::getAsmResultNames( |
| llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) { |
| SmallString<32> specialNameBuffer; |
| llvm::raw_svector_ostream specialName(specialNameBuffer); |
| specialName << getVariable() << "_addr"; |
| setNameFn(getResult(), specialName.str()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
// spirv.ControlBarrier
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::ControlBarrierOp::verify() { |
| return verifyMemorySemantics(getOperation(), getMemorySemantics()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.EntryPoint |
| //===----------------------------------------------------------------------===// |
| |
| void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state, |
| spirv::ExecutionModel executionModel, |
| spirv::FuncOp function, |
| ArrayRef<Attribute> interfaceVars) { |
| build(builder, state, |
| spirv::ExecutionModelAttr::get(builder.getContext(), executionModel), |
| SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars)); |
| } |
| |
| ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| spirv::ExecutionModel execModel; |
| SmallVector<OpAsmParser::UnresolvedOperand, 0> identifiers; |
| SmallVector<Type, 0> idTypes; |
| SmallVector<Attribute, 4> interfaceVars; |
| |
| FlatSymbolRefAttr fn; |
| if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) || |
| parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) { |
| return failure(); |
| } |
| |
| if (!parser.parseOptionalComma()) { |
| // Parse the interface variables |
| if (parser.parseCommaSeparatedList([&]() -> ParseResult { |
| // The name of the interface variable attribute isnt important |
| FlatSymbolRefAttr var; |
| NamedAttrList attrs; |
| if (parser.parseAttribute(var, Type(), "var_symbol", attrs)) |
| return failure(); |
| interfaceVars.push_back(var); |
| return success(); |
| })) |
| return failure(); |
| } |
| result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name), |
| parser.getBuilder().getArrayAttr(interfaceVars)); |
| return success(); |
| } |
| |
| void spirv::EntryPointOp::print(OpAsmPrinter &printer) { |
| printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" "; |
| printer.printSymbolName(getFn()); |
| auto interfaceVars = getInterface().getValue(); |
| if (!interfaceVars.empty()) { |
| printer << ", "; |
| llvm::interleaveComma(interfaceVars, printer); |
| } |
| } |
| |
| LogicalResult spirv::EntryPointOp::verify() { |
| // Checks for fn and interface symbol reference are done in spirv::ModuleOp |
| // verification. |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ExecutionMode |
| //===----------------------------------------------------------------------===// |
| |
| /// Convenience builder: wraps `function`, `executionMode` and `params` into
| /// attributes and forwards them to the attribute-based `build` overload.
| void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
|                                    spirv::FuncOp function,
|                                    spirv::ExecutionMode executionMode,
|                                    ArrayRef<int32_t> params) {
|   auto fnRef = SymbolRefAttr::get(function);
|   auto modeAttr =
|       spirv::ExecutionModeAttr::get(builder.getContext(), executionMode);
|   auto valuesAttr = builder.getI32ArrayAttr(params);
|   build(builder, state, fnRef, modeAttr, valuesAttr);
| }
| |
| /// Parses the custom assembly form of spirv.ExecutionMode: the function
| /// attribute (stored under kFnNameAttrName), the execution mode spelled as a
| /// string, and then zero or more comma-separated 32-bit integer literals that
| /// are collected into the op's values array attribute.
| ///
| /// Returns failure if any component fails to parse. A literal that is not an
| /// integer attribute is reported as a parse error.
| ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
|                                           OperationState &result) {
|   spirv::ExecutionMode execMode;
|   Attribute fn;
|   if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
|       parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
|     return failure();
|   }
| 
|   SmallVector<int32_t, 4> values;
|   Type i32Type = parser.getBuilder().getIntegerType(32);
|   while (!parser.parseOptionalComma()) {
|     // Scratch attribute list: the individual literals are not attached to the
|     // op; only the aggregated array attribute below is.
|     NamedAttrList attr;
|     // Parse directly as an IntegerAttr so that a literal of another kind is
|     // diagnosed by the parser instead of tripping the assertion in an
|     // unchecked llvm::cast<IntegerAttr>.
|     IntegerAttr value;
|     if (parser.parseAttribute(value, i32Type, "value", attr)) {
|       return failure();
|     }
|     values.push_back(value.getInt());
|   }
|   StringRef valuesAttrName =
|       spirv::ExecutionModeOp::getValuesAttrName(result.name);
|   result.addAttribute(valuesAttrName,
|                       parser.getBuilder().getI32ArrayAttr(values));
|   return success();
| }
| |
| /// Prints the function symbol, the quoted execution mode and, if the values
| /// array is non-empty, its integer entries separated by commas.
| void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
|   printer << " ";
|   printer.printSymbolName(getFn());
|   printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
| 
|   auto modeValues = this->getValues();
|   if (!modeValues.empty()) {
|     printer << ", ";
|     // Each entry is printed as a bare integer, without its type.
|     llvm::interleaveComma(modeValues, printer, [&](Attribute entry) {
|       printer << llvm::cast<IntegerAttr>(entry).getInt();
|     });
|   }
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.func |
| //===----------------------------------------------------------------------===// |
| |
| /// Parses the custom assembly form of spirv.func: the symbol name, the
| /// function signature, the function control, an optional keyword-introduced
| /// attribute dictionary, and an optional body region.
| ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
|   auto &builder = parser.getBuilder();
| 
|   // The function name is recorded as the symbol attribute.
|   StringAttr symName;
|   if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
|                              result.attributes))
|     return failure();
| 
|   // Signature: variadic signatures are not allowed.
|   SmallVector<OpAsmParser::Argument> entryArgs;
|   SmallVector<Type> resultTypes;
|   SmallVector<DictionaryAttr> resultAttrs;
|   bool isVariadic = false;
|   if (function_interface_impl::parseFunctionSignature(
|           parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
|           resultAttrs))
|     return failure();
| 
|   // Assemble the function type from the parsed argument and result types.
|   SmallVector<Type> argTypes;
|   argTypes.reserve(entryArgs.size());
|   for (const OpAsmParser::Argument &arg : entryArgs)
|     argTypes.push_back(arg.type);
|   result.addAttribute(
|       getFunctionTypeAttrName(result.name),
|       TypeAttr::get(builder.getFunctionType(argTypes, resultTypes)));
| 
|   // Function control enum, stored as an attribute on the op.
|   spirv::FunctionControl fnControl;
|   if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
|     return failure();
| 
|   // Any additional attributes.
|   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|     return failure();
| 
|   // Attach the per-argument and per-result attributes.
|   assert(resultAttrs.size() == resultTypes.size());
|   function_interface_impl::addArgAndResultAttrs(
|       builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
|       getResAttrsAttrName(result.name));
| 
|   // The body is optional; only a region that is present but malformed is an
|   // error.
|   Region *bodyRegion = result.addRegion();
|   OptionalParseResult bodyResult =
|       parser.parseOptionalRegion(*bodyRegion, entryArgs);
|   if (bodyResult.has_value() && failed(*bodyResult))
|     return failure();
|   return success();
| }
| |
| /// Prints the function name, signature and quoted function control, then the
| /// remaining attributes, and finally the body region when there is one.
| void spirv::FuncOp::print(OpAsmPrinter &printer) {
|   printer << " ";
|   printer.printSymbolName(getSymName());
| 
|   FunctionType fnType = getFunctionType();
|   function_interface_impl::printFunctionSignature(
|       printer, *this, fnType.getInputs(),
|       /*isVariadic=*/false, fnType.getResults());
| 
|   printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
|           << "\"";
| 
|   // Attributes already rendered above (or as part of the signature) are
|   // elided from the trailing attribute list.
|   function_interface_impl::printFunctionAttributes(
|       printer, *this,
|       {spirv::attributeName<spirv::FunctionControl>(),
|        getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
|        getFunctionControlAttrName()});
| 
|   // A function with an empty body region has nothing further to print.
|   Region &bodyRegion = this->getBody();
|   if (bodyRegion.empty())
|     return;
|   printer << ' ';
|   printer.printRegion(bodyRegion, /*printEntryBlockArgs=*/false,
|                       /*printBlockTerminators=*/true);
| }
| |
| /// Verifies the function type: at most one result, and the decoration
| /// requirements on pointer arguments that involve the PhysicalStorageBuffer
| /// storage class.
| LogicalResult spirv::FuncOp::verifyType() {
|   FunctionType fnType = getFunctionType();
|   if (fnType.getNumResults() > 1)
|     return emitOpError("cannot have more than one result");
| 
|   // For the first argument attribute named spirv::DecorationAttr::name whose
|   // value is a spirv::DecorationAttr, reports whether it equals `decoration`.
|   // Returns false when the argument carries no such attribute.
|   auto hasDecorationAttr = [&](spirv::Decoration decoration,
|                                unsigned argIndex) {
|     auto funcIface = llvm::cast<FunctionOpInterface>(getOperation());
|     for (auto namedAttr : funcIface.getArgAttrs(argIndex)) {
|       if (namedAttr.getName() != spirv::DecorationAttr::name)
|         continue;
|       if (auto decAttr = dyn_cast<spirv::DecorationAttr>(namedAttr.getValue()))
|         return decAttr.getValue() == decoration;
|     }
|     return false;
|   };
| 
|   for (unsigned argIdx = 0, numArgs = this->getNumArguments();
|        argIdx != numArgs; ++argIdx) {
|     // Only pointer arguments are subject to the checks below.
|     auto argPtrType =
|         dyn_cast<spirv::PointerType>(fnType.getInputs()[argIdx]);
|     if (!argPtrType)
|       continue;
| 
|     Type pointeeType = argPtrType.getPointeeType();
|     if (auto nestedPtrType = dyn_cast<spirv::PointerType>(pointeeType)) {
|       // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
|       // > If an OpFunctionParameter is a pointer (or contains a pointer)
|       // > and the type it points to is a pointer in the PhysicalStorageBuffer
|       // > storage class, the function parameter must be decorated with exactly
|       // > one of AliasedPointer or RestrictPointer.
|       if (nestedPtrType.getStorageClass() ==
|               spirv::StorageClass::PhysicalStorageBuffer &&
|           !hasDecorationAttr(spirv::Decoration::AliasedPointer, argIdx) &&
|           !hasDecorationAttr(spirv::Decoration::RestrictPointer, argIdx))
|         return emitOpError()
|                << "with a pointer points to a physical buffer pointer must "
|                   "be decorated either 'AliasedPointer' or 'RestrictPointer'";
|       continue;
|     }
| 
|     // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
|     // > If an OpFunctionParameter is a pointer (or contains a pointer) in
|     // > the PhysicalStorageBuffer storage class, the function parameter must
|     // > be decorated with exactly one of Aliased or Restrict.
|     //
|     // The pointer examined is the argument itself, unless the argument points
|     // to an array, in which case it is the array's element type (which may
|     // turn out not to be a pointer at all).
|     spirv::PointerType checkedPtrType = argPtrType;
|     if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
|       checkedPtrType = dyn_cast<spirv::PointerType>(arrayType.getElementType());
| 
|     if (!checkedPtrType || checkedPtrType.getStorageClass() !=
|                                spirv::StorageClass::PhysicalStorageBuffer)
|       continue;
| 
|     if (!hasDecorationAttr(spirv::Decoration::Aliased, argIdx) &&
|         !hasDecorationAttr(spirv::Decoration::Restrict, argIdx))
|       return emitOpError() << "with physical buffer pointer must be decorated "
|                               "either 'Aliased' or 'Restrict'";
|   }
| 
|   return success();
| }
| |
| /// Walks the function and checks every spirv.Return / spirv.ReturnValue op
| /// against the function type: spirv.Return is only allowed when the function
| /// has no results, and spirv.ReturnValue requires exactly one result whose
| /// type equals the returned value's type.
| LogicalResult spirv::FuncOp::verifyBody() {
|   FunctionType fnType = getFunctionType();
|   const unsigned numResults = fnType.getNumResults();
| 
|   WalkResult walkResult = walk([&](Operation *op) -> WalkResult {
|     if (auto voidRet = dyn_cast<spirv::ReturnOp>(op)) {
|       if (numResults != 0)
|         return voidRet.emitOpError(
|             "cannot be used in functions returning value");
|       return WalkResult::advance();
|     }
| 
|     auto valueRet = dyn_cast<spirv::ReturnValueOp>(op);
|     if (!valueRet)
|       return WalkResult::advance();
| 
|     if (numResults != 1)
|       return valueRet.emitOpError(
|                  "returns 1 value but enclosing function requires ")
|              << numResults << " results";
| 
|     Type returnedType = valueRet.getValue().getType();
|     Type declaredType = fnType.getResult(0);
|     if (returnedType != declaredType)
|       return valueRet.emitOpError(" return value's type (")
|              << returnedType << ") mismatch with function's result type ("
|              << declaredType << ")";
| 
|     return WalkResult::advance();
|   });
| 
|   // TODO: verify other bits like linkage type.
| 
|   // The walk is interrupted exactly when one of the checks above failed.
|   return failure(walkResult.wasInterrupted());
| }
| |
| /// Populates `state` with the symbol name, function type and function control
| /// attributes, appends the caller-provided `attrs`, and adds a region for the
| /// body.
| void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
|                           StringRef name, FunctionType type,
|                           spirv::FunctionControl control,
|                           ArrayRef<NamedAttribute> attrs) {
|   StringAttr symName = builder.getStringAttr(name);
|   auto controlAttr = builder.getAttr<spirv::FunctionControlAttr>(control);
| 
|   state.addAttribute(SymbolTable::getSymbolAttrName(), symName);
|   state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
|   state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
|                      controlAttr);
|   state.attributes.append(attrs.begin(), attrs.end());
|   state.addRegion();
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.GLFClampOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Parsing is delegated to the shared parseOneResultSameOperandTypeOp helper.
| ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
|                                      OperationState &result) {
|   return parseOneResultSameOperandTypeOp(parser, result);
| }
| 
| /// Printing is delegated to the shared printOneResultOp helper.
| void spirv::GLFClampOp::print(OpAsmPrinter &p) {
|   printOneResultOp(*this, p);
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.GLUClampOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Parsing is delegated to the shared parseOneResultSameOperandTypeOp helper.
| ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
|                                      OperationState &result) {
|   return parseOneResultSameOperandTypeOp(parser, result);
| }
| 
| /// Printing is delegated to the shared printOneResultOp helper.
| void spirv::GLUClampOp::print(OpAsmPrinter &p) {
|   printOneResultOp(*this, p);
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.GLSClampOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Parsing is delegated to the shared parseOneResultSameOperandTypeOp helper.
| ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
|                                      OperationState &result) {
|   return parseOneResultSameOperandTypeOp(parser, result);
| }
| 
| /// Printing is delegated to the shared printOneResultOp helper.
| void spirv::GLSClampOp::print(OpAsmPrinter &p) {
|   printOneResultOp(*this, p);
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.GLFmaOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Parsing is delegated to the shared parseOneResultSameOperandTypeOp helper.
| ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser,
|                                   OperationState &result) {
|   return parseOneResultSameOperandTypeOp(parser, result);
| }
| 
| /// Printing is delegated to the shared printOneResultOp helper.
| void spirv::GLFmaOp::print(OpAsmPrinter &p) {
|   printOneResultOp(*this, p);
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.GlobalVariable |
| //===----------------------------------------------------------------------===// |
| |
| /// Builds a global variable of the given type and name, and attaches the
| /// DescriptorSet and Binding decoration attributes as 32-bit integers.
| void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
|                                     Type type, StringRef name,
|                                     unsigned descriptorSet, unsigned binding) {
|   build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
| 
|   auto descriptorSetName =
|       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet);
|   state.addAttribute(descriptorSetName,
|                      builder.getI32IntegerAttr(descriptorSet));
| 
|   auto bindingName =
|       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding);
|   state.addAttribute(bindingName, builder.getI32IntegerAttr(binding));
| }
| |
| /// Builds a global variable of the given type and name, and attaches the
| /// BuiltIn decoration attribute holding the stringified `builtin` value.
| void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
|                                     Type type, StringRef name,
|                                     spirv::BuiltIn builtin) {
|   build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
| 
|   auto builtInName =
|       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn);
|   state.addAttribute(builtInName,
|                      builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
| }
| |
| /// Parses the custom assembly form of spirv.GlobalVariable: the symbol name,
| /// an optional parenthesized initializer symbol introduced by the initializer
| /// attribute's name, the variable decorations, and a trailing `: type` that
| /// must be a spirv.ptr type.
| ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
|                                            OperationState &result) {
|   // Variable name.
|   StringAttr symName;
|   if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
|                              result.attributes))
|     return failure();
| 
|   // Optional initializer; its keyword is the initializer attribute's name.
|   StringRef initializerAttrName =
|       spirv::GlobalVariableOp::getInitializerAttrName(result.name);
|   if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
|     FlatSymbolRefAttr initSymbol;
|     if (parser.parseLParen())
|       return failure();
|     if (parser.parseAttribute(initSymbol, Type(), initializerAttrName,
|                               result.attributes))
|       return failure();
|     if (parser.parseRParen())
|       return failure();
|   }
| 
|   if (parseVariableDecorations(parser, result))
|     return failure();
| 
|   // The error location is captured before the `: type` suffix is consumed.
|   auto typeLoc = parser.getCurrentLocation();
|   Type varType;
|   if (parser.parseColonType(varType))
|     return failure();
|   if (!llvm::isa<spirv::PointerType>(varType))
|     return parser.emitError(typeLoc, "expected spirv.ptr type");
| 
|   StringRef typeAttrName =
|       spirv::GlobalVariableOp::getTypeAttrName(result.name);
|   result.addAttribute(typeAttrName, TypeAttr::get(varType));
|   return success();
| }
| |
| /// Prints the variable name, the optional `initializer(@sym)` clause, the
| /// variable decorations, and the `: type` suffix. Attributes that are printed
| /// explicitly here (plus the storage class attribute) are passed to
| /// printVariableDecorations as elided.
| void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
|   SmallVector<StringRef, 4> elided{
|       spirv::attributeName<spirv::StorageClass>()};
| 
|   // Variable name.
|   printer << ' ';
|   printer.printSymbolName(getSymName());
|   elided.push_back(SymbolTable::getSymbolAttrName());
| 
|   // Optional initializer.
|   StringRef initializerAttrName = this->getInitializerAttrName();
|   if (auto initSymbol = this->getInitializer()) {
|     printer << " " << initializerAttrName << '(';
|     printer.printSymbolName(*initSymbol);
|     printer << ')';
|     elided.push_back(initializerAttrName);
|   }
| 
|   // The type is printed last, after the decorations.
|   StringRef typeAttrName = this->getTypeAttrName();
|   elided.push_back(typeAttrName);
|   spirv::printVariableDecorations(*this, printer, elided);
|   printer << " : " << getType();
| }
| |
| /// Verifies a spirv.GlobalVariable:
| ///  - its type must be a spirv.ptr type;
| ///  - its storage class must be neither Generic nor Function;
| ///  - if an initializer symbol is present, it must resolve to a
| ///    spirv.GlobalVariable, spirv.SpecConstant or spirv.SpecConstantComposite
| ///    op.
| LogicalResult spirv::GlobalVariableOp::verify() {
|   // The diagnostic uses the `!spirv.ptr` spelling, matching the dialect prefix
|   // used by the other diagnostics in this file (the previous text referred to
|   // the stale `!spv.ptr` spelling).
|   if (!llvm::isa<spirv::PointerType>(getType()))
|     return emitOpError("result must be of a !spirv.ptr type");
| 
|   // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
|   // object. It cannot be Generic. It must be the same as the Storage Class
|   // operand of the Result Type."
|   // Also, Function storage class is reserved by spirv.Variable.
|   auto storageClass = this->storageClass();
|   if (storageClass == spirv::StorageClass::Generic ||
|       storageClass == spirv::StorageClass::Function) {
|     return emitOpError("storage class cannot be '")
|            << stringifyStorageClass(storageClass) << "'";
|   }
| 
|   if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
|           this->getInitializerAttrName())) {
|     // The initializer symbol is looked up starting from the parent op.
|     Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
|         (*this)->getParentOp(), init.getAttr());
|     // TODO: Currently only variable initialization with specialization
|     // constants and other variables is supported. They could be normal
|     // constants in the module scope as well.
|     if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
|                         spirv::SpecConstantCompositeOp>(initOp)) {
|       return emitOpError("initializer must be result of a "
|                          "spirv.SpecConstant or spirv.GlobalVariable or "
|                          "spirv.SpecConstantCompositeOp op");
|     }
|   }
| 
|   return success();
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.INTEL.SubgroupBlockRead |
| //===----------------------------------------------------------------------===// |
| |
| /// Parses: storage-class string, pointer operand, `:` and the result type.
| /// The pointer operand is resolved against a spirv.ptr in the parsed storage
| /// class whose pointee is the result type, or its element type when the
| /// result type is a vector.
| ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
|                                                    OperationState &result) {
|   spirv::StorageClass storageClass;
|   OpAsmParser::UnresolvedOperand ptrOperand;
|   Type valueType;
|   if (parseEnumStrAttr(storageClass, parser) ||
|       parser.parseOperand(ptrOperand) || parser.parseColon() ||
|       parser.parseType(valueType))
|     return failure();
| 
|   // Vectors are accessed through a pointer to their element type.
|   Type pointeeType = valueType;
|   if (auto vectorType = llvm::dyn_cast<VectorType>(valueType))
|     pointeeType = vectorType.getElementType();
|   auto ptrType = spirv::PointerType::get(pointeeType, storageClass);
| 
|   if (parser.resolveOperand(ptrOperand, ptrType, result.operands))
|     return failure();
| 
|   result.addTypes(valueType);
|   return success();
| }
| |
| /// Prints `"<storage class>" <ptr> : <result type>`, which is the form that
| /// INTELSubgroupBlockReadOp::parse consumes, so that printed IR can be parsed
| /// back.
| void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
|   // The parser requires a leading storage-class string; it is recovered from
|   // the pointer operand's type. If the operand type is not a spirv.ptr, the
|   // storage class is left out instead of asserting inside the printer.
|   if (auto ptrType = llvm::dyn_cast<spirv::PointerType>(getPtr().getType()))
|     printer << " \"" << spirv::stringifyStorageClass(ptrType.getStorageClass())
|             << "\"";
|   printer << " " << getPtr() << " : " << getType();
| }
| |
| /// Verification consists solely of the shared pointer/value type check.
| LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
|   return verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue());
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.INTEL.SubgroupBlockWrite |
| //===----------------------------------------------------------------------===// |
| |
| /// Parses: storage-class string, two operands (pointer, value), `:` and the
| /// value type. The pointer operand is resolved against a spirv.ptr in the
| /// parsed storage class whose pointee is the value type, or its element type
| /// when the value type is a vector.
| ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
|                                                     OperationState &result) {
|   spirv::StorageClass storageClass;
|   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
|   // Captured up front; used as the location for operand resolution errors.
|   auto loc = parser.getCurrentLocation();
|   Type valueType;
|   if (parseEnumStrAttr(storageClass, parser) ||
|       parser.parseOperandList(operands, 2) || parser.parseColon() ||
|       parser.parseType(valueType))
|     return failure();
| 
|   // Vectors are accessed through a pointer to their element type.
|   Type pointeeType = valueType;
|   if (auto vectorType = llvm::dyn_cast<VectorType>(valueType))
|     pointeeType = vectorType.getElementType();
|   auto ptrType = spirv::PointerType::get(pointeeType, storageClass);
| 
|   if (parser.resolveOperands(operands, {ptrType, valueType}, loc,
|                              result.operands))
|     return failure();
|   return success();
| }
| |
| /// Prints `"<storage class>" <ptr>, <value> : <value type>`, which is the form
| /// that INTELSubgroupBlockWriteOp::parse consumes, so that printed IR can be
| /// parsed back.
| void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
|   // The parser requires a leading storage-class string; it is recovered from
|   // the pointer operand's type. If the operand type is not a spirv.ptr, the
|   // storage class is left out instead of asserting inside the printer.
|   if (auto ptrType = llvm::dyn_cast<spirv::PointerType>(getPtr().getType()))
|     printer << " \"" << spirv::stringifyStorageClass(ptrType.getStorageClass())
|             << "\"";
|   printer << " " << getPtr() << ", " << getValue() << " : "
|           << getValue().getType();
| }
| |
| /// Verification consists solely of the shared pointer/value type check.
| LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
|   return verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue());
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.IAddCarryOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Verification is delegated to the shared extended-binary-op verifier.
| LogicalResult spirv::IAddCarryOp::verify() {
|   return ::verifyArithmeticExtendedBinaryOp(*this);
| }
| 
| /// Parsing is delegated to the shared extended-binary-op parser.
| ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
|                                       OperationState &result) {
|   return ::parseArithmeticExtendedBinaryOp(parser, result);
| }
| 
| /// Printing is delegated to the shared extended-binary-op printer.
| void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
|   ::printArithmeticExtendedBinaryOp(*this, printer);
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ISubBorrowOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Verification is delegated to the shared extended-binary-op verifier.
| LogicalResult spirv::ISubBorrowOp::verify() {
|   return ::verifyArithmeticExtendedBinaryOp(*this);
| }
| 
| /// Parsing is delegated to the shared extended-binary-op parser.
| ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
|                                        OperationState &result) {
|   return ::parseArithmeticExtendedBinaryOp(parser, result);
| }
| 
| /// Printing is delegated to the shared extended-binary-op printer.
| void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
|   ::printArithmeticExtendedBinaryOp(*this, printer);
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.SMulExtended |
| //===----------------------------------------------------------------------===// |
| |
| /// Verification is delegated to the shared extended-binary-op verifier.
| LogicalResult spirv::SMulExtendedOp::verify() {
|   return ::verifyArithmeticExtendedBinaryOp(*this);
| }
| 
| /// Parsing is delegated to the shared extended-binary-op parser.
| ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
|                                          OperationState &result) {
|   return ::parseArithmeticExtendedBinaryOp(parser, result);
| }
| 
| /// Printing is delegated to the shared extended-binary-op printer.
| void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
|   ::printArithmeticExtendedBinaryOp(*this, printer);
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.UMulExtended |
| //===----------------------------------------------------------------------===// |
| |
| /// Verification is delegated to the shared extended-binary-op verifier.
| LogicalResult spirv::UMulExtendedOp::verify() {
|   return ::verifyArithmeticExtendedBinaryOp(*this);
| }
| 
| /// Parsing is delegated to the shared extended-binary-op parser.
| ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
|                                          OperationState &result) {
|   return ::parseArithmeticExtendedBinaryOp(parser, result);
| }
| 
| /// Printing is delegated to the shared extended-binary-op printer.
| void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
|   ::printArithmeticExtendedBinaryOp(*this, printer);
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.MemoryBarrierOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Verification is delegated to verifyMemorySemantics, which is given this
| /// operation and its memory semantics.
| LogicalResult spirv::MemoryBarrierOp::verify() {
|   Operation *barrierOp = getOperation();
|   return verifyMemorySemantics(barrierOp, getMemorySemantics());
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.module |
| //===----------------------------------------------------------------------===// |
| |
| /// Adds a region with a single block to `state` and, when `name` is given,
| /// records it as the symbol name. The builder's insertion point is restored
| /// on return.
| void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
|                             std::optional<StringRef> name) {
|   OpBuilder::InsertionGuard guard(builder);
|   builder.createBlock(state.addRegion());
|   if (!name)
|     return;
|   state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
|                           builder.getStringAttr(*name));
| }
| |
| /// Sets the addressing and memory model attributes, adds a region with a
| /// single block, and attaches the optional VCE triple and symbol name when
| /// they are provided. The builder's insertion point is restored on return.
| void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
|                             spirv::AddressingModel addressingModel,
|                             spirv::MemoryModel memoryModel,
|                             std::optional<VerCapExtAttr> vceTriple,
|                             std::optional<StringRef> name) {
|   auto addressingAttr =
|       builder.getAttr<spirv::AddressingModelAttr>(addressingModel);
|   state.addAttribute("addressing_model", addressingAttr);
| 
|   auto memoryAttr = builder.getAttr<spirv::MemoryModelAttr>(memoryModel);
|   state.addAttribute("memory_model", memoryAttr);
| 
|   OpBuilder::InsertionGuard guard(builder);
|   builder.createBlock(state.addRegion());
| 
|   if (vceTriple)
|     state.addAttribute(getVCETripleAttrName(), *vceTriple);
|   if (name) {
|     state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
|                        builder.getStringAttr(*name));
|   }
| }
| |
| /// Parses the custom assembly form of spirv.module: an optional symbol name,
| /// the addressing model and memory model keywords, an optional `requires`
| /// clause carrying the VCE triple, an optional keyword-introduced attribute
| /// dictionary, and the body region.
| ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
|                                    OperationState &result) {
|   Region *bodyRegion = result.addRegion();
| 
|   // The module name is optional, so the parse result is deliberately ignored.
|   StringAttr symName;
|   (void)parser.parseOptionalSymbolName(
|       symName, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
| 
|   // Addressing model followed by memory model.
|   spirv::AddressingModel addrModel;
|   if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
|                                                               result))
|     return failure();
|   spirv::MemoryModel memoryModel;
|   if (spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
|                                                           result))
|     return failure();
| 
|   // Optional `requires <vce-triple>`.
|   if (succeeded(parser.parseOptionalKeyword("requires"))) {
|     spirv::VerCapExtAttr vceTriple;
|     if (parser.parseAttribute(vceTriple,
|                               spirv::ModuleOp::getVCETripleAttrName(),
|                               result.attributes))
|       return failure();
|   }
| 
|   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|     return failure();
|   if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
|     return failure();
| 
|   // Make sure we have at least one block.
|   if (bodyRegion->empty())
|     bodyRegion->push_back(new Block());
| 
|   return success();
| }
| |
| /// Prints the optional module name, the addressing and memory models, the
| /// optional `requires` clause, the remaining attributes, and the body region.
| void spirv::ModuleOp::print(OpAsmPrinter &printer) {
|   if (std::optional<StringRef> moduleName = getName()) {
|     printer << ' ';
|     printer.printSymbolName(*moduleName);
|   }
| 
|   printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
|           << spirv::stringifyMemoryModel(getMemoryModel());
| 
|   // Attributes printed explicitly are elided from the trailing dictionary.
|   SmallVector<StringRef, 2> elided;
|   elided.assign({spirv::attributeName<spirv::AddressingModel>(),
|                  spirv::attributeName<spirv::MemoryModel>(),
|                  mlir::SymbolTable::getSymbolAttrName()});
| 
|   if (std::optional<spirv::VerCapExtAttr> vce = getVceTriple()) {
|     printer << " requires " << *vce;
|     elided.push_back(spirv::ModuleOp::getVCETripleAttrName());
|   }
| 
|   printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elided);
|   printer << ' ';
|   printer.printRegion(getRegion());
| }
| |
| /// Verifies the ops directly nested in the module body:
| ///  - every op must belong to this op's dialect;
| ///  - each spirv.EntryPoint must reference an existing spirv.func, its
| ///    interface entries must be symbol references to spirv.GlobalVariable ops,
| ///    and no (function, execution model) pair may appear twice;
| ///  - an external spirv.func must carry 'Import' linkage attributes, and the
| ///    ops in a function's blocks must belong to this op's dialect.
| LogicalResult spirv::ModuleOp::verifyRegions() {
|   using EntryPointKey = std::pair<spirv::FuncOp, spirv::ExecutionModel>;
| 
|   Dialect *moduleDialect = (*this)->getDialect();
|   DenseMap<EntryPointKey, spirv::EntryPointOp> seenEntryPoints;
|   mlir::SymbolTable symbols(*this);
| 
|   for (Operation &op : *getBody()) {
|     if (op.getDialect() != moduleDialect)
|       return op.emitError("'spirv.module' can only contain spirv.* ops");
| 
|     // For EntryPoint op, check that the function and execution model is not
|     // duplicated in EntryPointOps. Also verify that the interface specified
|     // comes from globalVariables here to make this check cheaper.
|     if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
|       auto entryFunc = symbols.lookup<spirv::FuncOp>(entryPointOp.getFn());
|       if (!entryFunc)
|         return entryPointOp.emitError("function '")
|                << entryPointOp.getFn() << "' not found in 'spirv.module'";
| 
|       if (auto interface = entryPointOp.getInterface()) {
|         for (Attribute varRef : interface) {
|           auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
|           if (!varSymRef)
|             return entryPointOp.emitError(
|                        "expected symbol reference for interface "
|                        "specification instead of '")
|                    << varRef;
| 
|           if (!symbols.lookup<spirv::GlobalVariableOp>(varSymRef.getValue()))
|             return entryPointOp.emitError("expected spirv.GlobalVariable "
|                                           "symbol reference instead of'")
|                    << varSymRef << "'";
|         }
|       }
| 
|       // Insertion fails exactly when the same pair was already recorded.
|       EntryPointKey key(entryFunc, entryPointOp.getExecutionModel());
|       if (!seenEntryPoints.insert({key, entryPointOp}).second)
|         return entryPointOp.emitError("duplicate of a previous EntryPointOp");
|       continue;
|     }
| 
|     auto funcOp = dyn_cast<spirv::FuncOp>(op);
|     if (!funcOp)
|       continue;
| 
|     // If the function is external and does not have 'Import'
|     // linkage_attributes(LinkageAttributes), throw an error. 'Import'
|     // LinkageAttributes is used to import external functions.
|     bool hasImportLinkage = false;
|     if (auto linkageAttr = funcOp.getLinkageAttributes())
|       hasImportLinkage = linkageAttr.value().getLinkageType().getValue() ==
|                          spirv::LinkageType::Import;
|     if (funcOp.isExternal() && !hasImportLinkage)
|       return op.emitError(
|           "'spirv.module' cannot contain external functions "
|           "without 'Import' linkage_attributes (LinkageAttributes)");
| 
|     // TODO: move this check to spirv.func.
|     for (auto &block : funcOp) {
|       for (auto &nestedOp : block) {
|         if (nestedOp.getDialect() != moduleDialect)
|           return nestedOp.emitError(
|               "functions in 'spirv.module' can only contain spirv.* ops");
|       }
|     }
|   }
| 
|   return success();
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.mlir.referenceof |
| //===----------------------------------------------------------------------===// |
| |
| /// Verifies that the referenced symbol resolves to a spirv.SpecConstant or a
| /// spirv.SpecConstantComposite and that this op's result type equals the
| /// type of that specialization constant.
| LogicalResult spirv::ReferenceOfOp::verify() {
|   // The symbol is looked up starting from the parent op.
|   Operation *referenced = SymbolTable::lookupNearestSymbolFrom(
|       (*this)->getParentOp(), getSpecConstAttr());
| 
|   Type constType;
|   if (auto specConstOp =
|           dyn_cast_or_null<spirv::SpecConstantOp>(referenced)) {
|     constType = specConstOp.getDefaultValue().getType();
|   } else if (auto compositeOp =
|                  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(
|                      referenced)) {
|     constType = compositeOp.getType();
|   } else {
|     return emitOpError(
|         "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
|   }
| 
|   if (getReference().getType() != constType)
|     return emitOpError("result type mismatch with the referenced "
|                        "specialization constant's type");
| 
|   return success();
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.SpecConstant |
| //===----------------------------------------------------------------------===// |
| |
| /// Parses the custom assembly form of spirv.SpecConstant: the symbol name, an
| /// optional parenthesized spec id introduced by the kSpecIdAttrName keyword,
| /// then `=` followed by the default value attribute.
| ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
|                                          OperationState &result) {
|   StringAttr symName;
|   if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
|                              result.attributes))
|     return failure();
| 
|   // Optional spec_id.
|   if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
|     IntegerAttr specIdAttr;
|     if (parser.parseLParen())
|       return failure();
|     if (parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes))
|       return failure();
|     if (parser.parseRParen())
|       return failure();
|   }
| 
|   // `=` default-value.
|   if (parser.parseEqual())
|     return failure();
|   Attribute defaultValue;
|   StringRef defaultValueAttrName =
|       spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
|   if (parser.parseAttribute(defaultValue, defaultValueAttrName,
|                             result.attributes))
|     return failure();
| 
|   return success();
| }
| |
| /// Prints the symbol name, the optional spec id clause (when an integer
| /// attribute named kSpecIdAttrName is present), and `= <default value>`.
| void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
|   printer << ' ';
|   printer.printSymbolName(getSymName());
| 
|   auto specIdAttr = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName);
|   if (specIdAttr) {
|     printer << ' ' << kSpecIdAttrName << '(' << specIdAttr.getInt() << ')';
|   }
| 
|   printer << " = " << getDefaultValue();
| }
| |
| /// Verifies that the spec id, if present, is not negative, and that the
| /// default value is an integer or float attribute whose type is a SPIR-V
| /// type.
| LogicalResult spirv::SpecConstantOp::verify() {
|   auto specIdAttr = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName);
|   if (specIdAttr && specIdAttr.getValue().isNegative())
|     return emitOpError("SpecId cannot be negative");
| 
|   auto defaultValue = getDefaultValue();
|   if (!llvm::isa<IntegerAttr, FloatAttr>(defaultValue))
|     return emitOpError(
|         "default value can only be a bool, integer, or float scalar");
| 
|   // Make sure bitwidth is allowed.
|   if (!llvm::isa<spirv::SPIRVType>(defaultValue.getType()))
|     return emitOpError("default value bitwidth disallowed");
| 
|   return success();
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.VectorShuffle |
| //===----------------------------------------------------------------------===// |
| |
| /// Verifies that the result vector has one element per component selector
| /// and that every selector is either below the combined element count of the
| /// two source vectors or equal to 0xffffffff.
| LogicalResult spirv::VectorShuffleOp::verify() {
|   auto resultType = llvm::cast<VectorType>(getType());
| 
|   size_t numResultElements = resultType.getNumElements();
|   size_t numSelectors = getComponents().size();
|   if (numResultElements != numSelectors)
|     return emitOpError("result type element count (")
|            << numResultElements
|            << ") mismatch with the number of component selectors ("
|            << numSelectors << ")";
| 
|   // Selectors index into the two source vectors taken together.
|   size_t totalSrcElements =
|       llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
|       llvm::cast<VectorType>(getVector2().getType()).getNumElements();
| 
|   // The all-ones selector is accepted regardless of the source sizes.
|   constexpr uint32_t kAllOnesSelector = std::numeric_limits<uint32_t>::max();
|   for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
|     uint32_t index = selector.getZExtValue();
|     if (index < totalSrcElements || index == kAllOnesSelector)
|       continue;
|     return emitOpError("component selector ")
|            << index << " out of range: expected to be in [0, "
|            << totalSrcElements << ") or 0xffffffff";
|   }
|   return success();
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.MatrixTimesScalar |
| //===----------------------------------------------------------------------===// |
| |
| /// Verifies that the scalar operand's type equals the element type of the
| /// matrix operand, which is a spirv::CooperativeMatrixType or a
| /// spirv::MatrixType.
| LogicalResult spirv::MatrixTimesScalarOp::verify() {
|   Type matrixOperandType = getMatrix().getType();
| 
|   // Any other operand type yields a null element type, caught by the assert.
|   Type elementType =
|       llvm::TypeSwitch<Type, Type>(matrixOperandType)
|           .Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
|               [](auto matrixType) { return matrixType.getElementType(); })
|           .Default([](Type) { return nullptr; });
|   assert(elementType && "Unhandled type");
| 
|   // Check that the scalar type is the same as the matrix element type.
|   Type scalarType = getScalar().getType();
|   if (scalarType != elementType)
|     return emitOpError("input matrix components' type and scaling value must "
|                        "have the same type");
| 
|   return success();
| }
| |
| //===----------------------------------------------------------------------===// |
| // spirv.Transpose |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::TransposeOp::verify() { |
| auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType()); |
| auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType()); |
| |
| // Verify that the input and output matrices have correct shapes. |
| if (inputMatrix.getNumRows() != resultMatrix.getNumColumns()) |
| return emitError("input matrix rows count must be equal to " |
| "output matrix columns count"); |
| |
| if (inputMatrix.getNumColumns() != resultMatrix.getNumRows()) |
| return emitError("input matrix columns count must be equal to " |
| "output matrix rows count"); |
| |
| // Verify that the input and output matrices have the same component type |
| if (inputMatrix.getElementType() != resultMatrix.getElementType()) |
| return emitError("input and output matrices must have the same " |
| "component type"); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.MatrixTimesMatrix |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::MatrixTimesMatrixOp::verify() { |
| auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType()); |
| auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType()); |
| auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType()); |
| |
| // left matrix columns' count and right matrix rows' count must be equal |
| if (leftMatrix.getNumColumns() != rightMatrix.getNumRows()) |
| return emitError("left matrix columns' count must be equal to " |
| "the right matrix rows' count"); |
| |
| // right and result matrices columns' count must be the same |
| if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns()) |
| return emitError( |
| "right and result matrices must have equal columns' count"); |
| |
| // right and result matrices component type must be the same |
| if (rightMatrix.getElementType() != resultMatrix.getElementType()) |
| return emitError("right and result matrices' component type must" |
| " be the same"); |
| |
| // left and result matrices component type must be the same |
| if (leftMatrix.getElementType() != resultMatrix.getElementType()) |
| return emitError("left and result matrices' component type" |
| " must be the same"); |
| |
| // left and result matrices rows count must be the same |
| if (leftMatrix.getNumRows() != resultMatrix.getNumRows()) |
| return emitError("left and result matrices must have equal rows' count"); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.SpecConstantComposite |
| //===----------------------------------------------------------------------===// |
| |
| ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| |
| StringAttr compositeName; |
| if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(), |
| result.attributes)) |
| return failure(); |
| |
| if (parser.parseLParen()) |
| return failure(); |
| |
| SmallVector<Attribute, 4> constituents; |
| |
| do { |
| // The name of the constituent attribute isn't important |
| const char *attrName = "spec_const"; |
| FlatSymbolRefAttr specConstRef; |
| NamedAttrList attrs; |
| |
| if (parser.parseAttribute(specConstRef, Type(), attrName, attrs)) |
| return failure(); |
| |
| constituents.push_back(specConstRef); |
| } while (!parser.parseOptionalComma()); |
| |
| if (parser.parseRParen()) |
| return failure(); |
| |
| StringAttr compositeSpecConstituentsName = |
| spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name); |
| result.addAttribute(compositeSpecConstituentsName, |
| parser.getBuilder().getArrayAttr(constituents)); |
| |
| Type type; |
| if (parser.parseColonType(type)) |
| return failure(); |
| |
| StringAttr typeAttrName = |
| spirv::SpecConstantCompositeOp::getTypeAttrName(result.name); |
| result.addAttribute(typeAttrName, TypeAttr::get(type)); |
| |
| return success(); |
| } |
| |
| void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) { |
| printer << " "; |
| printer.printSymbolName(getSymName()); |
| printer << " ("; |
| auto constituents = this->getConstituents().getValue(); |
| |
| if (!constituents.empty()) |
| llvm::interleaveComma(constituents, printer); |
| |
| printer << ") : " << getType(); |
| } |
| |
| LogicalResult spirv::SpecConstantCompositeOp::verify() { |
| auto cType = llvm::dyn_cast<spirv::CompositeType>(getType()); |
| auto constituents = this->getConstituents().getValue(); |
| |
| if (!cType) |
| return emitError("result type must be a composite type, but provided ") |
| << getType(); |
| |
| if (llvm::isa<spirv::CooperativeMatrixType>(cType)) |
| return emitError("unsupported composite type ") << cType; |
| if (llvm::isa<spirv::JointMatrixINTELType>(cType)) |
| return emitError("unsupported composite type ") << cType; |
| if (constituents.size() != cType.getNumElements()) |
| return emitError("has incorrect number of operands: expected ") |
| << cType.getNumElements() << ", but provided " |
| << constituents.size(); |
| |
| for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { |
| auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]); |
| |
| auto constituentSpecConstOp = |
| dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom( |
| (*this)->getParentOp(), constituent.getAttr())); |
| |
| if (constituentSpecConstOp.getDefaultValue().getType() != |
| cType.getElementType(index)) |
| return emitError("has incorrect types of operands: expected ") |
| << cType.getElementType(index) << ", but provided " |
| << constituentSpecConstOp.getDefaultValue().getType(); |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.SpecConstantOperation |
| //===----------------------------------------------------------------------===// |
| |
| ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| Region *body = result.addRegion(); |
| |
| if (parser.parseKeyword("wraps")) |
| return failure(); |
| |
| body->push_back(new Block); |
| Block &block = body->back(); |
| Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); |
| |
| if (!wrappedOp) |
| return failure(); |
| |
| OpBuilder builder(parser.getContext()); |
| builder.setInsertionPointToEnd(&block); |
| builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0)); |
| result.location = wrappedOp->getLoc(); |
| |
| result.addTypes(wrappedOp->getResult(0).getType()); |
| |
| if (parser.parseOptionalAttrDict(result.attributes)) |
| return failure(); |
| |
| return success(); |
| } |
| |
| void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) { |
| printer << " wraps "; |
| printer.printGenericOp(&getBody().front().front()); |
| } |
| |
| LogicalResult spirv::SpecConstantOperationOp::verifyRegions() { |
| Block &block = getRegion().getBlocks().front(); |
| |
| if (block.getOperations().size() != 2) |
| return emitOpError("expected exactly 2 nested ops"); |
| |
| Operation &enclosedOp = block.getOperations().front(); |
| |
| if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>()) |
| return emitOpError("invalid enclosed op"); |
| |
| for (auto operand : enclosedOp.getOperands()) |
| if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp, |
| spirv::SpecConstantOperationOp>(operand.getDefiningOp())) |
| return emitOpError( |
| "invalid operand, must be defined by a constant operation"); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.GL.FrexpStruct |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::GLFrexpStructOp::verify() { |
| spirv::StructType structTy = |
| llvm::dyn_cast<spirv::StructType>(getResult().getType()); |
| |
| if (structTy.getNumElements() != 2) |
| return emitError("result type must be a struct type with two memebers"); |
| |
| Type significandTy = structTy.getElementType(0); |
| Type exponentTy = structTy.getElementType(1); |
| VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy); |
| IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy); |
| |
| Type operandTy = getOperand().getType(); |
| VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy); |
| FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy); |
| |
| if (significandTy != operandTy) |
| return emitError("member zero of the resulting struct type must be the " |
| "same type as the operand"); |
| |
| if (exponentVecTy) { |
| IntegerType componentIntTy = |
| llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType()); |
| if (!componentIntTy || componentIntTy.getWidth() != 32) |
| return emitError("member one of the resulting struct type must" |
| "be a scalar or vector of 32 bit integer type"); |
| } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) { |
| return emitError("member one of the resulting struct type " |
| "must be a scalar or vector of 32 bit integer type"); |
| } |
| |
| // Check that the two member types have the same number of components |
| if (operandVecTy && exponentVecTy && |
| (exponentVecTy.getNumElements() == operandVecTy.getNumElements())) |
| return success(); |
| |
| if (operandFTy && exponentIntTy) |
| return success(); |
| |
| return emitError("member one of the resulting struct type must have the same " |
| "number of components as the operand type"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.GL.Ldexp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::GLLdexpOp::verify() { |
| Type significandType = getX().getType(); |
| Type exponentType = getExp().getType(); |
| |
| if (llvm::isa<FloatType>(significandType) != |
| llvm::isa<IntegerType>(exponentType)) |
| return emitOpError("operands must both be scalars or vectors"); |
| |
| auto getNumElements = [](Type type) -> unsigned { |
| if (auto vectorType = llvm::dyn_cast<VectorType>(type)) |
| return vectorType.getNumElements(); |
| return 1; |
| }; |
| |
| if (getNumElements(significandType) != getNumElements(exponentType)) |
| return emitOpError("operands must have the same number of elements"); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ImageDrefGather |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::ImageDrefGatherOp::verify() { |
| VectorType resultType = llvm::cast<VectorType>(getResult().getType()); |
| auto sampledImageType = |
| llvm::cast<spirv::SampledImageType>(getSampledimage().getType()); |
| auto imageType = |
| llvm::cast<spirv::ImageType>(sampledImageType.getImageType()); |
| |
| if (resultType.getNumElements() != 4) |
| return emitOpError("result type must be a vector of four components"); |
| |
| Type elementType = resultType.getElementType(); |
| Type sampledElementType = imageType.getElementType(); |
| if (!llvm::isa<NoneType>(sampledElementType) && |
| elementType != sampledElementType) |
| return emitOpError( |
| "the component type of result must be the same as sampled type of the " |
| "underlying image type"); |
| |
| spirv::Dim imageDim = imageType.getDim(); |
| spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo(); |
| |
| if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube && |
| imageDim != spirv::Dim::Rect) |
| return emitOpError( |
| "the Dim operand of the underlying image type must be 2D, Cube, or " |
| "Rect"); |
| |
| if (imageMS != spirv::ImageSamplingInfo::SingleSampled) |
| return emitOpError("the MS operand of the underlying image type must be 0"); |
| |
| spirv::ImageOperandsAttr attr = getImageoperandsAttr(); |
| auto operandArguments = getOperandArguments(); |
| |
| return verifyImageOperands(*this, attr, operandArguments); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ShiftLeftLogicalOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::ShiftLeftLogicalOp::verify() { |
| return verifyShiftOp(*this); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ShiftRightArithmeticOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::ShiftRightArithmeticOp::verify() { |
| return verifyShiftOp(*this); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ShiftRightLogicalOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::ShiftRightLogicalOp::verify() { |
| return verifyShiftOp(*this); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ImageQuerySize |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::ImageQuerySizeOp::verify() { |
| spirv::ImageType imageType = |
| llvm::cast<spirv::ImageType>(getImage().getType()); |
| Type resultType = getResult().getType(); |
| |
| spirv::Dim dim = imageType.getDim(); |
| spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo(); |
| spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo(); |
| switch (dim) { |
| case spirv::Dim::Dim1D: |
| case spirv::Dim::Dim2D: |
| case spirv::Dim::Dim3D: |
| case spirv::Dim::Cube: |
| if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled && |
| samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown && |
| samplerInfo != spirv::ImageSamplerUseInfo::NoSampler) |
| return emitError( |
| "if Dim is 1D, 2D, 3D, or Cube, " |
| "it must also have either an MS of 1 or a Sampled of 0 or 2"); |
| break; |
| case spirv::Dim::Buffer: |
| case spirv::Dim::Rect: |
| break; |
| default: |
| return emitError("the Dim operand of the image type must " |
| "be 1D, 2D, 3D, Buffer, Cube, or Rect"); |
| } |
| |
| unsigned componentNumber = 0; |
| switch (dim) { |
| case spirv::Dim::Dim1D: |
| case spirv::Dim::Buffer: |
| componentNumber = 1; |
| break; |
| case spirv::Dim::Dim2D: |
| case spirv::Dim::Cube: |
| case spirv::Dim::Rect: |
| componentNumber = 2; |
| break; |
| case spirv::Dim::Dim3D: |
| componentNumber = 3; |
| break; |
| default: |
| break; |
| } |
| |
| if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed) |
| componentNumber += 1; |
| |
| unsigned resultComponentNumber = 1; |
| if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType)) |
| resultComponentNumber = resultVectorType.getNumElements(); |
| |
| if (componentNumber != resultComponentNumber) |
| return emitError("expected the result to have ") |
| << componentNumber << " component(s), but found " |
| << resultComponentNumber << " component(s)"; |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.VectorTimesScalarOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::VectorTimesScalarOp::verify() { |
| if (getVector().getType() != getType()) |
| return emitOpError("vector operand and result type mismatch"); |
| auto scalarType = llvm::cast<VectorType>(getType()).getElementType(); |
| if (getScalar().getType() != scalarType) |
| return emitOpError("scalar operand and result element type match"); |
| return success(); |
| } |