diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp index bb860a8f76742..9d7c2757d6ee4 100644 --- a/clang/lib/AST/TextNodeDumper.cpp +++ b/clang/lib/AST/TextNodeDumper.cpp @@ -24,7 +24,7 @@ #include "clang/Basic/Specifiers.h" #include "clang/Basic/TypeTraits.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/Frontend/HLSL/HLSLRootSignatureUtils.h" +#include "llvm/Frontend/HLSL/HLSLRootSignature.h" #include #include diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index f2e992fb7fa69..73843247ce7f2 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -23,7 +23,7 @@ #include "clang/AST/Type.h" #include "clang/Basic/TargetOptions.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Frontend/HLSL/HLSLRootSignatureUtils.h" +#include "llvm/Frontend/HLSL/RootSignatureMetadata.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalVariable.h" diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 3bab0da5edea8..ca66c71370d60 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -39,7 +39,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" -#include "llvm/Frontend/HLSL/HLSLRootSignatureUtils.h" +#include "llvm/Frontend/HLSL/RootSignatureValidations.h" #include "llvm/Support/Casting.h" #include "llvm/Support/DXILABI.h" #include "llvm/Support/ErrorHandling.h" diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h index f552040ab31cc..f747c8ccaeb18 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -17,6 +17,7 @@ #include "llvm/BinaryFormat/DXContainer.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/DXILABI.h" +#include "llvm/Support/raw_ostream.h" #include #include @@ -135,6 +136,21 @@ using RootElement = std::variant; +/// The following contains the serialization interface for root elements +LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, const dxbc::RootFlags &Flags); +LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, + const RootConstants &Constants); +LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, + const DescriptorTableClause &Clause); +LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table); +LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, + const RootDescriptor &Descriptor); +LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, + const StaticSampler &StaticSampler); +LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, const RootElement &Element); + +LLVM_ABI void dumpRootElements(raw_ostream &OS, ArrayRef Elements); + } // namespace rootsig } // namespace hlsl } // namespace llvm diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h new file mode 100644 index 0000000000000..c473a7f1e02e5 --- /dev/null +++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h @@ -0,0 +1,56 @@ +//===- RootSignatureMetadata.h - HLSL Root Signature helpers --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains a library for working with HLSL Root Signatures and +/// their metadata representation. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_FRONTEND_HLSL_ROOTSIGNATUREMETADATA_H +#define LLVM_FRONTEND_HLSL_ROOTSIGNATUREMETADATA_H + +#include "llvm/Frontend/HLSL/HLSLRootSignature.h" + +namespace llvm { +class LLVMContext; +class MDNode; +class Metadata; + +namespace hlsl { +namespace rootsig { + +class MetadataBuilder { +public: + MetadataBuilder(llvm::LLVMContext &Ctx, ArrayRef Elements) + : Ctx(Ctx), Elements(Elements) {} + + /// Iterates through the elements and dispatches onto the correct Build method + /// + /// Accumulates the root signature and returns the Metadata node that is just + /// a list of all the elements + LLVM_ABI MDNode *BuildRootSignature(); + +private: + /// Define the various builders for the different metadata types + MDNode *BuildRootFlags(const dxbc::RootFlags &Flags); + MDNode *BuildRootConstants(const RootConstants &Constants); + MDNode *BuildRootDescriptor(const RootDescriptor &Descriptor); + MDNode *BuildDescriptorTable(const DescriptorTable &Table); + MDNode *BuildDescriptorTableClause(const DescriptorTableClause &Clause); + MDNode *BuildStaticSampler(const StaticSampler &Sampler); + + llvm::LLVMContext &Ctx; + ArrayRef Elements; + SmallVector GeneratedMetadata; +}; + +} // namespace rootsig +} // namespace hlsl +} // namespace llvm + +#endif // LLVM_FRONTEND_HLSL_ROOTSIGNATUREMETADATA_H diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h similarity index 55% rename from llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h rename to llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h index 4fa080e949d54..14eb7c482c08c 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h +++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h @@ -1,4 +1,4 @@ -//===- HLSLRootSignatureUtils.h - HLSL Root Signature helpers -------------===// +//===- RootSignatureValidations.h - HLSL Root Signature helpers -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -11,68 +11,16 @@ /// //===----------------------------------------------------------------------===// -#ifndef LLVM_FRONTEND_HLSL_HLSLROOTSIGNATUREUTILS_H -#define LLVM_FRONTEND_HLSL_HLSLROOTSIGNATUREUTILS_H +#ifndef LLVM_FRONTEND_HLSL_ROOTSIGNATUREVALIDATIONS_H +#define LLVM_FRONTEND_HLSL_ROOTSIGNATUREVALIDATIONS_H -#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/IntervalMap.h" #include "llvm/Frontend/HLSL/HLSLRootSignature.h" -#include "llvm/Support/Compiler.h" -#include "llvm/Support/raw_ostream.h" namespace llvm { -class LLVMContext; -class MDNode; -class Metadata; - namespace hlsl { namespace rootsig { -LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, const dxbc::RootFlags &Flags); - -LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, - const RootConstants &Constants); - -LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, - const DescriptorTableClause &Clause); - -LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table); - -LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, - const RootDescriptor &Descriptor); - -LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, - const StaticSampler &StaticSampler); - -LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, const RootElement &Element); - -LLVM_ABI void dumpRootElements(raw_ostream &OS, ArrayRef Elements); - -class MetadataBuilder { -public: - MetadataBuilder(llvm::LLVMContext &Ctx, ArrayRef Elements) - : Ctx(Ctx), Elements(Elements) {} - - /// Iterates through the elements and dispatches onto the correct Build method - /// - /// Accumulates the root signature and returns the Metadata node that is just - /// a list of all the elements - LLVM_ABI MDNode *BuildRootSignature(); - -private: - /// Define the various builders for the different metadata types - MDNode *BuildRootFlags(const dxbc::RootFlags &Flags); - MDNode *BuildRootConstants(const RootConstants &Constants); - MDNode *BuildRootDescriptor(const RootDescriptor &Descriptor); - MDNode *BuildDescriptorTable(const DescriptorTable &Table); - MDNode *BuildDescriptorTableClause(const DescriptorTableClause &Clause); - MDNode *BuildStaticSampler(const StaticSampler &Sampler); - - llvm::LLVMContext &Ctx; - ArrayRef Elements; - SmallVector GeneratedMetadata; -}; - struct RangeInfo { const static uint32_t Unbounded = ~0u; @@ -140,4 +88,4 @@ class ResourceRange { } // namespace hlsl } // namespace llvm -#endif // LLVM_FRONTEND_HLSL_HLSLROOTSIGNATUREUTILS_H +#endif // LLVM_FRONTEND_HLSL_ROOTSIGNATUREVALIDATIONS_H diff --git a/llvm/lib/Frontend/HLSL/CMakeLists.txt b/llvm/lib/Frontend/HLSL/CMakeLists.txt index 8928144730477..534346920ff19 100644 --- a/llvm/lib/Frontend/HLSL/CMakeLists.txt +++ b/llvm/lib/Frontend/HLSL/CMakeLists.txt @@ -1,7 +1,9 @@ add_llvm_component_library(LLVMFrontendHLSL CBuffer.cpp HLSLResource.cpp - HLSLRootSignatureUtils.cpp + HLSLRootSignature.cpp + RootSignatureMetadata.cpp + RootSignatureValidations.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp new file mode 100644 index 0000000000000..78c20a6c5c9ff --- /dev/null +++ b/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp @@ -0,0 +1,244 @@ +//===- HLSLRootSignature.cpp - HLSL Root Signature helpers ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains helpers for working with HLSL Root Signatures. +/// +//===----------------------------------------------------------------------===// + +#include "llvm/Frontend/HLSL/HLSLRootSignature.h" +#include "llvm/Support/ScopedPrinter.h" + +namespace llvm { +namespace hlsl { +namespace rootsig { + +template +static std::optional getEnumName(const T Value, + ArrayRef> Enums) { + for (const auto &EnumItem : Enums) + if (EnumItem.Value == Value) + return EnumItem.Name; + return std::nullopt; +} + +template +static raw_ostream &printEnum(raw_ostream &OS, const T Value, + ArrayRef> Enums) { + auto MaybeName = getEnumName(Value, Enums); + if (MaybeName) + OS << *MaybeName; + return OS; +} + +template +static raw_ostream &printFlags(raw_ostream &OS, const T Value, + ArrayRef> Flags) { + bool FlagSet = false; + unsigned Remaining = llvm::to_underlying(Value); + while (Remaining) { + unsigned Bit = 1u << llvm::countr_zero(Remaining); + if (Remaining & Bit) { + if (FlagSet) + OS << " | "; + + auto MaybeFlag = getEnumName(T(Bit), Flags); + if (MaybeFlag) + OS << *MaybeFlag; + else + OS << "invalid: " << Bit; + + FlagSet = true; + } + Remaining &= ~Bit; + } + + if (!FlagSet) + OS << "None"; + return OS; +} + +static const EnumEntry RegisterNames[] = { + {"b", RegisterType::BReg}, + {"t", RegisterType::TReg}, + {"u", RegisterType::UReg}, + {"s", RegisterType::SReg}, +}; + +static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) { + printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames)); + OS << Reg.Number; + + return OS; +} + +static raw_ostream &operator<<(raw_ostream &OS, + const llvm::dxbc::ShaderVisibility &Visibility) { + printEnum(OS, Visibility, dxbc::getShaderVisibility()); + + return OS; +} + +static raw_ostream &operator<<(raw_ostream &OS, + const llvm::dxbc::SamplerFilter &Filter) { + printEnum(OS, Filter, dxbc::getSamplerFilters()); + + return OS; +} + +static raw_ostream &operator<<(raw_ostream &OS, + const dxbc::TextureAddressMode &Address) { + printEnum(OS, Address, dxbc::getTextureAddressModes()); + + return OS; +} + +static raw_ostream &operator<<(raw_ostream &OS, + const dxbc::ComparisonFunc &CompFunc) { + printEnum(OS, CompFunc, dxbc::getComparisonFuncs()); + + return OS; +} + +static raw_ostream &operator<<(raw_ostream &OS, + const dxbc::StaticBorderColor &BorderColor) { + printEnum(OS, BorderColor, dxbc::getStaticBorderColors()); + + return OS; +} + +static const EnumEntry ResourceClassNames[] = { + {"CBV", dxil::ResourceClass::CBuffer}, + {"SRV", dxil::ResourceClass::SRV}, + {"UAV", dxil::ResourceClass::UAV}, + {"Sampler", dxil::ResourceClass::Sampler}, +}; + +static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { + printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)), + ArrayRef(ResourceClassNames)); + + return OS; +} + +static raw_ostream &operator<<(raw_ostream &OS, + const dxbc::RootDescriptorFlags &Flags) { + printFlags(OS, Flags, dxbc::getRootDescriptorFlags()); + + return OS; +} + +static raw_ostream &operator<<(raw_ostream &OS, + const llvm::dxbc::DescriptorRangeFlags &Flags) { + printFlags(OS, Flags, dxbc::getDescriptorRangeFlags()); + + return OS; +} + +raw_ostream &operator<<(raw_ostream &OS, const dxbc::RootFlags &Flags) { + OS << "RootFlags("; + printFlags(OS, Flags, dxbc::getRootFlags()); + OS << ")"; + + return OS; +} + +raw_ostream &operator<<(raw_ostream &OS, const RootConstants &Constants) { + OS << "RootConstants(num32BitConstants = " << Constants.Num32BitConstants + << ", " << Constants.Reg << ", space = " << Constants.Space + << ", visibility = " << Constants.Visibility << ")"; + + return OS; +} + +raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table) { + OS << "DescriptorTable(numClauses = " << Table.NumClauses + << ", visibility = " << Table.Visibility << ")"; + + return OS; +} + +raw_ostream &operator<<(raw_ostream &OS, const DescriptorTableClause &Clause) { + OS << Clause.Type << "(" << Clause.Reg << ", numDescriptors = "; + if (Clause.NumDescriptors == NumDescriptorsUnbounded) + OS << "unbounded"; + else + OS << Clause.NumDescriptors; + OS << ", space = " << Clause.Space << ", offset = "; + if (Clause.Offset == DescriptorTableOffsetAppend) + OS << "DescriptorTableOffsetAppend"; + else + OS << Clause.Offset; + OS << ", flags = " << Clause.Flags << ")"; + + return OS; +} + +raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) { + ClauseType Type = ClauseType(llvm::to_underlying(Descriptor.Type)); + OS << "Root" << Type << "(" << Descriptor.Reg + << ", space = " << Descriptor.Space + << ", visibility = " << Descriptor.Visibility + << ", flags = " << Descriptor.Flags << ")"; + + return OS; +} + +raw_ostream &operator<<(raw_ostream &OS, const StaticSampler &Sampler) { + OS << "StaticSampler(" << Sampler.Reg << ", filter = " << Sampler.Filter + << ", addressU = " << Sampler.AddressU + << ", addressV = " << Sampler.AddressV + << ", addressW = " << Sampler.AddressW + << ", mipLODBias = " << Sampler.MipLODBias + << ", maxAnisotropy = " << Sampler.MaxAnisotropy + << ", comparisonFunc = " << Sampler.CompFunc + << ", borderColor = " << Sampler.BorderColor + << ", minLOD = " << Sampler.MinLOD << ", maxLOD = " << Sampler.MaxLOD + << ", space = " << Sampler.Space << ", visibility = " << Sampler.Visibility + << ")"; + return OS; +} + +namespace { + +// We use the OverloadVisit with std::visit to ensure the compiler catches if a +// new RootElement variant type is added but it's operator<< isn't handled. +template struct OverloadedVisit : Ts... { + using Ts::operator()...; +}; +template OverloadedVisit(Ts...) -> OverloadedVisit; + +} // namespace + +raw_ostream &operator<<(raw_ostream &OS, const RootElement &Element) { + const auto Visitor = OverloadedVisit{ + [&OS](const dxbc::RootFlags &Flags) { OS << Flags; }, + [&OS](const RootConstants &Constants) { OS << Constants; }, + [&OS](const RootDescriptor &Descriptor) { OS << Descriptor; }, + [&OS](const DescriptorTableClause &Clause) { OS << Clause; }, + [&OS](const DescriptorTable &Table) { OS << Table; }, + [&OS](const StaticSampler &Sampler) { OS << Sampler; }, + }; + std::visit(Visitor, Element); + return OS; +} + +void dumpRootElements(raw_ostream &OS, ArrayRef Elements) { + OS << " RootElements{"; + bool First = true; + for (const RootElement &Element : Elements) { + if (!First) + OS << ","; + OS << " " << Element; + First = false; + } + OS << "}"; +} + +} // namespace rootsig +} // namespace hlsl +} // namespace llvm diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp deleted file mode 100644 index 67f512008b069..0000000000000 --- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp +++ /dev/null @@ -1,457 +0,0 @@ -//===- HLSLRootSignatureUtils.cpp - HLSL Root Signature helpers -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file This file contains helpers for working with HLSL Root Signatures. -/// -//===----------------------------------------------------------------------===// - -#include "llvm/Frontend/HLSL/HLSLRootSignatureUtils.h" -#include "llvm/ADT/SmallString.h" -#include "llvm/ADT/bit.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Metadata.h" -#include "llvm/Support/ScopedPrinter.h" - -namespace llvm { -namespace hlsl { -namespace rootsig { - -template -static std::optional getEnumName(const T Value, - ArrayRef> Enums) { - for (const auto &EnumItem : Enums) - if (EnumItem.Value == Value) - return EnumItem.Name; - return std::nullopt; -} - -template -static raw_ostream &printEnum(raw_ostream &OS, const T Value, - ArrayRef> Enums) { - auto MaybeName = getEnumName(Value, Enums); - if (MaybeName) - OS << *MaybeName; - return OS; -} - -template -static raw_ostream &printFlags(raw_ostream &OS, const T Value, - ArrayRef> Flags) { - bool FlagSet = false; - unsigned Remaining = llvm::to_underlying(Value); - while (Remaining) { - unsigned Bit = 1u << llvm::countr_zero(Remaining); - if (Remaining & Bit) { - if (FlagSet) - OS << " | "; - - auto MaybeFlag = getEnumName(T(Bit), Flags); - if (MaybeFlag) - OS << *MaybeFlag; - else - OS << "invalid: " << Bit; - - FlagSet = true; - } - Remaining &= ~Bit; - } - - if (!FlagSet) - OS << "None"; - return OS; -} - -static const EnumEntry RegisterNames[] = { - {"b", RegisterType::BReg}, - {"t", RegisterType::TReg}, - {"u", RegisterType::UReg}, - {"s", RegisterType::SReg}, -}; - -static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) { - printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames)); - OS << Reg.Number; - - return OS; -} - -static raw_ostream &operator<<(raw_ostream &OS, - const llvm::dxbc::ShaderVisibility &Visibility) { - printEnum(OS, Visibility, dxbc::getShaderVisibility()); - - return OS; -} - -static raw_ostream &operator<<(raw_ostream &OS, - const llvm::dxbc::SamplerFilter &Filter) { - printEnum(OS, Filter, dxbc::getSamplerFilters()); - - return OS; -} - -static raw_ostream &operator<<(raw_ostream &OS, - const dxbc::TextureAddressMode &Address) { - printEnum(OS, Address, dxbc::getTextureAddressModes()); - - return OS; -} - -static raw_ostream &operator<<(raw_ostream &OS, - const dxbc::ComparisonFunc &CompFunc) { - printEnum(OS, CompFunc, dxbc::getComparisonFuncs()); - - return OS; -} - -static raw_ostream &operator<<(raw_ostream &OS, - const dxbc::StaticBorderColor &BorderColor) { - printEnum(OS, BorderColor, dxbc::getStaticBorderColors()); - - return OS; -} - -static const EnumEntry ResourceClassNames[] = { - {"CBV", dxil::ResourceClass::CBuffer}, - {"SRV", dxil::ResourceClass::SRV}, - {"UAV", dxil::ResourceClass::UAV}, - {"Sampler", dxil::ResourceClass::Sampler}, -}; - -static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { - printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)), - ArrayRef(ResourceClassNames)); - - return OS; -} - -static raw_ostream &operator<<(raw_ostream &OS, - const dxbc::RootDescriptorFlags &Flags) { - printFlags(OS, Flags, dxbc::getRootDescriptorFlags()); - - return OS; -} - -static raw_ostream &operator<<(raw_ostream &OS, - const llvm::dxbc::DescriptorRangeFlags &Flags) { - printFlags(OS, Flags, dxbc::getDescriptorRangeFlags()); - - return OS; -} - -raw_ostream &operator<<(raw_ostream &OS, const dxbc::RootFlags &Flags) { - OS << "RootFlags("; - printFlags(OS, Flags, dxbc::getRootFlags()); - OS << ")"; - - return OS; -} - -raw_ostream &operator<<(raw_ostream &OS, const RootConstants &Constants) { - OS << "RootConstants(num32BitConstants = " << Constants.Num32BitConstants - << ", " << Constants.Reg << ", space = " << Constants.Space - << ", visibility = " << Constants.Visibility << ")"; - - return OS; -} - -raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table) { - OS << "DescriptorTable(numClauses = " << Table.NumClauses - << ", visibility = " << Table.Visibility << ")"; - - return OS; -} - -raw_ostream &operator<<(raw_ostream &OS, const DescriptorTableClause &Clause) { - OS << Clause.Type << "(" << Clause.Reg << ", numDescriptors = "; - if (Clause.NumDescriptors == NumDescriptorsUnbounded) - OS << "unbounded"; - else - OS << Clause.NumDescriptors; - OS << ", space = " << Clause.Space << ", offset = "; - if (Clause.Offset == DescriptorTableOffsetAppend) - OS << "DescriptorTableOffsetAppend"; - else - OS << Clause.Offset; - OS << ", flags = " << Clause.Flags << ")"; - - return OS; -} - -raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) { - ClauseType Type = ClauseType(llvm::to_underlying(Descriptor.Type)); - OS << "Root" << Type << "(" << Descriptor.Reg - << ", space = " << Descriptor.Space - << ", visibility = " << Descriptor.Visibility - << ", flags = " << Descriptor.Flags << ")"; - - return OS; -} - -raw_ostream &operator<<(raw_ostream &OS, const StaticSampler &Sampler) { - OS << "StaticSampler(" << Sampler.Reg << ", filter = " << Sampler.Filter - << ", addressU = " << Sampler.AddressU - << ", addressV = " << Sampler.AddressV - << ", addressW = " << Sampler.AddressW - << ", mipLODBias = " << Sampler.MipLODBias - << ", maxAnisotropy = " << Sampler.MaxAnisotropy - << ", comparisonFunc = " << Sampler.CompFunc - << ", borderColor = " << Sampler.BorderColor - << ", minLOD = " << Sampler.MinLOD << ", maxLOD = " << Sampler.MaxLOD - << ", space = " << Sampler.Space << ", visibility = " << Sampler.Visibility - << ")"; - return OS; -} - -namespace { - -// We use the OverloadVisit with std::visit to ensure the compiler catches if a -// new RootElement variant type is added but it's operator<< or metadata -// generation isn't handled. -template struct OverloadedVisit : Ts... { - using Ts::operator()...; -}; -template OverloadedVisit(Ts...) -> OverloadedVisit; - -} // namespace - -raw_ostream &operator<<(raw_ostream &OS, const RootElement &Element) { - const auto Visitor = OverloadedVisit{ - [&OS](const dxbc::RootFlags &Flags) { OS << Flags; }, - [&OS](const RootConstants &Constants) { OS << Constants; }, - [&OS](const RootDescriptor &Descriptor) { OS << Descriptor; }, - [&OS](const DescriptorTableClause &Clause) { OS << Clause; }, - [&OS](const DescriptorTable &Table) { OS << Table; }, - [&OS](const StaticSampler &Sampler) { OS << Sampler; }, - }; - std::visit(Visitor, Element); - return OS; -} - -void dumpRootElements(raw_ostream &OS, ArrayRef Elements) { - OS << " RootElements{"; - bool First = true; - for (const RootElement &Element : Elements) { - if (!First) - OS << ","; - OS << " " << Element; - First = false; - } - OS << "}"; -} - -MDNode *MetadataBuilder::BuildRootSignature() { - const auto Visitor = OverloadedVisit{ - [this](const dxbc::RootFlags &Flags) -> MDNode * { - return BuildRootFlags(Flags); - }, - [this](const RootConstants &Constants) -> MDNode * { - return BuildRootConstants(Constants); - }, - [this](const RootDescriptor &Descriptor) -> MDNode * { - return BuildRootDescriptor(Descriptor); - }, - [this](const DescriptorTableClause &Clause) -> MDNode * { - return BuildDescriptorTableClause(Clause); - }, - [this](const DescriptorTable &Table) -> MDNode * { - return BuildDescriptorTable(Table); - }, - [this](const StaticSampler &Sampler) -> MDNode * { - return BuildStaticSampler(Sampler); - }, - }; - - for (const RootElement &Element : Elements) { - MDNode *ElementMD = std::visit(Visitor, Element); - assert(ElementMD != nullptr && - "Root Element must be initialized and validated"); - GeneratedMetadata.push_back(ElementMD); - } - - return MDNode::get(Ctx, GeneratedMetadata); -} - -MDNode *MetadataBuilder::BuildRootFlags(const dxbc::RootFlags &Flags) { - IRBuilder<> Builder(Ctx); - Metadata *Operands[] = { - MDString::get(Ctx, "RootFlags"), - ConstantAsMetadata::get(Builder.getInt32(llvm::to_underlying(Flags))), - }; - return MDNode::get(Ctx, Operands); -} - -MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) { - IRBuilder<> Builder(Ctx); - Metadata *Operands[] = { - MDString::get(Ctx, "RootConstants"), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Constants.Visibility))), - ConstantAsMetadata::get(Builder.getInt32(Constants.Reg.Number)), - ConstantAsMetadata::get(Builder.getInt32(Constants.Space)), - ConstantAsMetadata::get(Builder.getInt32(Constants.Num32BitConstants)), - }; - return MDNode::get(Ctx, Operands); -} - -MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) { - IRBuilder<> Builder(Ctx); - std::optional TypeName = - getEnumName(dxil::ResourceClass(llvm::to_underlying(Descriptor.Type)), - ArrayRef(ResourceClassNames)); - assert(TypeName && "Provided an invalid Resource Class"); - llvm::SmallString<7> Name({"Root", *TypeName}); - Metadata *Operands[] = { - MDString::get(Ctx, Name), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))), - ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)), - ConstantAsMetadata::get(Builder.getInt32(Descriptor.Space)), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Descriptor.Flags))), - }; - return MDNode::get(Ctx, Operands); -} - -MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) { - IRBuilder<> Builder(Ctx); - SmallVector TableOperands; - // Set the mandatory arguments - TableOperands.push_back(MDString::get(Ctx, "DescriptorTable")); - TableOperands.push_back(ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Table.Visibility)))); - - // Remaining operands are references to the table's clauses. The in-memory - // representation of the Root Elements created from parsing will ensure that - // the previous N elements are the clauses for this table. - assert(Table.NumClauses <= GeneratedMetadata.size() && - "Table expected all owned clauses to be generated already"); - // So, add a refence to each clause to our operands - TableOperands.append(GeneratedMetadata.end() - Table.NumClauses, - GeneratedMetadata.end()); - // Then, remove those clauses from the general list of Root Elements - GeneratedMetadata.pop_back_n(Table.NumClauses); - - return MDNode::get(Ctx, TableOperands); -} - -MDNode *MetadataBuilder::BuildDescriptorTableClause( - const DescriptorTableClause &Clause) { - IRBuilder<> Builder(Ctx); - std::optional Name = - getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)), - ArrayRef(ResourceClassNames)); - assert(Name && "Provided an invalid Resource Class"); - Metadata *Operands[] = { - MDString::get(Ctx, *Name), - ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)), - ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)), - ConstantAsMetadata::get(Builder.getInt32(Clause.Space)), - ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Clause.Flags))), - }; - return MDNode::get(Ctx, Operands); -} - -MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) { - IRBuilder<> Builder(Ctx); - Metadata *Operands[] = { - MDString::get(Ctx, "StaticSampler"), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.Filter))), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressU))), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressV))), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressW))), - ConstantAsMetadata::get(llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), - Sampler.MipLODBias)), - ConstantAsMetadata::get(Builder.getInt32(Sampler.MaxAnisotropy)), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.CompFunc))), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.BorderColor))), - ConstantAsMetadata::get( - llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MinLOD)), - ConstantAsMetadata::get( - llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MaxLOD)), - ConstantAsMetadata::get(Builder.getInt32(Sampler.Reg.Number)), - ConstantAsMetadata::get(Builder.getInt32(Sampler.Space)), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.Visibility))), - }; - return MDNode::get(Ctx, Operands); -} - -std::optional -ResourceRange::getOverlapping(const RangeInfo &Info) const { - MapT::const_iterator Interval = Intervals.find(Info.LowerBound); - if (!Interval.valid() || Info.UpperBound < Interval.start()) - return std::nullopt; - return Interval.value(); -} - -const RangeInfo *ResourceRange::lookup(uint32_t X) const { - return Intervals.lookup(X, nullptr); -} - -void ResourceRange::clear() { return Intervals.clear(); } - -std::optional ResourceRange::insert(const RangeInfo &Info) { - uint32_t LowerBound = Info.LowerBound; - uint32_t UpperBound = Info.UpperBound; - - std::optional Res = std::nullopt; - MapT::iterator Interval = Intervals.begin(); - - while (true) { - if (UpperBound < LowerBound) - break; - - Interval.advanceTo(LowerBound); - if (!Interval.valid()) // No interval found - break; - - // Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that - // a <= y implicitly from Intervals.find(LowerBound) - if (UpperBound < Interval.start()) - break; // found interval does not overlap with inserted one - - if (!Res.has_value()) // Update to be the first found intersection - Res = Interval.value(); - - if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) { - // x <= a <= b <= y implies that [a;b] is covered by [x;y] - // -> so we don't need to insert this, report an overlap - return Res; - } else if (LowerBound <= Interval.start() && - Interval.stop() <= UpperBound) { - // a <= x <= y <= b implies that [x;y] is covered by [a;b] - // -> so remove the existing interval that we will cover with the - // overwrite - Interval.erase(); - } else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) { - // a < x <= b <= y implies that [a; x] is not covered but [x;b] is - // -> so set b = x - 1 such that [a;x-1] is now the interval to insert - UpperBound = Interval.start() - 1; - } else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) { - // a < x <= b <= y implies that [y; b] is not covered but [a;y] is - // -> so set a = y + 1 such that [y+1;b] is now the interval to insert - LowerBound = Interval.stop() + 1; - } - } - - assert(LowerBound <= UpperBound && "Attempting to insert an empty interval"); - Intervals.insert(LowerBound, UpperBound, &Info); - return Res; -} - -} // namespace rootsig -} // namespace hlsl -} // namespace llvm diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp new file mode 100644 index 0000000000000..f7669f09dcecc --- /dev/null +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -0,0 +1,194 @@ +//===- RootSignatureMetadata.h - HLSL Root Signature helpers --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file implements a library for working with HLSL Root Signatures +/// and their metadata representation. +/// +//===----------------------------------------------------------------------===// + +#include "llvm/Frontend/HLSL/RootSignatureMetadata.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Metadata.h" +#include "llvm/Support/ScopedPrinter.h" + +namespace llvm { +namespace hlsl { +namespace rootsig { + +static const EnumEntry ResourceClassNames[] = { + {"CBV", dxil::ResourceClass::CBuffer}, + {"SRV", dxil::ResourceClass::SRV}, + {"UAV", dxil::ResourceClass::UAV}, + {"Sampler", dxil::ResourceClass::Sampler}, +}; + +static std::optional getResourceName(dxil::ResourceClass Class) { + for (const auto &ClassEnum : ResourceClassNames) + if (ClassEnum.Value == Class) + return ClassEnum.Name; + return std::nullopt; +} + +namespace { + +// We use the OverloadVisit with std::visit to ensure the compiler catches if a +// new RootElement variant type is added but it's metadata generation isn't +// handled. +template struct OverloadedVisit : Ts... { + using Ts::operator()...; +}; +template OverloadedVisit(Ts...) -> OverloadedVisit; + +} // namespace + +MDNode *MetadataBuilder::BuildRootSignature() { + const auto Visitor = OverloadedVisit{ + [this](const dxbc::RootFlags &Flags) -> MDNode * { + return BuildRootFlags(Flags); + }, + [this](const RootConstants &Constants) -> MDNode * { + return BuildRootConstants(Constants); + }, + [this](const RootDescriptor &Descriptor) -> MDNode * { + return BuildRootDescriptor(Descriptor); + }, + [this](const DescriptorTableClause &Clause) -> MDNode * { + return BuildDescriptorTableClause(Clause); + }, + [this](const DescriptorTable &Table) -> MDNode * { + return BuildDescriptorTable(Table); + }, + [this](const StaticSampler &Sampler) -> MDNode * { + return BuildStaticSampler(Sampler); + }, + }; + + for (const RootElement &Element : Elements) { + MDNode *ElementMD = std::visit(Visitor, Element); + assert(ElementMD != nullptr && + "Root Element must be initialized and validated"); + GeneratedMetadata.push_back(ElementMD); + } + + return MDNode::get(Ctx, GeneratedMetadata); +} + +MDNode *MetadataBuilder::BuildRootFlags(const dxbc::RootFlags &Flags) { + IRBuilder<> Builder(Ctx); + Metadata *Operands[] = { + MDString::get(Ctx, "RootFlags"), + ConstantAsMetadata::get(Builder.getInt32(llvm::to_underlying(Flags))), + }; + return MDNode::get(Ctx, Operands); +} + +MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) { + IRBuilder<> Builder(Ctx); + Metadata *Operands[] = { + MDString::get(Ctx, "RootConstants"), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Constants.Visibility))), + ConstantAsMetadata::get(Builder.getInt32(Constants.Reg.Number)), + ConstantAsMetadata::get(Builder.getInt32(Constants.Space)), + ConstantAsMetadata::get(Builder.getInt32(Constants.Num32BitConstants)), + }; + return MDNode::get(Ctx, Operands); +} + +MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) { + IRBuilder<> Builder(Ctx); + std::optional ResName = getResourceName( + dxil::ResourceClass(llvm::to_underlying(Descriptor.Type))); + assert(ResName && "Provided an invalid Resource Class"); + llvm::SmallString<7> Name({"Root", *ResName}); + Metadata *Operands[] = { + MDString::get(Ctx, Name), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))), + ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)), + ConstantAsMetadata::get(Builder.getInt32(Descriptor.Space)), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Descriptor.Flags))), + }; + return MDNode::get(Ctx, Operands); +} + +MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) { + IRBuilder<> Builder(Ctx); + SmallVector TableOperands; + // Set the mandatory arguments + TableOperands.push_back(MDString::get(Ctx, "DescriptorTable")); + TableOperands.push_back(ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Table.Visibility)))); + + // Remaining operands are references to the table's clauses. The in-memory + // representation of the Root Elements created from parsing will ensure that + // the previous N elements are the clauses for this table. + assert(Table.NumClauses <= GeneratedMetadata.size() && + "Table expected all owned clauses to be generated already"); + // So, add a refence to each clause to our operands + TableOperands.append(GeneratedMetadata.end() - Table.NumClauses, + GeneratedMetadata.end()); + // Then, remove those clauses from the general list of Root Elements + GeneratedMetadata.pop_back_n(Table.NumClauses); + + return MDNode::get(Ctx, TableOperands); +} + +MDNode *MetadataBuilder::BuildDescriptorTableClause( + const DescriptorTableClause &Clause) { + IRBuilder<> Builder(Ctx); + std::optional ResName = + getResourceName(dxil::ResourceClass(llvm::to_underlying(Clause.Type))); + assert(ResName && "Provided an invalid Resource Class"); + Metadata *Operands[] = { + MDString::get(Ctx, *ResName), + ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)), + ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)), + ConstantAsMetadata::get(Builder.getInt32(Clause.Space)), + ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Clause.Flags))), + }; + return MDNode::get(Ctx, Operands); +} + +MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) { + IRBuilder<> Builder(Ctx); + Metadata *Operands[] = { + MDString::get(Ctx, "StaticSampler"), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.Filter))), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.AddressU))), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.AddressV))), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.AddressW))), + ConstantAsMetadata::get(llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), + Sampler.MipLODBias)), + ConstantAsMetadata::get(Builder.getInt32(Sampler.MaxAnisotropy)), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.CompFunc))), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.BorderColor))), + ConstantAsMetadata::get( + llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MinLOD)), + ConstantAsMetadata::get( + llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MaxLOD)), + ConstantAsMetadata::get(Builder.getInt32(Sampler.Reg.Number)), + ConstantAsMetadata::get(Builder.getInt32(Sampler.Space)), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.Visibility))), + }; + return MDNode::get(Ctx, Operands); +} + +} // namespace rootsig +} // namespace hlsl +} // namespace llvm diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp new file mode 100644 index 0000000000000..9825946d59690 --- /dev/null +++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp @@ -0,0 +1,84 @@ +//===- HLSLRootSignatureValidations.cpp - HLSL Root Signature helpers -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains helpers for working with HLSL Root Signatures. +/// +//===----------------------------------------------------------------------===// + +#include "llvm/Frontend/HLSL/RootSignatureValidations.h" + +namespace llvm { +namespace hlsl { +namespace rootsig { + +std::optional +ResourceRange::getOverlapping(const RangeInfo &Info) const { + MapT::const_iterator Interval = Intervals.find(Info.LowerBound); + if (!Interval.valid() || Info.UpperBound < Interval.start()) + return std::nullopt; + return Interval.value(); +} + +const RangeInfo *ResourceRange::lookup(uint32_t X) const { + return Intervals.lookup(X, nullptr); +} + +void ResourceRange::clear() { return Intervals.clear(); } + +std::optional ResourceRange::insert(const RangeInfo &Info) { + uint32_t LowerBound = Info.LowerBound; + uint32_t UpperBound = Info.UpperBound; + + std::optional Res = std::nullopt; + MapT::iterator Interval = Intervals.begin(); + + while (true) { + if (UpperBound < LowerBound) + break; + + Interval.advanceTo(LowerBound); + if (!Interval.valid()) // No interval found + break; + + // Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that + // a <= y implicitly from Intervals.find(LowerBound) + if (UpperBound < Interval.start()) + break; // found interval does not overlap with inserted one + + if (!Res.has_value()) // Update to be the first found intersection + Res = Interval.value(); + + if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) { + // x <= a <= b <= y implies that [a;b] is covered by [x;y] + // -> so we don't need to insert this, report an overlap + return Res; + } else if (LowerBound <= Interval.start() && + Interval.stop() <= UpperBound) { + // a <= x <= y <= b implies that [x;y] is covered by [a;b] + // -> so remove the existing interval that we will cover with the + // overwrite + Interval.erase(); + } else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) { + // a < x <= b <= y implies that [a; x] is not covered but [x;b] is + // -> so set b = x - 1 such that [a;x-1] is now the interval to insert + UpperBound = Interval.start() - 1; + } else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) { + // a < x <= b <= y implies that [y; b] is not covered but [a;y] is + // -> so set a = y + 1 such that [y+1;b] is now the interval to insert + LowerBound = Interval.stop() + 1; + } + } + + assert(LowerBound <= UpperBound && "Attempting to insert an empty interval"); + Intervals.insert(LowerBound, UpperBound, &Info); + return Res; +} + +} // namespace rootsig +} // namespace hlsl +} // namespace llvm diff --git a/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp b/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp index e090f6bae470f..73a1fc67e6d28 100644 --- a/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp +++ b/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Frontend/HLSL/HLSLRootSignatureUtils.h" +#include "llvm/Frontend/HLSL/HLSLRootSignature.h" #include "gtest/gtest.h" using namespace llvm::hlsl::rootsig; diff --git a/llvm/unittests/Frontend/HLSLRootSignatureRangesTest.cpp b/llvm/unittests/Frontend/HLSLRootSignatureRangesTest.cpp index 0ef6fe84f0ec9..be3f51e0e83d5 100644 --- a/llvm/unittests/Frontend/HLSLRootSignatureRangesTest.cpp +++ b/llvm/unittests/Frontend/HLSLRootSignatureRangesTest.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Frontend/HLSL/HLSLRootSignatureUtils.h" +#include "llvm/Frontend/HLSL/RootSignatureValidations.h" #include "gtest/gtest.h" using namespace llvm::hlsl::rootsig;