12
12
// ===----------------------------------------------------------------------===//
13
13
#include " DXILRootSignature.h"
14
14
#include " DirectX.h"
15
+ #include " llvm/ADT/StringRef.h"
15
16
#include " llvm/ADT/StringSwitch.h"
16
17
#include " llvm/ADT/Twine.h"
17
18
#include " llvm/Analysis/DXILMetadataAnalysis.h"
30
31
#include < cmath>
31
32
#include < cstdint>
32
33
#include < optional>
34
+ #include < string>
33
35
#include < utility>
34
36
35
37
using namespace llvm ;
@@ -48,6 +50,71 @@ static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
48
50
return true ;
49
51
}
50
52
53
+ // Template function to get formatted type string based on C++ type
54
+ template <typename T> std::string getTypeFormatted () {
55
+ if constexpr (std::is_same_v<T, MDString>) {
56
+ return " string" ;
57
+ } else if constexpr (std::is_same_v<T, MDNode *> ||
58
+ std::is_same_v<T, const MDNode *>) {
59
+ return " metadata" ;
60
+ } else if constexpr (std::is_same_v<T, ConstantAsMetadata *> ||
61
+ std::is_same_v<T, const ConstantAsMetadata *>) {
62
+ return " constant" ;
63
+ } else if constexpr (std::is_same_v<T, ConstantAsMetadata>) {
64
+ return " constant" ;
65
+ } else if constexpr (std::is_same_v<T, ConstantInt *> ||
66
+ std::is_same_v<T, const ConstantInt *>) {
67
+ return " constant int" ;
68
+ } else if constexpr (std::is_same_v<T, ConstantInt>) {
69
+ return " constant int" ;
70
+ }
71
+ return " unknown" ;
72
+ }
73
+
74
+ // Helper function to get the actual type of a metadata operand
75
+ std::string getActualMDType (const MDNode *Node, unsigned Index) {
76
+ if (!Node || Index >= Node->getNumOperands ())
77
+ return " null" ;
78
+
79
+ Metadata *Op = Node->getOperand (Index);
80
+ if (!Op)
81
+ return " null" ;
82
+
83
+ if (isa<MDString>(Op))
84
+ return getTypeFormatted<MDString>();
85
+
86
+ if (isa<ConstantAsMetadata>(Op)) {
87
+ if (auto *CAM = dyn_cast<ConstantAsMetadata>(Op)) {
88
+ Type *T = CAM->getValue ()->getType ();
89
+ if (T->isIntegerTy ())
90
+ return (Twine (" i" ) + Twine (T->getIntegerBitWidth ())).str ();
91
+ if (T->isFloatingPointTy ())
92
+ return T->isFloatTy () ? getTypeFormatted<float >()
93
+ : T->isDoubleTy () ? getTypeFormatted<double >()
94
+ : " fp" ;
95
+
96
+ return getTypeFormatted<ConstantAsMetadata>();
97
+ }
98
+ }
99
+ if (isa<MDNode>(Op))
100
+ return getTypeFormatted<MDNode *>();
101
+
102
+ return " unknown" ;
103
+ }
104
+
105
+ // Helper function to simplify error reporting for invalid metadata values
106
+ template <typename ET>
107
+ auto reportInvalidTypeError (LLVMContext *Ctx, Twine ParamName,
108
+ const MDNode *Node, unsigned Index) {
109
+ std::string ExpectedType = getTypeFormatted<ET>();
110
+ std::string ActualType = getActualMDType (Node, Index);
111
+
112
+ return reportError (Ctx, " Root Signature Node: " + ParamName +
113
+ " expected metadata node of type " +
114
+ ExpectedType + " at index " + Twine (Index) +
115
+ " but got " + ActualType);
116
+ }
117
+
51
118
static std::optional<uint32_t > extractMdIntValue (MDNode *Node,
52
119
unsigned int OpId) {
53
120
if (auto *CI =
@@ -80,7 +147,8 @@ static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
80
147
if (std::optional<uint32_t > Val = extractMdIntValue (RootFlagNode, 1 ))
81
148
RSD.Flags = *Val;
82
149
else
83
- return reportError (Ctx, " Invalid value for RootFlag" );
150
+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootFlagNode" ,
151
+ RootFlagNode, 1 );
84
152
85
153
return false ;
86
154
}
@@ -100,23 +168,27 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
100
168
if (std::optional<uint32_t > Val = extractMdIntValue (RootConstantNode, 1 ))
101
169
Header.ShaderVisibility = *Val;
102
170
else
103
- return reportError (Ctx, " Invalid value for ShaderVisibility" );
171
+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootConstantNode" ,
172
+ RootConstantNode, 1 );
104
173
105
174
dxbc::RTS0::v1::RootConstants Constants;
106
175
if (std::optional<uint32_t > Val = extractMdIntValue (RootConstantNode, 2 ))
107
176
Constants.ShaderRegister = *Val;
108
177
else
109
- return reportError (Ctx, " Invalid value for ShaderRegister" );
178
+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootConstantNode" ,
179
+ RootConstantNode, 2 );
110
180
111
181
if (std::optional<uint32_t > Val = extractMdIntValue (RootConstantNode, 3 ))
112
182
Constants.RegisterSpace = *Val;
113
183
else
114
- return reportError (Ctx, " Invalid value for RegisterSpace" );
184
+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootConstantNode" ,
185
+ RootConstantNode, 3 );
115
186
116
187
if (std::optional<uint32_t > Val = extractMdIntValue (RootConstantNode, 4 ))
117
188
Constants.Num32BitValues = *Val;
118
189
else
119
- return reportError (Ctx, " Invalid value for Num32BitValues" );
190
+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootConstantNode" ,
191
+ RootConstantNode, 4 );
120
192
121
193
RSD.ParametersContainer .addParameter (Header, Constants);
122
194
@@ -154,18 +226,21 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
154
226
if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 1 ))
155
227
Header.ShaderVisibility = *Val;
156
228
else
157
- return reportError (Ctx, " Invalid value for ShaderVisibility" );
229
+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootDescriptorNode" ,
230
+ RootDescriptorNode, 1 );
158
231
159
232
dxbc::RTS0::v2::RootDescriptor Descriptor;
160
233
if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 2 ))
161
234
Descriptor.ShaderRegister = *Val;
162
235
else
163
- return reportError (Ctx, " Invalid value for ShaderRegister" );
236
+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootDescriptorNode" ,
237
+ RootDescriptorNode, 2 );
164
238
165
239
if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 3 ))
166
240
Descriptor.RegisterSpace = *Val;
167
241
else
168
- return reportError (Ctx, " Invalid value for RegisterSpace" );
242
+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootDescriptorNode" ,
243
+ RootDescriptorNode, 3 );
169
244
170
245
if (RSD.Version == 1 ) {
171
246
RSD.ParametersContainer .addParameter (Header, Descriptor);
@@ -176,7 +251,8 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
176
251
if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 4 ))
177
252
Descriptor.Flags = *Val;
178
253
else
179
- return reportError (Ctx, " Invalid value for Root Descriptor Flags" );
254
+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootDescriptorNode" ,
255
+ RootDescriptorNode, 4 );
180
256
181
257
RSD.ParametersContainer .addParameter (Header, Descriptor);
182
258
return false ;
@@ -196,7 +272,8 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
196
272
extractMdStringValue (RangeDescriptorNode, 0 );
197
273
198
274
if (!ElementText.has_value ())
199
- return reportError (Ctx, " Descriptor Range, first element is not a string." );
275
+ return reportInvalidTypeError<MDString>(Ctx, " RangeDescriptorNode" ,
276
+ RangeDescriptorNode, 0 );
200
277
201
278
Range.RangeType =
202
279
StringSwitch<uint32_t >(*ElementText)
@@ -213,28 +290,32 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
213
290
if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 1 ))
214
291
Range.NumDescriptors = *Val;
215
292
else
216
- return reportError (Ctx, " Invalid value for Number of Descriptor in Range" );
293
+ return reportInvalidTypeError<MDString>(Ctx, " RangeDescriptorNode" ,
294
+ RangeDescriptorNode, 1 );
217
295
218
296
if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 2 ))
219
297
Range.BaseShaderRegister = *Val;
220
298
else
221
- return reportError (Ctx, " Invalid value for BaseShaderRegister" );
299
+ return reportInvalidTypeError<MDString>(Ctx, " RangeDescriptorNode" ,
300
+ RangeDescriptorNode, 2 );
222
301
223
302
if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 3 ))
224
303
Range.RegisterSpace = *Val;
225
304
else
226
- return reportError (Ctx, " Invalid value for RegisterSpace" );
305
+ return reportInvalidTypeError<MDString>(Ctx, " RangeDescriptorNode" ,
306
+ RangeDescriptorNode, 3 );
227
307
228
308
if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 4 ))
229
309
Range.OffsetInDescriptorsFromTableStart = *Val;
230
310
else
231
- return reportError (Ctx,
232
- " Invalid value for OffsetInDescriptorsFromTableStart " );
311
+ return reportInvalidTypeError<MDString> (Ctx, " RangeDescriptorNode " ,
312
+ RangeDescriptorNode, 4 );
233
313
234
314
if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 5 ))
235
315
Range.Flags = *Val;
236
316
else
237
- return reportError (Ctx, " Invalid value for Descriptor Range Flags" );
317
+ return reportInvalidTypeError<MDString>(Ctx, " RangeDescriptorNode" ,
318
+ RangeDescriptorNode, 5 );
238
319
239
320
Table.Ranges .push_back (Range);
240
321
return false ;
@@ -251,7 +332,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
251
332
if (std::optional<uint32_t > Val = extractMdIntValue (DescriptorTableNode, 1 ))
252
333
Header.ShaderVisibility = *Val;
253
334
else
254
- return reportError (Ctx, " Invalid value for ShaderVisibility" );
335
+ return reportInvalidTypeError<MDString>(Ctx, " DescriptorTableNode" ,
336
+ DescriptorTableNode, 1 );
255
337
256
338
mcdxbc::DescriptorTable Table;
257
339
Header.ParameterType =
@@ -260,7 +342,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
260
342
for (unsigned int I = 2 ; I < NumOperands; I++) {
261
343
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand (I));
262
344
if (Element == nullptr )
263
- return reportError (Ctx, " Missing Root Element Metadata Node." );
345
+ return reportInvalidTypeError<MDNode>(Ctx, " DescriptorTableNode" ,
346
+ DescriptorTableNode, I);
264
347
265
348
if (parseDescriptorRange (Ctx, RSD, Table, Element))
266
349
return true ;
0 commit comments