Skip to content

Commit 80c29d9

Browse files
committed
[NFC][IR2Vec] Minor refactoring of opcode access in vocabulary
1 parent 70dcd29 commit 80c29d9

File tree

2 files changed

+30
-24
lines changed

2 files changed

+30
-24
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,15 +163,18 @@ class Vocabulary {
163163
static constexpr unsigned MaxOperandKinds =
164164
static_cast<unsigned>(OperandKind::MaxOperandKind);
165165

166+
/// Helper function to get vocabulary key for a given Opcode
167+
static StringRef getVocabKeyForOpcode(unsigned Opcode);
168+
169+
/// Helper function to get vocabulary key for a given TypeID
170+
static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
171+
166172
/// Helper function to get vocabulary key for a given OperandKind
167173
static StringRef getVocabKeyForOperandKind(OperandKind Kind);
168174

169175
/// Helper function to classify an operand into OperandKind
170176
static OperandKind getOperandKind(const Value *Op);
171177

172-
/// Helper function to get vocabulary key for a given TypeID
173-
static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
174-
175178
public:
176179
Vocabulary() = default;
177180
Vocabulary(VocabVector &&Vocab);

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
243243
return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
244244
}
245245

246+
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
247+
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
248+
#define HANDLE_INST(NUM, OPCODE, CLASS) \
249+
if (Opcode == NUM) { \
250+
return #OPCODE; \
251+
}
252+
#include "llvm/IR/Instruction.def"
253+
#undef HANDLE_INST
254+
return "UnknownOpcode";
255+
}
256+
246257
StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
247258
switch (TypeID) {
248259
case Type::VoidTyID:
@@ -280,6 +291,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
280291
default:
281292
return "UnknownTy";
282293
}
294+
return "UnknownTy";
283295
}
284296

285297
// Operand kinds supported by IR2Vec - string mappings
@@ -297,9 +309,9 @@ StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
297309
OPERAND_KINDS
298310
#undef OPERAND_KIND
299311
case Vocabulary::OperandKind::MaxOperandKind:
300-
llvm_unreachable("Invalid OperandKind");
312+
return "UnknownOperand";
301313
}
302-
llvm_unreachable("Unknown OperandKind");
314+
return "UnknownOperand";
303315
}
304316

305317
#undef OPERAND_KINDS
@@ -332,14 +344,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) {
332344
assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
333345
"Position out of bounds in vocabulary");
334346
// Opcode
335-
if (Pos < MaxOpcodes) {
336-
#define HANDLE_INST(NUM, OPCODE, CLASS) \
337-
if (Pos == NUM - 1) { \
338-
return #OPCODE; \
339-
}
340-
#include "llvm/IR/Instruction.def"
341-
#undef HANDLE_INST
342-
}
347+
if (Pos < MaxOpcodes)
348+
return getVocabKeyForOpcode(Pos + 1);
343349
// Type
344350
if (Pos < MaxOpcodes + MaxTypeIDs)
345351
return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
@@ -447,21 +453,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
447453
// Handle Opcodes
448454
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
449455
Embedding(Dim, 0));
450-
#define HANDLE_INST(NUM, OPCODE, CLASS) \
451-
{ \
452-
auto It = OpcVocab.find(#OPCODE); \
453-
if (It != OpcVocab.end()) \
454-
NumericOpcodeEmbeddings[NUM - 1] = It->second; \
455-
else \
456-
handleMissingEntity(#OPCODE); \
456+
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
457+
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
458+
auto It = OpcVocab.find(VocabKey.str());
459+
if (It != OpcVocab.end())
460+
NumericOpcodeEmbeddings[Opcode] = It->second;
461+
else
462+
handleMissingEntity(VocabKey.str());
457463
}
458-
#include "llvm/IR/Instruction.def"
459-
#undef HANDLE_INST
460464
Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
461465
NumericOpcodeEmbeddings.end());
462466

463-
// Handle Types using direct iteration through TypeID enum
464-
// We iterate through all possible TypeID values and map them to embeddings
467+
// Handle Types
465468
std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
466469
Embedding(Dim, 0));
467470
for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {

0 commit comments

Comments
 (0)