-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This commit integrates hipBLAS for gemm computation. Co-authored-by: Lesheng Jin <[email protected]>
- Loading branch information
1 parent
2cbf393
commit 3869fa1
Showing
5 changed files
with
56 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""A compiler pass that dispatches patterns to CUBLAS.""" | ||
|
||
import tvm | ||
import tvm.relax.backend.contrib.cublas as _cublas | ||
import tvm.relax.backend.contrib.hipblas as _hipblas | ||
from tvm import IRModule, relax | ||
from tvm.relax.backend import get_patterns_with_prefix | ||
|
||
|
||
@tvm.transform.module_pass(opt_level=0, name="BLASDispatch")
class BLASDispatch:  # pylint: disable=too-few-public-methods,broad-exception-raised
    """A compiler pass that dispatches patterns to cuBLAS/hipBLAS.

    On a CUDA target the pass offloads matching subgraphs to cuBLAS; on a
    ROCm target, to hipBLAS. Any other target kind is rejected at
    construction time.
    """

    # target.kind.name -> (runtime-extension global func, pattern prefix, display name).
    # The global func existing in the registry is the signal that the
    # corresponding BLAS extension was compiled into this TVM build.
    _BACKENDS = {
        "cuda": ("relax.ext.cublas", "cublas", "cuBLAS"),
        "rocm": ("relax.ext.hipblas", "hipblas", "hipBLAS"),
    }

    def __init__(self, target: tvm.target.Target) -> None:
        """Select the BLAS backend for ``target``.

        Parameters
        ----------
        target : tvm.target.Target
            The compilation target; its ``kind.name`` chooses the backend.

        Raises
        ------
        Exception
            If the target kind is unsupported, or the matching BLAS
            extension is not enabled in the TVM build.
        """
        backend = self._BACKENDS.get(target.kind.name)
        if backend is None:
            raise Exception(f"Unsupported target {target.kind.name} for BLAS dispatch.")
        ext_func, prefix, display_name = backend
        # Second arg True = allow_missing: returns None instead of raising
        # when the extension func is absent, so this doubles as a feature probe.
        self.has_blas = tvm.get_global_func(ext_func, True)
        if not self.has_blas:
            raise Exception(f"{display_name} is not enabled.")
        self.patterns = get_patterns_with_prefix(prefix)

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """IRModule-level transformation.

        Fuses ops matching the backend's patterns into composite functions
        and runs codegen to offload them to the selected BLAS library.
        """
        entry_names = [
            gv.name_hint for gv, func in mod.functions.items() if isinstance(func, relax.Function)
        ]
        # Exclude single batch decode: keep functions that are batched or are
        # not decode at all (presumably single-batch decode matmuls do not
        # benefit from BLAS offload — TODO confirm with maintainers).
        entry_names = [name for name in entry_names if "batch" in name or "decode" not in name]
        mod = tvm.transform.Sequential(
            [
                relax.transform.FuseOpsByPattern(
                    self.patterns,
                    bind_constants=False,
                    annotate_codegen=True,
                    entry_functions=entry_names,
                ),
                relax.transform.RunCodegen({}, entry_functions=entry_names),
            ]
        )(mod)
        return mod
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters