Skip to content

Commit

Permalink
Pick up vector length from 'zvlXXXb' (RVV) mattr for riscv
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalint13 committed Feb 12, 2025
1 parent 4ac03b3 commit 146157d
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 35 deletions.
17 changes: 17 additions & 0 deletions python/tvm/target/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,23 @@ def llvm_cpu_has_features(cpu_features, target=None):
return has_feats


def llvm_get_vector_width(target=None):
"""Get vector width from LLVM target's `-mtriple` and `-mcpu` and considering `-mattr`.
Parameters
----------
target : Target
The TVM target.
Returns
-------
vector_width : int
Vector with of target in number of bits, -1 on error.
"""
assert isinstance(target, Target) or target is None
return _ffi_api.llvm_get_vector_width(target)


def llvm_version_major(allow_none=False):
"""Get the major LLVM version.
Expand Down
35 changes: 35 additions & 0 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,41 @@ def _multi_gpu_exists():
"llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", target_kind_hardware="llvm"
)


# Mark a test as requiring minimum llvm version
def requires_llvm_minimum_version(major_version):
"""Mark a test as requiring at least a specific version of LLVM.
Unit test marked with this decorator will run only if the
installed version of LLVM is at least `major_version`.
This also marks the test as requiring LLVM backend support.
Parameters
----------
major_version: int
"""

try:
llvm_version = tvm.target.codegen.llvm_version_major()
except RuntimeError:
llvm_version = 0

requires = [
pytest.mark.skipif(
llvm_version < major_version, reason=f"Requires LLVM >= {major_version}"
),
*requires_llvm.marks(),
]

def inner(func):
return _compose([func], requires)

return inner


# Mark a test as requiring a GPU to run.
requires_gpu = Feature("gpu", run_time_check=_any_gpu_exists)

Expand Down
26 changes: 1 addition & 25 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,32 +174,8 @@ void CodeGenLLVM::InitTarget() {
data_layout_.reset(new llvm::DataLayout(module_.get()));
#endif
if (native_vector_bits_ == 0) {
const int vwidth = llvm_target_->GetVectorWidth();
const auto& arch = tm->getTargetTriple().getArch();
const std::string arch_name = std::string(tm->getTargetTriple().getArchName());
if (vwidth > 0) {
// override from target options
// e.g. llvm -vector-width=xxx
native_vector_bits_ = vwidth;
} else if (arch == llvm::Triple::x86_64) {
// for avx512
native_vector_bits_ = 512;
} else if (arch == llvm::Triple::x86) {
native_vector_bits_ = 256;
} else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
native_vector_bits_ = 128;
} else if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
native_vector_bits_ = 256;
LOG(WARNING) << "LLVM RVV VLEN inference failed, "
<< "using 256 bits, set -vector-width=XXX to override";
// fallback default
} else {
native_vector_bits_ = 128;
LOG(WARNING) << "Set native vector bits to be 128 for `" << arch_name
<< "`, use -vector-width=XXX to override.";
}
native_vector_bits_ = llvm_target_->GetVectorWidth();
}

#if TVM_LLVM_VERSION >= 60
bool use_float16_abi = false;
#if TVM_LLVM_VERSION >= 150
Expand Down
50 changes: 41 additions & 9 deletions src/target/llvm/llvm_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target)
// TVM & LLVM vector width options
if (const auto& w = Downcast<Optional<runtime::Int>>(target.Get("vector-width"))) {
vector_width_ = w.value();
if ((vector_width_ <= 0) || (vector_width_ > 65535)) {
if ((vector_width_ <= 0) || (vector_width_ > 65536)) {
LOG(FATAL) << "Invalid -vector-width value: " << vector_width_;
}
}
Expand All @@ -300,26 +300,32 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target)
code_model_ = llvm::CodeModel::Medium;
#if TVM_LLVM_VERSION >= 140
// VLEN inference
const auto* mci = GetOrCreateTargetMachine(false)->getMCSubtargetInfo();
const auto cpu_name = mci->getCPU();
const auto m_arch = llvm::RISCV::getMArchFromMcpu(cpu_name);
const auto cpu_name = GetOrCreateTargetMachine(false)->getMCSubtargetInfo()->getCPU();
const auto canon_arch = llvm::RISCV::getMArchFromMcpu(cpu_name);
auto ISAInfo =
llvm::RISCVISAInfo::parseArchString(m_arch, /*EnableExperimentalExtensions=*/true);
// infer VLEN from LLVM or via options
llvm::RISCVISAInfo::parseArchString(canon_arch, /*EnableExperimentalExtensions=*/true);
// infer VLEN from LLVM RISCVInfo parser
if (!llvm::errorToBool(ISAInfo.takeError()) && (vector_width_ == 0)) {
vector_width_ = (*ISAInfo)->getMinVLen();
}
// infer VLEN from LLVM options (zvlXXXb override)
for (const auto& attr : attrs_) {
if (attr.find("zvl") != std::string::npos) {
std::string vec;
for (char c : attr) {
if (std::isdigit(c)) vec += c;
}
vector_width_ = std::stoi(vec);
}
}
#endif
if (vector_width_ > 0) {
// push cl-opt to LLVM
llvm_options_.push_back(
ParseOptionString("-riscv-v-vector-bits-min:int=" + std::to_string(vector_width_)));
llvm_options_.push_back(
ParseOptionString("-riscv-v-vector-bits-max:int=" + std::to_string(vector_width_)));
} else {
// fallback default (codegen will warn)
llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-min:int=256"));
llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-max:int=256"));
}
}

Expand Down Expand Up @@ -924,6 +930,32 @@ const bool LLVMTargetInfo::TargetHasCPUFeature(const std::string& feature) const
return has_feature;
}

const int LLVMTargetInfo::GetVectorWidth() {
auto* tm = GetOrCreateTargetMachine(false);
const auto& arch = tm->getTargetTriple().getArch();
const std::string arch_name = std::string(tm->getTargetTriple().getArchName());
if (vector_width_ == 0) {
if (arch == llvm::Triple::x86_64) {
// for avx512
vector_width_ = 512;
} else if (arch == llvm::Triple::x86) {
vector_width_ = 256;
} else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
vector_width_ = 128;
} else if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
vector_width_ = 256;
LOG(WARNING) << "LLVM RVV VLEN inference failed, "
<< "using 256 bits, set -vector-width=XXX to override";
// fallback default
} else {
vector_width_ = 128;
LOG(WARNING) << "Set native vector bits to be 128 for `" << arch_name
<< "`, use -vector-width=XXX to override.";
}
}
return vector_width_;
}

// LLVMTarget

bool LLVMTarget::modified_llvm_state_ = false;
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/llvm_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ class LLVMTargetInfo {
* \brief Get the TVM & LLVM vector_width
* \return number of bits for vector width
*/
const int GetVectorWidth() const { return vector_width_; }
const int GetVectorWidth();
/*!
* \brief Get the LLVM optimization level
* \return optimization level for this target
Expand Down
13 changes: 13 additions & 0 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,19 @@ TVM_REGISTER_GLOBAL("target.llvm_get_system_x86_vendor").set_body_typed([]() ->
return "unimplemented";
});

TVM_REGISTER_GLOBAL("target.llvm_get_vector_width").set_body_typed([](const Target& target) -> int {
auto use_target = target.defined() ? target : Target::Current(false);
// ignore non "llvm" target
if (target.defined()) {
if (target->kind->name != "llvm") {
return -1;
}
}
auto llvm_instance = std::make_unique<LLVMInstance>();
LLVMTargetInfo llvm_backend(*llvm_instance, use_target);
return llvm_backend.GetVectorWidth();
});

TVM_REGISTER_GLOBAL("target.llvm_get_system_triple").set_body_typed([]() -> String {
return llvm::sys::getDefaultTargetTriple();
});
Expand Down
49 changes: 49 additions & 0 deletions tests/python/codegen/test_target_codegen_riscv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.script import tir as T
from tvm.target.codegen import target_has_features


@tvm.testing.requires_llvm_minimum_version(14)
@tvm.testing.parametrize_targets(
"llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m",
"llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v",
"llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m",
"llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
)
def test_rvv(target):
def check_rvv_presence(N, extent):
@T.prim_func
def load_vec(A: T.Buffer((N,), "int8")):
for j in T.vectorized(0, extent):
A[j] = 1

f = tvm.build(load_vec, target)
# Check RVV `vsetvli` prensence
assembly = f.get_source("asm")
if target_has_features("v"):
assert "vsetvli" in assembly
else:
assert "vsetvli" not in assembly

with tvm.target.Target(target):
check_rvv_presence(16, 32)


if __name__ == "__main__":
test_rvv()
66 changes: 66 additions & 0 deletions tests/python/target/test_riscv_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

import tvm
from tvm.target import _ffi_api, codegen, Target
from tvm.target.codegen import target_has_features, llvm_get_vector_width

LLVM_VERSION = codegen.llvm_version_major()

# fmt: off
min_llvm_version, tvm_target, vec_width = tvm.testing.parameters(
# generic, no-vec -> (default 256)
(-1, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+i,+m", 256),
(-1, "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+64bit,+a,+c,+d,+f,+m", 256),
# generic, with-vec -> (default 256)
(-1, "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v", 256),
(-1, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", 256),
# explicit -vector-width
(-1, "llvm -device=riscv_cpu -vector-width=128 -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v", 128),
(-1, "llvm -device=riscv_cpu -vector-width=128 -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", 128),
(-1, "llvm -device=riscv_cpu -vector-width=512 -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v", 512),
(-1, "llvm -device=riscv_cpu -vector-width=512 -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", 512),
# explicit +zvlXXXb
(-1, "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v,+zvl64b", 64),
(-1, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v,+zvl64b", 64),
# vendor CPU
(17, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=sifive-x280", 512),
(18, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=sifive-p670", 128),
(19, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=spacemit-x60", 256),
) # fmt: on


def test_riscv_rvv_features(min_llvm_version, tvm_target, vec_width):
"""Test RVV features support for different targets.
Parameters
----------
min_llvm_version : int
Minimal LLVM version.
tvm_target : str
TVM target.
vec_width : bool
Expected vector width.
"""

# skip test on llvm_version
if LLVM_VERSION < min_llvm_version:
return

with Target(tvm_target):
assert llvm_get_vector_width() == vec_width

0 comments on commit 146157d

Please sign in to comment.