//===- DXILOpBuilder.cpp - Helper class for building DXILOp functions -----===//
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| /// |
/// \file This file contains a class to help build DXIL op functions.
| //===----------------------------------------------------------------------===// |
| |
| #include "DXILOpBuilder.h" |
| #include "DXILConstants.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/Module.h" |
| #include "llvm/Support/DXILABI.h" |
| #include "llvm/Support/ErrorHandling.h" |
| |
| using namespace llvm; |
| using namespace llvm::dxil; |
| |
| constexpr StringLiteral DXILOpNamePrefix = "dx.op."; |
| |
namespace {

/// Single-bit flags naming each overload type a DXIL operation may be
/// specialized for. An operation's set of supported overloads is stored as
/// the bitwise OR of these flags (see OpCodeProperty::OverloadTys).
enum OverloadKind : uint16_t {
  VOID = 0x001,
  HALF = 0x002,
  FLOAT = 0x004,
  DOUBLE = 0x008,
  I1 = 0x010,
  I8 = 0x020,
  I16 = 0x040,
  I32 = 0x080,
  I64 = 0x100,
  UserDefineType = 0x200,
  ObjectType = 0x400,
};

} // namespace
| |
| static const char *getOverloadTypeName(OverloadKind Kind) { |
| switch (Kind) { |
| case OverloadKind::HALF: |
| return "f16"; |
| case OverloadKind::FLOAT: |
| return "f32"; |
| case OverloadKind::DOUBLE: |
| return "f64"; |
| case OverloadKind::I1: |
| return "i1"; |
| case OverloadKind::I8: |
| return "i8"; |
| case OverloadKind::I16: |
| return "i16"; |
| case OverloadKind::I32: |
| return "i32"; |
| case OverloadKind::I64: |
| return "i64"; |
| case OverloadKind::VOID: |
| case OverloadKind::ObjectType: |
| case OverloadKind::UserDefineType: |
| break; |
| } |
| llvm_unreachable("invalid overload type for name"); |
| return "void"; |
| } |
| |
| static OverloadKind getOverloadKind(Type *Ty) { |
| Type::TypeID T = Ty->getTypeID(); |
| switch (T) { |
| case Type::VoidTyID: |
| return OverloadKind::VOID; |
| case Type::HalfTyID: |
| return OverloadKind::HALF; |
| case Type::FloatTyID: |
| return OverloadKind::FLOAT; |
| case Type::DoubleTyID: |
| return OverloadKind::DOUBLE; |
| case Type::IntegerTyID: { |
| IntegerType *ITy = cast<IntegerType>(Ty); |
| unsigned Bits = ITy->getBitWidth(); |
| switch (Bits) { |
| case 1: |
| return OverloadKind::I1; |
| case 8: |
| return OverloadKind::I8; |
| case 16: |
| return OverloadKind::I16; |
| case 32: |
| return OverloadKind::I32; |
| case 64: |
| return OverloadKind::I64; |
| default: |
| llvm_unreachable("invalid overload type"); |
| return OverloadKind::VOID; |
| } |
| } |
| case Type::PointerTyID: |
| return OverloadKind::UserDefineType; |
| case Type::StructTyID: |
| return OverloadKind::ObjectType; |
| default: |
| llvm_unreachable("invalid overload type"); |
| return OverloadKind::VOID; |
| } |
| } |
| |
| static std::string getTypeName(OverloadKind Kind, Type *Ty) { |
| if (Kind < OverloadKind::UserDefineType) { |
| return getOverloadTypeName(Kind); |
| } else if (Kind == OverloadKind::UserDefineType) { |
| StructType *ST = cast<StructType>(Ty); |
| return ST->getStructName().str(); |
| } else if (Kind == OverloadKind::ObjectType) { |
| StructType *ST = cast<StructType>(Ty); |
| return ST->getStructName().str(); |
| } else { |
| std::string Str; |
| raw_string_ostream OS(Str); |
| Ty->print(OS); |
| return OS.str(); |
| } |
| } |
| |
// Static properties of a single DXIL operation.
// NOTE(review): instances are presumably initialized positionally by the
// TableGen-generated table in DXILOperation.inc (included below). If so,
// member order is part of that contract. Confirm against the emitter before
// reordering or inserting members.
struct OpCodeProperty {
  dxil::OpCode OpCode;
  // Offset in DXILOpCodeNameTable.
  unsigned OpCodeNameOffset;
  dxil::OpCodeClass OpCodeClass;
  // Offset in DXILOpCodeClassNameTable.
  unsigned OpCodeClassNameOffset;
  // Bitwise OR of the OverloadKind flags this operation accepts; tested
  // against a candidate kind in DXILOpBuilder::createDXILOpCall.
  uint16_t OverloadTys;
  // Function attribute kind recorded for the operation (not read in this
  // file).
  llvm::Attribute::AttrKind FuncAttr;
  // Index of the parameter that controls the overload, where 0 denotes the
  // return type (see DXILOpBuilder::getOverloadTy). When < 0, the operation
  // has exactly one overload type, stored directly in OverloadTys.
  int OverloadParamIndex;
  // Number of parameters.
  // NOTE(review): an earlier comment said this includes the return value.
  // However, getDXILOpFunctionType adds this many parameter types in addition
  // to the return type and the i32 opcode. Confirm which is intended.
  unsigned NumOfParameters;
  // Offset in ParameterTable.
  unsigned ParameterTableOffset;
};
| |
// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
| #define DXIL_OP_OPERATION_TABLE |
| #include "DXILOperation.inc" |
| #undef DXIL_OP_OPERATION_TABLE |
| |
| static std::string constructOverloadName(OverloadKind Kind, Type *Ty, |
| const OpCodeProperty &Prop) { |
| if (Kind == OverloadKind::VOID) { |
| return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); |
| } |
| return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + |
| getTypeName(Kind, Ty)) |
| .str(); |
| } |
| |
| static std::string constructOverloadTypeName(OverloadKind Kind, |
| StringRef TypeName) { |
| if (Kind == OverloadKind::VOID) |
| return TypeName.str(); |
| |
| assert(Kind < OverloadKind::UserDefineType && "invalid overload kind"); |
| return (Twine(TypeName) + getOverloadTypeName(Kind)).str(); |
| } |
| |
| static StructType *getOrCreateStructType(StringRef Name, |
| ArrayRef<Type *> EltTys, |
| LLVMContext &Ctx) { |
| StructType *ST = StructType::getTypeByName(Ctx, Name); |
| if (ST) |
| return ST; |
| |
| return StructType::create(Ctx, EltTys, Name); |
| } |
| |
| static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) { |
| OverloadKind Kind = getOverloadKind(OverloadTy); |
| std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet."); |
| Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy, |
| Type::getInt32Ty(Ctx)}; |
| return getOrCreateStructType(TypeName, FieldTypes, Ctx); |
| } |
| |
| static StructType *getHandleType(LLVMContext &Ctx) { |
| return getOrCreateStructType("dx.types.Handle", PointerType::getUnqual(Ctx), |
| Ctx); |
| } |
| |
| static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) { |
| auto &Ctx = OverloadTy->getContext(); |
| switch (Kind) { |
| case ParameterKind::Void: |
| return Type::getVoidTy(Ctx); |
| case ParameterKind::Half: |
| return Type::getHalfTy(Ctx); |
| case ParameterKind::Float: |
| return Type::getFloatTy(Ctx); |
| case ParameterKind::Double: |
| return Type::getDoubleTy(Ctx); |
| case ParameterKind::I1: |
| return Type::getInt1Ty(Ctx); |
| case ParameterKind::I8: |
| return Type::getInt8Ty(Ctx); |
| case ParameterKind::I16: |
| return Type::getInt16Ty(Ctx); |
| case ParameterKind::I32: |
| return Type::getInt32Ty(Ctx); |
| case ParameterKind::I64: |
| return Type::getInt64Ty(Ctx); |
| case ParameterKind::Overload: |
| return OverloadTy; |
| case ParameterKind::ResourceRet: |
| return getResRetType(OverloadTy, Ctx); |
| case ParameterKind::DXILHandle: |
| return getHandleType(Ctx); |
| default: |
| break; |
| } |
| llvm_unreachable("Invalid parameter kind"); |
| return nullptr; |
| } |
| |
| /// Construct DXIL function type. This is the type of a function with |
| /// the following prototype |
| /// OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>) |
| /// <param-types> are constructed from types in Prop. |
| /// \param Prop Structure containing DXIL Operation properties based on |
| /// its specification in DXIL.td. |
| /// \param OverloadTy Return type to be used to construct DXIL function type. |
| static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop, |
| Type *ReturnTy, Type *OverloadTy) { |
| SmallVector<Type *> ArgTys; |
| |
| auto ParamKinds = getOpCodeParameterKind(*Prop); |
| |
| // Add ReturnTy as return type of the function |
| ArgTys.emplace_back(ReturnTy); |
| |
| // Add DXIL Opcode value type viz., Int32 as first argument |
| ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext())); |
| |
| // Add DXIL Operation parameter types as specified in DXIL properties |
| for (unsigned I = 0; I < Prop->NumOfParameters; ++I) { |
| ParameterKind Kind = ParamKinds[I]; |
| ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy)); |
| } |
| return FunctionType::get( |
| ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false); |
| } |
| |
| namespace llvm { |
| namespace dxil { |
| |
| CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, |
| Type *OverloadTy, |
| SmallVector<Value *> Args) { |
| const OpCodeProperty *Prop = getOpCodeProperty(OpCode); |
| |
| OverloadKind Kind = getOverloadKind(OverloadTy); |
| if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { |
| report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false); |
| } |
| |
| std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop); |
| FunctionCallee DXILFn; |
| // Get the function with name DXILFnName, if one exists |
| if (auto *Func = M.getFunction(DXILFnName)) { |
| DXILFn = FunctionCallee(Func); |
| } else { |
| // Construct and add a function with name DXILFnName |
| FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy); |
| DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT); |
| } |
| |
| return B.CreateCall(DXILFn, Args); |
| } |
| |
| Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) { |
| |
| const OpCodeProperty *Prop = getOpCodeProperty(OpCode); |
| // If DXIL Op has no overload parameter, just return the |
| // precise return type specified. |
| if (Prop->OverloadParamIndex < 0) { |
| auto &Ctx = FT->getContext(); |
| switch (Prop->OverloadTys) { |
| case OverloadKind::VOID: |
| return Type::getVoidTy(Ctx); |
| case OverloadKind::HALF: |
| return Type::getHalfTy(Ctx); |
| case OverloadKind::FLOAT: |
| return Type::getFloatTy(Ctx); |
| case OverloadKind::DOUBLE: |
| return Type::getDoubleTy(Ctx); |
| case OverloadKind::I1: |
| return Type::getInt1Ty(Ctx); |
| case OverloadKind::I8: |
| return Type::getInt8Ty(Ctx); |
| case OverloadKind::I16: |
| return Type::getInt16Ty(Ctx); |
| case OverloadKind::I32: |
| return Type::getInt32Ty(Ctx); |
| case OverloadKind::I64: |
| return Type::getInt64Ty(Ctx); |
| default: |
| llvm_unreachable("invalid overload type"); |
| return nullptr; |
| } |
| } |
| |
| // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType(). |
| Type *OverloadType = FT->getReturnType(); |
| if (Prop->OverloadParamIndex != 0) { |
| // Skip Return Type. |
| OverloadType = FT->getParamType(Prop->OverloadParamIndex - 1); |
| } |
| |
| auto ParamKinds = getOpCodeParameterKind(*Prop); |
| auto Kind = ParamKinds[Prop->OverloadParamIndex]; |
| // For ResRet and CBufferRet, OverloadTy is in field of StructType. |
| if (Kind == ParameterKind::CBufferRet || |
| Kind == ParameterKind::ResourceRet) { |
| auto *ST = cast<StructType>(OverloadType); |
| OverloadType = ST->getElementType(0); |
| } |
| return OverloadType; |
| } |
| |
| const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) { |
| return ::getOpCodeName(DXILOp); |
| } |
| } // namespace dxil |
| } // namespace llvm |