9
9
// / \file
10
10
// / This file implements the IR2Vec embedding generation tool.
11
11
// /
12
- // / Currently supports triplet generation for vocabulary training.
13
- // / Future updates will support embedding generation using trained vocabulary.
12
+ // / This tool provides two main functionalities:
14
13
// /
15
- // / Usage: llvm-ir2vec input.bc -o triplets.txt
14
+ // / 1. Triplet Generation Mode (--mode=triplets):
15
+ // / Generates triplets (opcode, type, operands) for vocabulary training.
16
+ // / Usage: llvm-ir2vec --mode=triplets input.bc -o triplets.txt
16
17
// /
17
- // / TODO: Add embedding generation mode with vocabulary support
18
+ // / 2. Embedding Generation Mode (--mode=embeddings):
19
+ // / Generates IR2Vec embeddings using a trained vocabulary.
20
+ // / Usage: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=vocab.json
21
+ // / --level=func input.bc -o embeddings.txt Levels: --level=inst
22
+ // / (instructions), --level=bb (basic blocks), --level=func (functions)
23
+ // / (See IR2Vec.cpp for more embedding generation options)
18
24
// /
19
25
// ===----------------------------------------------------------------------===//
20
26
24
30
#include " llvm/IR/Instructions.h"
25
31
#include " llvm/IR/LLVMContext.h"
26
32
#include " llvm/IR/Module.h"
33
+ #include " llvm/IR/PassInstrumentation.h"
34
+ #include " llvm/IR/PassManager.h"
27
35
#include " llvm/IR/Type.h"
28
36
#include " llvm/IRReader/IRReader.h"
29
37
#include " llvm/Support/CommandLine.h"
34
42
#include " llvm/Support/raw_ostream.h"
35
43
36
44
using namespace llvm ;
37
- using namespace ir2vec ;
45
+ using namespace llvm :: ir2vec;
38
46
39
47
#define DEBUG_TYPE " ir2vec"
40
48
@@ -50,16 +58,63 @@ static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
50
58
cl::init(" -" ),
51
59
cl::cat(IR2VecToolCategory));
52
60
61
+ enum ToolMode {
62
+ TripletMode, // Generate triplets for vocabulary training
63
+ EmbeddingMode // Generate embeddings using trained vocabulary
64
+ };
65
+
66
+ static cl::opt<ToolMode>
67
+ Mode (" mode" , cl::desc(" Tool operation mode:" ),
68
+ cl::values(clEnumValN(TripletMode, " triplets" ,
69
+ " Generate triplets for vocabulary training" ),
70
+ clEnumValN(EmbeddingMode, " embeddings" ,
71
+ " Generate embeddings using trained vocabulary" )),
72
+ cl::init(EmbeddingMode), cl::cat(IR2VecToolCategory));
73
+
74
+ static cl::opt<std::string>
75
+ FunctionName (" function" , cl::desc(" Process specific function only" ),
76
+ cl::value_desc(" name" ), cl::Optional, cl::init(" " ),
77
+ cl::cat(IR2VecToolCategory));
78
+
79
+ enum EmbeddingLevel {
80
+ InstructionLevel, // Generate instruction-level embeddings
81
+ BasicBlockLevel, // Generate basic block-level embeddings
82
+ FunctionLevel // Generate function-level embeddings
83
+ };
84
+
85
+ static cl::opt<EmbeddingLevel>
86
+ Level (" level" , cl::desc(" Embedding generation level (for embedding mode):" ),
87
+ cl::values(clEnumValN(InstructionLevel, " inst" ,
88
+ " Generate instruction-level embeddings" ),
89
+ clEnumValN(BasicBlockLevel, " bb" ,
90
+ " Generate basic block-level embeddings" ),
91
+ clEnumValN(FunctionLevel, " func" ,
92
+ " Generate function-level embeddings" )),
93
+ cl::init(FunctionLevel), cl::cat(IR2VecToolCategory));
94
+
53
95
namespace {
54
96
55
- // / Helper class for collecting IR information and generating triplets
97
+ // / Helper class for collecting IR triplets and generating embeddings
56
98
class IR2VecTool {
57
99
private:
58
100
Module &M;
101
+ ModuleAnalysisManager MAM;
102
+ const Vocabulary *Vocab = nullptr ;
59
103
60
104
public:
61
105
explicit IR2VecTool (Module &M) : M(M) {}
62
106
107
+ // / Initialize the IR2Vec vocabulary analysis
108
+ bool initializeVocabulary () {
109
+ // Register and run the IR2Vec vocabulary analysis
110
+ // The vocabulary file path is specified via --ir2vec-vocab-path global
111
+ // option
112
+ MAM.registerPass ([&] { return PassInstrumentationAnalysis (); });
113
+ MAM.registerPass ([&] { return IR2VecVocabAnalysis (); });
114
+ Vocab = &MAM.getResult <IR2VecVocabAnalysis>(M);
115
+ return Vocab->isValid ();
116
+ }
117
+
63
118
// / Generate triplets for the entire module
64
119
void generateTriplets (raw_ostream &OS) const {
65
120
for (const Function &F : M)
@@ -81,6 +136,68 @@ class IR2VecTool {
81
136
OS << LocalOutput;
82
137
}
83
138
139
+ // / Generate embeddings for the entire module
140
+ void generateEmbeddings (raw_ostream &OS) const {
141
+ if (!Vocab->isValid ()) {
142
+ OS << " Error: Vocabulary is not valid. IR2VecTool not initialized.\n " ;
143
+ return ;
144
+ }
145
+
146
+ for (const Function &F : M)
147
+ generateEmbeddings (F, OS);
148
+ }
149
+
150
+ // / Generate embeddings for a single function
151
+ void generateEmbeddings (const Function &F, raw_ostream &OS) const {
152
+ if (F.isDeclaration ()) {
153
+ OS << " Function " << F.getName () << " is a declaration, skipping.\n " ;
154
+ return ;
155
+ }
156
+
157
+ // Create embedder for this function
158
+ assert (Vocab->isValid () && " Vocabulary is not valid" );
159
+ auto Emb = Embedder::create (IR2VecKind::Symbolic, F, *Vocab);
160
+ if (!Emb) {
161
+ OS << " Error: Failed to create embedder for function " << F.getName ()
162
+ << " \n " ;
163
+ return ;
164
+ }
165
+
166
+ OS << " Function: " << F.getName () << " \n " ;
167
+
168
+ // Generate embeddings based on the specified level
169
+ switch (Level) {
170
+ case FunctionLevel: {
171
+ Emb->getFunctionVector ().print (OS);
172
+ break ;
173
+ }
174
+ case BasicBlockLevel: {
175
+ const auto &BBVecMap = Emb->getBBVecMap ();
176
+ for (const BasicBlock &BB : F) {
177
+ auto It = BBVecMap.find (&BB);
178
+ if (It != BBVecMap.end ()) {
179
+ OS << BB.getName () << " :" ;
180
+ It->second .print (OS);
181
+ }
182
+ }
183
+ break ;
184
+ }
185
+ case InstructionLevel: {
186
+ const auto &InstMap = Emb->getInstVecMap ();
187
+ for (const BasicBlock &BB : F) {
188
+ for (const Instruction &I : BB) {
189
+ auto It = InstMap.find (&I);
190
+ if (It != InstMap.end ()) {
191
+ I.print (OS);
192
+ It->second .print (OS);
193
+ }
194
+ }
195
+ }
196
+ break ;
197
+ }
198
+ }
199
+ }
200
+
84
201
private:
85
202
// / Process a single basic block for triplet generation
86
203
void traverseBasicBlock (const BasicBlock &BB, raw_string_ostream &OS) const {
@@ -105,8 +222,42 @@ class IR2VecTool {
105
222
106
223
Error processModule (Module &M, raw_ostream &OS) {
107
224
IR2VecTool Tool (M);
108
- Tool.generateTriplets (OS);
109
225
226
+ if (Mode == EmbeddingMode) {
227
+ // Initialize vocabulary for embedding generation
228
+ // Note: Requires --ir2vec-vocab-path option to be set
229
+ if (!Tool.initializeVocabulary ())
230
+ return createStringError (
231
+ errc::invalid_argument,
232
+ " Failed to initialize IR2Vec vocabulary. "
233
+ " Make sure to specify --ir2vec-vocab-path for embedding mode." );
234
+
235
+ if (!FunctionName.empty ()) {
236
+ // Process single function
237
+ if (const Function *F = M.getFunction (FunctionName))
238
+ Tool.generateEmbeddings (*F, OS);
239
+ else
240
+ return createStringError (errc::invalid_argument,
241
+ " Function '%s' not found" ,
242
+ FunctionName.c_str ());
243
+ } else {
244
+ // Process all functions
245
+ Tool.generateEmbeddings (OS);
246
+ }
247
+ } else {
248
+ // Triplet generation mode - no vocabulary needed
249
+ if (!FunctionName.empty ())
250
+ // Process single function
251
+ if (const Function *F = M.getFunction (FunctionName))
252
+ Tool.generateTriplets (*F, OS);
253
+ else
254
+ return createStringError (errc::invalid_argument,
255
+ " Function '%s' not found" ,
256
+ FunctionName.c_str ());
257
+ else
258
+ // Process all functions
259
+ Tool.generateTriplets (OS);
260
+ }
110
261
return Error::success ();
111
262
}
112
263
@@ -117,11 +268,21 @@ int main(int argc, char **argv) {
117
268
cl::HideUnrelatedOptions (IR2VecToolCategory);
118
269
cl::ParseCommandLineOptions (
119
270
argc, argv,
120
- " IR2Vec - Triplet Generation Tool\n "
121
- " Generates triplets for vocabulary training from LLVM IR.\n "
122
- " Future updates will support embedding generation.\n\n "
271
+ " IR2Vec - Embedding Generation Tool\n "
272
+ " Generates embeddings for a given LLVM IR and "
273
+ " supports triplet generation for vocabulary "
274
+ " training and embedding generation.\n\n "
123
275
" Usage:\n "
124
- " llvm-ir2vec input.bc -o triplets.txt\n " );
276
+ " Triplet mode: llvm-ir2vec --mode=triplets input.bc\n "
277
+ " Embedding mode: llvm-ir2vec --mode=embeddings "
278
+ " --ir2vec-vocab-path=vocab.json --level=func input.bc\n "
279
+ " Levels: --level=inst (instructions), --level=bb (basic blocks), "
280
+ " --level=func (functions)\n " );
281
+
282
+ // Validate command line options
283
+ if (Mode == TripletMode && Level != FunctionLevel) {
284
+ errs () << " Warning: --level option is ignored in triplet mode\n " ;
285
+ }
125
286
126
287
// Parse the input LLVM IR file
127
288
SMDiagnostic Err;
0 commit comments