From 54288e138f7f6398ace5420ea660ef72ebcd3247 Mon Sep 17 00:00:00 2001 From: Balint Cristian Date: Sun, 9 Feb 2025 00:51:50 +0200 Subject: [PATCH] Handle vector width (VLEN) for RISCV arches --- src/target/llvm/codegen_llvm.cc | 17 ++++++-- src/target/llvm/llvm_instance.cc | 52 ++++++++++++++++++++++- src/target/llvm/llvm_instance.h | 6 +++ src/target/target_kind.cc | 2 + tests/python/target/test_target_target.py | 7 +++ 5 files changed, 80 insertions(+), 4 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index b1caf28149b5..f674974fc0ea 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -174,18 +174,29 @@ void CodeGenLLVM::InitTarget() { data_layout_.reset(new llvm::DataLayout(module_.get())); #endif if (native_vector_bits_ == 0) { + const int vwidth = llvm_target_->GetVectorWidth(); const auto& arch = tm->getTargetTriple().getArch(); - if (arch == llvm::Triple::x86_64) { + const std::string arch_name = std::string(tm->getTargetTriple().getArchName()); + if (vwidth > 0) { + // override from target options + // e.g. llvm -vector-width=xxx + native_vector_bits_ = vwidth; + } else if (arch == llvm::Triple::x86_64) { // for avx512 native_vector_bits_ = 512; } else if (arch == llvm::Triple::x86) { native_vector_bits_ = 256; } else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) { native_vector_bits_ = 128; + } else if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) { + native_vector_bits_ = 256; + LOG(WARNING) << "RISCV VLEN inference failed in LLVM, " + << "using 256 bits, set -vector-width=XXX to override"; + // fallback default } else { native_vector_bits_ = 128; - std::string arch_name = std::string(tm->getTargetTriple().getArchName()); - LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name; + LOG(WARNING) << "Set native vector bits to be 128 for `" << arch_name + << "`, use -vector-width=XXX to override."; } } diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index e2c5e28592b7..ed3b440bcb93 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -53,6 +53,18 @@ #include #include #include +#if TVM_LLVM_VERSION >= 190 +#include +#else +#if TVM_LLVM_VERSION >= 140 +#include +#endif +#endif +#if TVM_LLVM_VERSION >= 160 +#include +#else +#include +#endif #include #include #include @@ -273,11 +285,49 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) } } - // RISCV code model + printf("\n----------------\n"); + + // TVM & LLVM vector width options + if (const auto& w = Downcast>(target.Get("vector-width"))) { + vector_width_ = w.value(); + printf("VECTOR-WIDTH = %i\n", vector_width_); + if ((vector_width_ <= 0) || (vector_width_ > 65535)) { + LOG(FATAL) << "Invalid -vector-width value: " << vector_width_; + } + } + + // RISCV code model & vlen auto arch = llvm::Triple(triple_).getArch(); + printf("BEGIN vector_width = %i\n", vector_width_); if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) { + // code model + printf("ENTER RISCV\n"); code_model_ = llvm::CodeModel::Medium; + // VLEN inference + const auto* mci = GetOrCreateTargetMachine(false)->getMCSubtargetInfo(); + const auto cpu_name = mci->getCPU(); +#if TVM_LLVM_VERSION >= 140 + const auto m_arch = llvm::RISCV::getMArchFromMcpu(cpu_name); + auto ISAInfo = + llvm::RISCVISAInfo::parseArchString(m_arch, /*EnableExperimentalExtensions=*/true); + // infer VLEN from LLVM or via options + if (!llvm::errorToBool(ISAInfo.takeError()) && (vector_width_ == 0)) { + vector_width_ = (*ISAInfo)->getMinVLen(); + } +#endif + if (vector_width_ > 0) { + // push cl-opt to LLVM + llvm_options_.push_back( + ParseOptionString("-riscv-v-vector-bits-min:int=" + std::to_string(vector_width_))); + llvm_options_.push_back( + ParseOptionString("-riscv-v-vector-bits-max:int=" + std::to_string(vector_width_))); + } else { + // fallback default (codegen will warn) + llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-min:int=256")); + llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-max:int=256")); + } } + printf("AFTER vector_width = %i\n", vector_width_); // Target options #if TVM_LLVM_VERSION < 50 diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h index a7711384d00c..82358878f157 100644 --- a/src/target/llvm/llvm_instance.h +++ b/src/target/llvm/llvm_instance.h @@ -242,6 +242,11 @@ class LLVMTargetInfo { * \return the type name of the JIT engine (default "orcjit" or "mcjit") */ const std::string GetJITEngine() const { return jit_engine_; } + /*! + * \brief Get the TVM & LLVM vector_width + * \return number of bits for vector width + */ + const int GetVectorWidth() const { return vector_width_; } /*! * \brief Get the LLVM optimization level * \return optimization level for this target @@ -356,6 +361,7 @@ class LLVMTargetInfo { llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small; std::shared_ptr target_machine_; std::string jit_engine_ = "orcjit"; + int vector_width_{0}; }; /*! diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 979b755af846..e0a0ad23a1b6 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -283,6 +283,8 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option>("cl-opt") // LLVM JIT engine mcjit/orcjit .add_attr_option("jit") + // TVM & LLVM custom vector bit width + .add_attr_option("vector-width") .set_default_keys({"cpu"}) // Force the external codegen kind attribute to be registered, even if no external // codegen targets are enabled by the TVM build. diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py index fd79661ce632..cda228939f31 100644 --- a/tests/python/target/test_target_target.py +++ b/tests/python/target/test_target_target.py @@ -178,6 +178,13 @@ def test_target_llvm_jit_options(): assert target.attrs["jit"] == "orcjit" +def test_target_llvm_vector_width(): + target = tvm.target.Target("llvm -vector-width=256") + assert target.attrs["vector-width"] == 256 + target = tvm.target.Target("llvm -vector-width=1024") + assert target.attrs["vector-width"] == 1024 + + def test_target_create(): targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"), bifrost()] for tgt in targets: