Skip to content

Commit

Permalink
Handle vector width (VLEN) for RISCV arches
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalint13 committed Feb 9, 2025
1 parent 4fdf4ae commit 465f877
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 4 deletions.
17 changes: 14 additions & 3 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,18 +174,29 @@ void CodeGenLLVM::InitTarget() {
data_layout_.reset(new llvm::DataLayout(module_.get()));
#endif
if (native_vector_bits_ == 0) {
const int vwidth = llvm_target_->GetVectorWidth();
const auto& arch = tm->getTargetTriple().getArch();
if (arch == llvm::Triple::x86_64) {
const std::string arch_name = std::string(tm->getTargetTriple().getArchName());
if (vwidth > 0) {
// override from target options
// e.g. llvm -vector-width=xxx
native_vector_bits_ = vwidth;
} else if (arch == llvm::Triple::x86_64) {
// for avx512
native_vector_bits_ = 512;
} else if (arch == llvm::Triple::x86) {
native_vector_bits_ = 256;
} else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
native_vector_bits_ = 128;
} else if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
native_vector_bits_ = 256;
LOG(WARNING) << "LLVM RVV VLEN inference failed, "
<< "using 256 bits, set -vector-width=XXX to override";
// fallback default
} else {
native_vector_bits_ = 128;
std::string arch_name = std::string(tm->getTargetTriple().getArchName());
LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
LOG(WARNING) << "Set native vector bits to be 128 for `" << arch_name
<< "`, use -vector-width=XXX to override.";
}
}

Expand Down
46 changes: 45 additions & 1 deletion src/target/llvm/llvm_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Target/TargetOptions.h>
#if TVM_LLVM_VERSION >= 190
#include <llvm/TargetParser/RISCVISAInfo.h>
#else
#if TVM_LLVM_VERSION >= 140
#include <llvm/Support/RISCVISAInfo.h>
#endif
#endif
#if TVM_LLVM_VERSION >= 160
#include <llvm/TargetParser/RISCVTargetParser.h>
#else
#include <llvm/Support/TargetParser.h>
#endif
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/optional.h>
Expand Down Expand Up @@ -273,10 +285,42 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target)
}
}

// RISCV code model
// TVM & LLVM vector width options
if (const auto& w = Downcast<Optional<runtime::Int>>(target.Get("vector-width"))) {
vector_width_ = w.value();
if ((vector_width_ <= 0) || (vector_width_ > 65535)) {
LOG(FATAL) << "Invalid -vector-width value: " << vector_width_;
}
}

// RISCV code model & vlen
auto arch = llvm::Triple(triple_).getArch();
if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
// code model
code_model_ = llvm::CodeModel::Medium;
#if TVM_LLVM_VERSION >= 140
// VLEN inference
const auto* mci = GetOrCreateTargetMachine(false)->getMCSubtargetInfo();
const auto cpu_name = mci->getCPU();
const auto m_arch = llvm::RISCV::getMArchFromMcpu(cpu_name);
auto ISAInfo =
llvm::RISCVISAInfo::parseArchString(m_arch, /*EnableExperimentalExtensions=*/true);
// infer VLEN from LLVM or via options
if (!llvm::errorToBool(ISAInfo.takeError()) && (vector_width_ == 0)) {
vector_width_ = (*ISAInfo)->getMinVLen();
}
#endif
if (vector_width_ > 0) {
// push cl-opt to LLVM
llvm_options_.push_back(
ParseOptionString("-riscv-v-vector-bits-min:int=" + std::to_string(vector_width_)));
llvm_options_.push_back(
ParseOptionString("-riscv-v-vector-bits-max:int=" + std::to_string(vector_width_)));
} else {
// fallback default (codegen will warn)
llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-min:int=256"));
llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-max:int=256"));
}
}

// Target options
Expand Down
6 changes: 6 additions & 0 deletions src/target/llvm/llvm_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ class LLVMTargetInfo {
* \return the type name of the JIT engine (default "orcjit" or "mcjit")
*/
const std::string GetJITEngine() const { return jit_engine_; }
/*!
* \brief Get the TVM & LLVM vector_width
* \return number of bits for vector width
*/
const int GetVectorWidth() const { return vector_width_; }
/*!
* \brief Get the LLVM optimization level
* \return optimization level for this target
Expand Down Expand Up @@ -356,6 +361,7 @@ class LLVMTargetInfo {
llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small;
std::shared_ptr<llvm::TargetMachine> target_machine_;
std::string jit_engine_ = "orcjit";
int vector_width_{0};
};

/*!
Expand Down
2 changes: 2 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.add_attr_option<Array<String>>("cl-opt")
// LLVM JIT engine mcjit/orcjit
.add_attr_option<String>("jit")
// TVM & LLVM custom vector bit width
.add_attr_option<runtime::Int>("vector-width")
.set_default_keys({"cpu"})
// Force the external codegen kind attribute to be registered, even if no external
// codegen targets are enabled by the TVM build.
Expand Down
7 changes: 7 additions & 0 deletions tests/python/target/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ def test_target_llvm_jit_options():
assert target.attrs["jit"] == "orcjit"


def test_target_llvm_vector_width():
target = tvm.target.Target("llvm -vector-width=256")
assert target.attrs["vector-width"] == 256
target = tvm.target.Target("llvm -vector-width=1024")
assert target.attrs["vector-width"] == 1024


def test_target_create():
targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"), bifrost()]
for tgt in targets:
Expand Down

0 comments on commit 465f877

Please sign in to comment.