| // Copyright (c) 2024 NVIDIA Corporation |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // Validate instructions that manipulate tensor layout and view objects |
| |
| #include "source/opcode.h" |
| #include "source/spirv_target_env.h" |
| #include "source/val/instruction.h" |
| #include "source/val/validate.h" |
| #include "source/val/validation_state.h" |
| |
| namespace spvtools { |
| namespace val { |
| namespace { |
| |
| spv_result_t ValidateTensorLayoutResultTypeNV(ValidationState_t& _, |
| const Instruction* inst) { |
| const auto result_type_index = 0; |
| const auto result_type_id = inst->GetOperandAs<uint32_t>(result_type_index); |
| const auto result_type = _.FindDef(result_type_id); |
| |
| if (!result_type || spv::Op::OpTypeTensorLayoutNV != result_type->opcode()) { |
| return _.diag(SPV_ERROR_INVALID_ID, inst) |
| << spvOpcodeString(inst->opcode()) << " Result Type <id> " |
| << _.getIdName(result_type_id) << " is not a tensor layout type."; |
| } |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t ValidateTensorViewResultTypeNV(ValidationState_t& _, |
| const Instruction* inst) { |
| const auto result_type_index = 0; |
| const auto result_type_id = inst->GetOperandAs<uint32_t>(result_type_index); |
| const auto result_type = _.FindDef(result_type_id); |
| |
| if (!result_type || spv::Op::OpTypeTensorViewNV != result_type->opcode()) { |
| return _.diag(SPV_ERROR_INVALID_ID, inst) |
| << spvOpcodeString(inst->opcode()) << " Result Type <id> " |
| << _.getIdName(result_type_id) << " is not a tensor view type."; |
| } |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t ValidateCreateTensorLayoutNV(ValidationState_t& _, |
| const Instruction* inst) { |
| if (auto error = ValidateTensorLayoutResultTypeNV(_, inst)) return error; |
| |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t ValidateCreateTensorViewNV(ValidationState_t& _, |
| const Instruction* inst) { |
| if (auto error = ValidateTensorViewResultTypeNV(_, inst)) return error; |
| |
| return SPV_SUCCESS; |
| } |
| |
// How many trailing value operands a tensor layout/view instruction must
// carry, expressed relative to the Dim of its tensor layout/view type.
enum ExpectedNumValues {
  DIM = 0,    // exactly Dim values
  DIMx2 = 1,  // 2 * Dim values
  ONE = 2,    // exactly one value
  FOUR = 3,   // exactly four values
};
| |
| spv_result_t ValidateTensorTypeWithDimValuesNV(ValidationState_t& _, |
| const Instruction* inst, |
| ExpectedNumValues expected, |
| bool is_view) { |
| std::string type_str; |
| if (is_view) { |
| if (auto error = ValidateTensorViewResultTypeNV(_, inst)) return error; |
| type_str = "TensorView"; |
| } else { |
| if (auto error = ValidateTensorLayoutResultTypeNV(_, inst)) return error; |
| type_str = "TensorLayout"; |
| } |
| |
| const auto result_type_id = inst->GetOperandAs<uint32_t>(0); |
| const auto tensor_id = inst->GetOperandAs<uint32_t>(2); |
| const auto tensor = _.FindDef(tensor_id); |
| if (!tensor || result_type_id != tensor->type_id()) { |
| return _.diag(SPV_ERROR_INVALID_ID, inst) |
| << spvOpcodeString(inst->opcode()) << " Result Type <id> " |
| << _.getIdName(result_type_id) << " does not match " << type_str |
| << " type."; |
| } |
| |
| const auto num_values = inst->operands().size() - 3; |
| |
| const auto result_type = _.FindDef(result_type_id); |
| const auto dim_index = 1; |
| const auto dim_id = result_type->GetOperandAs<uint32_t>(dim_index); |
| uint64_t dim_value; |
| if (_.EvalConstantValUint64(dim_id, &dim_value)) { |
| uint64_t expected_num_values = 0; |
| switch (expected) { |
| case DIM: |
| expected_num_values = dim_value; |
| break; |
| case DIMx2: |
| expected_num_values = dim_value * 2; |
| break; |
| case ONE: |
| expected_num_values = 1; |
| break; |
| case FOUR: |
| expected_num_values = 4; |
| break; |
| } |
| |
| if (num_values != expected_num_values) { |
| return _.diag(SPV_ERROR_INVALID_ID, inst) |
| << spvOpcodeString(inst->opcode()) |
| << " unexpected number of operands."; |
| } |
| } |
| |
| for (uint32_t i = 0; i < num_values; ++i) { |
| const auto val_id = inst->GetOperandAs<uint32_t>(i + 3); |
| const auto val = _.FindDef(val_id); |
| if (!val || !_.IsIntScalarType(val->type_id()) || |
| _.GetBitWidth(val->type_id()) != 32) { |
| return _.diag(SPV_ERROR_INVALID_ID, inst) |
| << spvOpcodeString(inst->opcode()) << " operand <id> " |
| << _.getIdName(val_id) << " is not a 32-bit integer."; |
| } |
| } |
| |
| return SPV_SUCCESS; |
| } |
| |
| } // namespace |
| |
| spv_result_t TensorLayoutPass(ValidationState_t& _, const Instruction* inst) { |
| switch (inst->opcode()) { |
| case spv::Op::OpCreateTensorLayoutNV: |
| if (auto error = ValidateCreateTensorLayoutNV(_, inst)) return error; |
| break; |
| case spv::Op::OpCreateTensorViewNV: |
| if (auto error = ValidateCreateTensorViewNV(_, inst)) return error; |
| break; |
| case spv::Op::OpTensorLayoutSetBlockSizeNV: |
| case spv::Op::OpTensorLayoutSetDimensionNV: |
| case spv::Op::OpTensorLayoutSetStrideNV: |
| if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIM, false)) |
| return error; |
| break; |
| case spv::Op::OpTensorLayoutSliceNV: |
| if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIMx2, false)) |
| return error; |
| break; |
| case spv::Op::OpTensorLayoutSetClampValueNV: |
| if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, ONE, false)) |
| return error; |
| break; |
| case spv::Op::OpTensorViewSetDimensionNV: |
| case spv::Op::OpTensorViewSetStrideNV: |
| if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIM, true)) |
| return error; |
| break; |
| case spv::Op::OpTensorViewSetClipNV: |
| if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, FOUR, true)) |
| return error; |
| break; |
| default: |
| break; |
| } |
| |
| return SPV_SUCCESS; |
| } |
| |
| } // namespace val |
| } // namespace spvtools |