From 02f1f21b8ecc608341440c573483e69c161a06d4 Mon Sep 17 00:00:00 2001 From: joaosaffran Date: Fri, 6 Jun 2025 20:04:00 +0000 Subject: [PATCH 1/2] changing error message --- llvm/lib/Target/DirectX/DXILRootSignature.cpp | 119 +++++++++++++++--- ...re-RootConstants-Invalid-Num32BitValues.ll | 2 +- ...ure-RootConstants-Invalid-RegisterSpace.ll | 2 +- ...re-RootConstants-Invalid-ShaderRegister.ll | 2 +- 4 files changed, 104 insertions(+), 21 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index 3aef7d3eb1e69..3a27afc6c660f 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "DXILRootSignature.h" #include "DirectX.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/DXILMetadataAnalysis.h" @@ -30,6 +31,7 @@ #include #include #include +#include #include using namespace llvm; @@ -48,6 +50,71 @@ static bool reportValueError(LLVMContext *Ctx, Twine ParamName, return true; } +// Template function to get formatted type string based on C++ type +template std::string getTypeFormatted() { + if constexpr (std::is_same_v) { + return "string"; + } else if constexpr (std::is_same_v || + std::is_same_v) { + return "metadata"; + } else if constexpr (std::is_same_v || + std::is_same_v) { + return "constant"; + } else if constexpr (std::is_same_v) { + return "constant"; + } else if constexpr (std::is_same_v || + std::is_same_v) { + return "constant int"; + } else if constexpr (std::is_same_v) { + return "constant int"; + } + return "unknown"; +} + +// Helper function to get the actual type of a metadata operand +std::string getActualMDType(const MDNode *Node, unsigned Index) { + if (!Node || Index >= Node->getNumOperands()) + return "null"; + + Metadata *Op = Node->getOperand(Index); + if (!Op) + return "null"; + + if (isa(Op)) + return getTypeFormatted(); + + if (isa(Op)) { + if (auto *CAM = dyn_cast(Op)) { + Type *T = CAM->getValue()->getType(); + if (T->isIntegerTy()) + return (Twine("i") + Twine(T->getIntegerBitWidth())).str(); + if (T->isFloatingPointTy()) + return T->isFloatTy() ? getTypeFormatted() + : T->isDoubleTy() ? getTypeFormatted() + : "fp"; + + return getTypeFormatted(); + } + } + if (isa(Op)) + return getTypeFormatted(); + + return "unknown"; +} + +// Helper function to simplify error reporting for invalid metadata values +template +auto reportInvalidTypeError(LLVMContext *Ctx, Twine ParamName, + const MDNode *Node, unsigned Index) { + std::string ExpectedType = getTypeFormatted(); + std::string ActualType = getActualMDType(Node, Index); + + return reportError(Ctx, "Root Signature Node: " + ParamName + + " expected metadata node of type " + + ExpectedType + " at index " + Twine(Index) + + " but got " + ActualType); +} + static std::optional extractMdIntValue(MDNode *Node, unsigned int OpId) { if (auto *CI = @@ -80,7 +147,8 @@ static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, if (std::optional Val = extractMdIntValue(RootFlagNode, 1)) RSD.Flags = *Val; else - return reportError(Ctx, "Invalid value for RootFlag"); + return reportInvalidTypeError(Ctx, "RootFlagNode", + RootFlagNode, 1); return false; } @@ -100,23 +168,27 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, if (std::optional Val = extractMdIntValue(RootConstantNode, 1)) Header.ShaderVisibility = *Val; else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + return reportInvalidTypeError(Ctx, "RootConstantNode", + RootConstantNode, 1); dxbc::RTS0::v1::RootConstants Constants; if (std::optional Val = extractMdIntValue(RootConstantNode, 2)) Constants.ShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for ShaderRegister"); + return reportInvalidTypeError(Ctx, "RootConstantNode", + RootConstantNode, 2); if (std::optional Val = extractMdIntValue(RootConstantNode, 3)) Constants.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return reportInvalidTypeError(Ctx, "RootConstantNode", + RootConstantNode, 3); if (std::optional Val = extractMdIntValue(RootConstantNode, 4)) Constants.Num32BitValues = *Val; else - return reportError(Ctx, "Invalid value for Num32BitValues"); + return reportInvalidTypeError(Ctx, "RootConstantNode", + RootConstantNode, 4); RSD.ParametersContainer.addParameter(Header, Constants); @@ -154,18 +226,21 @@ static bool parseRootDescriptors(LLVMContext *Ctx, if (std::optional Val = extractMdIntValue(RootDescriptorNode, 1)) Header.ShaderVisibility = *Val; else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + return reportInvalidTypeError(Ctx, "RootDescriptorNode", + RootDescriptorNode, 1); dxbc::RTS0::v2::RootDescriptor Descriptor; if (std::optional Val = extractMdIntValue(RootDescriptorNode, 2)) Descriptor.ShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for ShaderRegister"); + return reportInvalidTypeError(Ctx, "RootDescriptorNode", + RootDescriptorNode, 2); if (std::optional Val = extractMdIntValue(RootDescriptorNode, 3)) Descriptor.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return reportInvalidTypeError(Ctx, "RootDescriptorNode", + RootDescriptorNode, 3); if (RSD.Version == 1) { RSD.ParametersContainer.addParameter(Header, Descriptor); @@ -176,7 +251,8 @@ static bool parseRootDescriptors(LLVMContext *Ctx, if (std::optional Val = extractMdIntValue(RootDescriptorNode, 4)) Descriptor.Flags = *Val; else - return reportError(Ctx, "Invalid value for Root Descriptor Flags"); + return reportInvalidTypeError(Ctx, "RootDescriptorNode", + RootDescriptorNode, 4); RSD.ParametersContainer.addParameter(Header, Descriptor); return false; @@ -196,7 +272,8 @@ static bool parseDescriptorRange(LLVMContext *Ctx, extractMdStringValue(RangeDescriptorNode, 0); if (!ElementText.has_value()) - return reportError(Ctx, "Descriptor Range, first element is not a string."); + return reportInvalidTypeError(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 0); Range.RangeType = StringSwitch(*ElementText) @@ -213,28 +290,32 @@ static bool parseDescriptorRange(LLVMContext *Ctx, if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 1)) Range.NumDescriptors = *Val; else - return reportError(Ctx, "Invalid value for Number of Descriptor in Range"); + return reportInvalidTypeError(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 1); if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 2)) Range.BaseShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for BaseShaderRegister"); + return reportInvalidTypeError(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 2); if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 3)) Range.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return reportInvalidTypeError(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 3); if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 4)) Range.OffsetInDescriptorsFromTableStart = *Val; else - return reportError(Ctx, - "Invalid value for OffsetInDescriptorsFromTableStart"); + return reportInvalidTypeError(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 4); if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 5)) Range.Flags = *Val; else - return reportError(Ctx, "Invalid value for Descriptor Range Flags"); + return reportInvalidTypeError(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 5); Table.Ranges.push_back(Range); return false; @@ -251,7 +332,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx, if (std::optional Val = extractMdIntValue(DescriptorTableNode, 1)) Header.ShaderVisibility = *Val; else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + return reportInvalidTypeError(Ctx, "DescriptorTableNode", + DescriptorTableNode, 1); mcdxbc::DescriptorTable Table; Header.ParameterType = @@ -260,7 +342,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx, for (unsigned int I = 2; I < NumOperands; I++) { MDNode *Element = dyn_cast(DescriptorTableNode->getOperand(I)); if (Element == nullptr) - return reportError(Ctx, "Missing Root Element Metadata Node."); + return reportInvalidTypeError(Ctx, "DescriptorTableNode", + DescriptorTableNode, I); if (parseDescriptorRange(Ctx, RSD, Table, Element)) return true; diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-Num32BitValues.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-Num32BitValues.ll index 552c128e5ab57..0d5bbdfc097c4 100644 --- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-Num32BitValues.ll +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-Num32BitValues.ll @@ -2,7 +2,7 @@ target triple = "dxil-unknown-shadermodel6.0-compute" -; CHECK: error: Invalid value for Num32BitValues +; CHECK: error: Root Signature Node: RootConstantNode expected metadata node of type constant int at index 4 but got string ; CHECK-NOT: Root Signature Definitions define void @main() { diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-RegisterSpace.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-RegisterSpace.ll index 1087b414942e2..1384da4baca98 100644 --- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-RegisterSpace.ll +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-RegisterSpace.ll @@ -2,7 +2,7 @@ target triple = "dxil-unknown-shadermodel6.0-compute" -; CHECK: error: Invalid value for RegisterSpace +; CHECK: error: Root Signature Node: RootConstantNode expected metadata node of type constant int at index 3 but got string ; CHECK-NOT: Root Signature Definitions define void @main() #0 { diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-ShaderRegister.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-ShaderRegister.ll index 53fd924e8f46e..e1fd6a4414609 100644 --- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-ShaderRegister.ll +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-ShaderRegister.ll @@ -2,7 +2,7 @@ target triple = "dxil-unknown-shadermodel6.0-compute" -; CHECK: error: Invalid value for ShaderRegister +; CHECK: error: Root Signature Node: RootConstantNode expected metadata node of type constant int at index 2 but got string ; CHECK-NOT: Root Signature Definitions define void @main() #0 { From e62419f82edd38bb027f3451dc350ecb01b0be2c Mon Sep 17 00:00:00 2001 From: joaosaffran Date: Mon, 16 Jun 2025 19:50:29 +0000 Subject: [PATCH 2/2] clean up --- llvm/lib/Target/DirectX/DXILRootSignature.cpp | 65 +++++++++++-------- 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index 3a27afc6c660f..57d5ee8ac467c 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -12,7 +12,6 @@ //===----------------------------------------------------------------------===// #include "DXILRootSignature.h" #include "DirectX.h" -#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/DXILMetadataAnalysis.h" @@ -31,7 +30,6 @@ #include #include #include -#include #include using namespace llvm; @@ -290,32 +288,32 @@ static bool parseDescriptorRange(LLVMContext *Ctx, if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 1)) Range.NumDescriptors = *Val; else - return reportInvalidTypeError(Ctx, "RangeDescriptorNode", - RangeDescriptorNode, 1); + return reportInvalidTypeError(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 1); if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 2)) Range.BaseShaderRegister = *Val; else - return reportInvalidTypeError(Ctx, "RangeDescriptorNode", - RangeDescriptorNode, 2); + return reportInvalidTypeError(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 2); if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 3)) Range.RegisterSpace = *Val; else - return reportInvalidTypeError(Ctx, "RangeDescriptorNode", - RangeDescriptorNode, 3); + return reportInvalidTypeError(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 3); if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 4)) Range.OffsetInDescriptorsFromTableStart = *Val; else - return reportInvalidTypeError(Ctx, "RangeDescriptorNode", - RangeDescriptorNode, 4); + return reportInvalidTypeError(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 4); if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 5)) Range.Flags = *Val; else - return reportInvalidTypeError(Ctx, "RangeDescriptorNode", - RangeDescriptorNode, 5); + return reportInvalidTypeError(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 5); Table.Ranges.push_back(Range); return false; @@ -332,8 +330,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx, if (std::optional Val = extractMdIntValue(DescriptorTableNode, 1)) Header.ShaderVisibility = *Val; else - return reportInvalidTypeError(Ctx, "DescriptorTableNode", - DescriptorTableNode, 1); + return reportInvalidTypeError(Ctx, "DescriptorTableNode", + DescriptorTableNode, 1); mcdxbc::DescriptorTable Table; Header.ParameterType = @@ -362,67 +360,80 @@ static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, if (std::optional Val = extractMdIntValue(StaticSamplerNode, 1)) Sampler.Filter = *Val; else - return reportError(Ctx, "Invalid value for Filter"); + return reportInvalidTypeError(Ctx, "StaticSamplerNode", + StaticSamplerNode, 1); if (std::optional Val = extractMdIntValue(StaticSamplerNode, 2)) Sampler.AddressU = *Val; else - return reportError(Ctx, "Invalid value for AddressU"); + return reportInvalidTypeError(Ctx, "StaticSamplerNode", + StaticSamplerNode, 2); if (std::optional Val = extractMdIntValue(StaticSamplerNode, 3)) Sampler.AddressV = *Val; else - return reportError(Ctx, "Invalid value for AddressV"); + return reportInvalidTypeError(Ctx, "StaticSamplerNode", + StaticSamplerNode, 3); if (std::optional Val = extractMdIntValue(StaticSamplerNode, 4)) Sampler.AddressW = *Val; else - return reportError(Ctx, "Invalid value for AddressW"); + return reportInvalidTypeError(Ctx, "StaticSamplerNode", + StaticSamplerNode, 4); if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 5)) Sampler.MipLODBias = Val->convertToFloat(); else - return reportError(Ctx, "Invalid value for MipLODBias"); + return reportInvalidTypeError(Ctx, "StaticSamplerNode", + StaticSamplerNode, 5); if (std::optional Val = extractMdIntValue(StaticSamplerNode, 6)) Sampler.MaxAnisotropy = *Val; else - return reportError(Ctx, "Invalid value for MaxAnisotropy"); + return reportInvalidTypeError(Ctx, "StaticSamplerNode", + StaticSamplerNode, 6); if (std::optional Val = extractMdIntValue(StaticSamplerNode, 7)) Sampler.ComparisonFunc = *Val; else - return reportError(Ctx, "Invalid value for ComparisonFunc "); + return reportInvalidTypeError(Ctx, "StaticSamplerNode", + StaticSamplerNode, 7); if (std::optional Val = extractMdIntValue(StaticSamplerNode, 8)) Sampler.BorderColor = *Val; else - return reportError(Ctx, "Invalid value for ComparisonFunc "); + return reportInvalidTypeError(Ctx, "StaticSamplerNode", + StaticSamplerNode, 8); if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 9)) Sampler.MinLOD = Val->convertToFloat(); else - return reportError(Ctx, "Invalid value for MinLOD"); + return reportInvalidTypeError(Ctx, "StaticSamplerNode", + StaticSamplerNode, 9); if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 10)) Sampler.MaxLOD = Val->convertToFloat(); else - return reportError(Ctx, "Invalid value for MaxLOD"); + return reportInvalidTypeError(Ctx, "StaticSamplerNode", + StaticSamplerNode, 10); if (std::optional Val = extractMdIntValue(StaticSamplerNode, 11)) Sampler.ShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for ShaderRegister"); + return reportInvalidTypeError(Ctx, "StaticSamplerNode", + StaticSamplerNode, 11); if (std::optional Val = extractMdIntValue(StaticSamplerNode, 12)) Sampler.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return reportInvalidTypeError(Ctx, "StaticSamplerNode", + StaticSamplerNode, 12); if (std::optional Val = extractMdIntValue(StaticSamplerNode, 13)) Sampler.ShaderVisibility = *Val; else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + return reportInvalidTypeError(Ctx, "StaticSamplerNode", + StaticSamplerNode, 13); RSD.StaticSamplers.push_back(Sampler); return false;