-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[IR2Vec][NFC] Add helper methods for numeric ID mapping in Vocabulary #149212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-mlgo Author: S. VenkataKeerthy (svkeerthy) Changes: Add helper methods to IR2Vec's Vocabulary class for numeric ID mapping and vocabulary size calculation. These APIs will be useful in triplet generation for the `llvm-ir2vec` tool (See #149214). (Tracking issue - #141817) Full diff: https://github.com/llvm/llvm-project/pull/149212.diff 3 Files Affected:
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3d7edf08c8807..d87457cac7642 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -170,6 +170,10 @@ class Vocabulary {
unsigned getDimension() const;
size_t size() const;
+ static size_t expectedSize() {
+ return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
+ }
+
/// Helper function to get vocabulary key for a given Opcode
static StringRef getVocabKeyForOpcode(unsigned Opcode);
@@ -182,6 +186,11 @@ class Vocabulary {
/// Helper function to classify an operand into OperandKind
static OperandKind getOperandKind(const Value *Op);
+ /// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
+ static unsigned getNumericID(unsigned Opcode);
+ static unsigned getNumericID(Type::TypeID TypeID);
+ static unsigned getNumericID(const Value *Op);
+
/// Accessors to get the embedding for a given entity.
const ir2vec::Embedding &operator[](unsigned Opcode) const;
const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 898bf5b202feb..95f30fd3f4275 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -215,7 +215,7 @@ Vocabulary::Vocabulary(VocabVector &&Vocab)
: Vocab(std::move(Vocab)), Valid(true) {}
bool Vocabulary::isValid() const {
- return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid;
+ return Vocab.size() == Vocabulary::expectedSize() && Valid;
}
size_t Vocabulary::size() const {
@@ -324,8 +324,24 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}
+unsigned Vocabulary::getNumericID(unsigned Opcode) {
+ assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+ return Opcode - 1; // Convert to zero-based index
+}
+
+unsigned Vocabulary::getNumericID(Type::TypeID TypeID) {
+ assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+ return MaxOpcodes + static_cast<unsigned>(TypeID);
+}
+
+unsigned Vocabulary::getNumericID(const Value *Op) {
+ unsigned Index = static_cast<unsigned>(getOperandKind(Op));
+ assert(Index < MaxOperandKinds && "Invalid OperandKind");
+ return MaxOpcodes + MaxTypeIDs + Index;
+}
+
StringRef Vocabulary::getStringKey(unsigned Pos) {
- assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
+ assert(Pos < Vocabulary::expectedSize() &&
"Position out of bounds in vocabulary");
// Opcode
if (Pos < MaxOpcodes)
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index cb6d633306a81..7c9a5464bfe1d 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -396,6 +396,69 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
}
}
+TEST(IR2VecVocabularyTest, NumericIDMap) {
+ // Test getNumericID for opcodes
+ EXPECT_EQ(Vocabulary::getNumericID(1u), 0u);
+ EXPECT_EQ(Vocabulary::getNumericID(13u), 12u);
+ EXPECT_EQ(Vocabulary::getNumericID(MaxOpcodes), MaxOpcodes - 1);
+
+ // Test getNumericID for Type IDs
+ EXPECT_EQ(Vocabulary::getNumericID(Type::VoidTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::VoidTyID));
+ EXPECT_EQ(Vocabulary::getNumericID(Type::HalfTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::HalfTyID));
+ EXPECT_EQ(Vocabulary::getNumericID(Type::FloatTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::FloatTyID));
+ EXPECT_EQ(Vocabulary::getNumericID(Type::IntegerTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::IntegerTyID));
+ EXPECT_EQ(Vocabulary::getNumericID(Type::PointerTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::PointerTyID));
+
+ // Test getNumericID for Value operands
+ LLVMContext Ctx;
+ Module M("TestM", Ctx);
+ FunctionType *FTy =
+ FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)}, false);
+ Function *F = Function::Create(FTy, Function::ExternalLinkage, "testFunc", M);
+
+ // Test Function operand
+ EXPECT_EQ(Vocabulary::getNumericID(F),
+ MaxOpcodes + MaxTypeIDs + 0u); // Function = 0
+
+ // Test Constant operand
+ Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
+ EXPECT_EQ(Vocabulary::getNumericID(C),
+ MaxOpcodes + MaxTypeIDs + 2u); // Constant = 2
+
+ // Test Pointer operand
+ BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
+ AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
+ EXPECT_EQ(Vocabulary::getNumericID(PtrVal),
+ MaxOpcodes + MaxTypeIDs + 1u); // Pointer = 1
+
+ // Test Variable operand (function argument)
+ Argument *Arg = F->getArg(0);
+ EXPECT_EQ(Vocabulary::getNumericID(Arg),
+ MaxOpcodes + MaxTypeIDs + 3u); // Variable = 3
+}
+
+#if GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {
+ // Test invalid opcode IDs
+ EXPECT_DEATH(Vocabulary::getNumericID(0u), "Invalid opcode");
+ EXPECT_DEATH(Vocabulary::getNumericID(MaxOpcodes + 1), "Invalid opcode");
+
+ // Test invalid type IDs
+ EXPECT_DEATH(Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs)),
+ "Invalid type ID");
+ EXPECT_DEATH(
+ Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
+ "Invalid type ID");
+}
+#endif // NDEBUG
+#endif // GTEST_HAS_DEATH_TEST
+
TEST(IR2VecVocabularyTest, StringKeyGeneration) {
EXPECT_EQ(Vocabulary::getStringKey(0), "Ret");
EXPECT_EQ(Vocabulary::getStringKey(12), "Add");
|
6ae5021
to
3ad45e3
Compare
bc03736
to
68ae9f5
Compare
42671b8
to
a395af5
Compare
68ae9f5
to
1d7ca80
Compare
a395af5
to
586947a
Compare
01c6091
to
f24c6f1
Compare
Merge activity
|
f24c6f1
to
faf9baa
Compare
@svkeerthy This didn't get reviewed at all? |
Right. Pushed it as it was a minor refactoring. Feel free to add any comments. Will fix it. |
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/162/builds/27073 Here is the relevant piece of the build log for reference:
|
Add helper methods to IR2Vec's Vocabulary class for numeric ID mapping and vocabulary size calculation. These APIs will be useful in triplet generation for the `llvm-ir2vec` tool (See #149214). (Tracking issue - #141817)