31
31
32
32
#include " llvm/ADT/DenseMap.h"
33
33
#include " llvm/IR/PassManager.h"
34
+ #include " llvm/IR/Type.h"
34
35
#include " llvm/Support/CommandLine.h"
35
36
#include " llvm/Support/ErrorOr.h"
36
37
#include " llvm/Support/JSON.h"
@@ -42,10 +43,10 @@ class Module;
42
43
class BasicBlock ;
43
44
class Instruction ;
44
45
class Function ;
45
- class Type ;
46
46
class Value ;
47
47
class raw_ostream ;
48
48
class LLVMContext ;
49
+ class IR2VecVocabAnalysis ;
49
50
50
51
// / IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
51
52
// / Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -124,9 +125,73 @@ struct Embedding {
124
125
125
126
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
126
127
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
127
- // FIXME: Current the keys are strings. This can be changed to
128
- // use integers for cheaper lookups.
129
- using Vocab = std::map<std::string, Embedding>;
128
+
129
+ // / Class for storing and accessing the IR2Vec vocabulary.
130
+ // / Encapsulates all vocabulary-related constants, logic, and access methods.
131
+ class Vocabulary {
132
+ friend class llvm ::IR2VecVocabAnalysis;
133
+ using VocabVector = std::vector<ir2vec::Embedding>;
134
+ VocabVector Vocab;
135
+ bool Valid = false ;
136
+
137
+ // / Operand kinds supported by IR2Vec Vocabulary
138
+ #define OPERAND_KINDS \
139
+ OPERAND_KIND (FunctionID, " Function" ) \
140
+ OPERAND_KIND (PointerID, " Pointer" ) \
141
+ OPERAND_KIND (ConstantID, " Constant" ) \
142
+ OPERAND_KIND (VariableID, " Variable" )
143
+
144
+ enum class OperandKind : unsigned {
145
+ #define OPERAND_KIND (Name, Str ) Name,
146
+ OPERAND_KINDS
147
+ #undef OPERAND_KIND
148
+ MaxOperandKind
149
+ };
150
+
151
+ #undef OPERAND_KINDS
152
+
153
+ // / Vocabulary layout constants
154
+ #define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
155
+ #include " llvm/IR/Instruction.def"
156
+ #undef LAST_OTHER_INST
157
+
158
+ static constexpr unsigned MaxTypes = Type::TypeID::TargetExtTyID + 1 ;
159
+ static constexpr unsigned MaxOperandKinds =
160
+ static_cast <unsigned >(OperandKind::MaxOperandKind);
161
+
162
+ // / Helper function to get vocabulary key for a given OperandKind
163
+ static StringRef getVocabKeyForOperandKind (OperandKind Kind);
164
+
165
+ // / Helper function to classify an operand into OperandKind
166
+ static OperandKind getOperandKind (const Value *Op);
167
+
168
+ // / Helper function to get vocabulary key for a given TypeID
169
+ static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
170
+
171
+ public:
172
+ Vocabulary () = default ;
173
+ Vocabulary (VocabVector &&Vocab);
174
+
175
+ bool isValid () const ;
176
+ unsigned getDimension () const ;
177
+ unsigned size () const ;
178
+
179
+ const ir2vec::Embedding &at (unsigned Position) const ;
180
+ const ir2vec::Embedding &operator [](unsigned Opcode) const ;
181
+ const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
182
+ const ir2vec::Embedding &operator [](const Value *Arg) const ;
183
+
184
+ // / Returns the string key for a given index position in the vocabulary.
185
+ // / This is useful for debugging or printing the vocabulary. Do not use this
186
+ // / for embedding generation as string based lookups are inefficient.
187
+ static StringRef getStringKey (unsigned Pos);
188
+
189
+ // / Create a dummy vocabulary for testing purposes.
190
+ static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
191
+
192
+ bool invalidate (Module &M, const PreservedAnalyses &PA,
193
+ ModuleAnalysisManager::Invalidator &Inv) const ;
194
+ };
130
195
131
196
// / Embedder provides the interface to generate embeddings (vector
132
197
// / representations) for instructions, basic blocks, and functions. The
@@ -137,7 +202,7 @@ using Vocab = std::map<std::string, Embedding>;
137
202
class Embedder {
138
203
protected:
139
204
const Function &F;
140
- const Vocab &Vocabulary ;
205
+ const Vocabulary &Vocab ;
141
206
142
207
// / Dimension of the vector representation; captured from the input vocabulary
143
208
const unsigned Dimension;
@@ -152,7 +217,7 @@ class Embedder {
152
217
mutable BBEmbeddingsMap BBVecMap;
153
218
mutable InstEmbeddingsMap InstVecMap;
154
219
155
- Embedder (const Function &F, const Vocab &Vocabulary );
220
+ Embedder (const Function &F, const Vocabulary &Vocab );
156
221
157
222
// / Helper function to compute embeddings. It generates embeddings for all
158
223
// / the instructions and basic blocks in the function F. Logic of computing
@@ -163,16 +228,12 @@ class Embedder {
163
228
// / Specific to the kind of embeddings being computed.
164
229
virtual void computeEmbeddings (const BasicBlock &BB) const = 0;
165
230
166
- // / Lookup vocabulary for a given Key. If the key is not found, it returns a
167
- // / zero vector.
168
- Embedding lookupVocab (const std::string &Key) const ;
169
-
170
231
public:
171
232
virtual ~Embedder () = default ;
172
233
173
234
// / Factory method to create an Embedder object.
174
235
static std::unique_ptr<Embedder> create (IR2VecKind Mode, const Function &F,
175
- const Vocab &Vocabulary );
236
+ const Vocabulary &Vocab );
176
237
177
238
// / Returns a map containing instructions and the corresponding embeddings for
178
239
// / the function F if it has been computed. If not, it computes the embeddings
@@ -198,56 +259,40 @@ class Embedder {
198
259
// / representations obtained from the Vocabulary.
199
260
class SymbolicEmbedder : public Embedder {
200
261
private:
201
- // / Utility function to compute the embedding for a given type.
202
- Embedding getTypeEmbedding (const Type *Ty) const ;
203
-
204
- // / Utility function to compute the embedding for a given operand.
205
- Embedding getOperandEmbedding (const Value *Op) const ;
206
-
207
262
void computeEmbeddings () const override ;
208
263
void computeEmbeddings (const BasicBlock &BB) const override ;
209
264
210
265
public:
211
- SymbolicEmbedder (const Function &F, const Vocab &Vocabulary )
212
- : Embedder(F, Vocabulary ) {
266
+ SymbolicEmbedder (const Function &F, const Vocabulary &Vocab )
267
+ : Embedder(F, Vocab ) {
213
268
FuncVector = Embedding (Dimension, 0 );
214
269
}
215
270
};
216
271
217
272
} // namespace ir2vec
218
273
219
- // / Class for storing the result of the IR2VecVocabAnalysis.
220
- class IR2VecVocabResult {
221
- ir2vec::Vocab Vocabulary;
222
- bool Valid = false ;
223
-
224
- public:
225
- IR2VecVocabResult () = default ;
226
- IR2VecVocabResult (ir2vec::Vocab &&Vocabulary);
227
-
228
- bool isValid () const { return Valid; }
229
- const ir2vec::Vocab &getVocabulary () const ;
230
- unsigned getDimension () const ;
231
- bool invalidate (Module &M, const PreservedAnalyses &PA,
232
- ModuleAnalysisManager::Invalidator &Inv) const ;
233
- };
234
-
235
274
// / This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
236
275
// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
237
276
// / its corresponding embedding.
238
277
class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
239
- ir2vec::Vocab Vocabulary;
278
+ using VocabVector = std::vector<ir2vec::Embedding>;
279
+ using VocabMap = std::map<std::string, ir2vec::Embedding>;
280
+ VocabMap OpcVocab, TypeVocab, ArgVocab;
281
+ VocabVector Vocab;
282
+
283
+ unsigned Dim = 0 ;
240
284
Error readVocabulary ();
241
285
Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
242
- ir2vec::Vocab &TargetVocab, unsigned &Dim);
286
+ VocabMap &TargetVocab, unsigned &Dim);
287
+ void generateNumMappedVocab ();
243
288
void emitError (Error Err, LLVMContext &Ctx);
244
289
245
290
public:
246
291
static AnalysisKey Key;
247
292
IR2VecVocabAnalysis () = default ;
248
- explicit IR2VecVocabAnalysis (const ir2vec::Vocab &Vocab);
249
- explicit IR2VecVocabAnalysis (ir2vec::Vocab &&Vocab);
250
- using Result = IR2VecVocabResult ;
293
+ explicit IR2VecVocabAnalysis (const VocabVector &Vocab);
294
+ explicit IR2VecVocabAnalysis (VocabVector &&Vocab);
295
+ using Result = ir2vec::Vocabulary ;
251
296
Result run (Module &M, ModuleAnalysisManager &MAM);
252
297
};
253
298
0 commit comments