Skip to content

Commit 01c6091

Browse files
committed
exposing-new-methods
1 parent f295617 commit 01c6091

File tree

3 files changed

+90
-2
lines changed

3 files changed

+90
-2
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@ class Vocabulary {
170170
unsigned getDimension() const;
171171
size_t size() const;
172172

173+
static size_t expectedSize() {
174+
return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
175+
}
176+
173177
/// Helper function to get vocabulary key for a given Opcode
174178
static StringRef getVocabKeyForOpcode(unsigned Opcode);
175179

@@ -182,6 +186,11 @@ class Vocabulary {
182186
/// Helper function to classify an operand into OperandKind
183187
static OperandKind getOperandKind(const Value *Op);
184188

189+
/// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
190+
static unsigned getNumericID(unsigned Opcode);
191+
static unsigned getNumericID(Type::TypeID TypeID);
192+
static unsigned getNumericID(const Value *Op);
193+
185194
/// Accessors to get the embedding for a given entity.
186195
const ir2vec::Embedding &operator[](unsigned Opcode) const;
187196
const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ Vocabulary::Vocabulary(VocabVector &&Vocab)
215215
: Vocab(std::move(Vocab)), Valid(true) {}
216216

217217
bool Vocabulary::isValid() const {
218-
return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid;
218+
return Vocab.size() == Vocabulary::expectedSize() && Valid;
219219
}
220220

221221
size_t Vocabulary::size() const {
@@ -324,8 +324,24 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
324324
return OperandKind::VariableID;
325325
}
326326

327+
unsigned Vocabulary::getNumericID(unsigned Opcode) {
328+
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
329+
return Opcode - 1; // Convert to zero-based index
330+
}
331+
332+
unsigned Vocabulary::getNumericID(Type::TypeID TypeID) {
333+
assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
334+
return MaxOpcodes + static_cast<unsigned>(TypeID);
335+
}
336+
337+
unsigned Vocabulary::getNumericID(const Value *Op) {
338+
unsigned Index = static_cast<unsigned>(getOperandKind(Op));
339+
assert(Index < MaxOperandKinds && "Invalid OperandKind");
340+
return MaxOpcodes + MaxTypeIDs + Index;
341+
}
342+
327343
StringRef Vocabulary::getStringKey(unsigned Pos) {
328-
assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
344+
assert(Pos < Vocabulary::expectedSize() &&
329345
"Position out of bounds in vocabulary");
330346
// Opcode
331347
if (Pos < MaxOpcodes)

llvm/unittests/Analysis/IR2VecTest.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,69 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
396396
}
397397
}
398398

399+
TEST(IR2VecVocabularyTest, NumericIDMap) {
400+
// Test getNumericID for opcodes
401+
EXPECT_EQ(Vocabulary::getNumericID(1u), 0u);
402+
EXPECT_EQ(Vocabulary::getNumericID(13u), 12u);
403+
EXPECT_EQ(Vocabulary::getNumericID(MaxOpcodes), MaxOpcodes - 1);
404+
405+
// Test getNumericID for Type IDs
406+
EXPECT_EQ(Vocabulary::getNumericID(Type::VoidTyID),
407+
MaxOpcodes + static_cast<unsigned>(Type::VoidTyID));
408+
EXPECT_EQ(Vocabulary::getNumericID(Type::HalfTyID),
409+
MaxOpcodes + static_cast<unsigned>(Type::HalfTyID));
410+
EXPECT_EQ(Vocabulary::getNumericID(Type::FloatTyID),
411+
MaxOpcodes + static_cast<unsigned>(Type::FloatTyID));
412+
EXPECT_EQ(Vocabulary::getNumericID(Type::IntegerTyID),
413+
MaxOpcodes + static_cast<unsigned>(Type::IntegerTyID));
414+
EXPECT_EQ(Vocabulary::getNumericID(Type::PointerTyID),
415+
MaxOpcodes + static_cast<unsigned>(Type::PointerTyID));
416+
417+
// Test getNumericID for Value operands
418+
LLVMContext Ctx;
419+
Module M("TestM", Ctx);
420+
FunctionType *FTy =
421+
FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)}, false);
422+
Function *F = Function::Create(FTy, Function::ExternalLinkage, "testFunc", M);
423+
424+
// Test Function operand
425+
EXPECT_EQ(Vocabulary::getNumericID(F),
426+
MaxOpcodes + MaxTypeIDs + 0u); // Function = 0
427+
428+
// Test Constant operand
429+
Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
430+
EXPECT_EQ(Vocabulary::getNumericID(C),
431+
MaxOpcodes + MaxTypeIDs + 2u); // Constant = 2
432+
433+
// Test Pointer operand
434+
BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
435+
AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
436+
EXPECT_EQ(Vocabulary::getNumericID(PtrVal),
437+
MaxOpcodes + MaxTypeIDs + 1u); // Pointer = 1
438+
439+
// Test Variable operand (function argument)
440+
Argument *Arg = F->getArg(0);
441+
EXPECT_EQ(Vocabulary::getNumericID(Arg),
442+
MaxOpcodes + MaxTypeIDs + 3u); // Variable = 3
443+
}
444+
445+
#if GTEST_HAS_DEATH_TEST
446+
#ifndef NDEBUG
447+
TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {
448+
// Test invalid opcode IDs
449+
EXPECT_DEATH(Vocabulary::getNumericID(0u), "Invalid opcode");
450+
EXPECT_DEATH(Vocabulary::getNumericID(MaxOpcodes + 1), "Invalid opcode");
451+
452+
// Test invalid type IDs
453+
EXPECT_DEATH(Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs)),
454+
"Invalid type ID");
455+
EXPECT_DEATH(
456+
Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
457+
"Invalid type ID");
458+
}
459+
#endif // NDEBUG
460+
#endif // GTEST_HAS_DEATH_TEST
461+
399462
TEST(IR2VecVocabularyTest, StringKeyGeneration) {
400463
EXPECT_EQ(Vocabulary::getStringKey(0), "Ret");
401464
EXPECT_EQ(Vocabulary::getStringKey(12), "Add");

0 commit comments

Comments
 (0)