| //===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // This implements Semantic Analysis for HLSL constructs. |
| //===----------------------------------------------------------------------===// |
| |
| #include "clang/Sema/SemaHLSL.h" |
| #include "clang/Basic/DiagnosticSema.h" |
| #include "clang/Basic/LLVM.h" |
| #include "clang/Basic/TargetInfo.h" |
| #include "clang/Sema/Sema.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/Support/ErrorHandling.h" |
| #include "llvm/TargetParser/Triple.h" |
| #include <iterator> |
| |
| using namespace clang; |
| |
| SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {} |
| |
| Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer, |
| SourceLocation KwLoc, IdentifierInfo *Ident, |
| SourceLocation IdentLoc, |
| SourceLocation LBrace) { |
| // For anonymous namespace, take the location of the left brace. |
| DeclContext *LexicalParent = SemaRef.getCurLexicalContext(); |
| HLSLBufferDecl *Result = HLSLBufferDecl::Create( |
| getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace); |
| |
| SemaRef.PushOnScopeChains(Result, BufferScope); |
| SemaRef.PushDeclContext(BufferScope, Result); |
| |
| return Result; |
| } |
| |
| void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) { |
| auto *BufDecl = cast<HLSLBufferDecl>(Dcl); |
| BufDecl->setRBraceLoc(RBrace); |
| SemaRef.PopDeclContext(); |
| } |
| |
| HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D, |
| const AttributeCommonInfo &AL, |
| int X, int Y, int Z) { |
| if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) { |
| if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) { |
| Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; |
| Diag(AL.getLoc(), diag::note_conflicting_attribute); |
| } |
| return nullptr; |
| } |
| return ::new (getASTContext()) |
| HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z); |
| } |
| |
| HLSLShaderAttr * |
| SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, |
| HLSLShaderAttr::ShaderType ShaderType) { |
| if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) { |
| if (NT->getType() != ShaderType) { |
| Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; |
| Diag(AL.getLoc(), diag::note_conflicting_attribute); |
| } |
| return nullptr; |
| } |
| return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL); |
| } |
| |
| HLSLParamModifierAttr * |
| SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, |
| HLSLParamModifierAttr::Spelling Spelling) { |
| // We can only merge an `in` attribute with an `out` attribute. All other |
| // combinations of duplicated attributes are ill-formed. |
| if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) { |
| if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) || |
| (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) { |
| D->dropAttr<HLSLParamModifierAttr>(); |
| SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()}; |
| return HLSLParamModifierAttr::Create( |
| getASTContext(), /*MergedSpelling=*/true, AdjustedRange, |
| HLSLParamModifierAttr::Keyword_inout); |
| } |
| Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL; |
| Diag(PA->getLocation(), diag::note_conflicting_attribute); |
| return nullptr; |
| } |
| return HLSLParamModifierAttr::Create(getASTContext(), AL); |
| } |
| |
| void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) { |
| auto &TargetInfo = getASTContext().getTargetInfo(); |
| |
| if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry) |
| return; |
| |
| StringRef Env = TargetInfo.getTriple().getEnvironmentName(); |
| HLSLShaderAttr::ShaderType ShaderType; |
| if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) { |
| if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) { |
| // The entry point is already annotated - check that it matches the |
| // triple. |
| if (Shader->getType() != ShaderType) { |
| Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch) |
| << Shader; |
| FD->setInvalidDecl(); |
| } |
| } else { |
| // Implicitly add the shader attribute if the entry function isn't |
| // explicitly annotated. |
| FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType, |
| FD->getBeginLoc())); |
| } |
| } else { |
| switch (TargetInfo.getTriple().getEnvironment()) { |
| case llvm::Triple::UnknownEnvironment: |
| case llvm::Triple::Library: |
| break; |
| default: |
| llvm_unreachable("Unhandled environment in triple"); |
| } |
| } |
| } |
| |
| void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) { |
| const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); |
| assert(ShaderAttr && "Entry point has no shader attribute"); |
| HLSLShaderAttr::ShaderType ST = ShaderAttr->getType(); |
| |
| switch (ST) { |
| case HLSLShaderAttr::Pixel: |
| case HLSLShaderAttr::Vertex: |
| case HLSLShaderAttr::Geometry: |
| case HLSLShaderAttr::Hull: |
| case HLSLShaderAttr::Domain: |
| case HLSLShaderAttr::RayGeneration: |
| case HLSLShaderAttr::Intersection: |
| case HLSLShaderAttr::AnyHit: |
| case HLSLShaderAttr::ClosestHit: |
| case HLSLShaderAttr::Miss: |
| case HLSLShaderAttr::Callable: |
| if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) { |
| DiagnoseAttrStageMismatch(NT, ST, |
| {HLSLShaderAttr::Compute, |
| HLSLShaderAttr::Amplification, |
| HLSLShaderAttr::Mesh}); |
| FD->setInvalidDecl(); |
| } |
| break; |
| |
| case HLSLShaderAttr::Compute: |
| case HLSLShaderAttr::Amplification: |
| case HLSLShaderAttr::Mesh: |
| if (!FD->hasAttr<HLSLNumThreadsAttr>()) { |
| Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads) |
| << HLSLShaderAttr::ConvertShaderTypeToStr(ST); |
| FD->setInvalidDecl(); |
| } |
| break; |
| } |
| |
| for (ParmVarDecl *Param : FD->parameters()) { |
| if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) { |
| CheckSemanticAnnotation(FD, Param, AnnotationAttr); |
| } else { |
| // FIXME: Handle struct parameters where annotations are on struct fields. |
| // See: https://github.com/llvm/llvm-project/issues/57875 |
| Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation); |
| Diag(Param->getLocation(), diag::note_previous_decl) << Param; |
| FD->setInvalidDecl(); |
| } |
| } |
| // FIXME: Verify return type semantic annotation. |
| } |
| |
| void SemaHLSL::CheckSemanticAnnotation( |
| FunctionDecl *EntryPoint, const Decl *Param, |
| const HLSLAnnotationAttr *AnnotationAttr) { |
| auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>(); |
| assert(ShaderAttr && "Entry point has no shader attribute"); |
| HLSLShaderAttr::ShaderType ST = ShaderAttr->getType(); |
| |
| switch (AnnotationAttr->getKind()) { |
| case attr::HLSLSV_DispatchThreadID: |
| case attr::HLSLSV_GroupIndex: |
| if (ST == HLSLShaderAttr::Compute) |
| return; |
| DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute}); |
| break; |
| default: |
| llvm_unreachable("Unknown HLSLAnnotationAttr"); |
| } |
| } |
| |
| void SemaHLSL::DiagnoseAttrStageMismatch( |
| const Attr *A, HLSLShaderAttr::ShaderType Stage, |
| std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) { |
| SmallVector<StringRef, 8> StageStrings; |
| llvm::transform(AllowedStages, std::back_inserter(StageStrings), |
| [](HLSLShaderAttr::ShaderType ST) { |
| return StringRef( |
| HLSLShaderAttr::ConvertShaderTypeToStr(ST)); |
| }); |
| Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage) |
| << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage) |
| << (AllowedStages.size() != 1) << join(StageStrings, ", "); |
| } |