Skip to content

Commit

Permalink
Pick up vector length from 'zvlXXXb' (RVV) mattr for riscv
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalint13 committed Feb 11, 2025
1 parent e5cea6d commit 2b186e9
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 9 deletions.
35 changes: 35 additions & 0 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,41 @@ def _multi_gpu_exists():
"llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", target_kind_hardware="llvm"
)


# Mark a test as requiring minimum llvm version
def requires_llvm_minimum_version(major_version):
"""Mark a test as requiring at least a specific version of LLVM.
Unit test marked with this decorator will run only if the
installed version of LLVM is at least `major_version`.
This also marks the test as requiring LLVM backend support.
Parameters
----------
major_version: int
"""

try:
llvm_version = tvm.target.codegen.llvm_version_major()
except RuntimeError:
llvm_version = 0

requires = [
pytest.mark.skipif(
llvm_version < major_version, reason=f"Requires LLVM >= {major_version}"
),
*requires_llvm.marks(),
]

def inner(func):
return _compose([func], requires)

return inner


# Mark a test as requiring a GPU to run.
requires_gpu = Feature("gpu", run_time_check=_any_gpu_exists)

Expand Down
24 changes: 15 additions & 9 deletions src/target/llvm/llvm_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target)
// TVM & LLVM vector width options
if (const auto& w = Downcast<Optional<runtime::Int>>(target.Get("vector-width"))) {
vector_width_ = w.value();
if ((vector_width_ <= 0) || (vector_width_ > 65535)) {
if ((vector_width_ <= 0) || (vector_width_ > 65536)) {
LOG(FATAL) << "Invalid -vector-width value: " << vector_width_;
}
}
Expand All @@ -300,26 +300,32 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target)
code_model_ = llvm::CodeModel::Medium;
#if TVM_LLVM_VERSION >= 140
// VLEN inference
const auto* mci = GetOrCreateTargetMachine(false)->getMCSubtargetInfo();
const auto cpu_name = mci->getCPU();
const auto m_arch = llvm::RISCV::getMArchFromMcpu(cpu_name);
const auto cpu_name = GetOrCreateTargetMachine(false)->getMCSubtargetInfo()->getCPU();
const auto canon_arch = llvm::RISCV::getMArchFromMcpu(cpu_name);
auto ISAInfo =
llvm::RISCVISAInfo::parseArchString(m_arch, /*EnableExperimentalExtensions=*/true);
// infer VLEN from LLVM or via options
llvm::RISCVISAInfo::parseArchString(canon_arch, /*EnableExperimentalExtensions=*/true);
// infer VLEN from LLVM RISCVInfo parser
if (!llvm::errorToBool(ISAInfo.takeError()) && (vector_width_ == 0)) {
vector_width_ = (*ISAInfo)->getMinVLen();
}
// infer VLEN from LLVM options (zvlXXXb override)
for (const auto& attr : attrs_) {
if (attr.find("zvl") != std::string::npos) {
std::string vec;
for (char c : attr) {
if (std::isdigit(c)) vec += c;
}
vector_width_ = std::stoi(vec);
}
}
#endif
if (vector_width_ > 0) {
// push cl-opt to LLVM
llvm_options_.push_back(
ParseOptionString("-riscv-v-vector-bits-min:int=" + std::to_string(vector_width_)));
llvm_options_.push_back(
ParseOptionString("-riscv-v-vector-bits-max:int=" + std::to_string(vector_width_)));
} else {
// fallback default (codegen will warn)
llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-min:int=256"));
llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-max:int=256"));
}
}

Expand Down
56 changes: 56 additions & 0 deletions tests/python/codegen/test_target_codegen_riscv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
from tvm.target.codegen import target_has_features


@tvm.testing.requires_llvm_minimum_version(14)
@tvm.testing.parametrize_targets(
"llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m",
"llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v",
"llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m",
"llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
)
def test_rvv(target):
def check_rvv_presence(N):
K = te.size_var("K")
A = te.placeholder((K, N), dtype="int8", name="A")
B = te.placeholder((K, N), dtype="int8", name="B")
k = te.reduce_axis((0, K))
C = te.compute(
(N,),
lambda n: te.sum(A[k, n].astype("int32") * B[k, n].astype("int32"), axis=[k]),
name="C",
)
s = te.create_schedule(C.op)
s[C].vectorize(s[C].op.axis[0])
f = tvm.build(s, [A, B, C], target)

# Check RVV vsetvli prensence
assembly = f.get_source("asm")
with tvm.target.Target(target):
if target_has_features("v"):
assert "vsetvli" in assembly
else:
assert "vsetvli" not in assembly

check_rvv_presence(4)


if __name__ == "__main__":
test_rvv()

0 comments on commit 2b186e9

Please sign in to comment.