[ROCm] hipBLAS integration (#2830)
This commit integrates hipBLAS for GEMM computation on ROCm targets.

Co-authored-by: Lesheng Jin <[email protected]>
MasterJH5574 and LeshengJin authored Aug 23, 2024
1 parent 2cbf393 commit 3869fa1
Showing 5 changed files with 56 additions and 40 deletions.
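The new dispatch pass (first file below) requires the matching BLAS codegen to be compiled into TVM; a quick, hedged way to check a local build, using the same global-function lookups the pass itself performs:

# Check whether this TVM build registers the cuBLAS / hipBLAS Relax codegens.
import tvm

print("cublas :", tvm.get_global_func("relax.ext.cublas", True) is not None)
print("hipblas:", tvm.get_global_func("relax.ext.hipblas", True) is not None)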
46 changes: 46 additions & 0 deletions python/mlc_llm/compiler_pass/blas_dispatch.py
@@ -0,0 +1,46 @@
"""A compiler pass that dispatches patterns to CUBLAS."""

import tvm
import tvm.relax.backend.contrib.cublas as _cublas
import tvm.relax.backend.contrib.hipblas as _hipblas
from tvm import IRModule, relax
from tvm.relax.backend import get_patterns_with_prefix


@tvm.transform.module_pass(opt_level=0, name="BLASDispatch")
class BLASDispatch: # pylint: disable=too-few-public-methods,broad-exception-raised
"""A compiler pass that dispatches patterns to cuBLAS/hipBLAS."""

def __init__(self, target: tvm.target.Target) -> None:
if target.kind.name == "cuda":
self.has_blas = tvm.get_global_func("relax.ext.cublas", True)
if not self.has_blas:
raise Exception("cuBLAS is not enabled.")
self.patterns = get_patterns_with_prefix("cublas")
elif target.kind.name == "rocm":
self.has_blas = tvm.get_global_func("relax.ext.hipblas", True)
if not self.has_blas:
raise Exception("hipBLAS is not enabled.")
self.patterns = get_patterns_with_prefix("hipblas")
else:
raise Exception(f"Unsupported target {target.kind.name} for BLAS dispatch.")

def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
"""IRModule-level transformation"""
model_names = [
gv.name_hint for gv, func in mod.functions.items() if isinstance(func, relax.Function)
]
# exclude single batch decode
model_names = [name for name in model_names if "batch" in name or "decode" not in name]
mod = tvm.transform.Sequential(
[
relax.transform.FuseOpsByPattern(
self.patterns,
bind_constants=False,
annotate_codegen=True,
entry_functions=model_names,
),
relax.transform.RunCodegen({}, entry_functions=model_names),
]
)(mod)
return mod
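For illustration only, here is a hedged sketch of applying this pass on its own to a toy Relax module. `ToyModule` and its `prefill` function are made up for the example; it assumes a TVM build with hipBLAS enabled (otherwise the constructor above raises), and whether a given matmul is actually offloaded still depends on the hipblas pattern checks. In the real flow, the pass runs inside the MLC pipeline shown in pipeline.py below.

# Illustrative only: run BLASDispatch on a toy Relax module under a ROCm target.
import tvm
from tvm.script import ir as I
from tvm.script import relax as R

from mlc_llm.compiler_pass.blas_dispatch import BLASDispatch


@I.ir_module
class ToyModule:
    @R.function
    def prefill(
        x: R.Tensor((1, 32, 4096), "float16"),
        w: R.Tensor((4096, 4096), "float16"),
    ) -> R.Tensor((1, 32, 4096), "float16"):
        with R.dataflow():
            out = R.matmul(x, w)
            R.output(out)
        return out


target = tvm.target.Target("rocm")
mod = BLASDispatch(target)(ToyModule)  # matched matmuls are fused and offloaded to hipBLAS
mod.show()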
37 changes: 0 additions & 37 deletions python/mlc_llm/compiler_pass/cublas_dispatch.py

This file was deleted.

4 changes: 2 additions & 2 deletions python/mlc_llm/compiler_pass/pipeline.py
@@ -24,8 +24,8 @@
     AttachPipelineParallelStages,
     AttachVariableBounds,
 )
+from .blas_dispatch import BLASDispatch
 from .clean_up_tir_attrs import CleanUpTIRAttrs
-from .cublas_dispatch import CublasDispatch
 from .dispatch_kv_cache_creation import DispatchKVCacheCreation
 from .estimate_memory_usage import AttachMetadataWithMemoryUsage
 from .fuse_add_norm import FuseAddRMSNorm
@@ -118,7 +118,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
                 _LogProgress("Running TVM Relax graph-level optimizations"),
                 FuseFTDequantizeEpilogue(),
                 FuseDequantizeTranspose(),
-                CublasDispatch() if cublas_gemm else tvm.transform.Sequential([]),
+                BLASDispatch(target) if cublas_gemm else tvm.transform.Sequential([]),
                 FuseAddRMSNorm(target=target),
                 FuseTransposeMatmul(),
                 _DebugDump("debug-phase1.py", debug_dump, show_meta=False),
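In the changed line above, an empty `tvm.transform.Sequential([])` stands in as a no-op when BLAS dispatch is disabled. A minimal illustration of that idiom (illustrative, not part of the commit):

# An empty Sequential is effectively an identity pass: it applies no transforms.
import tvm

noop = tvm.transform.Sequential([])
# For any IRModule `mod`, `noop(mod)` leaves its contents unchanged.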
2 changes: 1 addition & 1 deletion python/mlc_llm/interface/compiler_flags.py
@@ -102,7 +102,7 @@ def _flashinfer(target) -> bool:
 
         def _cublas_gemm(target, quantization) -> bool:
            """correct cublas_gemm flag"""
-            if not target.kind.name == "cuda":
+            if not target.kind.name in ["cuda", "rocm"]:
                return False
            if not (
                quantization.name in ["q0f16", "q0f32"]
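A hedged restatement of the widened target gate as a stand-alone helper; `_blas_capable_target` is hypothetical and mirrors only the target-kind check, since the quantization part of the condition continues past the lines shown above.

# Hypothetical helper mirroring only the target-kind gate of _cublas_gemm;
# the real function also inspects the quantization scheme.
def _blas_capable_target(target_kind_name: str) -> bool:
    return target_kind_name in ("cuda", "rocm")  # "rocm" newly accepted by this commit


assert _blas_capable_target("cuda")
assert _blas_capable_target("rocm")
assert not _blas_capable_target("metal")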
7 changes: 7 additions & 0 deletions python/mlc_llm/support/auto_target.py
@@ -51,6 +51,13 @@ def detect_target_and_host(target_hint: str, host_hint: str = "auto") -> Tuple[T
         )
         target = Target(target_dict)
         _register_cuda_hook(target)
+    elif target.kind.name == "rocm":
+        target_dict = dict(target.export())
+        extra_libs = ["thrust", "rocblas", "miopen", "hipblas"]
+        target_dict["libs"] = (
+            (target_dict["libs"] + extra_libs) if "libs" in target_dict else extra_libs
+        )
+        target = Target(target_dict)
     return target, build_func


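For reference, a small hedged sketch of what the new ROCm branch does to the detected target; it only requires that TVM can construct a rocm Target spec, and the exact attributes printed will vary by build.

# Illustrative only: replicate the ROCm "libs" augmentation from auto_target.py.
from tvm.target import Target

target = Target("rocm")
target_dict = dict(target.export())
extra_libs = ["thrust", "rocblas", "miopen", "hipblas"]
target_dict["libs"] = (
    (list(target_dict["libs"]) + extra_libs) if "libs" in target_dict else extra_libs
)
print(Target(target_dict))  # expect ... -libs=thrust,rocblas,miopen,hipblas ...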
