Skip to content

[IR2Vec] Minor vocab changes and exposing weights #143200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions llvm/include/llvm/Analysis/IR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@

#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/JSON.h"
#include <map>

namespace llvm {
Expand All @@ -43,6 +45,7 @@ class Function;
class Type;
class Value;
class raw_ostream;
class LLVMContext;

/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
Expand All @@ -53,6 +56,11 @@ class raw_ostream;
enum class IR2VecKind { Symbolic };

namespace ir2vec {

/// Weights for the opcode, type and argument embeddings. These are defined in
/// IR2Vec.cpp as the command-line options -ir2vec-opc-weight,
/// -ir2vec-type-weight and -ir2vec-arg-weight; declaring them here lets other
/// translation units read or override them.
extern cl::opt<float> OpcWeight;
extern cl::opt<float> TypeWeight;
extern cl::opt<float> ArgWeight;

/// Embedding is a datatype that wraps std::vector<double>. It provides
/// additional functionality for arithmetic and comparison operations.
/// It is meant to be used *like* std::vector<double> but is more restrictive
Expand Down Expand Up @@ -226,10 +234,13 @@ class IR2VecVocabResult {
/// Module analysis that produces the IR2Vec vocabulary as an
/// IR2VecVocabResult. The vocabulary is either supplied up front through one
/// of the constructors or read from the vocabulary file when run() is called.
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
// Vocabulary held by the analysis; empty until set by a constructor or by
// readVocabulary().
ir2vec::Vocab Vocabulary;
// Reads and validates the vocabulary file, filling Vocabulary on success.
Error readVocabulary();
// Reports every error contained in Err through Ctx.
void emitError(Error Err, LLVMContext &Ctx);

public:
static AnalysisKey Key;
// Leaves Vocabulary empty.
IR2VecVocabAnalysis() = default;
// Starts with a copy of Vocab.
explicit IR2VecVocabAnalysis(const ir2vec::Vocab &Vocab);
// Starts with Vocab moved in.
explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
using Result = IR2VecVocabResult;
// Returns the vocabulary result for M; a default-constructed result is
// returned after an error has been emitted on M's context.
Result run(Module &M, ModuleAnalysisManager &MAM);
};
Expand Down
82 changes: 51 additions & 31 deletions llvm/lib/Analysis/IR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/MemoryBuffer.h"

using namespace llvm;
Expand All @@ -33,25 +31,26 @@ using namespace ir2vec;
// Statistic for lookups of entities that have no entry in the vocabulary.
STATISTIC(VocabMissCounter,
          "Number of lookups to entities not present in the vocabulary");

namespace llvm {
namespace ir2vec {
static cl::OptionCategory IR2VecCategory("IR2Vec Options");

// Path of the vocabulary file (-ir2vec-vocab-path); empty unless given.
// FIXME: Use a default vocab when not specified
static cl::opt<std::string>
    VocabFile("ir2vec-vocab-path", cl::Optional,
              cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
              cl::cat(IR2VecCategory));

// The weight options have external linkage on purpose: IR2Vec.h declares them
// `extern`. Each option is defined exactly once; a second definition would
// both redefine the variable and register the same option name twice.
cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),
                         cl::desc("Weight for opcode embeddings"),
                         cl::cat(IR2VecCategory));
cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
                          cl::desc("Weight for type embeddings"),
                          cl::cat(IR2VecCategory));
cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
                         cl::desc("Weight for argument embeddings"),
                         cl::cat(IR2VecCategory));
} // namespace ir2vec
} // namespace llvm

AnalysisKey IR2VecVocabAnalysis::Key;

Expand Down Expand Up @@ -251,49 +250,70 @@ bool IR2VecVocabResult::invalidate(
// by auto-generating a default vocabulary during the build time.
/// Reads the file named by VocabFile, parses it as JSON into Vocabulary and
/// validates the result.
///
/// \returns Error::success() when the vocabulary is non-empty and every
/// embedding has the same non-zero dimension; otherwise an Error describing
/// the file, parse or validation failure.
Error IR2VecVocabAnalysis::readVocabulary() {
  auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
  if (!BufOrError)
    return createFileError(VocabFile, BufOrError.getError());

  auto Content = BufOrError.get()->getBuffer();
  json::Path::Root Path("");
  Expected<json::Value> ParsedVocabValue = json::parse(Content);
  if (!ParsedVocabValue)
    return ParsedVocabValue.takeError();

  bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path);
  if (!Res)
    return createStringError(errc::illegal_byte_sequence,
                             "Unable to parse the vocabulary");

  // Malformed vocabularies are reported as errors rather than asserted on, so
  // the caller sees them in every build mode.
  if (Vocabulary.empty())
    return createStringError(errc::illegal_byte_sequence,
                             "Vocabulary is empty");

  // The first entry fixes the dimension every other entry must match.
  unsigned Dim = Vocabulary.begin()->second.size();
  if (Dim == 0)
    return createStringError(errc::illegal_byte_sequence,
                             "Dimension of vocabulary is zero");

  // `const auto &` binds directly to the stored entry; naming a different pair
  // type here would materialize a converted copy of each entry.
  if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
                   [Dim](const auto &Entry) {
                     return Entry.second.size() == Dim;
                   }))
    return createStringError(
        errc::illegal_byte_sequence,
        "All vectors in the vocabulary are not of the same dimension");

  return Error::success();
}

// Pre-populates the analysis with a copy of the given vocabulary.
IR2VecVocabAnalysis::IR2VecVocabAnalysis(const Vocab &Vocabulary)
: Vocabulary(Vocabulary) {}

// Pre-populates the analysis by moving the given vocabulary in.
IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
: Vocabulary(std::move(Vocabulary)) {}

/// Consumes \p Err and reports each contained error on \p Ctx, prefixed with
/// "Error reading vocabulary: ".
void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
  auto ReportOnContext = [&](const ErrorInfoBase &Info) {
    Ctx.emitError("Error reading vocabulary: " + Info.message());
  };
  handleAllErrors(std::move(Err), ReportOnContext);
}

/// Produces the vocabulary result for \p M.
///
/// A vocabulary supplied through the constructor takes precedence. Otherwise
/// the vocabulary file is read; if no path was given or reading fails, an
/// error is emitted on the module's context and a default-constructed result
/// is returned.
IR2VecVocabAnalysis::Result
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
  auto Ctx = &M.getContext();
  // If vocabulary is already populated by the constructor, use it.
  // NOTE(review): the vocabulary is handed over with std::move, so presumably
  // a later run() on the same object would find Vocabulary empty - confirm.
  if (!Vocabulary.empty())
    return IR2VecVocabResult(std::move(Vocabulary));

  // Otherwise, try to read from the vocabulary file.
  if (VocabFile.empty()) {
    // FIXME: Use default vocabulary
    Ctx->emitError("IR2Vec vocabulary file path not specified");
    return IR2VecVocabResult(); // Return invalid result
  }
  if (auto Err = readVocabulary()) {
    // Err is consumed exactly once, by the shared reporting helper.
    emitError(std::move(Err), *Ctx);
    return IR2VecVocabResult();
  }
  // FIXME: Scale the vocabulary here once. This would avoid scaling per use
  // later.
  return IR2VecVocabResult(std::move(Vocabulary));
}

Expand Down
137 changes: 102 additions & 35 deletions llvm/unittests/Analysis/IR2VecTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,25 +281,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
EXPECT_EQ(validResult.getDimension(), 2u);
}

// Helper to create a minimal function and embedder for getter tests
struct GetterTestEnv {
Vocab V = {};
// Fixture for IR2Vec tests requiring IR setup and weight management.
class IR2VecTestFixture : public ::testing::Test {
protected:
Vocab V;
LLVMContext Ctx;
std::unique_ptr<Module> M = nullptr;
std::unique_ptr<Module> M;
Function *F = nullptr;
BasicBlock *BB = nullptr;
Instruction *Add = nullptr;
Instruction *Ret = nullptr;
std::unique_ptr<Embedder> Emb = nullptr;
Instruction *AddInst = nullptr;
Instruction *RetInst = nullptr;

GetterTestEnv() {
float OriginalOpcWeight = ::OpcWeight;
float OriginalTypeWeight = ::TypeWeight;
float OriginalArgWeight = ::ArgWeight;

void SetUp() override {
V = {{"add", {1.0, 2.0}},
{"integerTy", {0.5, 0.5}},
{"constant", {0.2, 0.3}},
{"variable", {0.0, 0.0}},
{"unknownTy", {0.0, 0.0}}};

M = std::make_unique<Module>("M", Ctx);
// Setup IR
M = std::make_unique<Module>("TestM", Ctx);
FunctionType *FTy = FunctionType::get(
Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
false);
Expand All @@ -308,61 +313,82 @@ struct GetterTestEnv {
Argument *Arg = F->getArg(0);
llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);

Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
Ret = ReturnInst::Create(Ctx, Add, BB);
AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
RetInst = ReturnInst::Create(Ctx, AddInst, BB);
}

void setWeights(float OpcWeight, float TypeWeight, float ArgWeight) {
::OpcWeight = OpcWeight;
::TypeWeight = TypeWeight;
::ArgWeight = ArgWeight;
}

auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_TRUE(static_cast<bool>(Result));
Emb = std::move(*Result);
void TearDown() override {
// Restore original global weights
::OpcWeight = OriginalOpcWeight;
::TypeWeight = OriginalTypeWeight;
::ArgWeight = OriginalArgWeight;
}
};

// Checks the per-instruction embedding map for the fixture's two-instruction
// function (an add followed by a ret).
TEST_F(IR2VecTestFixture, GetInstVecMap) {
  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
  ASSERT_TRUE(static_cast<bool>(Result));
  auto Emb = std::move(*Result);

  const auto &InstMap = Emb->getInstVecMap();

  EXPECT_EQ(InstMap.size(), 2u);
  EXPECT_TRUE(InstMap.count(AddInst));
  EXPECT_TRUE(InstMap.count(RetInst));

  EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
  EXPECT_EQ(InstMap.at(RetInst).size(), 2u);

  // Check values for add: {1.29, 2.31}
  EXPECT_THAT(InstMap.at(AddInst),
              ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));

  // Check values for ret: {0.0, 0.0}; neither ret nor voidTy is present in
  // the vocab
  EXPECT_THAT(InstMap.at(RetInst), ElementsAre(0.0, 0.0));
}

// Checks the per-basic-block embedding map for the fixture's single block.
TEST_F(IR2VecTestFixture, GetBBVecMap) {
  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
  ASSERT_TRUE(static_cast<bool>(Result));
  auto Emb = std::move(*Result);

  const auto &BBMap = Emb->getBBVecMap();

  EXPECT_EQ(BBMap.size(), 1u);
  EXPECT_TRUE(BBMap.count(BB));
  EXPECT_EQ(BBMap.at(BB).size(), 2u);

  // BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
  // {1.29, 2.31}
  EXPECT_THAT(BBMap.at(BB),
              ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}

// Checks the embedding returned for the fixture's basic block when it is
// queried directly.
TEST_F(IR2VecTestFixture, GetBBVector) {
  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
  ASSERT_TRUE(static_cast<bool>(Result));
  auto Emb = std::move(*Result);

  const auto &BBVec = Emb->getBBVector(*BB);

  EXPECT_EQ(BBVec.size(), 2u);
  EXPECT_THAT(BBVec,
              ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}

TEST(IR2VecTest, GetFunctionVector) {
GetterTestEnv Env;
const auto &FuncVec = Env.Emb->getFunctionVector();
TEST_F(IR2VecTestFixture, GetFunctionVector) {
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Result));
auto Emb = std::move(*Result);

const auto &FuncVec = Emb->getFunctionVector();

EXPECT_EQ(FuncVec.size(), 2u);

Expand All @@ -371,4 +397,45 @@ TEST(IR2VecTest, GetFunctionVector) {
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}

// Function embedding when the opcode, type and argument weights are all
// overridden to 1.
TEST_F(IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
  setWeights(/*OpcWeight=*/1.0, /*TypeWeight=*/1.0, /*ArgWeight=*/1.0);

  auto EmbedderOrErr = Embedder::create(IR2VecKind::Symbolic, *F, V);
  ASSERT_TRUE(static_cast<bool>(EmbedderOrErr));
  auto Emb = std::move(*EmbedderOrErr);

  const auto &FnEmbedding = Emb->getFunctionVector();
  EXPECT_EQ(FnEmbedding.size(), 2u);

  // With unit weights the three contributions add up unscaled:
  //   opcodes   [1.0 2.0] + [0.0 0.0]
  // + types     [0.5 0.5] + [0.0 0.0]
  // + arguments [0.2 0.3] + [0.0 0.0]
  // = [1.7 2.8]
  EXPECT_THAT(FnEmbedding,
              ElementsAre(DoubleNear(1.7, 1e-6), DoubleNear(2.8, 1e-6)));
}

// An analysis constructed from a vocabulary must return exactly that
// vocabulary from run().
TEST(IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
  Vocab Seed = {{"key1", {1.1, 2.2}}, {"key2", {3.3, 4.4}}};
  // Take the reference copy and the dimension before Seed is moved from.
  Vocab Reference = Seed;
  unsigned ReferenceDim = Seed.begin()->second.size();

  IR2VecVocabAnalysis Analysis(std::move(Seed));

  LLVMContext Context;
  Module Mod("TestModuleForVocabAnalysis", Context);
  ModuleAnalysisManager Manager;
  IR2VecVocabResult Outcome = Analysis.run(Mod, Manager);

  EXPECT_TRUE(Outcome.isValid());
  ASSERT_FALSE(Outcome.getVocabulary().empty());
  EXPECT_EQ(Outcome.getDimension(), ReferenceDim);

  // Every reference entry must be present with an identical embedding.
  const auto &Produced = Outcome.getVocabulary();
  EXPECT_EQ(Produced.size(), Reference.size());
  for (const auto &Entry : Reference) {
    const auto &Key = Entry.first;
    const auto &Expected = Entry.second;
    EXPECT_TRUE(Produced.count(Key));
    EXPECT_THAT(Produced.at(Key), ElementsAreArray(Expected));
  }
}

} // end anonymous namespace
Loading