Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle vector width (VLEN) for RISCV arches #17631

Merged
merged 1 commit into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,18 +174,29 @@ void CodeGenLLVM::InitTarget() {
data_layout_.reset(new llvm::DataLayout(module_.get()));
#endif
if (native_vector_bits_ == 0) {
const int vwidth = llvm_target_->GetVectorWidth();
const auto& arch = tm->getTargetTriple().getArch();
if (arch == llvm::Triple::x86_64) {
const std::string arch_name = std::string(tm->getTargetTriple().getArchName());
if (vwidth > 0) {
// override from target options
// e.g. llvm -vector-width=xxx
native_vector_bits_ = vwidth;
} else if (arch == llvm::Triple::x86_64) {
// for avx512
native_vector_bits_ = 512;
} else if (arch == llvm::Triple::x86) {
native_vector_bits_ = 256;
} else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
native_vector_bits_ = 128;
} else if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
native_vector_bits_ = 256;
LOG(WARNING) << "LLVM RVV VLEN inference failed, "
<< "using 256 bits, set -vector-width=XXX to override";
// fallback default
} else {
native_vector_bits_ = 128;
std::string arch_name = std::string(tm->getTargetTriple().getArchName());
LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
LOG(WARNING) << "Set native vector bits to be 128 for `" << arch_name
<< "`, use -vector-width=XXX to override.";
}
}

Expand Down
46 changes: 45 additions & 1 deletion src/target/llvm/llvm_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Target/TargetOptions.h>
#if TVM_LLVM_VERSION >= 190
#include <llvm/TargetParser/RISCVISAInfo.h>
#else
#if TVM_LLVM_VERSION >= 140
#include <llvm/Support/RISCVISAInfo.h>
#endif
#endif
#if TVM_LLVM_VERSION >= 160
#include <llvm/TargetParser/RISCVTargetParser.h>
#else
#include <llvm/Support/TargetParser.h>
#endif
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/optional.h>
Expand Down Expand Up @@ -273,10 +285,42 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target)
}
}

// RISCV code model
// TVM & LLVM vector width options
if (const auto& w = Downcast<Optional<runtime::Int>>(target.Get("vector-width"))) {
vector_width_ = w.value();
if ((vector_width_ <= 0) || (vector_width_ > 65535)) {
LOG(FATAL) << "Invalid -vector-width value: " << vector_width_;
}
}

// RISCV code model & vlen
auto arch = llvm::Triple(triple_).getArch();
if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
// code model
code_model_ = llvm::CodeModel::Medium;
#if TVM_LLVM_VERSION >= 140
// VLEN inference
const auto* mci = GetOrCreateTargetMachine(false)->getMCSubtargetInfo();
const auto cpu_name = mci->getCPU();
const auto m_arch = llvm::RISCV::getMArchFromMcpu(cpu_name);
auto ISAInfo =
llvm::RISCVISAInfo::parseArchString(m_arch, /*EnableExperimentalExtensions=*/true);
// infer VLEN from LLVM or via options
if (!llvm::errorToBool(ISAInfo.takeError()) && (vector_width_ == 0)) {
vector_width_ = (*ISAInfo)->getMinVLen();
}
#endif
if (vector_width_ > 0) {
// push cl-opt to LLVM
llvm_options_.push_back(
ParseOptionString("-riscv-v-vector-bits-min:int=" + std::to_string(vector_width_)));
llvm_options_.push_back(
ParseOptionString("-riscv-v-vector-bits-max:int=" + std::to_string(vector_width_)));
} else {
// fallback default (codegen will warn)
llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-min:int=256"));
llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-max:int=256"));
}
}

// Target options
Expand Down
6 changes: 6 additions & 0 deletions src/target/llvm/llvm_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ class LLVMTargetInfo {
* \return the type name of the JIT engine (default "orcjit" or "mcjit")
*/
const std::string GetJITEngine() const { return jit_engine_; }
/*!
* \brief Get the TVM & LLVM vector_width
* \return number of bits for vector width
*/
const int GetVectorWidth() const { return vector_width_; }
/*!
* \brief Get the LLVM optimization level
* \return optimization level for this target
Expand Down Expand Up @@ -356,6 +361,7 @@ class LLVMTargetInfo {
llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small;
std::shared_ptr<llvm::TargetMachine> target_machine_;
std::string jit_engine_ = "orcjit";
int vector_width_{0};
};

/*!
Expand Down
2 changes: 2 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.add_attr_option<Array<String>>("cl-opt")
// LLVM JIT engine mcjit/orcjit
.add_attr_option<String>("jit")
// TVM & LLVM custom vector bit width
.add_attr_option<runtime::Int>("vector-width")
.set_default_keys({"cpu"})
// Force the external codegen kind attribute to be registered, even if no external
// codegen targets are enabled by the TVM build.
Expand Down
7 changes: 7 additions & 0 deletions tests/python/target/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ def test_target_llvm_jit_options():
assert target.attrs["jit"] == "orcjit"


def test_target_llvm_vector_width():
target = tvm.target.Target("llvm -vector-width=256")
assert target.attrs["vector-width"] == 256
target = tvm.target.Target("llvm -vector-width=1024")
assert target.attrs["vector-width"] == 1024


def test_target_create():
targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"), bifrost()]
for tgt in targets:
Expand Down
Loading