blob: c4c7ff722175dce4384321f80bb31b6be37b9f7f [file] [log] [blame]
//===- MemoryOps.cpp - MLIR SPIR-V Memory Ops ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the memory operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "SPIRVOpUtils.h"
#include "SPIRVParsingUtils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Casting.h"
using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
/// Parses optional memory access (a.k.a. memory operand) attributes attached to
/// a memory access operand/pointer. Specifically, parses the following syntax:
/// (`[` memory-access `]`)?
/// where:
/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
/// integer-literal | `"NonTemporal"`
template <typename MemoryOpTy>
ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
                                        OperationState &state) {
  // The bracketed list is optional: without a leading '[' there is nothing to
  // parse and no attribute is recorded.
  if (failed(parser.parseOptionalLSquare()))
    return success();

  // Parse the memory access enum string and attach it to `state` under the
  // op's memory access attribute name.
  spirv::MemoryAccess access;
  StringAttr accessName = MemoryOpTy::getMemoryAccessAttrName(state.name);
  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(access, parser, state,
                                                       accessName))
    return failure();

  // An access containing the Aligned bit is followed by `, <alignment>`,
  // parsed as an i32-typed attribute and stored under the alignment name.
  if (spirv::bitEnumContainsAll(access, spirv::MemoryAccess::Aligned)) {
    Attribute alignment;
    StringAttr alignmentName = MemoryOpTy::getAlignmentAttrName(state.name);
    Type i32 = parser.getBuilder().getIntegerType(32);
    if (parser.parseComma())
      return failure();
    if (parser.parseAttribute(alignment, i32, alignmentName, state.attributes))
      return failure();
  }

  return parser.parseRSquare();
}
// TODO: Merge this and the previous function into one template parameterized
// by the memory access attribute name and the alignment attribute name. Doing
// so currently makes VS2017 produce an internal compiler error (at the call
// site) that is not detailed enough to understand what is happening.
template <typename MemoryOpTy>
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
                                                     OperationState &state) {
  // The bracketed list is optional: without a leading '[' there is nothing to
  // parse and no attribute is recorded.
  if (failed(parser.parseOptionalLSquare()))
    return success();

  // Parse the memory access enum string and attach it to `state` under the
  // op's *source* memory access attribute name.
  spirv::MemoryAccess access;
  StringRef accessName = MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(access, parser, state,
                                                       accessName))
    return failure();

  // An access containing the Aligned bit is followed by `, <alignment>`,
  // parsed as an i32-typed attribute and stored under the source alignment
  // attribute name.
  if (spirv::bitEnumContainsAll(access, spirv::MemoryAccess::Aligned)) {
    Attribute alignment;
    StringAttr alignmentName =
        MemoryOpTy::getSourceAlignmentAttrName(state.name);
    Type i32 = parser.getBuilder().getIntegerType(32);
    if (parser.parseComma())
      return failure();
    if (parser.parseAttribute(alignment, i32, alignmentName, state.attributes))
      return failure();
  }

  return parser.parseRSquare();
}
// TODO: Merge this and the following function (printMemoryAccessAttribute)
// into one template parameterized by the memory access attribute name and the
// alignment attribute name. Doing so currently makes VS2017 produce an
// internal compiler error (at the call site) that is not detailed enough to
// understand what is happening.
template <typename MemoryOpTy>
static void printSourceMemoryAccessAttribute(
MemoryOpTy memoryOp, OpAsmPrinter &printer,
SmallVectorImpl<StringRef> &elidedAttrs,
std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
printer << ", ";
// Print optional memory access attribute.
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
: memoryOp.getMemoryAccess())) {
elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.getAlignment())) {
elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
printer << ", " << *alignment;
}
}
printer << "]";
}
elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
}
template <typename MemoryOpTy>
static void printMemoryAccessAttribute(
MemoryOpTy memoryOp, OpAsmPrinter &printer,
SmallVectorImpl<StringRef> &elidedAttrs,
std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
// Print optional memory access attribute.
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
: memoryOp.getMemoryAccess())) {
elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.getAlignment())) {
elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
printer << ", " << *alignment;
}
}
printer << "]";
}
elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
}
/// Checks that the type of `val` equals the pointee type of `ptr`; emits an
/// op error on `op` otherwise.
template <typename LoadStoreOpTy>
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
                                                   Value val) {
  // ODS already guarantees that `ptr` has spirv::PointerType, so only the
  // pointee/value type equality is left to check here.
  //
  // TODO: Check that the value type satisfies restrictions of
  // SPIR-V OpLoad/OpStore operations
  Type pointeeType =
      llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType();
  if (val.getType() == pointeeType)
    return success();
  return op.emitOpError("mismatch in result type and pointer type");
}
/// Verifies the consistency between the memory access attribute and the
/// alignment attribute of `memoryOp`:
///  * no memory access attribute  -> alignment must be absent;
///  * memory access with Aligned  -> alignment must be present;
///  * memory access w/o Aligned   -> alignment must be absent.
/// Emits an op error and returns failure on any violation.
template <typename MemoryOpTy>
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
  Operation *op = memoryOp.getOperation();
  Attribute memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
  const bool hasAlignment =
      op->getAttr(memoryOp.getAlignmentAttrName()) != nullptr;

  if (!memAccessAttr) {
    // Alignment attribute shouldn't be present if memory access attribute is
    // not present.
    if (hasAlignment) {
      return memoryOp.emitOpError(
          "invalid alignment specification without aligned memory access "
          "specification");
    }
    return success();
  }

  // Use dyn_cast (not cast) so that an attribute of the wrong kind yields the
  // diagnostic below instead of tripping the assertion inside llvm::cast.
  auto memAccess = llvm::dyn_cast<spirv::MemoryAccessAttr>(memAccessAttr);
  if (!memAccess) {
    return memoryOp.emitOpError("invalid memory access specifier: ")
           << memAccessAttr;
  }

  const bool isAligned = spirv::bitEnumContainsAll(
      memAccess.getValue(), spirv::MemoryAccess::Aligned);
  if (isAligned && !hasAlignment)
    return memoryOp.emitOpError("missing alignment value");
  if (!isAligned && hasAlignment) {
    return memoryOp.emitOpError(
        "invalid alignment specification with non-aligned memory access "
        "specification");
  }
  return success();
}
// TODO: Merge this and the previous function into one template parameterized
// by the memory access attribute name and the alignment attribute name. Doing
// so currently makes VS2017 produce an internal compiler error (at the call
// site) that is not detailed enough to understand what is happening.
/// Verifies the consistency between the *source* memory access attribute and
/// the *source* alignment attribute of `memoryOp`:
///  * no memory access attribute  -> alignment must be absent;
///  * memory access with Aligned  -> alignment must be present;
///  * memory access w/o Aligned   -> alignment must be absent.
/// Emits an op error and returns failure on any violation.
template <typename MemoryOpTy>
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
  Operation *op = memoryOp.getOperation();
  Attribute memAccessAttr =
      op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
  const bool hasAlignment =
      op->getAttr(memoryOp.getSourceAlignmentAttrName()) != nullptr;

  if (!memAccessAttr) {
    // Alignment attribute shouldn't be present if memory access attribute is
    // not present.
    if (hasAlignment) {
      return memoryOp.emitOpError(
          "invalid alignment specification without aligned memory access "
          "specification");
    }
    return success();
  }

  // Use dyn_cast (not cast) so that an attribute of the wrong kind yields the
  // diagnostic below instead of tripping the assertion inside llvm::cast.
  auto memAccess = llvm::dyn_cast<spirv::MemoryAccessAttr>(memAccessAttr);
  if (!memAccess) {
    // Report the attribute actually found on the op; `memAccess` is null
    // here and carries no information.
    return memoryOp.emitOpError("invalid memory access specifier: ")
           << memAccessAttr;
  }

  const bool isAligned = spirv::bitEnumContainsAll(
      memAccess.getValue(), spirv::MemoryAccess::Aligned);
  if (isAligned && !hasAlignment)
    return memoryOp.emitOpError("missing alignment value");
  if (!isAligned && hasAlignment) {
    return memoryOp.emitOpError(
        "invalid alignment specification with non-aligned memory access "
        "specification");
  }
  return success();
}
//===----------------------------------------------------------------------===//
// spirv.AccessChainOp
//===----------------------------------------------------------------------===//
/// Computes the pointer type produced by indexing into the pointee of `type`
/// with `indices`. Returns a null Type after emitting an error at `baseLoc`
/// when `type` is not a pointer, when an index is applied to a non-composite
/// type, or when a struct index is not a constant within bounds. The result
/// keeps the storage class of the base pointer.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
  auto basePtrType = llvm::dyn_cast<spirv::PointerType>(type);
  if (!basePtrType) {
    emitError(baseLoc,
              "'spirv.AccessChain' op expected a pointer to composite type, "
              "but provided ")
        << type;
    return nullptr;
  }

  Type currentType = basePtrType.getPointeeType();
  // Holds the element position used for the current step; it is reset to 0
  // on every iteration and only overwritten for struct types.
  int32_t index = 0;

  for (Value indexValue : indices) {
    auto compositeType = llvm::dyn_cast<spirv::CompositeType>(currentType);
    if (!compositeType) {
      emitError(baseLoc,
                "'spirv.AccessChain' op cannot extract from non-composite "
                "type ")
          << currentType << " with index " << index;
      return nullptr;
    }

    index = 0;
    if (llvm::isa<spirv::StructType>(currentType)) {
      // Struct members are selected by position, so the index value has to
      // be extractable from its defining constant op.
      Operation *definingOp = indexValue.getDefiningOp();
      if (!definingOp) {
        emitError(baseLoc,
                  "'spirv.AccessChain' op index must be an integer "
                  "spirv.Constant to access element of spirv.struct");
        return nullptr;
      }

      // TODO: this should be relaxed to allow
      // integer literals of other bitwidths.
      if (failed(spirv::extractValueFromConstOp(definingOp, index))) {
        emitError(baseLoc,
                  "'spirv.AccessChain' index must be an integer "
                  "spirv.Constant to access element of spirv.struct, but "
                  "provided ")
            << definingOp->getName();
        return nullptr;
      }

      const bool outOfBounds =
          index < 0 ||
          static_cast<uint64_t>(index) >= compositeType.getNumElements();
      if (outOfBounds) {
        emitError(baseLoc, "'spirv.AccessChain' op index ")
            << index << " out of bounds for " << currentType;
        return nullptr;
      }
    }

    currentType = compositeType.getElementType(index);
  }

  return spirv::PointerType::get(currentType, basePtrType.getStorageClass());
}
/// Builds an AccessChainOp whose result type is deduced from `basePtr` and
/// `indices`.
void AccessChainOp::build(OpBuilder &builder, OperationState &state,
                          Value basePtr, ValueRange indices) {
  Type resultType =
      getElementPtrType(basePtr.getType(), indices, state.location);
  assert(resultType &&
         "Unable to deduce return type based on basePtr and indices");
  build(builder, state, resultType, basePtr, indices);
}
/// Parses `base-ptr [indices] : base-ptr-type, index-types` and deduces the
/// result type from the base pointer type and the resolved indices.
ParseResult AccessChainOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand basePtrOperand;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexOperands;
  SmallVector<Type, 4> indexTypes;
  Type basePtrType;
  auto loc = parser.getCurrentLocation();

  if (parser.parseOperand(basePtrOperand) ||
      parser.parseOperandList(indexOperands,
                              OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(basePtrType) ||
      parser.resolveOperand(basePtrOperand, basePtrType, result.operands))
    return failure();

  // Reject an empty index list before attempting to parse the index types.
  if (indexOperands.empty())
    return mlir::emitError(result.location,
                           "'spirv.AccessChain' op expected at "
                           "least one index ");

  if (parser.parseComma() || parser.parseTypeList(indexTypes))
    return failure();

  // Every index needs exactly one type.
  if (indexTypes.size() != indexOperands.size())
    return mlir::emitError(
        result.location, "'spirv.AccessChain' op indices types' count must be "
                         "equal to indices info count");

  if (parser.resolveOperands(indexOperands, indexTypes, loc, result.operands))
    return failure();

  // The first resolved operand is the base pointer; the rest are indices.
  Type resultType = getElementPtrType(
      basePtrType, llvm::ArrayRef(result.operands).drop_front(),
      result.location);
  if (!resultType)
    return failure();

  result.addTypes(resultType);
  return success();
}
/// Prints ` base-ptr[indices] : base-ptr-type, index-types` for `op`.
template <typename Op>
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
  auto basePtr = op.getBasePtr();
  printer << ' ' << basePtr << '[' << indices << "] : ";
  printer << basePtr.getType() << ", " << indices.getTypes();
}
/// Prints this op through the shared access-chain printer.
void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
  ValueRange indices = getIndices();
  printAccessChain(*this, indices, printer);
}
/// Verifies that the result type of `accessChainOp` is the pointer type
/// deduced from its base pointer type and `indices`. Emits an error and
/// returns failure when the type cannot be deduced, when the result is not a
/// pointer, or when it differs from the deduced type.
template <typename Op>
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
  Type expectedType = getElementPtrType(accessChainOp.getBasePtr().getType(),
                                        indices, accessChainOp.getLoc());
  // getElementPtrType has already emitted a diagnostic in this case.
  if (!expectedType)
    return failure();

  Type actualType = accessChainOp.getType();
  auto actualPtrType = llvm::dyn_cast<spirv::PointerType>(actualType);
  if (!actualPtrType) {
    // Report the type that was actually provided; the failed dyn_cast result
    // is null and would print no useful information.
    return accessChainOp.emitOpError(
               "result type must be a pointer, but provided ")
           << actualType;
  }

  if (expectedType != actualPtrType)
    return accessChainOp.emitOpError("invalid result type: expected ")
           << expectedType << ", but provided " << actualPtrType;

  return success();
}
/// Verifies this op through the shared access-chain verifier.
LogicalResult AccessChainOp::verify() {
  ValueRange indices = getIndices();
  return verifyAccessChain(*this, indices);
}
//===----------------------------------------------------------------------===//
// spirv.LoadOp
//===----------------------------------------------------------------------===//
/// Builds a LoadOp whose result type is the pointee type of `basePtr`.
void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
                   MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
  Type loadedType =
      llvm::cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
  build(builder, state, loadedType, basePtr, memoryAccess, alignment);
}
/// Parses
///   storage-class ptr (`[` memory-access `]`)? attr-dict? `:` element-type
/// The pointer operand type is reconstructed from the storage class and the
/// element type, and the element type becomes the result type.
ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
  spirv::StorageClass storageClass;
  OpAsmParser::UnresolvedOperand ptrOperand;
  Type elementType;

  // Storage class, pointer operand and optional memory access list.
  if (parseEnumStrAttr(storageClass, parser) ||
      parser.parseOperand(ptrOperand) ||
      parseMemoryAccessAttributes<LoadOp>(parser, result))
    return failure();

  // Optional attribute dictionary followed by `: element-type`.
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColon() || parser.parseType(elementType))
    return failure();

  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (parser.resolveOperand(ptrOperand, ptrType, result.operands))
    return failure();

  result.addTypes(elementType);
  return success();
}
/// Prints
///   ` "storage-class" ptr (["access"(, alignment)?])? attr-dict : type`
void LoadOp::print(OpAsmPrinter &printer) {
  auto ptrType = llvm::cast<spirv::PointerType>(getPtr().getType());
  StringRef storageClass = stringifyStorageClass(ptrType.getStorageClass());
  printer << " \"" << storageClass << "\" " << getPtr();

  // Attributes printed as part of the memory access list are collected here
  // so the attribute dictionary does not repeat them.
  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(*this, printer, elidedAttrs);
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);

  printer << " : " << getType();
}
/// Verifies that the loaded value type matches the pointee type and that the
/// memory access / alignment attributes are consistent.
LogicalResult LoadOp::verify() {
  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
  // type with fixed size; i.e., it cannot be, nor include, any
  // OpTypeRuntimeArray types."
  LogicalResult typesMatch =
      verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue());
  if (failed(typesMatch))
    return failure();
  return verifyMemoryAccessAttribute(*this);
}
//===----------------------------------------------------------------------===//
// spirv.StoreOp
//===----------------------------------------------------------------------===//
/// Parses
///   storage-class ptr `,` value (`[` memory-access `]`)?
///   `:` element-type attr-dict?
/// The pointer operand type is reconstructed from the storage class and the
/// element type.
ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the storage class specification
  spirv::StorageClass storageClass;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
  auto loc = parser.getCurrentLocation();
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) ||
      parser.parseOperandList(operandInfo, 2) ||
      parseMemoryAccessAttributes<StoreOp>(parser, result) ||
      parser.parseColon() || parser.parseType(elementType)) {
    return failure();
  }

  // StoreOp::print emits the optional attribute dictionary after the element
  // type, so it has to be accepted here for the printed form to parse back.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
                             result.operands)) {
    return failure();
  }
  return success();
}
/// Prints
///   ` "storage-class" ptr, value (["access"(, alignment)?])? : type attr-dict`
void StoreOp::print(OpAsmPrinter &printer) {
  auto ptrType = llvm::cast<spirv::PointerType>(getPtr().getType());
  StringRef storageClass = stringifyStorageClass(ptrType.getStorageClass());
  printer << " \"" << storageClass << "\" " << getPtr() << ", " << getValue();

  // Attributes printed as part of the memory access list are collected here
  // so the attribute dictionary does not repeat them.
  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(*this, printer, elidedAttrs);

  printer << " : " << getValue().getType();
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
}
/// Verifies that the stored value type matches the pointee type and that the
/// memory access / alignment attributes are consistent.
LogicalResult StoreOp::verify() {
  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
  // OpTypePointer whose Type operand is the same as the type of Object."
  LogicalResult typesMatch =
      verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue());
  if (failed(typesMatch))
    return failure();
  return verifyMemoryAccessAttribute(*this);
}
//===----------------------------------------------------------------------===//
// spirv.CopyMemory
//===----------------------------------------------------------------------===//
/// Prints the target and source pointers, each preceded by its quoted storage
/// class, then the target and source memory access lists, the remaining
/// attributes, and finally ` : pointee-type` of the target pointer.
void CopyMemoryOp::print(OpAsmPrinter &printer) {
  auto targetPtrType = llvm::cast<spirv::PointerType>(getTarget().getType());
  auto sourcePtrType = llvm::cast<spirv::PointerType>(getSource().getType());

  printer << ' ';
  printer << " \"" << stringifyStorageClass(targetPtrType.getStorageClass())
          << "\" " << getTarget() << ", ";
  printer << " \"" << stringifyStorageClass(sourcePtrType.getStorageClass())
          << "\" " << getSource();

  // Attributes printed as part of the memory access lists are collected here
  // so the attribute dictionary does not repeat them.
  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(*this, printer, elidedAttrs);
  printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
                                   getSourceMemoryAccess(),
                                   getSourceAlignment());
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);

  printer << " : " << targetPtrType.getPointeeType();
}
/// Parses
///   storage-class target `,` storage-class source
///   (`[` memory-access `]`)? (`,` (`[` memory-access `]`)?)?
///   attr-dict? `:` element-type attr-dict?
/// Both pointer operand types are reconstructed from their storage class and
/// the shared element type.
///
/// The attribute dictionary is accepted before the colon because that is
/// where CopyMemoryOp::print emits it; the position after the element type is
/// still accepted so previously valid input keeps parsing.
ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
  spirv::StorageClass targetStorageClass;
  OpAsmParser::UnresolvedOperand targetPtrInfo;

  spirv::StorageClass sourceStorageClass;
  OpAsmParser::UnresolvedOperand sourcePtrInfo;

  Type elementType;

  if (parseEnumStrAttr(targetStorageClass, parser) ||
      parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
      parseEnumStrAttr(sourceStorageClass, parser) ||
      parser.parseOperand(sourcePtrInfo) ||
      parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
    return failure();
  }

  if (!parser.parseOptionalComma()) {
    // Parse 2nd memory access attributes.
    if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
      return failure();
    }
  }

  // Attribute dictionary in the position used by the printer. Remember
  // whether it contributed anything so that a second dictionary after the
  // type cannot add the same attribute names again.
  const size_t numAttrsBeforeDict = result.attributes.getAttrs().size();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  const bool sawLeadingAttrDict =
      result.attributes.getAttrs().size() != numAttrsBeforeDict;

  if (parser.parseColon() || parser.parseType(elementType))
    return failure();

  // Attribute dictionary in the trailing position.
  if (!sawLeadingAttrDict && parser.parseOptionalAttrDict(result.attributes))
    return failure();

  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);

  if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
      parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
    return failure();
  }

  return success();
}
/// Verifies that target and source point to the same type and that both the
/// target and the source memory access / alignment attributes are consistent.
LogicalResult CopyMemoryOp::verify() {
  auto targetPtrType = llvm::cast<spirv::PointerType>(getTarget().getType());
  auto sourcePtrType = llvm::cast<spirv::PointerType>(getSource().getType());

  if (targetPtrType.getPointeeType() != sourcePtrType.getPointeeType())
    return emitOpError("both operands must be pointers to the same type");

  LogicalResult targetAccessOk = verifyMemoryAccessAttribute(*this);
  if (failed(targetAccessOk))
    return failure();

  // TODO - According to the spec:
  //
  // If two masks are present, the first applies to Target and cannot include
  // MakePointerVisible, and the second applies to Source and cannot include
  // MakePointerAvailable.
  //
  // Add such verification here.

  return verifySourceMemoryAccessAttribute(*this);
}
/// Shared parser for the pointer access chain ops. Parses
///   base-ptr `[` element (`,` indices)* `]` `:` base-ptr-type `,` types
/// and deduces the result type from the base pointer type and the operands
/// that follow the base pointer and the element operand. `opName` is only
/// used as a prefix in diagnostics.
static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
                                             OpAsmParser &parser,
                                             OperationState &state) {
  OpAsmParser::UnresolvedOperand basePtrOperand;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> bracketOperands;
  SmallVector<Type, 4> bracketTypes;
  Type basePtrType;
  auto loc = parser.getCurrentLocation();

  if (parser.parseOperand(basePtrOperand) ||
      parser.parseOperandList(bracketOperands,
                              OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(basePtrType) ||
      parser.resolveOperand(basePtrOperand, basePtrType, state.operands))
    return failure();

  // Reject an empty bracket list before attempting to parse its type list.
  if (bracketOperands.empty())
    return emitError(state.location) << opName << " expected element";

  if (parser.parseComma() || parser.parseTypeList(bracketTypes))
    return failure();

  // Every bracketed operand needs exactly one type.
  if (bracketTypes.size() != bracketOperands.size())
    return emitError(state.location)
           << opName
           << " indices types' count must be equal to indices info count";

  if (parser.resolveOperands(bracketOperands, bracketTypes, loc,
                             state.operands))
    return failure();

  // Skip the first two resolved operands; the remaining ones are the indices
  // used for type deduction.
  Type resultType = getElementPtrType(
      basePtrType, llvm::ArrayRef(state.operands).drop_front(2),
      state.location);
  if (!resultType)
    return failure();

  state.addTypes(resultType);
  return success();
}
template <typename Op>
static auto concatElemAndIndices(Op op) {
SmallVector<Value> ret(op.getIndices().size() + 1);
ret[0] = op.getElement();
llvm::copy(op.getIndices(), ret.begin() + 1);
return ret;
}
//===----------------------------------------------------------------------===//
// spirv.InBoundsPtrAccessChainOp
//===----------------------------------------------------------------------===//
/// Builds an InBoundsPtrAccessChainOp whose result type is deduced from
/// `basePtr` and `indices`.
void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
                                     Value basePtr, Value element,
                                     ValueRange indices) {
  Type resultType =
      getElementPtrType(basePtr.getType(), indices, state.location);
  assert(resultType &&
         "Unable to deduce return type based on basePtr and indices");
  build(builder, state, resultType, basePtr, element, indices);
}
/// Parses this op through the shared pointer access chain parser.
ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
                                            OperationState &result) {
  StringRef opName = spirv::InBoundsPtrAccessChainOp::getOperationName();
  return parsePtrAccessChainOpImpl(opName, parser, result);
}
/// Prints this op through the shared access-chain printer, with the element
/// operand listed in front of the indices.
void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
  auto elementAndIndices = concatElemAndIndices(*this);
  printAccessChain(*this, elementAndIndices, printer);
}
/// Verifies this op through the shared access-chain verifier.
LogicalResult InBoundsPtrAccessChainOp::verify() {
  ValueRange indices = getIndices();
  return verifyAccessChain(*this, indices);
}
//===----------------------------------------------------------------------===//
// spirv.PtrAccessChainOp
//===----------------------------------------------------------------------===//
/// Builds a PtrAccessChainOp whose result type is deduced from `basePtr` and
/// `indices`.
void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
                             Value basePtr, Value element, ValueRange indices) {
  Type resultType =
      getElementPtrType(basePtr.getType(), indices, state.location);
  assert(resultType &&
         "Unable to deduce return type based on basePtr and indices");
  build(builder, state, resultType, basePtr, element, indices);
}
/// Parses this op through the shared pointer access chain parser.
ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  StringRef opName = spirv::PtrAccessChainOp::getOperationName();
  return parsePtrAccessChainOpImpl(opName, parser, result);
}
/// Prints this op through the shared access-chain printer, with the element
/// operand listed in front of the indices.
void PtrAccessChainOp::print(OpAsmPrinter &printer) {
  auto elementAndIndices = concatElemAndIndices(*this);
  printAccessChain(*this, elementAndIndices, printer);
}
/// Verifies this op through the shared access-chain verifier.
LogicalResult PtrAccessChainOp::verify() {
  ValueRange indices = getIndices();
  return verifyAccessChain(*this, indices);
}
//===----------------------------------------------------------------------===//
// spirv.Variable
//===----------------------------------------------------------------------===//
/// Parses
///   (`init(` initializer `)`)? variable-decorations `:` pointer-type
/// The storage class attribute is derived from the parsed pointer type.
ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
  // Optional `init(<operand>)` clause.
  std::optional<OpAsmParser::UnresolvedOperand> initializer;
  if (succeeded(parser.parseOptionalKeyword("init"))) {
    initializer = OpAsmParser::UnresolvedOperand();
    if (parser.parseLParen() || parser.parseOperand(*initializer) ||
        parser.parseRParen())
      return failure();
  }

  if (parseVariableDecorations(parser, result))
    return failure();

  // `: <type>`; the location of the type is kept for the diagnostic below.
  if (parser.parseColon())
    return failure();
  auto typeLoc = parser.getCurrentLocation();
  Type parsedType;
  if (parser.parseType(parsedType))
    return failure();

  auto ptrType = llvm::dyn_cast<spirv::PointerType>(parsedType);
  if (!ptrType)
    return parser.emitError(typeLoc, "expected spirv.ptr type");
  result.addTypes(ptrType);

  // The initializer, when present, has the pointee type of the result.
  if (initializer &&
      parser.resolveOperand(*initializer, ptrType.getPointeeType(),
                            result.operands))
    return failure();

  auto storageClassAttr =
      parser.getBuilder().getAttr<spirv::StorageClassAttr>(
          ptrType.getStorageClass());
  result.addAttribute(spirv::attributeName<spirv::StorageClass>(),
                      storageClassAttr);

  return success();
}
/// Prints ` init(<initializer>)` when an operand is present, then the variable
/// decorations and ` : <result type>`.
void VariableOp::print(OpAsmPrinter &printer) {
  // The storage class attribute is elided from the start.
  SmallVector<StringRef, 4> elidedAttrs;
  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());

  const bool hasInitializer = getNumOperands() != 0;
  if (hasInitializer)
    printer << " init(" << getInitializer() << ")";

  printVariableDecorations(*this, printer, elidedAttrs);
  printer << " : " << getType();
}
/// Verifies the storage class, the optional initializer, the absence of
/// decorations reserved for spirv.GlobalVariable, and the aliasing decoration
/// requirement for PhysicalStorageBuffer pointees.
LogicalResult VariableOp::verify() {
  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  // object. It cannot be Generic. It must be the same as the Storage Class
  // operand of the Result Type."
  if (getStorageClass() != spirv::StorageClass::Function)
    return emitOpError(
        "can only be used to model function-level variables. Use "
        "spirv.GlobalVariable for module-level variables.");

  auto resultPtrType = llvm::cast<spirv::PointerType>(getPointer().getType());
  if (getStorageClass() != resultPtrType.getStorageClass())
    return emitOpError(
        "storage class must match result pointer's storage class");

  if (getNumOperands() != 0) {
    // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
    // a global (module scope) OpVariable instruction".
    Operation *initOp = getOperand(0).getDefiningOp();
    const bool isAllowedInit =
        initOp && isa<spirv::ConstantOp,    // for normal constant
                      spirv::ReferenceOfOp, // for spec constant
                      spirv::AddressOfOp>(initOp);
    if (!isAllowedInit)
      return emitOpError("initializer must be the result of a "
                         "constant or spirv.GlobalVariable op");
  }

  // Looks up the attribute named after the snake_case spelling of the given
  // decoration on this operation.
  auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) {
    return op->getAttr(
        llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)));
  };

  // TODO: generate these strings using ODS.
  for (auto decoration :
       {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
        spirv::Decoration::BuiltIn}) {
    if (!getDecorationAttr(decoration))
      continue;
    return emitOpError("cannot have '")
           << llvm::convertToSnakeFromCamelCase(
                  stringifyDecoration(decoration))
           << "' attribute (only allowed in spirv.GlobalVariable)";
  }

  // From SPV_KHR_physical_storage_buffer:
  // > If an OpVariable's pointee type is a pointer (or array of pointers) in
  // > PhysicalStorageBuffer storage class, then the variable must be decorated
  // > with exactly one of AliasedPointer or RestrictPointer.
  Type pointeeType = getPointeeType();
  auto pointeePtrType = dyn_cast<spirv::PointerType>(pointeeType);
  if (!pointeePtrType)
    if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      pointeePtrType =
          dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());

  const bool isPhysicalBufferPtr =
      pointeePtrType && pointeePtrType.getStorageClass() ==
                            spirv::StorageClass::PhysicalStorageBuffer;
  if (!isPhysicalBufferPtr)
    return success();

  const bool hasAliasedPtr =
      getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr;
  const bool hasRestrictPtr =
      getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr;

  if (!hasAliasedPtr && !hasRestrictPtr)
    return emitOpError() << " with physical buffer pointer must be decorated "
                            "either 'AliasedPointer' or 'RestrictPointer'";

  if (hasAliasedPtr && hasRestrictPtr)
    return emitOpError()
           << " with physical buffer pointer must have exactly one "
              "aliasing decoration";

  return success();
}
} // namespace mlir::spirv