| //===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| /// |
| /// \file |
| /// This file implements the IR2Vec algorithm. |
| /// |
| //===----------------------------------------------------------------------===// |
| |
| #include "llvm/Analysis/IR2Vec.h" |
| |
| #include "llvm/ADT/DepthFirstIterator.h" |
| #include "llvm/ADT/Sequence.h" |
| #include "llvm/ADT/Statistic.h" |
| #include "llvm/IR/CFG.h" |
| #include "llvm/IR/Module.h" |
| #include "llvm/IR/PassManager.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/Errc.h" |
| #include "llvm/Support/Error.h" |
| #include "llvm/Support/ErrorHandling.h" |
| #include "llvm/Support/Format.h" |
| #include "llvm/Support/MemoryBuffer.h" |
| |
| using namespace llvm; |
| using namespace ir2vec; |
| |
| #define DEBUG_TYPE "ir2vec" |
| |
| STATISTIC(VocabMissCounter, |
| "Number of lookups to entities not present in the vocabulary"); |
| |
| namespace llvm { |
| namespace ir2vec { |
| static cl::OptionCategory IR2VecCategory("IR2Vec Options"); |
| |
| // FIXME: Use a default vocab when not specified |
| static cl::opt<std::string> |
| VocabFile("ir2vec-vocab-path", cl::Optional, |
| cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""), |
| cl::cat(IR2VecCategory)); |
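| // The following weights control the relative contribution of opcode, type, |
| // and argument embeddings. They are applied once to the corresponding |
| // vocabulary sections in IR2VecVocabAnalysis::run(). |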
| cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0), |
| cl::desc("Weight for opcode embeddings"), |
| cl::cat(IR2VecCategory)); |
| cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5), |
| cl::desc("Weight for type embeddings"), |
| cl::cat(IR2VecCategory)); |
| cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2), |
| cl::desc("Weight for argument embeddings"), |
| cl::cat(IR2VecCategory)); |
| } // namespace ir2vec |
| } // namespace llvm |
| |
| AnalysisKey IR2VecVocabAnalysis::Key; |
| |
| //===----------------------------------------------------------------------===// |
| // Local helper functions |
| //===----------------------------------------------------------------------===// |
| namespace llvm::json { |
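| // Adapter that lets json::fromJSON deserialize a JSON array of doubles |
| // directly into an ir2vec::Embedding; used when parsing vocabulary sections. |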
| inline bool fromJSON(const llvm::json::Value &E, Embedding &Out, |
| llvm::json::Path P) { |
| std::vector<double> TempOut; |
| if (!llvm::json::fromJSON(E, TempOut, P)) |
| return false; |
| Out = Embedding(std::move(TempOut)); |
| return true; |
| } |
| } // namespace llvm::json |
| |
| //===----------------------------------------------------------------------===// |
| // Embedding |
| //===----------------------------------------------------------------------===// |
| Embedding &Embedding::operator+=(const Embedding &RHS) { |
| assert(this->size() == RHS.size() && "Vectors must have the same dimension"); |
| std::transform(this->begin(), this->end(), RHS.begin(), this->begin(), |
| std::plus<double>()); |
| return *this; |
| } |
| |
| Embedding Embedding::operator+(const Embedding &RHS) const { |
| Embedding Result(*this); |
| Result += RHS; |
| return Result; |
| } |
| |
| Embedding &Embedding::operator-=(const Embedding &RHS) { |
| assert(this->size() == RHS.size() && "Vectors must have the same dimension"); |
| std::transform(this->begin(), this->end(), RHS.begin(), this->begin(), |
| std::minus<double>()); |
| return *this; |
| } |
| |
| Embedding Embedding::operator-(const Embedding &RHS) const { |
| Embedding Result(*this); |
| Result -= RHS; |
| return Result; |
| } |
| |
| Embedding &Embedding::operator*=(double Factor) { |
| std::transform(this->begin(), this->end(), this->begin(), |
| [Factor](double Elem) { return Elem * Factor; }); |
| return *this; |
| } |
| |
| Embedding Embedding::operator*(double Factor) const { |
| Embedding Result(*this); |
| Result *= Factor; |
| return Result; |
| } |
| |
| Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) { |
| assert(this->size() == Src.size() && "Vectors must have the same dimension"); |
| for (size_t Itr = 0; Itr < this->size(); ++Itr) |
| (*this)[Itr] += Src[Itr] * Factor; |
| return *this; |
| } |
| |
| bool Embedding::approximatelyEquals(const Embedding &RHS, |
| double Tolerance) const { |
| assert(this->size() == RHS.size() && "Vectors must have the same dimension"); |
| for (size_t Itr = 0; Itr < this->size(); ++Itr) |
| if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) |
| return false; |
| return true; |
| } |
| |
| void Embedding::print(raw_ostream &OS) const { |
| OS << " ["; |
| for (const auto &Elem : Data) |
| OS << " " << format("%.2f", Elem) << " "; |
| OS << "]\n"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Embedder and its subclasses |
| //===----------------------------------------------------------------------===// |
| |
| Embedder::Embedder(const Function &F, const Vocabulary &Vocab) |
| : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()), |
| OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) { |
| } |
| |
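| // Factory for concrete Embedder implementations. Returns nullptr if the |
| // requested IR2VecKind is not handled. |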
| std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F, |
| const Vocabulary &Vocab) { |
| switch (Mode) { |
| case IR2VecKind::Symbolic: |
| return std::make_unique<SymbolicEmbedder>(F, Vocab); |
| } |
| return nullptr; |
| } |
| |
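| // The instruction and basic block maps are computed lazily on first access |
| // and cached for subsequent queries. |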
| const InstEmbeddingsMap &Embedder::getInstVecMap() const { |
| if (InstVecMap.empty()) |
| computeEmbeddings(); |
| return InstVecMap; |
| } |
| |
| const BBEmbeddingsMap &Embedder::getBBVecMap() const { |
| if (BBVecMap.empty()) |
| computeEmbeddings(); |
| return BBVecMap; |
| } |
| |
| const Embedding &Embedder::getBBVector(const BasicBlock &BB) const { |
| auto It = BBVecMap.find(&BB); |
| if (It != BBVecMap.end()) |
| return It->second; |
| computeEmbeddings(BB); |
| return BBVecMap[&BB]; |
| } |
| |
| const Embedding &Embedder::getFunctionVector() const { |
| // Currently, we always (re)compute the embeddings for the function. |
| // This is cheaper than caching the vector. |
| computeEmbeddings(); |
| return FuncVector; |
| } |
| |
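| // An instruction's embedding is the sum of the vocabulary entries for its |
| // opcode, its type, and each of its operands. The opcode/type/argument |
| // weights are already folded into the vocabulary by IR2VecVocabAnalysis, so |
| // no additional scaling is applied here. |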
| void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { |
| Embedding BBVector(Dimension, 0); |
| |
| // We consider only the non-debug and non-pseudo instructions |
| for (const auto &I : BB.instructionsWithoutDebug()) { |
| Embedding ArgEmb(Dimension, 0); |
| for (const auto &Op : I.operands()) |
| ArgEmb += Vocab[Op]; |
| auto InstVector = |
| Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; |
| InstVecMap[&I] = InstVector; |
| BBVector += InstVector; |
| } |
| BBVecMap[&BB] = BBVector; |
| } |
| |
| void SymbolicEmbedder::computeEmbeddings() const { |
| if (F.isDeclaration()) |
| return; |
| |
| // Reset the function vector so that repeated calls do not accumulate. |
| FuncVector = Embedding(Dimension, 0); |
| |
| // Consider only the basic blocks that are reachable from entry. |
| for (const BasicBlock *BB : depth_first(&F)) { |
| computeEmbeddings(*BB); |
| FuncVector += BBVecMap[BB]; |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Vocabulary |
| //===----------------------------------------------------------------------===// |
| |
| Vocabulary::Vocabulary(VocabVector &&Vocab) |
| : Vocab(std::move(Vocab)), Valid(true) {} |
| |
| bool Vocabulary::isValid() const { |
| return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid; |
| } |
| |
| size_t Vocabulary::size() const { |
| assert(Valid && "IR2Vec Vocabulary is invalid"); |
| return Vocab.size(); |
| } |
| |
| unsigned Vocabulary::getDimension() const { |
| assert(Valid && "IR2Vec Vocabulary is invalid"); |
| return Vocab[0].size(); |
| } |
| |
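| // The vocabulary is a single flat vector: opcode embeddings occupy positions |
| // [0, MaxOpcodes), followed by type embeddings and then operand-kind |
| // embeddings. The operator[] overloads below index into these regions. |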
| const Embedding &Vocabulary::operator[](unsigned Opcode) const { |
| assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); |
| return Vocab[Opcode - 1]; |
| } |
| |
| const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const { |
| assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID"); |
| return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)]; |
| } |
| |
| const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const { |
| OperandKind ArgKind = getOperandKind(Arg); |
| return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)]; |
| } |
| |
| StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { |
| assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); |
| #define HANDLE_INST(NUM, OPCODE, CLASS) \ |
| if (Opcode == NUM) { \ |
| return #OPCODE; \ |
| } |
| #include "llvm/IR/Instruction.def" |
| #undef HANDLE_INST |
| return "UnknownOpcode"; |
| } |
| |
| StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) { |
| switch (TypeID) { |
| case Type::VoidTyID: |
| return "VoidTy"; |
| case Type::HalfTyID: |
| case Type::BFloatTyID: |
| case Type::FloatTyID: |
| case Type::DoubleTyID: |
| case Type::X86_FP80TyID: |
| case Type::FP128TyID: |
| case Type::PPC_FP128TyID: |
| return "FloatTy"; |
| case Type::IntegerTyID: |
| return "IntegerTy"; |
| case Type::FunctionTyID: |
| return "FunctionTy"; |
| case Type::StructTyID: |
| return "StructTy"; |
| case Type::ArrayTyID: |
| return "ArrayTy"; |
| case Type::PointerTyID: |
| case Type::TypedPointerTyID: |
| return "PointerTy"; |
| case Type::FixedVectorTyID: |
| case Type::ScalableVectorTyID: |
| return "VectorTy"; |
| case Type::LabelTyID: |
| return "LabelTy"; |
| case Type::TokenTyID: |
| return "TokenTy"; |
| case Type::MetadataTyID: |
| return "MetadataTy"; |
| case Type::X86_AMXTyID: |
| case Type::TargetExtTyID: |
| return "UnknownTy"; |
| } |
| return "UnknownTy"; |
| } |
| |
| StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) { |
| unsigned Index = static_cast<unsigned>(Kind); |
| assert(Index < MaxOperandKinds && "Invalid OperandKind"); |
| return OperandKindNames[Index]; |
| } |
| |
| Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) { |
| VocabVector DummyVocab; |
| float DummyVal = 0.1f; |
| // Create a dummy vocabulary with entries for all opcodes, types, and |
| // operand kinds. |
| for (unsigned _ : seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs + |
| Vocabulary::MaxOperandKinds)) { |
| DummyVocab.push_back(Embedding(Dim, DummyVal)); |
| DummyVal += 0.1; |
| } |
| return DummyVocab; |
| } |
| |
| // Helper function to classify an operand into OperandKind |
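| // The checks are ordered: a Function is also a pointer-typed Constant, so it |
| // must be matched before the pointer and constant cases. |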
| Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { |
| if (isa<Function>(Op)) |
| return OperandKind::FunctionID; |
| if (isa<PointerType>(Op->getType())) |
| return OperandKind::PointerID; |
| if (isa<Constant>(Op)) |
| return OperandKind::ConstantID; |
| return OperandKind::VariableID; |
| } |
| |
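| // Maps a flat vocabulary position back to a human-readable key; used by |
| // IR2VecVocabPrinterPass when dumping the vocabulary. |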
| StringRef Vocabulary::getStringKey(unsigned Pos) { |
| assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds && |
| "Position out of bounds in vocabulary"); |
| // Opcode |
| if (Pos < MaxOpcodes) |
| return getVocabKeyForOpcode(Pos + 1); |
| // Type |
| if (Pos < MaxOpcodes + MaxTypeIDs) |
| return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes)); |
| // Operand |
| return getVocabKeyForOperandKind( |
| static_cast<OperandKind>(Pos - MaxOpcodes - MaxTypeIDs)); |
| } |
| |
| // For now, assume vocabulary is stable unless explicitly invalidated. |
| bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA, |
| ModuleAnalysisManager::Invalidator &Inv) const { |
| auto PAC = PA.getChecker<IR2VecVocabAnalysis>(); |
| return !(PAC.preservedWhenStateless()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // IR2VecVocabAnalysis |
| //===----------------------------------------------------------------------===// |
| |
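| // The vocabulary file is a single JSON object with one section per entity |
| // kind ("Opcodes", "Types", "Arguments"). Each section maps a string key to |
| // a vector of doubles, and all vectors must have the same dimension. An |
| // illustrative sketch (the values and the operand-kind key are made up): |
| // |
| //   { |
| //     "Opcodes":   { "Add": [0.1, 0.2], "Ret": [0.3, 0.4], ... }, |
| //     "Types":     { "IntegerTy": [0.5, 0.6], ... }, |
| //     "Arguments": { "Constant": [0.7, 0.8], ... } |
| //   } |
| // |
| // The expected keys come from getVocabKeyForOpcode, getVocabKeyForTypeID, |
| // and getVocabKeyForOperandKind. |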
| Error IR2VecVocabAnalysis::parseVocabSection( |
| StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, |
| unsigned &Dim) { |
| json::Path::Root Path(""); |
| const json::Object *RootObj = ParsedVocabValue.getAsObject(); |
| if (!RootObj) |
| return createStringError(errc::invalid_argument, |
| "JSON root is not an object"); |
| |
| const json::Value *SectionValue = RootObj->get(Key); |
| if (!SectionValue) |
| return createStringError(errc::invalid_argument, |
| "Missing '" + std::string(Key) + |
| "' section in vocabulary file"); |
| if (!json::fromJSON(*SectionValue, TargetVocab, Path)) |
| return createStringError(errc::illegal_byte_sequence, |
| "Unable to parse '" + std::string(Key) + |
| "' section from vocabulary"); |
| |
| if (TargetVocab.empty()) |
| return createStringError(errc::illegal_byte_sequence, |
| "'" + std::string(Key) + |
| "' section of the vocabulary is empty"); |
| |
| Dim = TargetVocab.begin()->second.size(); |
| if (Dim == 0) |
| return createStringError(errc::illegal_byte_sequence, |
| "Dimension of '" + std::string(Key) + |
| "' section of the vocabulary is zero"); |
| |
| if (!std::all_of(TargetVocab.begin(), TargetVocab.end(), |
| [Dim](const std::pair<StringRef, Embedding> &Entry) { |
| return Entry.second.size() == Dim; |
| })) |
| return createStringError( |
| errc::illegal_byte_sequence, |
| "Not all vectors in the '" + std::string(Key) + |
| "' section of the vocabulary have the same dimension"); |
| |
| return Error::success(); |
| } |
| |
| // FIXME: Make this optional. We can avoid file reads |
| // by auto-generating a default vocabulary at build time. |
| Error IR2VecVocabAnalysis::readVocabulary() { |
| auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true); |
| if (!BufOrError) |
| return createFileError(VocabFile, BufOrError.getError()); |
| |
| auto Content = BufOrError.get()->getBuffer(); |
| |
| Expected<json::Value> ParsedVocabValue = json::parse(Content); |
| if (!ParsedVocabValue) |
| return ParsedVocabValue.takeError(); |
| |
| unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0; |
| if (auto Err = |
| parseVocabSection("Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim)) |
| return Err; |
| |
| if (auto Err = |
| parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim)) |
| return Err; |
| |
| if (auto Err = |
| parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim)) |
| return Err; |
| |
| if (!(OpcodeDim == TypeDim && TypeDim == ArgDim)) |
| return createStringError(errc::illegal_byte_sequence, |
| "Vocabulary sections have different dimensions"); |
| |
| return Error::success(); |
| } |
| |
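| // Builds the flat, numerically indexed vocabulary from the string-keyed |
| // sections read from JSON: opcode embeddings first, then type embeddings, |
| // then operand-kind embeddings, matching the layout that Vocabulary expects. |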
| void IR2VecVocabAnalysis::generateNumMappedVocab() { |
| // Helper for handling missing entities in the vocabulary. |
| // Currently, we use a zero vector. In the future, we will report an error to |
| // ensure that *all* known entities are present in the vocabulary. |
| auto handleMissingEntity = [](const std::string &Val) { |
| LLVM_DEBUG(errs() << Val |
| << " is not in vocabulary, using zero vector; this " |
| "will become an error in the future.\n"); |
| ++VocabMissCounter; |
| }; |
| |
| unsigned Dim = OpcVocab.begin()->second.size(); |
| assert(Dim > 0 && "Vocabulary dimension must be greater than zero"); |
| |
| // Handle Opcodes |
| std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes, |
| Embedding(Dim, 0)); |
| for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) { |
| StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1); |
| auto It = OpcVocab.find(VocabKey.str()); |
| if (It != OpcVocab.end()) |
| NumericOpcodeEmbeddings[Opcode] = It->second; |
| else |
| handleMissingEntity(VocabKey.str()); |
| } |
| Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(), |
| NumericOpcodeEmbeddings.end()); |
| |
| // Handle Types |
| std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs, |
| Embedding(Dim, 0)); |
| for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) { |
| StringRef VocabKey = |
| Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID)); |
| if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) { |
| NumericTypeEmbeddings[TypeID] = It->second; |
| continue; |
| } |
| handleMissingEntity(VocabKey.str()); |
| } |
| Vocab.insert(Vocab.end(), NumericTypeEmbeddings.begin(), |
| NumericTypeEmbeddings.end()); |
| |
| // Handle Arguments/Operands |
| std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds, |
| Embedding(Dim, 0)); |
| for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) { |
| Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind); |
| StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind); |
| auto It = ArgVocab.find(VocabKey.str()); |
| if (It != ArgVocab.end()) { |
| NumericArgEmbeddings[OpKind] = It->second; |
| continue; |
| } |
| handleMissingEntity(VocabKey.str()); |
| } |
| Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(), |
| NumericArgEmbeddings.end()); |
| } |
| |
| IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab) |
| : Vocab(Vocab) {} |
| |
| IR2VecVocabAnalysis::IR2VecVocabAnalysis(VocabVector &&Vocab) |
| : Vocab(std::move(Vocab)) {} |
| |
| void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) { |
| handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) { |
| Ctx.emitError("Error reading vocabulary: " + EI.message()); |
| }); |
| } |
| |
| IR2VecVocabAnalysis::Result |
| IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { |
| auto Ctx = &M.getContext(); |
| // If vocabulary is already populated by the constructor, use it. |
| if (!Vocab.empty()) |
| return Vocabulary(std::move(Vocab)); |
| |
| // Otherwise, try to read from the vocabulary file. |
| if (VocabFile.empty()) { |
| // FIXME: Use default vocabulary |
| Ctx->emitError("IR2Vec vocabulary file path not specified; you may need to " |
| "set it using --ir2vec-vocab-path"); |
| return Vocabulary(); // Return invalid result |
| } |
| if (auto Err = readVocabulary()) { |
| emitError(std::move(Err), *Ctx); |
| return Vocabulary(); |
| } |
| |
| // Scale the vocabulary sections based on the provided weights |
| auto scaleVocabSection = [](VocabMap &Vocab, double Weight) { |
| for (auto &Entry : Vocab) |
| Entry.second *= Weight; |
| }; |
| scaleVocabSection(OpcVocab, OpcWeight); |
| scaleVocabSection(TypeVocab, TypeWeight); |
| scaleVocabSection(ArgVocab, ArgWeight); |
| |
| // Generate the numeric lookup vocabulary |
| generateNumMappedVocab(); |
| |
| return Vocabulary(std::move(Vocab)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printer Passes |
| //===----------------------------------------------------------------------===// |
| |
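| // Prints the function, basic block, and instruction embeddings for every |
| // function in the module. Requires the vocabulary analysis to produce a |
| // valid vocabulary (e.g., via --ir2vec-vocab-path). |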
| PreservedAnalyses IR2VecPrinterPass::run(Module &M, |
| ModuleAnalysisManager &MAM) { |
| auto IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M); |
| assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid"); |
| |
| for (Function &F : M) { |
| std::unique_ptr<Embedder> Emb = |
| Embedder::create(IR2VecKind::Symbolic, F, IR2VecVocabulary); |
| if (!Emb) { |
| OS << "Error creating IR2Vec embeddings\n"; |
| continue; |
| } |
| |
| OS << "IR2Vec embeddings for function " << F.getName() << ":\n"; |
| OS << "Function vector: "; |
| Emb->getFunctionVector().print(OS); |
| |
| OS << "Basic block vectors:\n"; |
| const auto &BBMap = Emb->getBBVecMap(); |
| for (const BasicBlock &BB : F) { |
| auto It = BBMap.find(&BB); |
| if (It != BBMap.end()) { |
| OS << "Basic block: " << BB.getName() << ":\n"; |
| It->second.print(OS); |
| } |
| } |
| |
| OS << "Instruction vectors:\n"; |
| const auto &InstMap = Emb->getInstVecMap(); |
| for (const BasicBlock &BB : F) { |
| for (const Instruction &I : BB) { |
| auto It = InstMap.find(&I); |
| if (It != InstMap.end()) { |
| OS << "Instruction: "; |
| I.print(OS); |
| It->second.print(OS); |
| } |
| } |
| } |
| } |
| return PreservedAnalyses::all(); |
| } |
| |
| PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M, |
| ModuleAnalysisManager &MAM) { |
| auto IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M); |
| assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid"); |
| |
| // Print each entry |
| unsigned Pos = 0; |
| for (const auto &Entry : IR2VecVocabulary) { |
| OS << "Key: " << IR2VecVocabulary.getStringKey(Pos++) << ": "; |
| Entry.print(OS); |
| } |
| return PreservedAnalyses::all(); |
| } |