blob: b14bdd152b58b282392b4e82354ac3ec867c7b25 [file] [log] [blame]
//===- DecorateSPIRVCompositeTypeLayoutPass.cpp - Decorate composite type -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to decorate the composite types used by
// composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
// PushConstant storage classes with layout information. See SPIR-V spec
// "2.16.2. Validation Rules for Shader Capabilities" for more details.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/SPIRV/LayoutUtils.h"
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
/// Rewrite pattern that replaces a spirv::GlobalVariableOp with an equivalent
/// op whose pointee struct type has been passed through
/// VulkanLayoutUtils::decorateType.
///
/// Every named attribute of the original op except "type" is carried over to
/// the replacement unchanged; "type" is replaced by the decorated pointer type.
/// The pattern reports failure (instead of asserting) when the variable is not
/// a pointer to a struct or when no decorated type is produced.
class SPIRVGlobalVariableOpLayoutInfoDecoration
    : public OpRewritePattern<spirv::GlobalVariableOp> {
public:
  using OpRewritePattern<spirv::GlobalVariableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::GlobalVariableOp op,
                                PatternRewriter &rewriter) const override {
    // Only pointer-to-struct variables can be decorated here. Use checked
    // casts so that an unexpected type fails the match gracefully.
    auto ptrType = op.type().dyn_cast<spirv::PointerType>();
    if (!ptrType)
      return failure();
    auto pointeeType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
    if (!pointeeType)
      return failure();

    // NOTE(review): decorateType is defined outside this file; a null result
    // is treated as "layout could not be computed" -- confirm against
    // LayoutUtils. Building a pointer type from a null pointee must be avoided
    // either way.
    auto structType = VulkanLayoutUtils::decorateType(pointeeType);
    if (!structType)
      return failure();

    auto decoratedType =
        spirv::PointerType::get(structType, ptrType.getStorageClass());

    // Save all named attributes except "type" attribute, which is replaced by
    // the decorated type below.
    SmallVector<NamedAttribute, 4> globalVarAttrs;
    for (const auto &attr : op.getAttrs()) {
      if (attr.first == "type")
        continue;
      globalVarAttrs.push_back(attr);
    }

    rewriter.replaceOpWithNewOp<spirv::GlobalVariableOp>(
        op, TypeAttr::get(decoratedType), globalVarAttrs);
    return success();
  }
};
/// Rewrite pattern that re-creates a spirv::AddressOfOp so that its result
/// type matches the current type of the spirv::GlobalVariableOp it refers to.
///
/// The referenced variable is looked up by symbol name in the enclosing
/// spirv::ModuleOp. The pattern reports failure when there is no enclosing
/// SPIR-V module or when the symbol does not resolve to a global variable,
/// rather than dereferencing a null op.
class SPIRVAddressOfOpLayoutInfoDecoration
    : public OpRewritePattern<spirv::AddressOfOp> {
public:
  using OpRewritePattern<spirv::AddressOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::AddressOfOp op,
                                PatternRewriter &rewriter) const override {
    auto spirvModule = op.getParentOfType<spirv::ModuleOp>();
    if (!spirvModule)
      return failure();

    auto varName = op.variable();
    // lookupSymbol yields a null op if the symbol is missing or is not a
    // spirv::GlobalVariableOp; guard before reading its type.
    auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
    if (!varOp)
      return failure();

    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
        op, varOp.type(), rewriter.getSymbolRefAttr(varName));
    return success();
  }
};
} // namespace
/// Registers the layout-decoration rewrite patterns defined in this file
/// (global variable first, then address-of) into `patterns`, constructing each
/// with `ctx`.
static void populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *ctx) {
  patterns.insert<SPIRVGlobalVariableOpLayoutInfoDecoration>(ctx);
  patterns.insert<SPIRVAddressOfOpLayoutInfoDecoration>(ctx);
}
namespace {
/// Pass class built on the SPIRVCompositeTypeLayoutBase CRTP base. All of its
/// work happens in runOnOperation(), which is defined out of line below.
class DecorateSPIRVCompositeTypeLayoutPass
    : public SPIRVCompositeTypeLayoutBase<
          DecorateSPIRVCompositeTypeLayoutPass> {
private:
  /// Entry point invoked for the operation this pass is scheduled on.
  void runOnOperation() override;
};
} // namespace
void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
auto module = getOperation();
OwningRewritePatternList patterns;
populateSPIRVLayoutInfoPatterns(patterns, module.getContext());
ConversionTarget target(*(module.getContext()));
target.addLegalDialect<spirv::SPIRVDialect>();
target.addLegalOp<FuncOp>();
target.addDynamicallyLegalOp<spirv::GlobalVariableOp>(
[](spirv::GlobalVariableOp op) {
return VulkanLayoutUtils::isLegalType(op.type());
});
// Change the type for the direct users.
target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) {
return VulkanLayoutUtils::isLegalType(op.pointer().getType());
});
// TODO: Change the type for the indirect users such as spv.Load, spv.Store,
// spv.FunctionCall and so on.
for (auto spirvModule : module.getOps<spirv::ModuleOp>()) {
if (failed(applyFullConversion(spirvModule, target, patterns))) {
signalPassFailure();
}
}
}
/// Factory for the pass implemented in this file: returns a newly allocated
/// DecorateSPIRVCompositeTypeLayoutPass owned by the caller.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::spirv::createDecorateSPIRVCompositeTypeLayoutPass() {
  std::unique_ptr<OperationPass<ModuleOp>> layoutPass =
      std::make_unique<DecorateSPIRVCompositeTypeLayoutPass>();
  return layoutPass;
}