-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[IR2Vec] Scale embeddings once in vocab analysis instead of repetitive scaling #143986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -85,6 +85,12 @@ Embedding &Embedding::operator-=(const Embedding &RHS) { | |
return *this; | ||
} | ||
|
||
/// Scales this embedding in place: every element is multiplied by \p Factor.
///
/// \param Factor Multiplier applied to each element.
/// \returns A reference to this (now scaled) embedding.
Embedding &Embedding::operator*=(double Factor) {
  for (auto &Elem : *this)
    Elem = Elem * Factor;
  return *this;
}
|
||
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) { | ||
assert(this->size() == Src.size() && "Vectors must have the same dimension"); | ||
for (size_t Itr = 0; Itr < this->size(); ++Itr) | ||
|
@@ -101,6 +107,13 @@ bool Embedding::approximatelyEquals(const Embedding &RHS, | |
return true; | ||
} | ||
|
||
/// Writes this embedding to \p OS as a single bracketed line.
///
/// Each element is formatted with "%.2f" and surrounded by one space on each
/// side; the line is terminated with a newline.
///
/// \param OS Stream that receives the textual form.
void Embedding::print(raw_ostream &OS) const {
  OS << " [";
  for (const auto &Value : Data) {
    OS << " ";
    OS << format("%.2f", Value);
    OS << " ";
  }
  OS << "]\n";
}
|
||
//===----------------------------------------------------------------------===// | ||
// Embedder and its subclasses | ||
//===----------------------------------------------------------------------===// | ||
|
@@ -196,18 +209,12 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { | |
for (const auto &I : BB.instructionsWithoutDebug()) { | ||
Embedding InstVector(Dimension, 0); | ||
|
||
const auto OpcVec = lookupVocab(I.getOpcodeName()); | ||
InstVector.scaleAndAdd(OpcVec, OpcWeight); | ||
|
||
// FIXME: Currently lookups are string based. Use numeric Keys | ||
// for efficiency. | ||
const auto Type = I.getType(); | ||
const auto TypeVec = getTypeEmbedding(Type); | ||
InstVector.scaleAndAdd(TypeVec, TypeWeight); | ||
|
||
InstVector += lookupVocab(I.getOpcodeName()); | ||
InstVector += getTypeEmbedding(I.getType()); | ||
for (const auto &Op : I.operands()) { | ||
const auto OperandVec = getOperandEmbedding(Op.get()); | ||
InstVector.scaleAndAdd(OperandVec, ArgWeight); | ||
InstVector += getOperandEmbedding(Op.get()); | ||
} | ||
InstVecMap[&I] = InstVector; | ||
BBVector += InstVector; | ||
|
@@ -251,6 +258,43 @@ bool IR2VecVocabResult::invalidate( | |
return !(PAC.preservedWhenStateless()); | ||
} | ||
|
||
Error IR2VecVocabAnalysis::parseVocabSection( | ||
StringRef Key, const json::Value &ParsedVocabValue, | ||
ir2vec::Vocab &TargetVocab, unsigned &Dim) { | ||
json::Path::Root Path(""); | ||
const json::Object *RootObj = ParsedVocabValue.getAsObject(); | ||
if (!RootObj) | ||
return createStringError(errc::invalid_argument, | ||
"JSON root is not an object"); | ||
|
||
const json::Value *SectionValue = RootObj->get(Key); | ||
if (!SectionValue) | ||
return createStringError(errc::invalid_argument, | ||
"Missing '" + std::string(Key) + | ||
"' section in vocabulary file"); | ||
if (!json::fromJSON(*SectionValue, TargetVocab, Path)) | ||
return createStringError(errc::illegal_byte_sequence, | ||
"Unable to parse '" + std::string(Key) + | ||
"' section from vocabulary"); | ||
|
||
Dim = TargetVocab.begin()->second.size(); | ||
if (Dim == 0) | ||
return createStringError(errc::illegal_byte_sequence, | ||
"Dimension of '" + std::string(Key) + | ||
"' section of the vocabulary is zero"); | ||
|
||
if (!std::all_of(TargetVocab.begin(), TargetVocab.end(), | ||
[Dim](const std::pair<StringRef, Embedding> &Entry) { | ||
return Entry.second.size() == Dim; | ||
})) | ||
return createStringError( | ||
errc::illegal_byte_sequence, | ||
"All vectors in the '" + std::string(Key) + | ||
"' section of the vocabulary are not of the same dimension"); | ||
|
||
return Error::success(); | ||
}; | ||
|
||
// FIXME: Make this optional. We can avoid file reads | ||
// by auto-generating a default vocabulary during the build time. | ||
Error IR2VecVocabAnalysis::readVocabulary() { | ||
|
@@ -259,32 +303,40 @@ Error IR2VecVocabAnalysis::readVocabulary() { | |
return createFileError(VocabFile, BufOrError.getError()); | ||
|
||
auto Content = BufOrError.get()->getBuffer(); | ||
json::Path::Root Path(""); | ||
|
||
Expected<json::Value> ParsedVocabValue = json::parse(Content); | ||
if (!ParsedVocabValue) | ||
return ParsedVocabValue.takeError(); | ||
|
||
bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path); | ||
if (!Res) | ||
return createStringError(errc::illegal_byte_sequence, | ||
"Unable to parse the vocabulary"); | ||
ir2vec::Vocab OpcodeVocab, TypeVocab, ArgVocab; | ||
unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0; | ||
if (auto Err = parseVocabSection("Opcodes", *ParsedVocabValue, OpcodeVocab, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This changes the format, best to also update the doc. Also, this means the sections must all be present (in any order), even if empty, correct? SGTM, just something worth spelling out. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Correct. Will put it in the doc. |
||
OpcodeDim)) | ||
return Err; | ||
|
||
if (Vocabulary.empty()) | ||
return createStringError(errc::illegal_byte_sequence, | ||
"Vocabulary is empty"); | ||
if (auto Err = | ||
parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim)) | ||
return Err; | ||
|
||
unsigned Dim = Vocabulary.begin()->second.size(); | ||
if (Dim == 0) | ||
if (auto Err = | ||
parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim)) | ||
return Err; | ||
|
||
if (!(OpcodeDim == TypeDim && TypeDim == ArgDim)) | ||
return createStringError(errc::illegal_byte_sequence, | ||
"Dimension of vocabulary is zero"); | ||
"Vocabulary sections have different dimensions"); | ||
|
||
if (!std::all_of(Vocabulary.begin(), Vocabulary.end(), | ||
[Dim](const std::pair<StringRef, Embedding> &Entry) { | ||
return Entry.second.size() == Dim; | ||
})) | ||
return createStringError( | ||
errc::illegal_byte_sequence, | ||
"All vectors in the vocabulary are not of the same dimension"); | ||
auto scaleVocabSection = [](ir2vec::Vocab &Vocab, double Weight) { | ||
for (auto &Entry : Vocab) | ||
Entry.second *= Weight; | ||
}; | ||
scaleVocabSection(OpcodeVocab, OpcWeight); | ||
scaleVocabSection(TypeVocab, TypeWeight); | ||
scaleVocabSection(ArgVocab, ArgWeight); | ||
|
||
Vocabulary.insert(OpcodeVocab.begin(), OpcodeVocab.end()); | ||
Vocabulary.insert(TypeVocab.begin(), TypeVocab.end()); | ||
Vocabulary.insert(ArgVocab.begin(), ArgVocab.end()); | ||
|
||
return Error::success(); | ||
} | ||
|
@@ -304,7 +356,6 @@ void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) { | |
IR2VecVocabAnalysis::Result | ||
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { | ||
auto Ctx = &M.getContext(); | ||
// FIXME: Scale the vocabulary once. This would avoid scaling per use later. | ||
// If vocabulary is already populated by the constructor, use it. | ||
if (!Vocabulary.empty()) | ||
return IR2VecVocabResult(std::move(Vocabulary)); | ||
|
@@ -323,16 +374,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { | |
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// IR2VecPrinterPass | ||
// Printer Passes | ||
//===----------------------------------------------------------------------===// | ||
|
||
void IR2VecPrinterPass::printVector(const Embedding &Vec) const { | ||
OS << " ["; | ||
for (const auto &Elem : Vec) | ||
OS << " " << format("%.2f", Elem) << " "; | ||
OS << "]\n"; | ||
} | ||
|
||
PreservedAnalyses IR2VecPrinterPass::run(Module &M, | ||
ModuleAnalysisManager &MAM) { | ||
auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M); | ||
|
@@ -353,15 +397,15 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M, | |
|
||
OS << "IR2Vec embeddings for function " << F.getName() << ":\n"; | ||
OS << "Function vector: "; | ||
printVector(Emb->getFunctionVector()); | ||
Emb->getFunctionVector().print(OS); | ||
|
||
OS << "Basic block vectors:\n"; | ||
const auto &BBMap = Emb->getBBVecMap(); | ||
for (const BasicBlock &BB : F) { | ||
auto It = BBMap.find(&BB); | ||
if (It != BBMap.end()) { | ||
OS << "Basic block: " << BB.getName() << ":\n"; | ||
printVector(It->second); | ||
It->second.print(OS); | ||
} | ||
} | ||
|
||
|
@@ -373,10 +417,24 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M, | |
if (It != InstMap.end()) { | ||
OS << "Instruction: "; | ||
I.print(OS); | ||
printVector(It->second); | ||
It->second.print(OS); | ||
} | ||
} | ||
} | ||
} | ||
return PreservedAnalyses::all(); | ||
} | ||
|
||
/// Prints every entry of the IR2Vec vocabulary obtained for \p M to OS, one
/// "Key: <name>: <vector>" line per entry.
///
/// \param M   Module whose IR2VecVocabAnalysis result is printed.
/// \param MAM Analysis manager used to obtain the vocabulary result.
/// \returns PreservedAnalyses::all(), since this pass only prints.
PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
                                              ModuleAnalysisManager &MAM) {
  // Bind by reference: the result and the vocabulary are only read here, so
  // copying them just to print would be wasted work.
  auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
  assert(VocabResult.isValid() && "IR2Vec Vocabulary is invalid");

  const auto &Vocab = VocabResult.getVocabulary();
  for (const auto &Entry : Vocab) {
    OS << "Key: " << Entry.first << ": ";
    Entry.second.print(OS);
  }

  return PreservedAnalyses::all();
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Document that the sections are mandatory, but that the order in which they appear does not matter.