Skip to content

Commit 171aa34

Browse files
joaosaffranjoaosaffran
andauthored
[DirectX] Add static sampler support to root signature (#143422)
Implements static samplers parsing from root signature metadata representation. This is required to support Root Signatures in HLSL. Closes: #[126641](#126641) --------- Co-authored-by: joaosaffran <[email protected]>
1 parent 442f99d commit 171aa34

17 files changed

+497
-17
lines changed

llvm/lib/MC/DXContainerRootSignature.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,16 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
7171
BOS.reserveExtraSpace(getSize());
7272

7373
const uint32_t NumParameters = ParametersContainer.size();
74-
74+
const uint32_t NumSamplers = StaticSamplers.size();
7575
support::endian::write(BOS, Version, llvm::endianness::little);
7676
support::endian::write(BOS, NumParameters, llvm::endianness::little);
7777
support::endian::write(BOS, RootParameterOffset, llvm::endianness::little);
78-
support::endian::write(BOS, NumStaticSamplers, llvm::endianness::little);
79-
support::endian::write(BOS, StaticSamplersOffset, llvm::endianness::little);
78+
support::endian::write(BOS, NumSamplers, llvm::endianness::little);
79+
uint32_t SSO = StaticSamplersOffset;
80+
if (NumSamplers > 0)
81+
SSO = writePlaceholder(BOS);
82+
else
83+
support::endian::write(BOS, SSO, llvm::endianness::little);
8084
support::endian::write(BOS, Flags, llvm::endianness::little);
8185

8286
SmallVector<uint32_t> ParamsOffsets;
@@ -142,20 +146,23 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
142146
}
143147
}
144148
}
145-
for (const auto &S : StaticSamplers) {
146-
support::endian::write(BOS, S.Filter, llvm::endianness::little);
147-
support::endian::write(BOS, S.AddressU, llvm::endianness::little);
148-
support::endian::write(BOS, S.AddressV, llvm::endianness::little);
149-
support::endian::write(BOS, S.AddressW, llvm::endianness::little);
150-
support::endian::write(BOS, S.MipLODBias, llvm::endianness::little);
151-
support::endian::write(BOS, S.MaxAnisotropy, llvm::endianness::little);
152-
support::endian::write(BOS, S.ComparisonFunc, llvm::endianness::little);
153-
support::endian::write(BOS, S.BorderColor, llvm::endianness::little);
154-
support::endian::write(BOS, S.MinLOD, llvm::endianness::little);
155-
support::endian::write(BOS, S.MaxLOD, llvm::endianness::little);
156-
support::endian::write(BOS, S.ShaderRegister, llvm::endianness::little);
157-
support::endian::write(BOS, S.RegisterSpace, llvm::endianness::little);
158-
support::endian::write(BOS, S.ShaderVisibility, llvm::endianness::little);
149+
if (NumSamplers > 0) {
150+
rewriteOffsetToCurrentByte(BOS, SSO);
151+
for (const auto &S : StaticSamplers) {
152+
support::endian::write(BOS, S.Filter, llvm::endianness::little);
153+
support::endian::write(BOS, S.AddressU, llvm::endianness::little);
154+
support::endian::write(BOS, S.AddressV, llvm::endianness::little);
155+
support::endian::write(BOS, S.AddressW, llvm::endianness::little);
156+
support::endian::write(BOS, S.MipLODBias, llvm::endianness::little);
157+
support::endian::write(BOS, S.MaxAnisotropy, llvm::endianness::little);
158+
support::endian::write(BOS, S.ComparisonFunc, llvm::endianness::little);
159+
support::endian::write(BOS, S.BorderColor, llvm::endianness::little);
160+
support::endian::write(BOS, S.MinLOD, llvm::endianness::little);
161+
support::endian::write(BOS, S.MaxLOD, llvm::endianness::little);
162+
support::endian::write(BOS, S.ShaderRegister, llvm::endianness::little);
163+
support::endian::write(BOS, S.RegisterSpace, llvm::endianness::little);
164+
support::endian::write(BOS, S.ShaderVisibility, llvm::endianness::little);
165+
}
159166
}
160167
assert(Storage.size() == getSize());
161168
OS.write(Storage.data(), Storage.size());

llvm/lib/Target/DirectX/DXILRootSignature.cpp

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "llvm/Support/Error.h"
2828
#include "llvm/Support/ErrorHandling.h"
2929
#include "llvm/Support/raw_ostream.h"
30+
#include <cmath>
3031
#include <cstdint>
3132
#include <optional>
3233
#include <utility>
@@ -55,6 +56,13 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
5556
return std::nullopt;
5657
}
5758

59+
static std::optional<float> extractMdFloatValue(MDNode *Node,
60+
unsigned int OpId) {
61+
if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
62+
return CI->getValueAPF().convertToFloat();
63+
return std::nullopt;
64+
}
65+
5866
static std::optional<StringRef> extractMdStringValue(MDNode *Node,
5967
unsigned int OpId) {
6068
MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
@@ -261,6 +269,81 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
261269
return false;
262270
}
263271

272+
static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
273+
MDNode *StaticSamplerNode) {
274+
if (StaticSamplerNode->getNumOperands() != 14)
275+
return reportError(Ctx, "Invalid format for Static Sampler");
276+
277+
dxbc::RTS0::v1::StaticSampler Sampler;
278+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
279+
Sampler.Filter = *Val;
280+
else
281+
return reportError(Ctx, "Invalid value for Filter");
282+
283+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
284+
Sampler.AddressU = *Val;
285+
else
286+
return reportError(Ctx, "Invalid value for AddressU");
287+
288+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
289+
Sampler.AddressV = *Val;
290+
else
291+
return reportError(Ctx, "Invalid value for AddressV");
292+
293+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
294+
Sampler.AddressW = *Val;
295+
else
296+
return reportError(Ctx, "Invalid value for AddressW");
297+
298+
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
299+
Sampler.MipLODBias = *Val;
300+
else
301+
return reportError(Ctx, "Invalid value for MipLODBias");
302+
303+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
304+
Sampler.MaxAnisotropy = *Val;
305+
else
306+
return reportError(Ctx, "Invalid value for MaxAnisotropy");
307+
308+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
309+
Sampler.ComparisonFunc = *Val;
310+
else
311+
return reportError(Ctx, "Invalid value for ComparisonFunc ");
312+
313+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
314+
Sampler.BorderColor = *Val;
315+
else
316+
return reportError(Ctx, "Invalid value for ComparisonFunc ");
317+
318+
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
319+
Sampler.MinLOD = *Val;
320+
else
321+
return reportError(Ctx, "Invalid value for MinLOD");
322+
323+
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
324+
Sampler.MaxLOD = *Val;
325+
else
326+
return reportError(Ctx, "Invalid value for MaxLOD");
327+
328+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
329+
Sampler.ShaderRegister = *Val;
330+
else
331+
return reportError(Ctx, "Invalid value for ShaderRegister");
332+
333+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
334+
Sampler.RegisterSpace = *Val;
335+
else
336+
return reportError(Ctx, "Invalid value for RegisterSpace");
337+
338+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
339+
Sampler.ShaderVisibility = *Val;
340+
else
341+
return reportError(Ctx, "Invalid value for ShaderVisibility");
342+
343+
RSD.StaticSamplers.push_back(Sampler);
344+
return false;
345+
}
346+
264347
static bool parseRootSignatureElement(LLVMContext *Ctx,
265348
mcdxbc::RootSignatureDesc &RSD,
266349
MDNode *Element) {
@@ -276,6 +359,7 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
276359
.Case("RootSRV", RootSignatureElementKind::SRV)
277360
.Case("RootUAV", RootSignatureElementKind::UAV)
278361
.Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
362+
.Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
279363
.Default(RootSignatureElementKind::Error);
280364

281365
switch (ElementKind) {
@@ -290,6 +374,8 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
290374
return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
291375
case RootSignatureElementKind::DescriptorTable:
292376
return parseDescriptorTable(Ctx, RSD, Element);
377+
case RootSignatureElementKind::StaticSamplers:
378+
return parseStaticSampler(Ctx, RSD, Element);
293379
case RootSignatureElementKind::Error:
294380
return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
295381
}
@@ -406,6 +492,58 @@ static bool verifyDescriptorRangeFlag(uint32_t Version, uint32_t Type,
406492
return (Flags & ~Mask) == FlagT::NONE;
407493
}
408494

495+
static bool verifySamplerFilter(uint32_t Value) {
496+
switch (Value) {
497+
#define STATIC_SAMPLER_FILTER(Num, Val) \
498+
case llvm::to_underlying(dxbc::StaticSamplerFilter::Val):
499+
#include "llvm/BinaryFormat/DXContainerConstants.def"
500+
return true;
501+
}
502+
return false;
503+
}
504+
505+
// Values allowed here:
506+
// https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_texture_address_mode#syntax
507+
static bool verifyAddress(uint32_t Address) {
508+
switch (Address) {
509+
#define TEXTURE_ADDRESS_MODE(Num, Val) \
510+
case llvm::to_underlying(dxbc::TextureAddressMode::Val):
511+
#include "llvm/BinaryFormat/DXContainerConstants.def"
512+
return true;
513+
}
514+
return false;
515+
}
516+
517+
static bool verifyMipLODBias(float MipLODBias) {
518+
return MipLODBias >= -16.f && MipLODBias <= 15.99f;
519+
}
520+
521+
static bool verifyMaxAnisotropy(uint32_t MaxAnisotropy) {
522+
return MaxAnisotropy <= 16u;
523+
}
524+
525+
static bool verifyComparisonFunc(uint32_t ComparisonFunc) {
526+
switch (ComparisonFunc) {
527+
#define COMPARISON_FUNCTION(Num, Val) \
528+
case llvm::to_underlying(dxbc::SamplersComparisonFunction::Val):
529+
#include "llvm/BinaryFormat/DXContainerConstants.def"
530+
return true;
531+
}
532+
return false;
533+
}
534+
535+
static bool verifyBorderColor(uint32_t BorderColor) {
536+
switch (BorderColor) {
537+
#define STATIC_BORDER_COLOR(Num, Val) \
538+
case llvm::to_underlying(dxbc::SamplersBorderColor::Val):
539+
#include "llvm/BinaryFormat/DXContainerConstants.def"
540+
return true;
541+
}
542+
return false;
543+
}
544+
545+
static bool verifyLOD(float LOD) { return !std::isnan(LOD); }
546+
409547
static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
410548

411549
if (!verifyVersion(RSD.Version)) {
@@ -463,6 +601,48 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
463601
}
464602
}
465603

604+
for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
605+
if (!verifySamplerFilter(Sampler.Filter))
606+
return reportValueError(Ctx, "Filter", Sampler.Filter);
607+
608+
if (!verifyAddress(Sampler.AddressU))
609+
return reportValueError(Ctx, "AddressU", Sampler.AddressU);
610+
611+
if (!verifyAddress(Sampler.AddressV))
612+
return reportValueError(Ctx, "AddressV", Sampler.AddressV);
613+
614+
if (!verifyAddress(Sampler.AddressW))
615+
return reportValueError(Ctx, "AddressW", Sampler.AddressW);
616+
617+
if (!verifyMipLODBias(Sampler.MipLODBias))
618+
return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias);
619+
620+
if (!verifyMaxAnisotropy(Sampler.MaxAnisotropy))
621+
return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy);
622+
623+
if (!verifyComparisonFunc(Sampler.ComparisonFunc))
624+
return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc);
625+
626+
if (!verifyBorderColor(Sampler.BorderColor))
627+
return reportValueError(Ctx, "BorderColor", Sampler.BorderColor);
628+
629+
if (!verifyLOD(Sampler.MinLOD))
630+
return reportValueError(Ctx, "MinLOD", Sampler.MinLOD);
631+
632+
if (!verifyLOD(Sampler.MaxLOD))
633+
return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD);
634+
635+
if (!verifyRegisterValue(Sampler.ShaderRegister))
636+
return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister);
637+
638+
if (!verifyRegisterSpace(Sampler.RegisterSpace))
639+
return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace);
640+
641+
if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
642+
return reportValueError(Ctx, "ShaderVisibility",
643+
Sampler.ShaderVisibility);
644+
}
645+
466646
return false;
467647
}
468648

@@ -542,6 +722,9 @@ analyzeModule(Module &M) {
542722
// offset will always equal to the header size.
543723
RSD.RootParameterOffset = sizeof(dxbc::RTS0::v1::RootSignatureHeader);
544724

725+
// static sampler offset is calculated when writting dxcontainer.
726+
RSD.StaticSamplersOffset = 0u;
727+
545728
if (parse(Ctx, RSD, RootElementListNode) || validate(Ctx, RSD)) {
546729
return RSDMap;
547730
}

llvm/lib/Target/DirectX/DXILRootSignature.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ enum class RootSignatureElementKind {
3232
UAV = 4,
3333
CBV = 5,
3434
DescriptorTable = 6,
35+
StaticSamplers = 7
3536
};
3637
class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
3738
friend AnalysisInfoMixin<RootSignatureAnalysis>;
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
2+
3+
4+
target triple = "dxil-unknown-shadermodel6.0-compute"
5+
6+
; CHECK: error: Invalid value for AddressU: 666
7+
; CHECK-NOT: Root Signature Definitions
8+
9+
define void @main() #0 {
10+
entry:
11+
ret void
12+
}
13+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
14+
15+
16+
!dx.rootsignatures = !{!2} ; list of function/root signature pairs
17+
!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
18+
!3 = !{ !5 } ; list of root signature elements
19+
!5 = !{ !"StaticSampler", i32 4, i32 666, i32 3, i32 5, float 0x3FF6CCCCC0000000, i32 9, i32 3, i32 2, float -1.280000e+02, float 1.280000e+02, i32 42, i32 0, i32 0 }
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
2+
3+
4+
target triple = "dxil-unknown-shadermodel6.0-compute"
5+
6+
; CHECK: error: Invalid value for AddressV: 666
7+
; CHECK-NOT: Root Signature Definitions
8+
9+
define void @main() #0 {
10+
entry:
11+
ret void
12+
}
13+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
14+
15+
16+
!dx.rootsignatures = !{!2} ; list of function/root signature pairs
17+
!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
18+
!3 = !{ !5 } ; list of root signature elements
19+
!5 = !{ !"StaticSampler", i32 4, i32 2, i32 666, i32 5, float 0x3FF6CCCCC0000000, i32 9, i32 3, i32 2, float -1.280000e+02, float 1.280000e+02, i32 42, i32 0, i32 0 }
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
2+
3+
4+
target triple = "dxil-unknown-shadermodel6.0-compute"
5+
6+
; CHECK: error: Invalid value for AddressW: 666
7+
; CHECK-NOT: Root Signature Definitions
8+
9+
define void @main() #0 {
10+
entry:
11+
ret void
12+
}
13+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
14+
15+
16+
!dx.rootsignatures = !{!2} ; list of function/root signature pairs
17+
!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
18+
!3 = !{ !5 } ; list of root signature elements
19+
!5 = !{ !"StaticSampler", i32 4, i32 2, i32 3, i32 666, float 0x3FF6CCCCC0000000, i32 9, i32 3, i32 2, float -1.280000e+02, float 1.280000e+02, i32 42, i32 0, i32 0 }
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
2+
3+
4+
target triple = "dxil-unknown-shadermodel6.0-compute"
5+
6+
; CHECK: error: Invalid value for BorderColor: 666
7+
; CHECK-NOT: Root Signature Definitions
8+
9+
define void @main() #0 {
10+
entry:
11+
ret void
12+
}
13+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
14+
15+
16+
!dx.rootsignatures = !{!2} ; list of function/root signature pairs
17+
!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
18+
!3 = !{ !5 } ; list of root signature elements
19+
!5 = !{ !"StaticSampler", i32 4, i32 2, i32 3, i32 5, float 0x3FF6CCCCC0000000, i32 9, i32 3, i32 666, float -1.280000e+02, float 1.280000e+02, i32 42, i32 0, i32 0 }

0 commit comments

Comments
 (0)