| //===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Defines the cast and conversion operations in the SPIR-V dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| |
| #include "SPIRVOpUtils.h" |
| #include "SPIRVParsingUtils.h" |
| |
| #include "llvm/ADT/TypeSwitch.h" |
| |
| using namespace mlir::spirv::AttrNames; |
| |
| namespace mlir::spirv { |
| |
| static LogicalResult verifyCastOp(Operation *op, |
| bool requireSameBitWidth = true, |
| bool skipBitWidthCheck = false) { |
| // Some CastOps have no limit on bit widths for result and operand type. |
| if (skipBitWidthCheck) |
| return success(); |
| |
| Type operandType = op->getOperand(0).getType(); |
| Type resultType = op->getResult(0).getType(); |
| |
| // ODS checks that result type and operand type have the same shape. Check |
| // that composite types match and extract the element types, if any. |
| using TypePair = std::pair<Type, Type>; |
| auto [operandElemTy, resultElemTy] = |
| TypeSwitch<Type, TypePair>(operandType) |
| .Case<VectorType, spirv::CooperativeMatrixType, |
| spirv::JointMatrixINTELType>( |
| [resultType](auto concreteOperandTy) -> TypePair { |
| if (auto concreteResultTy = |
| dyn_cast<decltype(concreteOperandTy)>(resultType)) { |
| return {concreteOperandTy.getElementType(), |
| concreteResultTy.getElementType()}; |
| } |
| return {}; |
| }) |
| .Default([resultType](Type operandType) -> TypePair { |
| return {operandType, resultType}; |
| }); |
| |
| if (!operandElemTy || !resultElemTy) |
| return op->emitOpError("incompatible operand and result types"); |
| |
| unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth(); |
| unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth(); |
| bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth; |
| |
| if (requireSameBitWidth) { |
| if (!isSameBitWidth) { |
| return op->emitOpError( |
| "expected the same bit widths for operand type and result " |
| "type, but provided ") |
| << operandElemTy << " and " << resultElemTy; |
| } |
| return success(); |
| } |
| |
| if (isSameBitWidth) { |
| return op->emitOpError( |
| "expected the different bit widths for operand type and result " |
| "type, but provided ") |
| << operandElemTy << " and " << resultElemTy; |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.BitcastOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult BitcastOp::verify() { |
| // TODO: The SPIR-V spec validation rules are different for different |
| // versions. |
| auto operandType = getOperand().getType(); |
| auto resultType = getResult().getType(); |
| if (operandType == resultType) { |
| return emitError("result type must be different from operand type"); |
| } |
| if (llvm::isa<spirv::PointerType>(operandType) && |
| !llvm::isa<spirv::PointerType>(resultType)) { |
| return emitError( |
| "unhandled bit cast conversion from pointer type to non-pointer type"); |
| } |
| if (!llvm::isa<spirv::PointerType>(operandType) && |
| llvm::isa<spirv::PointerType>(resultType)) { |
| return emitError( |
| "unhandled bit cast conversion from non-pointer type to pointer type"); |
| } |
| auto operandBitWidth = getBitWidth(operandType); |
| auto resultBitWidth = getBitWidth(resultType); |
| if (operandBitWidth != resultBitWidth) { |
| return emitOpError("mismatch in result type bitwidth ") |
| << resultBitWidth << " and operand type bitwidth " |
| << operandBitWidth; |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ConvertPtrToUOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ConvertPtrToUOp::verify() { |
| auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); |
| auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType()); |
| if (!resultType || !resultType.isSignlessInteger()) |
| return emitError("result must be a scalar type of unsigned integer"); |
| auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>(); |
| if (!spirvModule) |
| return success(); |
| auto addressingModel = spirvModule.getAddressingModel(); |
| if ((addressingModel == spirv::AddressingModel::Logical) || |
| (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 && |
| operandType.getStorageClass() != |
| spirv::StorageClass::PhysicalStorageBuffer)) |
| return emitError("operand must be a physical pointer"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ConvertUToPtrOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ConvertUToPtrOp::verify() { |
| auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType()); |
| auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); |
| if (!operandType || !operandType.isSignlessInteger()) |
| return emitError("result must be a scalar type of unsigned integer"); |
| auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>(); |
| if (!spirvModule) |
| return success(); |
| auto addressingModel = spirvModule.getAddressingModel(); |
| if ((addressingModel == spirv::AddressingModel::Logical) || |
| (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 && |
| resultType.getStorageClass() != |
| spirv::StorageClass::PhysicalStorageBuffer)) |
| return emitError("result must be a physical pointer"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.PtrCastToGenericOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult PtrCastToGenericOp::verify() { |
| auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); |
| auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); |
| |
| spirv::StorageClass operandStorage = operandType.getStorageClass(); |
| if (operandStorage != spirv::StorageClass::Workgroup && |
| operandStorage != spirv::StorageClass::CrossWorkgroup && |
| operandStorage != spirv::StorageClass::Function) |
| return emitError("pointer must point to the Workgroup, CrossWorkgroup" |
| ", or Function Storage Class"); |
| |
| spirv::StorageClass resultStorage = resultType.getStorageClass(); |
| if (resultStorage != spirv::StorageClass::Generic) |
| return emitError("result type must be of storage class Generic"); |
| |
| Type operandPointeeType = operandType.getPointeeType(); |
| Type resultPointeeType = resultType.getPointeeType(); |
| if (operandPointeeType != resultPointeeType) |
| return emitOpError("pointer operand's pointee type must have the same " |
| "as the op result type, but found ") |
| << operandPointeeType << " vs " << resultPointeeType; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.GenericCastToPtrOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult GenericCastToPtrOp::verify() { |
| auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); |
| auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); |
| |
| spirv::StorageClass operandStorage = operandType.getStorageClass(); |
| if (operandStorage != spirv::StorageClass::Generic) |
| return emitError("pointer type must be of storage class Generic"); |
| |
| spirv::StorageClass resultStorage = resultType.getStorageClass(); |
| if (resultStorage != spirv::StorageClass::Workgroup && |
| resultStorage != spirv::StorageClass::CrossWorkgroup && |
| resultStorage != spirv::StorageClass::Function) |
| return emitError("result must point to the Workgroup, CrossWorkgroup, " |
| "or Function Storage Class"); |
| |
| Type operandPointeeType = operandType.getPointeeType(); |
| Type resultPointeeType = resultType.getPointeeType(); |
| if (operandPointeeType != resultPointeeType) |
| return emitOpError("pointer operand's pointee type must have the same " |
| "as the op result type, but found ") |
| << operandPointeeType << " vs " << resultPointeeType; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.GenericCastToPtrExplicitOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult GenericCastToPtrExplicitOp::verify() { |
| auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); |
| auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); |
| |
| spirv::StorageClass operandStorage = operandType.getStorageClass(); |
| if (operandStorage != spirv::StorageClass::Generic) |
| return emitError("pointer type must be of storage class Generic"); |
| |
| spirv::StorageClass resultStorage = resultType.getStorageClass(); |
| if (resultStorage != spirv::StorageClass::Workgroup && |
| resultStorage != spirv::StorageClass::CrossWorkgroup && |
| resultStorage != spirv::StorageClass::Function) |
| return emitError("result must point to the Workgroup, CrossWorkgroup, " |
| "or Function Storage Class"); |
| |
| Type operandPointeeType = operandType.getPointeeType(); |
| Type resultPointeeType = resultType.getPointeeType(); |
| if (operandPointeeType != resultPointeeType) |
| return emitOpError("pointer operand's pointee type must have the same " |
| "as the op result type, but found ") |
| << operandPointeeType << " vs " << resultPointeeType; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ConvertFToSOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ConvertFToSOp::verify() { |
| return verifyCastOp(*this, /*requireSameBitWidth=*/false, |
| /*skipBitWidthCheck=*/true); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ConvertFToUOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ConvertFToUOp::verify() { |
| return verifyCastOp(*this, /*requireSameBitWidth=*/false, |
| /*skipBitWidthCheck=*/true); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ConvertSToFOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ConvertSToFOp::verify() { |
| return verifyCastOp(*this, /*requireSameBitWidth=*/false, |
| /*skipBitWidthCheck=*/true); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.ConvertUToFOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ConvertUToFOp::verify() { |
| return verifyCastOp(*this, /*requireSameBitWidth=*/false, |
| /*skipBitWidthCheck=*/true); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.INTELConvertBF16ToFOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult INTELConvertBF16ToFOp::verify() { |
| auto operandType = getOperand().getType(); |
| auto resultType = getResult().getType(); |
| // ODS checks that vector result type and vector operand type have the same |
| // shape. |
| if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) { |
| unsigned operandNumElements = vectorType.getNumElements(); |
| unsigned resultNumElements = |
| llvm::cast<VectorType>(resultType).getNumElements(); |
| if (operandNumElements != resultNumElements) { |
| return emitOpError( |
| "operand and result must have same number of elements"); |
| } |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.INTELConvertFToBF16Op |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult INTELConvertFToBF16Op::verify() { |
| auto operandType = getOperand().getType(); |
| auto resultType = getResult().getType(); |
| // ODS checks that vector result type and vector operand type have the same |
| // shape. |
| if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) { |
| unsigned operandNumElements = vectorType.getNumElements(); |
| unsigned resultNumElements = |
| llvm::cast<VectorType>(resultType).getNumElements(); |
| if (operandNumElements != resultNumElements) { |
| return emitOpError( |
| "operand and result must have same number of elements"); |
| } |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.FConvertOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::FConvertOp::verify() { |
| return verifyCastOp(*this, /*requireSameBitWidth=*/false); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.SConvertOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::SConvertOp::verify() { |
| return verifyCastOp(*this, /*requireSameBitWidth=*/false); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.UConvertOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult spirv::UConvertOp::verify() { |
| return verifyCastOp(*this, /*requireSameBitWidth=*/false); |
| } |
| |
| } // namespace mlir::spirv |