@@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
243
243
return Vocab[MaxOpcodes + MaxTypeIDs + static_cast <unsigned >(ArgKind)];
244
244
}
245
245
246
+ StringRef Vocabulary::getVocabKeyForOpcode (unsigned Opcode) {
247
+ assert (Opcode >= 1 && Opcode <= MaxOpcodes && " Invalid opcode" );
248
+ #define HANDLE_INST (NUM, OPCODE, CLASS ) \
249
+ if (Opcode == NUM) { \
250
+ return #OPCODE; \
251
+ }
252
+ #include " llvm/IR/Instruction.def"
253
+ #undef HANDLE_INST
254
+ return " UnknownOpcode" ;
255
+ }
256
+
246
257
StringRef Vocabulary::getVocabKeyForTypeID (Type::TypeID TypeID) {
247
258
switch (TypeID) {
248
259
case Type::VoidTyID:
@@ -280,6 +291,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
280
291
default :
281
292
return " UnknownTy" ;
282
293
}
294
+ return " UnknownTy" ;
283
295
}
284
296
285
297
// Operand kinds supported by IR2Vec - string mappings
@@ -297,9 +309,9 @@ StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
297
309
OPERAND_KINDS
298
310
#undef OPERAND_KIND
299
311
case Vocabulary::OperandKind::MaxOperandKind:
300
- llvm_unreachable ( " Invalid OperandKind " ) ;
312
+ return " UnknownOperand " ;
301
313
}
302
- llvm_unreachable ( " Unknown OperandKind " ) ;
314
+ return " UnknownOperand " ;
303
315
}
304
316
305
317
#undef OPERAND_KINDS
@@ -332,14 +344,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) {
332
344
assert (Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
333
345
" Position out of bounds in vocabulary" );
334
346
// Opcode
335
- if (Pos < MaxOpcodes) {
336
- #define HANDLE_INST (NUM, OPCODE, CLASS ) \
337
- if (Pos == NUM - 1 ) { \
338
- return #OPCODE; \
339
- }
340
- #include " llvm/IR/Instruction.def"
341
- #undef HANDLE_INST
342
- }
347
+ if (Pos < MaxOpcodes)
348
+ return getVocabKeyForOpcode (Pos + 1 );
343
349
// Type
344
350
if (Pos < MaxOpcodes + MaxTypeIDs)
345
351
return getVocabKeyForTypeID (static_cast <Type::TypeID>(Pos - MaxOpcodes));
@@ -447,21 +453,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
447
453
// Handle Opcodes
448
454
std::vector<Embedding> NumericOpcodeEmbeddings (Vocabulary::MaxOpcodes,
449
455
Embedding (Dim, 0 ));
450
- # define HANDLE_INST ( NUM, OPCODE, CLASS ) \
451
- { \
452
- auto It = OpcVocab.find (#OPCODE); \
453
- if (It != OpcVocab.end ()) \
454
- NumericOpcodeEmbeddings[NUM - 1 ] = It->second ; \
455
- else \
456
- handleMissingEntity (#OPCODE); \
456
+ for ( unsigned Opcode : seq ( 0u , Vocabulary::MaxOpcodes)) {
457
+ StringRef VocabKey = Vocabulary::getVocabKeyForOpcode (Opcode + 1 );
458
+ auto It = OpcVocab.find (VocabKey. str ());
459
+ if (It != OpcVocab.end ())
460
+ NumericOpcodeEmbeddings[Opcode ] = It->second ;
461
+ else
462
+ handleMissingEntity (VocabKey. str ());
457
463
}
458
- #include " llvm/IR/Instruction.def"
459
- #undef HANDLE_INST
460
464
Vocab.insert (Vocab.end (), NumericOpcodeEmbeddings.begin (),
461
465
NumericOpcodeEmbeddings.end ());
462
466
463
- // Handle Types using direct iteration through TypeID enum
464
- // We iterate through all possible TypeID values and map them to embeddings
467
+ // Handle Types
465
468
std::vector<Embedding> NumericTypeEmbeddings (Vocabulary::MaxTypeIDs,
466
469
Embedding (Dim, 0 ));
467
470
for (unsigned TypeID : seq (0u , Vocabulary::MaxTypeIDs)) {
0 commit comments