blob: c4362015ec64300b4fa0360c19a5f78729ec32a0 [file] [log] [blame]
/* Copyright (c) 2024 LunarG, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pass.h"
#include <cstdint>
#include <spirv/unified1/spirv.hpp>
#include "generated/spirv_grammar_helper.h"
#include "instruction.h"
#include "module.h"
#include "gpu/shaders/gpuav_error_codes.h"
namespace gpuav {
namespace spirv {
const Variable& Pass::GetBuiltinVariable(uint32_t built_in) {
uint32_t variable_id = 0;
for (const auto& annotation : module_.annotations_) {
if (annotation->Opcode() == spv::OpDecorate && annotation->Word(2) == spv::DecorationBuiltIn &&
annotation->Word(3) == built_in) {
variable_id = annotation->Word(1);
break;
}
}
if (variable_id == 0) {
variable_id = module_.TakeNextId();
auto new_inst = std::make_unique<Instruction>(4, spv::OpDecorate);
new_inst->Fill({variable_id, spv::DecorationBuiltIn, built_in});
module_.annotations_.emplace_back(std::move(new_inst));
}
// Currently we only ever needed Input variables and the built-ins we are using are not those that can be used by both Input and
// Output storage classes
const Variable* built_in_variable = module_.type_manager_.FindVariableById(variable_id);
if (!built_in_variable) {
const Type& pointer_type = module_.type_manager_.GetTypePointerBuiltInInput(spv::BuiltIn(built_in));
auto new_inst = std::make_unique<Instruction>(4, spv::OpVariable);
new_inst->Fill({pointer_type.Id(), variable_id, spv::StorageClassInput});
built_in_variable = &module_.type_manager_.AddVariable(std::move(new_inst), pointer_type);
module_.AddInterfaceVariables(built_in_variable->Id(), spv::StorageClassInput);
}
return *built_in_variable;
}
// To reduce having to load this information everytime we do a OpFunctionCall, instead just create it once per Function block and
// reference it each time
uint32_t Pass::GetStageInfo(Function& function, BasicBlockIt target_block_it, InstructionIt& target_inst_it) {
// Cached so only need to compute this once
if (function.stage_info_id_ != 0) {
return function.stage_info_id_;
}
BasicBlock& block = function.GetFirstBlock();
InstructionIt inst_it = block.GetFirstInjectableInstrution();
// Stage info is always passed in as a uvec4
const Type& uint32_type = module_.type_manager_.GetTypeInt(32, false);
const Type& uvec4_type = module_.type_manager_.GetTypeVector(uint32_type, 4);
const uint32_t uint32_0_id = module_.type_manager_.GetConstantZeroUint32().Id();
uint32_t stage_info[4] = {uint32_0_id, uint32_0_id, uint32_0_id, uint32_0_id};
if (module_.entry_points_.size() > 1) {
// For Multi Entry Points it currently a lot of work to scan every function to see where it will be called from
// For now we will just report it is "unknown" and skip printing that part of the error message
stage_info[0] = module_.type_manager_.GetConstantUInt32(glsl::kHeaderStageIdMultiEntryPoint).Id();
} else {
spv::ExecutionModel execution_model = spv::ExecutionModel(module_.entry_points_.begin()->get()->Operand(0));
stage_info[0] = module_.type_manager_.GetConstantUInt32(execution_model).Id();
// Gets BuiltIn variable and creates a valid OpLoad of it
auto create_load = [this, &block, &inst_it](spv::BuiltIn built_in) {
const Variable& variable = GetBuiltinVariable(built_in);
const Type* pointer_type = variable.PointerType(module_.type_manager_);
const uint32_t load_id = module_.TakeNextId();
block.CreateInstruction(spv::OpLoad, {pointer_type->Id(), load_id, variable.Id()}, &inst_it);
return load_id;
};
switch (execution_model) {
case spv::ExecutionModelVertex: {
uint32_t load_id = create_load(spv::BuiltInVertexIndex);
stage_info[1] = CastToUint32(load_id, block, &inst_it);
load_id = create_load(spv::BuiltInInstanceIndex);
stage_info[2] = CastToUint32(load_id, block, &inst_it);
} break;
case spv::ExecutionModelFragment: {
const uint32_t load_id = create_load(spv::BuiltInFragCoord);
// convert vec4 to uvec4
const uint32_t bitcast_id = module_.TakeNextId();
block.CreateInstruction(spv::OpBitcast, {uvec4_type.Id(), bitcast_id, load_id}, &inst_it);
for (uint32_t i = 0; i < 2; i++) {
const uint32_t extract_id = module_.TakeNextId();
block.CreateInstruction(spv::OpCompositeExtract, {uint32_type.Id(), extract_id, bitcast_id, i}, &inst_it);
stage_info[i + 1] = extract_id;
}
} break;
case spv::ExecutionModelRayGenerationKHR:
case spv::ExecutionModelIntersectionKHR:
case spv::ExecutionModelAnyHitKHR:
case spv::ExecutionModelClosestHitKHR:
case spv::ExecutionModelMissKHR:
case spv::ExecutionModelCallableKHR: {
const uint32_t load_id = create_load(spv::BuiltInLaunchIdKHR);
for (uint32_t i = 0; i < 3; i++) {
const uint32_t extract_id = module_.TakeNextId();
block.CreateInstruction(spv::OpCompositeExtract, {uint32_type.Id(), extract_id, load_id, i}, &inst_it);
stage_info[i + 1] = extract_id;
}
} break;
case spv::ExecutionModelGLCompute:
case spv::ExecutionModelTaskNV:
case spv::ExecutionModelMeshNV:
case spv::ExecutionModelTaskEXT:
case spv::ExecutionModelMeshEXT: {
const uint32_t load_id = create_load(spv::BuiltInGlobalInvocationId);
for (uint32_t i = 0; i < 3; i++) {
const uint32_t extract_id = module_.TakeNextId();
block.CreateInstruction(spv::OpCompositeExtract, {uint32_type.Id(), extract_id, load_id, i}, &inst_it);
stage_info[i + 1] = extract_id;
}
} break;
case spv::ExecutionModelGeometry: {
const uint32_t primitive_id = create_load(spv::BuiltInPrimitiveId);
stage_info[1] = CastToUint32(primitive_id, block, &inst_it);
const uint32_t load_id = create_load(spv::BuiltInInvocationId);
stage_info[2] = CastToUint32(load_id, block, &inst_it);
} break;
case spv::ExecutionModelTessellationControl: {
const uint32_t load_id = create_load(spv::BuiltInInvocationId);
stage_info[1] = CastToUint32(load_id, block, &inst_it);
const uint32_t primitive_id = create_load(spv::BuiltInPrimitiveId);
stage_info[2] = CastToUint32(primitive_id, block, &inst_it);
} break;
case spv::ExecutionModelTessellationEvaluation: {
const uint32_t primitive_id = create_load(spv::BuiltInPrimitiveId);
stage_info[1] = CastToUint32(primitive_id, block, &inst_it);
// convert vec3 to uvec3
const Type& vec3_type = module_.type_manager_.GetTypeVector(uint32_type, 3);
const uint32_t load_id = create_load(spv::BuiltInTessCoord);
const uint32_t bitcast_id = module_.TakeNextId();
block.CreateInstruction(spv::OpBitcast, {vec3_type.Id(), bitcast_id, load_id}, &inst_it);
// TessCoord.uv values from it
for (uint32_t i = 0; i < 2; i++) {
const uint32_t extract_id = module_.TakeNextId();
block.CreateInstruction(spv::OpCompositeExtract, {uint32_type.Id(), extract_id, bitcast_id, i}, &inst_it);
stage_info[i + 2] = extract_id;
}
} break;
default:
module_.InternalError(Name(), "GetStageInfo has unsupported stage");
break;
}
}
function.stage_info_id_ = module_.TakeNextId();
block.CreateInstruction(spv::OpCompositeConstruct,
{uvec4_type.Id(), function.stage_info_id_, stage_info[0], stage_info[1], stage_info[2], stage_info[3]},
&inst_it);
function.stage_info_x_id_ = stage_info[0];
function.stage_info_y_id_ = stage_info[1];
function.stage_info_z_id_ = stage_info[2];
function.stage_info_w_id_ = stage_info[3];
// because we are injecting things in the first block, there is a chance we just destroyed the iterator if the target
// instruction was also in the first block, so need to regain it for the caller
if ((*target_block_it)->GetLabelId() == block.GetLabelId()) {
target_inst_it = FindTargetInstruction(block);
}
return function.stage_info_id_;
}
const Instruction* Pass::GetDecoration(uint32_t id, spv::Decoration decoration) {
for (const auto& annotation : module_.annotations_) {
if (annotation->Opcode() == spv::OpDecorate && annotation->Word(1) == id &&
spv::Decoration(annotation->Word(2)) == decoration) {
return annotation.get();
}
}
return nullptr;
}
const Instruction* Pass::GetMemberDecoration(uint32_t id, uint32_t member_index, spv::Decoration decoration) {
for (const auto& annotation : module_.annotations_) {
if (annotation->Opcode() == spv::OpMemberDecorate && annotation->Word(1) == id && annotation->Word(2) == member_index &&
spv::Decoration(annotation->Word(3)) == decoration) {
return annotation.get();
}
}
return nullptr;
}
// In an ideal world, this would be baked into the Type class when we construct it. The core issue is OpTypeMatrix size can be
// different depending where it is used. Because of this, we need to have a higher level view what is going on in order to correctly
// figure out the size of a given type.
uint32_t Pass::FindTypeByteSize(uint32_t type_id, uint32_t matrix_stride, bool col_major, bool in_matrix) {
const Type& type = *module_.type_manager_.FindTypeById(type_id);
switch (type.spv_type_) {
case SpvType::kPointer:
return 8; // Assuming PhysicalStorageBuffer pointer
break;
case SpvType::kMatrix: {
if (matrix_stride == 0) {
module_.InternalError("FindTypeByteSize", "missing matrix stride");
}
if (col_major) {
return type.inst_.Word(3) * matrix_stride;
} else {
const Type* vector_type = module_.type_manager_.FindTypeById(type.inst_.Word(2));
return vector_type->inst_.Word(3) * matrix_stride;
}
}
case SpvType::kVector: {
uint32_t size = type.inst_.Word(3);
const Type* component_type = module_.type_manager_.FindTypeById(type.inst_.Word(2));
// if vector in row major matrix, the vector is strided so return the number of bytes spanned by the vector
if (in_matrix && !col_major && matrix_stride > 0) {
return (size - 1) * matrix_stride + FindTypeByteSize(component_type->Id());
} else if (component_type->spv_type_ == SpvType::kFloat || component_type->spv_type_ == SpvType::kInt) {
const uint32_t width = component_type->inst_.Word(2);
size *= width;
} else {
module_.InternalError("FindTypeByteSize", "unexpected vector type");
}
return size / 8;
}
case SpvType::kFloat:
case SpvType::kInt: {
const uint32_t width = type.inst_.Word(2);
return width / 8;
}
case SpvType::kArray: {
const uint32_t array_stride = GetDecoration(type_id, spv::DecorationArrayStride)->Word(3);
const Constant* count = module_.type_manager_.FindConstantById(type.inst_.Operand(1));
// TODO - Need to handle spec constant here, for now return one to have things not blowup
assert(count && !count->is_spec_constant_);
const uint32_t array_length = (count && !count->is_spec_constant_) ? count->inst_.Operand(0) : 1;
return array_length * array_stride;
}
case SpvType::kStruct: {
const uint32_t struct_length = type.inst_.Length() - 2;
const uint32_t struct_id = type.inst_.ResultId();
// We do our best to find the "size" of the struct (see https://gitlab.khronos.org/spirv/SPIR-V/-/issues/763)
uint32_t highest_element_index = 0;
uint32_t highest_element_offset = 0;
for (uint32_t i = 0; i < struct_length; i++) {
for (const auto& annotation : module_.annotations_) {
if (annotation->Opcode() == spv::OpMemberDecorate && annotation->Word(1) == struct_id &&
annotation->Word(2) == i && spv::Decoration(annotation->Word(3)) == spv::DecorationOffset) {
const uint32_t member_offset = annotation->Word(4);
if (member_offset > highest_element_offset) {
highest_element_index = i;
highest_element_offset = member_offset;
}
break;
}
}
}
const uint32_t last_offset_id = type.inst_.Operand(highest_element_index);
const Type* last_offset_type = module_.type_manager_.FindTypeById(last_offset_id);
uint32_t highest_element_size = 0;
if (last_offset_type->spv_type_ == SpvType::kMatrix) {
// TODO - We need a better way to handle Matrix at the end of structs
const Instruction* decoration_matrix_stride =
GetMemberDecoration(struct_id, highest_element_index, spv::DecorationMatrixStride);
matrix_stride = decoration_matrix_stride ? decoration_matrix_stride->Word(4) : 0;
const Instruction* decoration_col_major =
GetMemberDecoration(struct_id, highest_element_index, spv::DecorationColMajor);
col_major = decoration_col_major != nullptr;
highest_element_size = FindTypeByteSize(last_offset_id, matrix_stride, col_major, true);
} else {
highest_element_size = FindTypeByteSize(last_offset_id);
}
return highest_element_offset + highest_element_size;
}
default:
break;
}
return 1;
}
// Find outermost buffer type and its access chain index.
// Because access chains indexes can be runtime values, we need to build arithmetic logic in the SPIR-V to get the runtime value of
// the indexing
uint32_t Pass::GetLastByte(const Type& descriptor_type, std::vector<const Instruction*>& access_chain_insts, BasicBlock& block,
InstructionIt* inst_it) {
assert(!access_chain_insts.empty());
uint32_t current_type_id = 0;
const uint32_t reset_ac_word = 4; // points to first "Index" operand of an OpAccessChain
uint32_t ac_word_index = reset_ac_word;
if (descriptor_type.IsArray()) {
current_type_id = descriptor_type.inst_.Operand(0);
ac_word_index++; // this jumps over the array of descriptors so we first start on the descriptor itself
} else if (descriptor_type.spv_type_ == SpvType::kStruct) {
current_type_id = descriptor_type.Id();
} else {
module_.InternalError(Name(), "GetLastByte has unexpected descriptor type");
return 0;
}
const Type& uint32_type = module_.type_manager_.GetTypeInt(32, false);
// instruction that will have calculated the sum of the byte offset
uint32_t sum_id = 0;
uint32_t matrix_stride = 0;
bool col_major = false;
uint32_t matrix_stride_id = 0;
bool in_matrix = false;
// This loop gets use to the last element, so if we have something like
//
// Struct foo {
// uint a; // 4 bytes
// vec4 b; // 16 bytes
// float c; <--- accessing
// }
//
// it will get us to 20 bytes
auto access_chain_iter = access_chain_insts.rbegin();
// This occurs in things like Slang where they have a single OpAccessChain for the descriptor
// (GLSL/HLSL will combine 2 indexes into the last OpAccessChain)
if (ac_word_index >= (*access_chain_iter)->Length()) {
++access_chain_iter;
ac_word_index = reset_ac_word;
}
while (access_chain_iter != access_chain_insts.rend()) {
const uint32_t ac_index_id = (*access_chain_iter)->Word(ac_word_index);
uint32_t current_offset_id = 0;
const Type* current_type = module_.type_manager_.FindTypeById(current_type_id);
switch (current_type->spv_type_) {
case SpvType::kArray:
case SpvType::kRuntimeArray: {
// Get array stride and multiply by current index
const uint32_t array_stride = GetDecoration(current_type_id, spv::DecorationArrayStride)->Word(3);
const uint32_t array_stride_id = module_.type_manager_.GetConstantUInt32(array_stride).Id();
const uint32_t ac_index_id_32 = ConvertTo32(ac_index_id, block, inst_it);
current_offset_id = module_.TakeNextId();
block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, array_stride_id, ac_index_id_32},
inst_it);
// Get element type for next step
current_type_id = current_type->inst_.Operand(0);
} break;
case SpvType::kMatrix: {
if (matrix_stride == 0) {
module_.InternalError(Name(), "GetLastByte is missing matrix stride");
}
matrix_stride_id = module_.type_manager_.GetConstantUInt32(matrix_stride).Id();
uint32_t vec_type_id = current_type->inst_.Operand(0);
// If column major, multiply column index by matrix stride, otherwise by vector component size and save matrix
// stride for vector (row) index
uint32_t col_stride_id = 0;
if (col_major) {
col_stride_id = matrix_stride_id;
} else {
const uint32_t component_type_id = module_.type_manager_.FindTypeById(vec_type_id)->inst_.Operand(0);
const uint32_t col_stride = FindTypeByteSize(component_type_id);
col_stride_id = module_.type_manager_.GetConstantUInt32(col_stride).Id();
}
const uint32_t ac_index_id_32 = ConvertTo32(ac_index_id, block, inst_it);
current_offset_id = module_.TakeNextId();
block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, col_stride_id, ac_index_id_32}, inst_it);
// Get element type for next step
current_type_id = vec_type_id;
in_matrix = true;
} break;
case SpvType::kVector: {
// If inside a row major matrix type, multiply index by matrix stride,
// else multiply by component size
const uint32_t component_type_id = current_type->inst_.Operand(0);
const uint32_t ac_index_id_32 = ConvertTo32(ac_index_id, block, inst_it);
if (in_matrix && !col_major) {
current_offset_id = module_.TakeNextId();
block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, matrix_stride_id, ac_index_id_32},
inst_it);
} else {
const uint32_t component_type_size = FindTypeByteSize(component_type_id);
const uint32_t size_id = module_.type_manager_.GetConstantUInt32(component_type_size).Id();
current_offset_id = module_.TakeNextId();
block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, size_id, ac_index_id_32}, inst_it);
}
// Get element type for next step
current_type_id = component_type_id;
} break;
case SpvType::kStruct: {
// Get buffer byte offset for the referenced member
const Constant* member_constant = module_.type_manager_.FindConstantById(ac_index_id);
assert(!member_constant->is_spec_constant_);
uint32_t member_index = member_constant->inst_.Operand(0);
uint32_t member_offset = GetMemberDecoration(current_type_id, member_index, spv::DecorationOffset)->Word(4);
current_offset_id = module_.type_manager_.GetConstantUInt32(member_offset).Id();
// Look for matrix stride for this member if there is one. The matrix
// stride is not on the matrix type, but in a OpMemberDecorate on the
// enclosing struct type at the member index. If none found, reset
// stride to 0.
const Instruction* decoration_matrix_stride =
GetMemberDecoration(current_type_id, member_index, spv::DecorationMatrixStride);
matrix_stride = decoration_matrix_stride ? decoration_matrix_stride->Word(4) : 0;
const Instruction* decoration_col_major =
GetMemberDecoration(current_type_id, member_index, spv::DecorationColMajor);
col_major = decoration_col_major != nullptr;
// Get element type for next step
current_type_id = current_type->inst_.Operand(member_index);
} break;
default: {
module_.InternalError(Name(), "GetLastByte has unexpected non-composite type");
} break;
}
if (sum_id == 0) {
sum_id = current_offset_id;
} else {
const uint32_t new_sum_id = module_.TakeNextId();
block.CreateInstruction(spv::OpIAdd, {uint32_type.Id(), new_sum_id, sum_id, current_offset_id}, inst_it);
sum_id = new_sum_id;
}
ac_word_index++;
if (ac_word_index >= (*access_chain_iter)->Length()) {
++access_chain_iter;
ac_word_index = reset_ac_word;
}
}
// Add in offset of last byte of referenced object
const uint32_t accessed_type_size = FindTypeByteSize(current_type_id, matrix_stride, col_major, in_matrix);
const uint32_t last_byte_index = accessed_type_size - 1;
const uint32_t last_byte_index_id = module_.type_manager_.GetConstantUInt32(last_byte_index).Id();
const uint32_t new_sum_id = module_.TakeNextId();
block.CreateInstruction(spv::OpIAdd, {uint32_type.Id(), new_sum_id, sum_id, last_byte_index_id}, inst_it);
return new_sum_id;
}
// Generate code to convert integer id to 32bit, if needed.
uint32_t Pass::ConvertTo32(uint32_t id, BasicBlock& block, InstructionIt* inst_it) {
// Find type doing the indexing into the access chain
const Type* type = nullptr;
const Constant* constant = module_.type_manager_.FindConstantById(id);
if (constant) {
type = &constant->type_;
} else {
const Instruction* inst = block.function_.FindInstruction(id);
if (inst) {
type = module_.type_manager_.FindTypeById(inst->TypeId());
}
}
if (!type) {
return id;
}
assert(type->spv_type_ == SpvType::kInt);
if (type->inst_.Word(2) == 32) {
return id;
}
const bool is_signed = type->inst_.Word(3) != 0;
const uint32_t new_id = module_.TakeNextId();
const Type& uint32_type = module_.type_manager_.GetTypeInt(32, false);
if (is_signed) {
block.CreateInstruction(spv::OpSConvert, {uint32_type.Id(), new_id, id}, inst_it);
} else {
block.CreateInstruction(spv::OpUConvert, {uint32_type.Id(), new_id, id}, inst_it);
}
return new_id; // Return an id to the 32bit equivalent.
}
// Generate code to cast integer it to 32bit unsigned, if needed.
uint32_t Pass::CastToUint32(uint32_t id, BasicBlock& block, InstructionIt* inst_it) {
// Convert value to 32-bit if necessary
uint32_t int32_id = ConvertTo32(id, block, inst_it);
const Type* type = nullptr;
const Constant* constant = module_.type_manager_.FindConstantById(int32_id);
if (constant) {
type = &constant->type_;
} else {
const Instruction* inst = block.function_.FindInstruction(int32_id);
if (inst) {
type = module_.type_manager_.FindTypeById(inst->TypeId());
}
}
if (!type) {
return int32_id;
}
assert(type->spv_type_ == SpvType::kInt);
const bool is_signed = type->inst_.Word(3) != 0;
if (!is_signed) {
return int32_id;
}
const Type& uint32_type = module_.type_manager_.GetTypeInt(32, false);
const uint32_t new_id = module_.TakeNextId();
block.CreateInstruction(spv::OpBitcast, {uint32_type.Id(), new_id, int32_id}, inst_it);
return new_id; // Return an id to the Uint equivalent.
}
InstructionIt Pass::FindTargetInstruction(BasicBlock& block) const {
const uint32_t target_id = target_instruction_->ResultId();
for (auto inst_it = block.instructions_.begin(); inst_it != block.instructions_.end(); ++inst_it) {
// This has to re-loop the entire block to find the instruction, using the ResultID, we can quickly compare
if ((*inst_it)->ResultId() == target_id) {
// Things like OpStore will have a result id of zero, so need to do deep instruction comparison
if (*(*inst_it) == *target_instruction_) {
return inst_it;
}
}
}
module_.InternalError(Name(), "failed to find instruction");
return block.instructions_.end();
}
} // namespace spirv
} // namespace gpuav