From 2b186e93afd494216a77ede29329ec1f2003c273 Mon Sep 17 00:00:00 2001 From: Balint Cristian Date: Tue, 11 Feb 2025 15:43:29 +0200 Subject: [PATCH] Pick up vector length from 'zvlXXXb' (RVV) mattr for riscv --- python/tvm/testing/utils.py | 35 ++++++++++++ src/target/llvm/llvm_instance.cc | 24 +++++--- .../codegen/test_target_codegen_riscv.py | 56 +++++++++++++++++++ 3 files changed, 106 insertions(+), 9 deletions(-) create mode 100644 tests/python/codegen/test_target_codegen_riscv.py diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 8546d4aef233..c6fdb529c030 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -848,6 +848,41 @@ def _multi_gpu_exists(): "llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", target_kind_hardware="llvm" ) + +# Mark a test as requiring minimum llvm version +def requires_llvm_minimum_version(major_version): + """Mark a test as requiring at least a specific version of LLVM. + + Unit test marked with this decorator will run only if the + installed version of LLVM is at least `major_version`. + + This also marks the test as requiring LLVM backend support. + + Parameters + ---------- + major_version: int + + + """ + + try: + llvm_version = tvm.target.codegen.llvm_version_major() + except RuntimeError: + llvm_version = 0 + + requires = [ + pytest.mark.skipif( + llvm_version < major_version, reason=f"Requires LLVM >= {major_version}" + ), + *requires_llvm.marks(), + ] + + def inner(func): + return _compose([func], requires) + + return inner + + # Mark a test as requiring a GPU to run. requires_gpu = Feature("gpu", run_time_check=_any_gpu_exists) diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index 02efd9e05360..710f7823dd0a 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -288,7 +288,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) // TVM & LLVM vector width options if (const auto& w = Downcast>(target.Get("vector-width"))) { vector_width_ = w.value(); - if ((vector_width_ <= 0) || (vector_width_ > 65535)) { + if ((vector_width_ <= 0) || (vector_width_ > 65536)) { LOG(FATAL) << "Invalid -vector-width value: " << vector_width_; } } @@ -300,26 +300,32 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) code_model_ = llvm::CodeModel::Medium; #if TVM_LLVM_VERSION >= 140 // VLEN inference - const auto* mci = GetOrCreateTargetMachine(false)->getMCSubtargetInfo(); - const auto cpu_name = mci->getCPU(); - const auto m_arch = llvm::RISCV::getMArchFromMcpu(cpu_name); + const auto cpu_name = GetOrCreateTargetMachine(false)->getMCSubtargetInfo()->getCPU(); + const auto canon_arch = llvm::RISCV::getMArchFromMcpu(cpu_name); auto ISAInfo = - llvm::RISCVISAInfo::parseArchString(m_arch, /*EnableExperimentalExtensions=*/true); - // infer VLEN from LLVM or via options + llvm::RISCVISAInfo::parseArchString(canon_arch, /*EnableExperimentalExtensions=*/true); + // infer VLEN from LLVM RISCVInfo parser if (!llvm::errorToBool(ISAInfo.takeError()) && (vector_width_ == 0)) { vector_width_ = (*ISAInfo)->getMinVLen(); } + // infer VLEN from LLVM options (zvlXXXb override) + for (const auto& attr : attrs_) { + if (attr.find("zvl") != std::string::npos) { + std::string vec; + for (char c : attr) { + if (std::isdigit(c)) vec += c; + } + vector_width_ = std::stoi(vec); + } + } #endif if (vector_width_ > 0) { // push cl-opt to LLVM llvm_options_.push_back( ParseOptionString("-riscv-v-vector-bits-min:int=" + std::to_string(vector_width_))); - llvm_options_.push_back( - ParseOptionString("-riscv-v-vector-bits-max:int=" + std::to_string(vector_width_))); } else { // fallback default (codegen will warn) llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-min:int=256")); - llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-max:int=256")); } } diff --git a/tests/python/codegen/test_target_codegen_riscv.py b/tests/python/codegen/test_target_codegen_riscv.py new file mode 100644 index 000000000000..0fb10a59c3a6 --- /dev/null +++ b/tests/python/codegen/test_target_codegen_riscv.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te +from tvm.target.codegen import target_has_features + + +@tvm.testing.requires_llvm_minimum_version(14) +@tvm.testing.parametrize_targets( + "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m", + "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v", + "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m", + "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", +) +def test_rvv(target): + def check_rvv_presence(N): + K = te.size_var("K") + A = te.placeholder((K, N), dtype="int8", name="A") + B = te.placeholder((K, N), dtype="int8", name="B") + k = te.reduce_axis((0, K)) + C = te.compute( + (N,), + lambda n: te.sum(A[k, n].astype("int32") * B[k, n].astype("int32"), axis=[k]), + name="C", + ) + s = te.create_schedule(C.op) + s[C].vectorize(s[C].op.axis[0]) + f = tvm.build(s, [A, B, C], target) + + # Check RVV vsetvli prensence + assembly = f.get_source("asm") + with tvm.target.Target(target): + if target_has_features("v"): + assert "vsetvli" in assembly + else: + assert "vsetvli" not in assembly + + check_rvv_presence(4) + + +if __name__ == "__main__": + test_rvv()