From 50e9c2de74927f5496a6ded1994f0e69ceec12bc Mon Sep 17 00:00:00 2001 From: zyuli Date: Mon, 29 Jun 2026 10:03:27 +0000 Subject: [PATCH] support nvshmem --- .../PatternTritonGPUOpToLLVM.h | 5 + include/triton/Dialect/Triton/IR/TritonOps.td | 26 + lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 + .../TritonGPUToLLVM/ExternCallOpToLLVM.cpp | 61 ++ .../TritonToTritonGPUPass.cpp | 1 + lib/Dialect/Triton/IR/Ops.cpp | 18 + python/src/ir.cc | 8 + .../experimental/tle/language/raw/core.py | 30 +- .../experimental/tle/raw/cuda/runtime.py | 144 ++++- python/triton/language/core.py | 88 +++ .../01-simple-shift/simple-shift-device.cu | 10 + .../01-simple-shift/simple-shift-host.cu | 47 ++ .../nvshmem/01-simple-shift/simple-shift.py | 98 +++ .../02-allgather-gemm/ag-gemm-device.cu | 60 ++ .../nvshmem/02-allgather-gemm/ag-gemm-host.cu | 58 ++ .../raw/nvshmem/02-allgather-gemm/ag-gemm.py | 568 ++++++++++++++++++ .../tutorials/tle/raw/nvshmem/common/build.py | 222 +++++++ .../tle/raw/nvshmem/common/common-host.cu | 60 ++ .../nvshmem/common/generate_extern_call.py | 197 ++++++ .../tutorials/tle/raw/nvshmem/common/utils.py | 103 ++++ python/tutorials/tle/raw/nvshmem/run.sh | 160 +++++ .../TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp | 2 + 22 files changed, 1964 insertions(+), 3 deletions(-) create mode 100644 lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp create mode 100644 python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-device.cu create mode 100644 python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-host.cu create mode 100644 python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift.py create mode 100644 python/tutorials/tle/raw/nvshmem/02-allgather-gemm/ag-gemm-device.cu create mode 100644 python/tutorials/tle/raw/nvshmem/02-allgather-gemm/ag-gemm-host.cu create mode 100644 python/tutorials/tle/raw/nvshmem/02-allgather-gemm/ag-gemm.py create mode 100644 python/tutorials/tle/raw/nvshmem/common/build.py create mode 100644 
python/tutorials/tle/raw/nvshmem/common/common-host.cu create mode 100644 python/tutorials/tle/raw/nvshmem/common/generate_extern_call.py create mode 100644 python/tutorials/tle/raw/nvshmem/common/utils.py create mode 100755 python/tutorials/tle/raw/nvshmem/run.sh diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index 680bf0e045..69102ed61b 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -102,6 +102,11 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, PatternBenefit benefit); +void populateExternCallOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + void populateInstrumentationToLLVMPatterns(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, RewritePatternSet &patterns, diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 7fd215f9de..95239b3cb9 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -845,6 +845,32 @@ def TT_MapElementwiseReturnOp: TT_Op<"map_elementwise.return", let assemblyFormat = "attr-dict ($result^ `:` type($result))?"; } +// +// External Call op +// +def TT_ExternCallOp : TT_Op<"extern_call", [ + DeclareOpInterfaceMethods, + ConditionallySpeculatable, +]> { + + let description = [{ + call an external function $symbol implemented in $libpath/$libname with $args + return $libpath/$libname:$symbol($args...) 
+ }]; + + let arguments = (ins Variadic:$srcs, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure); + + let results = (outs Variadic:$result); + + let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; + + let extraClassDeclaration = [{ + // Interface method for ConditionallySpeculatable. + Speculation::Speculatability getSpeculatability(); + }]; + +} + // // External Elementwise op // diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index d4f49c8d18..b3c9e36459 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -21,6 +21,7 @@ add_triton_library(TritonGPUToLLVM TypeConverter.cpp Utility.cpp ViewOpToLLVM.cpp + ExternCallOpToLLVM.cpp DEPENDS TritonGPUConversionPassIncGen diff --git a/lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp new file mode 100644 index 0000000000..81b6bf256e --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp @@ -0,0 +1,61 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace { + +class ExternCallOpConversion + : public ConvertOpToLLVMPattern<triton::ExternCallOp> { +public: + ExternCallOpConversion(const LLVMTypeConverter &converter, + const PatternBenefit &benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::ExternCallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + if (op->getNumResults() > 1) { + // Emit a located MLIR diagnostic rather than raw text on stderr. + return op.emitOpError( 
"does not support multiple results when lowering to LLVM"); + } + + LLVM::LLVMVoidType voidTy = void_ty(op->getContext()); + auto newOperands = adaptor.getOperands(); + Type retType = + op->getNumResults() == 0 + ? voidTy + : this->getTypeConverter()->convertType(op->getResult(0).getType()); + std::string funcName = op.getSymbol().str(); + StringRef libname = op.getLibname(); + StringRef libpath = op.getLibpath(); + + Operation *externCallOp; + Type funcType = mlir::triton::gpu::getFunctionType(retType, newOperands); + LLVM::LLVMFuncOp funcOp = mlir::triton::gpu::appendOrGetExternFuncOp( + rewriter, op, funcName, funcType, libname, libpath); + externCallOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, newOperands); + + if (op->getNumResults() == 0) { + rewriter.eraseOp(op); + } else { + rewriter.replaceOp(op, externCallOp->getResult(0)); + } + + return success(); + } +}; + +} // namespace + +void mlir::triton::populateExternCallOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add<ExternCallOpConversion>(typeConverter, benefit); +} diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index bf509ffe52..512ab61f61 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -618,6 +618,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 27fa26554d..854ae1f2bc 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1313,6 +1313,24 @@ Speculation::Speculatability ExternElementwiseOp::getSpeculatability() { return Speculation::NotSpeculatable; } +// -- ExternCallOp -- 
+void ExternCallOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +Speculation::Speculatability ExternCallOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + // -- GatherOp -- LogicalResult GatherOp::verify() { RankedTensorType indicesTy = getIndices().getType(); diff --git a/python/src/ir.cc b/python/src/ir.cc index e5591dcb08..9abfe8472f 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1673,6 +1673,14 @@ void init_triton_ir(py::module &&m) { return self.create(retType, argList, libName, libPath, symbol, isPure); }) + .def("create_extern_call", + [](TritonOpBuilder &self, const std::string &libName, + const std::string &libPath, const std::string &symbol, + std::vector &argList, const std::vector &retTypes, + bool isPure) -> OpState { + return self.create(retTypes, argList, libName, + libPath, symbol, isPure); + }) // Built-in instruction .def("create_get_program_id", [](TritonOpBuilder &self, int axis) -> Value { diff --git a/python/triton/experimental/tle/language/raw/core.py b/python/triton/experimental/tle/language/raw/core.py index 3ae5e56e32..8de3c40a09 100644 --- a/python/triton/experimental/tle/language/raw/core.py +++ b/python/triton/experimental/tle/language/raw/core.py @@ -1,6 +1,26 @@ import triton.language as tl from triton.language.core import builtin, constexpr as tl_constexpr, tensor from triton.experimental.tle.language.gpu import buffered_tensor +import importlib.util + + +def _pointer_type_hash(self): + return hash((self.name, self.element_ty, "tt_ptr")) + + +def patch_hash_method_for_pointer_type(): + elem_dtype_list = tl.core.dtype.SINT_TYPES + tl.core.dtype.UINT_TYPES + tl.core.dtype.FP_TYPES + tl.core.dtype.OTHER_TYPES + for elem_dtype in 
elem_dtype_list: + ptr_ty = type(tl.core.pointer_type(tl.core.dtype(elem_dtype))) + ptr_ty.__hash__ = _pointer_type_hash + + +def import_from_path(file_path): + module_name = f"_imported_{abs(hash(file_path))}" + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module def _resolve_alias_indices(func, llvm, handles, output_indices, _semantic): @@ -42,14 +62,22 @@ def _normalize_hint(hint): def _tle_raw_call(func, args, *, output_indices, hint, smem, _semantic): hint = _normalize_hint(hint) - handles = [arg.handle for arg in args] if getattr(func, "deferred", False): + handles = [arg.handle for arg in args] if output_indices is None: raise RuntimeError("deferred tle_raw.call requires explicit output_indices=") alias_indices = output_indices source_id = func.register_pending_source(hint=hint) dsl_region_op = func.create_region_deferred(_semantic.builder, source_id, handles, alias_indices, hint) else: + if func.compiler.lower() == "nvcc" or (func.compiler.lower() == "clang" and func.target.lower() == "bc"): + patch_hash_method_for_pointer_type() + module = import_from_path(func.extern_file) + target_fn = getattr(module, func.extern_func_name) + ret = target_fn(*args, _semantic=_semantic) + return ret + + handles = [arg.handle for arg in args] context = _semantic.builder.get_context() llvm = func.make_llvm(context) alias_indices = _resolve_alias_indices(func, llvm, handles, output_indices, _semantic) diff --git a/python/triton/experimental/tle/raw/cuda/runtime.py b/python/triton/experimental/tle/raw/cuda/runtime.py index 2fc2ae69fd..75e13224e3 100644 --- a/python/triton/experimental/tle/raw/cuda/runtime.py +++ b/python/triton/experimental/tle/raw/cuda/runtime.py @@ -8,6 +8,11 @@ from typing import Any, Final import torch +import tempfile +import signal +from triton import knobs +from triton.runtime.errors import PTXASError +from functools import partial from 
triton._C.libtriton import llvm # pyright: ignore[reportMissingImports] from triton._C.libtriton.tle.llvm import parse_llvm_ir # pyright: ignore[reportMissingImports] @@ -17,6 +22,16 @@ CLANG = os.getenv("CLANG", "clang") CLANG_FLAGS = shlex.split(os.getenv("CLANG_FLAGS", "")) +NVCC = os.getenv("NVCC", "nvcc") +NVCC_FLAGS = shlex.split(os.getenv("NVCC_FLAGS", "")) + +PTXAS = os.getenv("PTXAS", "ptxas") + +NVLINK = os.getenv("NVLINK", "nvlink") +NVLINK_FLAGS = shlex.split(os.getenv("NVLINK_FLAGS", "")) + +OPT = os.getenv("OPT", "opt") + def _sanitize_clang_ir(ir: str) -> str: # Newer clang emits attributes that this Triton branch's LLVM parser does @@ -43,23 +58,117 @@ def _get_cuda_gpu_arch() -> str: if arch: return f"--cuda-gpu-arch={arch}" major, minor = torch.cuda.get_device_capability() - return f"--cuda-gpu-arch=sm_{major}{minor}" + suffix = "a" if major >= 9 else "" + return f"--cuda-gpu-arch=sm_{major}{minor}{suffix}" + + +def make_cubin_inspection_hook(cuda_self, triton_self, stages, options, language, capability): + + def make_cubin(self, src, metadata, opt, capability): + fsrc_cuda = cuda_self.source_file + arch = _get_cuda_gpu_arch().split('=')[1] + fbin_cuda = tempfile.NamedTemporaryFile(delete=False, suffix='.o').name + + build = subprocess.run([NVCC, "-c", "-rdc=true", f"-arch={arch}", *NVCC_FLAGS, "-o", fbin_cuda, fsrc_cuda], + capture_output=True) + assert build.returncode == 0, (f"nvcc failed\nstderr:\n{build.stderr.decode()}") + + with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc_triton, \ + tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog: + fsrc_triton.write(src) + fsrc_triton.flush() + fbin_triton = fsrc_triton.name + '.o' + fbin_combined = fbin_triton + '.combined.cubin' + + compile_only_cmds = ['-c'] + line_info = ["-lineinfo", "-suppress-debug-info"] if knobs.compilation.disable_line_info else ["-lineinfo"] + fmad = [] if opt.enable_fp_fusion else ['--fmad=false'] + disable_opt = 
['--opt-level', '0'] if knobs.nvidia.disable_ptxas_opt else [] + ptx_extra_options = opt.ptx_options.split(" ") if opt.ptx_options else [] + + ptxas_cmd = [ + PTXAS, *compile_only_cmds, *line_info, *fmad, '-v', *disable_opt, *ptx_extra_options, + f'--gpu-name={arch}', fsrc_triton.name, '-o', fbin_triton + ] + + try: + subprocess.run(ptxas_cmd, check=True, close_fds=False, stderr=flog) + if os.path.exists(fsrc_triton.name): + os.remove(fsrc_triton.name) + if os.path.exists(flog.name): + os.remove(flog.name) + except subprocess.CalledProcessError as e: + with open(flog.name) as log_file: + log = log_file.read() + if os.path.exists(flog.name): + os.remove(flog.name) + + if e.returncode == 255: + error = 'Internal Triton PTX codegen error' + elif e.returncode == 128 + signal.SIGSEGV: + error = '`ptxas` raised SIGSEGV' + else: + error = f'`ptxas` failed with error code {e.returncode}' + raise PTXASError(f"{error}\n" + f"`ptxas` stderr:\n{log}\n" + f'Repro command: {" ".join(ptxas_cmd)}\n') + + nvlink_cmds = [ + NVLINK, + f"-arch={arch}", + *NVLINK_FLAGS, + fbin_triton, + fbin_cuda, + "-o", + fbin_combined, + ] + + try: + subprocess.run(nvlink_cmds, check=True, close_fds=False, stderr=flog) + except Exception as e: + import logging + logging.error(f"error runing nvlink: {shlex.join(nvlink_cmds)}") + logging.exception(e) + + with open(fbin_combined, 'rb') as f: + cubin = f.read() + if os.path.exists(fbin_combined): + os.remove(fbin_combined) + if os.path.exists(fbin_triton): + os.remove(fbin_triton) + if os.path.exists(fbin_cuda): + os.remove(fbin_cuda) + return cubin + + stages["cubin"] = lambda src, metadata: make_cubin(triton_self, src, metadata, options, triton_self.target.arch) class CUDAJITFunction(object): def __init__(self, fn: Any, file: Path, *args, **kwargs) -> None: - super().__init__(*args, **{k: v for k, v in kwargs.items() if k not in ("extern_func_name", "deferred")}) + super().__init__( + *args, **{ + k: v + for k, v in kwargs.items() + if k not in 
("compiler", "target", "extern_file", "extern_func_name", "deferred") + }) self.fn: Final[Any] = fn self.code: Final[str] = file.read_text() self.region_dialect: Final[str] = "cuda" self.lowered_region_dialect: Final[str] = "llvm" self.arg_dialect: Final[str] = "llvm" self.source_file: Final[str] = str(file) + self.compiler = kwargs.get("compiler", None) + self.target = kwargs.get("target", None) + self.extern_file = kwargs.get("extern_file", None) self.extern_func_name = kwargs.get("extern_func_name", None) self.deferred: Final[bool] = kwargs.get("deferred", False) self.__triton_builtin__: Final[bool] = True + if self.compiler.lower() == "nvcc" and knobs.runtime.add_stages_inspection_hook is None: + nvcc_cuda_hook = partial(make_cubin_inspection_hook, self) + knobs.runtime.add_stages_inspection_hook = nvcc_cuda_hook + def register_pending_source(self, *, hint: str = "") -> str: if not self.extern_func_name: raise RuntimeError("deferred tle_raw CUDA source requires extern_func_name= " @@ -116,6 +225,37 @@ def make_llvm(self, mlir_context) -> str: module = parse_llvm_ir(_sanitize_clang_ir(build.stdout.decode()), llvm_context, mlir_context) return f"{module}" + def make_bc(self, public_api_names=None): + fbc_cuda_unopti = Path(self.source_file).with_suffix('.bc.unopti') + fbc_cuda = Path(self.source_file).with_suffix('.bc') + + build = subprocess.run([ + CLANG, "-c", "-x", "cuda", "--cuda-device-only", + _get_cuda_gpu_arch(), "-emit-llvm", "-fcuda-flush-denormals-to-zero", *CLANG_FLAGS, "-o", fbc_cuda_unopti, + self.source_file + ], capture_output=True) + assert build.returncode == 0, (f"clang failed\nstderr:\n{build.stderr.decode()}") + + if public_api_names is None: + public_api_names = [self.extern_func_name] + elif isinstance(public_api_names, str): + public_api_names = [public_api_names] + else: + public_api_names = list(public_api_names) + if not public_api_names or any(not name for name in public_api_names): + raise ValueError("make_bc requires at least one 
public API name") + public_api_list = ",".join(dict.fromkeys(public_api_names)) + + opt = subprocess.run([ + OPT, "--passes=internalize,inline,globaldce", f"-internalize-public-api-list={public_api_list}", "-o", + fbc_cuda, fbc_cuda_unopti + ], capture_output=True) + assert opt.returncode == 0, (f"opt failed\nstderr:\n{opt.stderr.decode()}") + + if os.path.exists(fbc_cuda_unopti): + os.remove(fbc_cuda_unopti) + return fbc_cuda + def compile_deferred_pending_source(entry: dict, *, context) -> str: source_text = entry["source"] diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 439671c3a1..c87872c3af 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -3380,6 +3380,59 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(builder), is_pure), ret_type) +def dispatch_ec(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _semantic=None): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _semantic: the builder + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." 
+ f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." + f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_types = arg_type_symbol_dict[arg_types][1] + if not isinstance(ret_types, (builtins.list, builtins.tuple)): + ret_types = [ret_types] + + if symbol == "": + raise ValueError("Symbol can not be empty") + call = func(lib_name, lib_path, symbol, arg_list, [ret_type.to_ir(_semantic.builder) for ret_type in ret_types], + is_pure) + + if len(ret_types) == 0: + return tensor(call, void) + if len(ret_types) == 1: + return tensor(call.get_result(0), ret_types[0]) + return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(ret_types)) + + @builtin def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _semantic=None): @@ -3423,6 +3476,41 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_type, is_pure, _semantic) +@builtin +def extern_call(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _semantic=None): + ''' + Dispatch an function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param is_pure: whether the function is pure + :param _semantic: the semantic + :return: the return value of the function + ''' + dispatch_args = args.copy() + all_scalar = True + arg_types = [] + for i in builtins.range(len(dispatch_args)): + 
dispatch_args[i] = _semantic.to_tensor(dispatch_args[i]) + arg_types.append(dispatch_args[i].dtype) + if dispatch_args[i].type.is_block(): + all_scalar = False + if not all_scalar: + raise ValueError("extern call only support inputs with scalr type") + + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + func = _semantic.builder.create_extern_call + return dispatch_ec(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, is_pure, _semantic) + + def binary_op_type_legalization(lhs, rhs, semantic): ''' Convert both operands to a single common type diff --git a/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-device.cu b/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-device.cu new file mode 100644 index 0000000000..93f668f478 --- /dev/null +++ b/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-device.cu @@ -0,0 +1,10 @@ +extern "C" __device__ int nvshmem_my_pe(); +extern "C" __device__ int nvshmem_n_pes(); +extern "C" __device__ void nvshmem_int_p(int *dest, int value, int pe); + +extern "C" __device__ void simple_shift(int *destination) { + int mype = nvshmem_my_pe(); + int npes = nvshmem_n_pes(); + int peer = (mype + 1) % npes; + nvshmem_int_p(destination, mype, peer); +} diff --git a/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-host.cu b/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-host.cu new file mode 100644 index 0000000000..610cf7d665 --- /dev/null +++ b/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-host.cu @@ -0,0 +1,47 @@ +#include +#include +#include +#include + +#undef CUDA_CHECK +#define CUDA_CHECK(stmt) \ + do { \ + cudaError_t result = (stmt); \ + if (cudaSuccess != result) { \ + fprintf(stderr, "[%s:%d] cuda failed with %s \n", 
__FILE__, __LINE__, \ + cudaGetErrorString(result)); \ + exit(-1); \ + } \ + } while (0) + +extern "C" void simple_shift_before_launch(int *mype, int *npes, + int *mype_in_node, int *npes_in_node, + cudaStream_t *stream, int **dst, + int **data_h) { + nvshmem_init(); + *mype = nvshmem_my_pe(); + *npes = nvshmem_n_pes(); + *mype_in_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE); + *npes_in_node = nvshmem_team_n_pes(NVSHMEMX_TEAM_NODE); + + CUDA_CHECK(cudaSetDevice(*mype_in_node)); + CUDA_CHECK(cudaStreamCreate(stream)); + + *dst = (int *)nvshmem_malloc(sizeof(int)); + *data_h = (int *)malloc(sizeof(int)); +} + +extern "C" void simple_shift_after_launch(cudaStream_t stream, void *dst, + void *data_h, int mype, int npes) { + nvshmemx_barrier_all_on_stream(stream); + CUDA_CHECK(cudaMemcpyAsync(data_h, dst, sizeof(int), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + int *data = (int *)data_h; + printf("%d: received message %d\n", mype, data[0]); + + nvshmem_free(dst); + free(data_h); + + nvshmem_finalize(); +} diff --git a/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift.py b/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift.py new file mode 100644 index 0000000000..d45d8460f2 --- /dev/null +++ b/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift.py @@ -0,0 +1,98 @@ +import ctypes +import torch +import triton +import triton.experimental.tle.language.raw as tle_raw + +from pathlib import Path +from triton.experimental.tle.raw import dialect + +from common.utils import ( + install_cumodule_hook, + load_library, + prepare_clang_bitcode, +) + + +@dialect( + name="cuda", + compiler="clang", + target="bc", + file=(Path(__file__).parent / "simple-shift-device.cu"), + extern_file=(Path(__file__).parent / "simple-shift-device-extern-call.py"), + extern_func_name="simple_shift", +) +def simple_shift(*args, **kwargs): + ... 
+ + +@triton.jit +def simple_shift_kernel(destination_ptr, ): + tle_raw.call(simple_shift, [destination_ptr]) + + +def tensor_from_pointer(pointer, shape, dtype, device): + num_elements = 1 + for extent in shape: + num_elements *= extent + storage = torch._C._construct_storage_from_data_pointer( + pointer.value, + device, + num_elements * dtype.itemsize, + ) + return torch.empty(0, dtype=dtype, device=device).set_(storage).view(shape) + + +def simpe_shift(): + common_path = Path(__file__).parents[1] / "common" / "common-host.so" + host_path = Path(__file__).with_name("simple-shift-host.so") + common_lib = load_library(common_path) + host_lib = load_library(host_path) + + mype = ctypes.c_int() + npes = ctypes.c_int() + mype_in_node = ctypes.c_int() + npes_in_node = ctypes.c_int() + stream_ptr = ctypes.c_void_p() + destination_ptr = ctypes.c_void_p() + host_data_ptr = ctypes.c_void_p() + host_lib.simple_shift_before_launch( + ctypes.byref(mype), + ctypes.byref(npes), + ctypes.byref(mype_in_node), + ctypes.byref(npes_in_node), + ctypes.byref(stream_ptr), + ctypes.byref(destination_ptr), + ctypes.byref(host_data_ptr), + ) + + device = triton.runtime.driver.active.get_active_torch_device() + destination = tensor_from_pointer( + destination_ptr, + (1, ), + torch.int32, + device, + ) + stream = torch.cuda.ExternalStream(stream_ptr.value, device=device) + install_cumodule_hook(common_lib) + + extern_libs = prepare_clang_bitcode( + common_lib, + mype_in_node.value, + Path(__file__).with_name("simple-shift-device.bc"), + simple_shift, + ) + + with torch.cuda.stream(stream): + simple_shift_kernel[(1, )](destination, extern_libs=extern_libs) + + host_lib.simple_shift_after_launch( + stream_ptr, + destination_ptr, + host_data_ptr, + mype_in_node.value, + npes_in_node.value, + ) + + +if __name__ == "__main__": + simpe_shift() diff --git a/python/tutorials/tle/raw/nvshmem/02-allgather-gemm/ag-gemm-device.cu b/python/tutorials/tle/raw/nvshmem/02-allgather-gemm/ag-gemm-device.cu 
new file mode 100644 index 0000000000..0680fc4616 --- /dev/null +++ b/python/tutorials/tle/raw/nvshmem/02-allgather-gemm/ag-gemm-device.cu @@ -0,0 +1,60 @@ +#include +#include +#include + +extern "C" __device__ void +nvshmemx_putmem_signal_nbi_block(void *dest, const void *source, size_t nbytes, + uint64_t *sig_addr, uint64_t signal, + int sig_op, int pe); + +extern "C" __device__ uint64_t nvshmem_signal_wait_until(uint64_t *sig_addr, + int cmp, + uint64_t cmp_value); + +enum { + NVSHMEM_CMP_GE = 5, + NVSHMEM_SIGNAL_SET = 9, +}; + +extern "C" __device__ void ag_mark_local_ready(uint64_t *ready, int rank, + int num_chunks) { + int chunk_id = (int)blockIdx.x; + if (chunk_id >= num_chunks) { + return; + } + if (threadIdx.x == 0) { + __threadfence_system(); + ready[(size_t)rank * num_chunks + chunk_id] = 1; + } + __syncthreads(); +} + +// One Triton program publishes one chunk of this rank's A slice to one peer. +extern "C" __device__ void +ag_publish_local_chunk(__half *workspace, uint64_t *ready, + int elements_per_rank, int elements_per_chunk, + int num_chunks, int rank, int world_size) { + int block_id = (int)blockIdx.x; + int peer_offset = block_id / num_chunks + 1; + int chunk_id = block_id % num_chunks; + if (peer_offset >= world_size) { + return; + } + + int peer = (rank + peer_offset) % world_size; + __half *local_chunk = workspace + (size_t)rank * elements_per_rank + + (size_t)chunk_id * elements_per_chunk; + uint64_t *chunk_ready = ready + (size_t)rank * num_chunks + chunk_id; + + nvshmemx_putmem_signal_nbi_block(local_chunk, local_chunk, + (size_t)elements_per_chunk * sizeof(__half), + chunk_ready, 1, NVSHMEM_SIGNAL_SET, peer); +} + +// The GEMM program waits for the source chunk whose A rows it will consume. 
+extern "C" __device__ void ag_wait_ready(uint64_t *ready, int signal_index) { + if (threadIdx.x == 0) { + nvshmem_signal_wait_until(ready + signal_index, NVSHMEM_CMP_GE, 1); + } + __syncthreads(); +} diff --git a/python/tutorials/tle/raw/nvshmem/02-allgather-gemm/ag-gemm-host.cu b/python/tutorials/tle/raw/nvshmem/02-allgather-gemm/ag-gemm-host.cu new file mode 100644 index 0000000000..fd90e1e198 --- /dev/null +++ b/python/tutorials/tle/raw/nvshmem/02-allgather-gemm/ag-gemm-host.cu @@ -0,0 +1,58 @@ +#include +#include +#include +#include +#include +#include +#include + +#define CUDA_CHECK(stmt) \ + do { \ + cudaError_t result = (stmt); \ + if (result != cudaSuccess) { \ + fprintf(stderr, "[%s:%d] CUDA failed: %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(result)); \ + return -1; \ + } \ + } while (0) + +extern "C" int ag_gemm_workspace_create(int elements_per_rank, int num_chunks, + void **workspace, uint64_t **ready, + int *mype, int *npes, int *mype_in_node, + int *npes_in_node) { + if (elements_per_rank <= 0 || num_chunks <= 0 || !workspace || !ready || + !mype || !npes || !mype_in_node || !npes_in_node) { + return -1; + } + + *mype = nvshmem_my_pe(); + *npes = nvshmem_n_pes(); + *mype_in_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE); + *npes_in_node = nvshmem_team_n_pes(NVSHMEMX_TEAM_NODE); + CUDA_CHECK(cudaSetDevice(*mype_in_node)); + + size_t workspace_bytes = (size_t)(*npes) * elements_per_rank * sizeof(__half); + *workspace = nvshmem_malloc(workspace_bytes); + *ready = (uint64_t *)nvshmem_calloc((size_t)(*npes) * num_chunks, + sizeof(uint64_t)); + if (*workspace == nullptr || *ready == nullptr) { + if (*ready != nullptr) { + nvshmem_free(*ready); + } + if (*workspace != nullptr) { + nvshmem_free(*workspace); + } + return -2; + } + + nvshmem_barrier_all(); + return 0; +} + +extern "C" int ag_gemm_workspace_destroy(void *workspace, void *ready) { + CUDA_CHECK(cudaDeviceSynchronize()); + nvshmem_barrier_all(); + nvshmem_free(ready); + nvshmem_free(workspace); + return 0; +} diff --git 
a/python/tutorials/tle/raw/nvshmem/02-allgather-gemm/ag-gemm.py b/python/tutorials/tle/raw/nvshmem/02-allgather-gemm/ag-gemm.py new file mode 100644 index 0000000000..3cbee845e5 --- /dev/null +++ b/python/tutorials/tle/raw/nvshmem/02-allgather-gemm/ag-gemm.py @@ -0,0 +1,568 @@ +import argparse +import ctypes +import os +from pathlib import Path + +import torch +import triton +import triton.language as tl +import triton.experimental.tle.language.raw as tle_raw +from triton.experimental.tle.raw import dialect + +from common.utils import ( + load_library, + install_cumodule_hook, + init_torch_distributed, + init_nvshmem_by_torch_pg, + tensor_from_pointer, + prepare_clang_bitcode, +) + + +def _device_dialect(function_name): + return dialect( + name="cuda", + compiler="clang", + target="bc", + file=Path(__file__).parent / "ag-gemm-device.cu", + extern_file=Path(__file__).parent / "ag-gemm-device-extern-call.py", + extern_func_name=function_name, + ) + + +@_device_dialect("ag_publish_local_chunk") +def publish_chunk(*args, **kwargs): + ... + + +@_device_dialect("ag_mark_local_ready") +def mark_local_ready(*args, **kwargs): + ... + + +@_device_dialect("ag_wait_ready") +def wait_ready(*args, **kwargs): + ... 
+ + +@triton.jit +def set_local_ready(ready, rank, num_chunks): + tle_raw.call(mark_local_ready, [ready, rank, num_chunks]) + + +@triton.jit +def allgather_producer( + workspace, + ready, + elements_per_rank, + elements_per_chunk, + num_chunks, + rank, + world_size, +): + tle_raw.call( + publish_chunk, + [ + workspace, + ready, + elements_per_rank, + elements_per_chunk, + num_chunks, + rank, + world_size, + ], + ) + + +@triton.jit +def ag_gemm_consumer( + a_ptr, + b_ptr, + c_ptr, + ready, + M, + N, + K, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + CHUNK_M: tl.constexpr, + NUM_CHUNKS: tl.constexpr, + READY_VALUE: tl.constexpr, + LOCAL_WORLD_SIZE: tl.constexpr, +): + dtype = c_ptr.dtype.element_ty + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + m_per_rank = M // WORLD_SIZE + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + pid_m_offset = tl.cdiv(m_per_rank * RANK, BLOCK_M) + pid_m = (pid_m + pid_m_offset) % num_pid_m + + tile_m = pid_m * BLOCK_M + source_rank = tile_m // m_per_rank + source_rank_row = tile_m - source_rank * m_per_rank + chunk_id = source_rank_row // CHUNK_M + signal_index = source_rank * NUM_CHUNKS + chunk_id + tle_raw.call(wait_ready, [ready, signal_index]) + + offs_m = tile_m + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_K)): + k_offsets = k * BLOCK_K + offs_k + a_ptrs = a_ptr + offs_m[:, None] * K + k_offsets[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + k_offsets[:, None] + a_mask = (offs_m[:, 
None] < M) & (k_offsets[None, :] < K) + b_mask = (k_offsets[:, None] < K) & (offs_n[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + accumulator += tl.dot(a, b) + + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, accumulator.to(dtype), mask=c_mask) + + +def ag_gemm_op( + a, + b, + c, + rank, + world_size, + workspace, + ready, + comm_stream, + compute_stream, + extern_libs, + chunk_m, + local_world_size, + block_m=128, + block_n=256, + block_k=64, + group_m=8, + stages=3, + ready_value=1, +): + assert a.shape[1] == b.shape[1], "incompatible GEMM dimensions" + assert a.dtype == b.dtype == c.dtype, "incompatible GEMM dtypes" + + m, k = workspace.shape + m_per_rank = m // world_size + n_per_rank = b.shape[0] + num_chunks = m_per_rank // chunk_m + elements_per_chunk = chunk_m * k + total_tiles = triton.cdiv(m, block_m) * triton.cdiv(n_per_rank, block_n) + grid = (total_tiles, ) + + local_ready = torch.cuda.Event() + comm_start = torch.cuda.Event(enable_timing=True) + comm_done = torch.cuda.Event(enable_timing=True) + compute_start = torch.cuda.Event(enable_timing=True) + compute_done = torch.cuda.Event(enable_timing=True) + current_stream = torch.cuda.current_stream(b.device) + + with torch.cuda.stream(compute_stream): + compute_stream.wait_stream(current_stream) + local_ready.record(compute_stream) + + with torch.cuda.stream(comm_stream): + comm_stream.wait_event(local_ready) + comm_start.record(comm_stream) + allgather_producer[((world_size - 1) * num_chunks, )]( + workspace, + ready, + m_per_rank * k, + elements_per_chunk, + num_chunks, + rank, + world_size, + num_warps=32, + extern_libs=extern_libs, + ) + comm_done.record(comm_stream) + + with torch.cuda.stream(compute_stream): + compute_start.record(compute_stream) + ag_gemm_consumer[grid]( + workspace, + b, + c, + ready, + m, + n_per_rank, + k, + rank, + world_size, + 
block_m, + block_n, + block_k, + group_m, + chunk_m, + num_chunks, + ready_value, + local_world_size, + num_warps=8, + num_stages=stages, + extern_libs=extern_libs, + ) + compute_done.record(compute_stream) + + current_stream.wait_event(comm_done) + current_stream.wait_event(compute_done) + return c, { + "comm": (comm_start, comm_done), + "compute": (compute_start, compute_done), + } + + +def triton_prepare( + a_local, + workspace, + ready, + rank, + extern_libs, + chunk_m, +): + m_per_rank = a_local.shape[0] + num_chunks = m_per_rank // chunk_m + + ready.zero_() + workspace[rank * m_per_rank:(rank + 1) * m_per_rank].copy_(a_local) + set_local_ready[(num_chunks, )]( + ready, + rank, + num_chunks, + num_warps=1, + extern_libs=extern_libs, + ) + + +def torch_ag_gemm(group, a_local, b, gathered): + torch.distributed.all_gather_into_tensor(gathered, a_local, group=group) + return torch.matmul(gathered, b.T) + + +def create_workspace(host, world_size, m_per_rank, k, num_chunks, device): + workspace_ptr = ctypes.c_void_p() + ready_ptr = ctypes.c_void_p() + mype = ctypes.c_int() + npes = ctypes.c_int() + local_pe = ctypes.c_int() + local_npes = ctypes.c_int() + result = host.ag_gemm_workspace_create( + m_per_rank * k, + num_chunks, + ctypes.byref(workspace_ptr), + ctypes.byref(ready_ptr), + ctypes.byref(mype), + ctypes.byref(npes), + ctypes.byref(local_pe), + ctypes.byref(local_npes), + ) + assert result == 0, f"workspace allocation failed: {result}" + assert npes.value == world_size + workspace = tensor_from_pointer( + workspace_ptr, + (world_size * m_per_rank, k), + torch.float16, + device, + ) + ready = tensor_from_pointer( + ready_ptr, + (world_size, num_chunks), + torch.uint64, + device, + ) + return workspace_ptr, ready_ptr, workspace, ready, mype, local_pe, local_npes + + +def perf(fn, group, warmup, iters, prepare_fn=None): + assert warmup >= 0 and iters > 0 + + def unpack_result(result): + if isinstance(result, tuple) and len(result) == 2: + maybe_events = 
result[1] + if isinstance(maybe_events, dict): + return result[0], maybe_events + return result, {} + + def prepare(): + torch.distributed.barrier(group=group) + if prepare_fn is not None: + prepare_fn() + torch.cuda.synchronize() + torch.distributed.barrier(group=group) + + output = None + for _ in range(warmup): + prepare() + output, _ = unpack_result(fn()) + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + latencies_ms = [] + comm_latencies_ms = [] + compute_latencies_ms = [] + for _ in range(iters): + prepare() + start.record() + output, profile_events = unpack_result(fn()) + end.record() + end.synchronize() + latencies_ms.append(start.elapsed_time(end)) + if "comm" in profile_events: + comm_start, comm_end = profile_events["comm"] + comm_latencies_ms.append(comm_start.elapsed_time(comm_end)) + if "compute" in profile_events: + compute_start, compute_end = profile_events["compute"] + compute_latencies_ms.append(compute_start.elapsed_time(compute_end)) + + profile = {"total": sum(latencies_ms) / len(latencies_ms)} + if comm_latencies_ms: + profile["comm"] = sum(comm_latencies_ms) / len(comm_latencies_ms) + if compute_latencies_ms: + profile["compute"] = (sum(compute_latencies_ms) / len(compute_latencies_ms)) + return output, profile + + +def print_perf( + name: str, + value: float, + group, + rank: int, + world_size: int, + unit: str = "ms", +): + for index in range(world_size): + torch.distributed.barrier(group=group) + if rank == index: + print(f"{name} #{rank}: {value:.4f} {unit}", flush=True) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Chunked NVSHMEM all-gather GEMM with Triton-distributed-style benchmarking") + parser.add_argument("--m-per-rank", type=int, default=1024) + parser.add_argument( + "--chunk-m", + type=int, + default=1024, + help="rows per independently transferred and signaled A chunk", + ) + parser.add_argument( + "--n-per-rank", + type=int, 
+ default=4096, + help="local output width (the local B shard has shape N_per_rank x K)", + ) + parser.add_argument("--k", type=int, default=8192) + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--iters", type=int, default=5) + + parser.add_argument( + "--mode", + choices=("check", "perf"), + default="check", + help=("check: run correctness only; " + "perf: run benchmark only"), + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + group = init_torch_distributed() + rank = group.rank() + world_size = group.size() + local_rank = int(os.environ["LOCAL_RANK"]) + device = torch.device("cuda", local_rank) + local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + + assert world_size >= 2, "AG-GEMM overlap requires at least two GPUs" + assert world_size == local_world_size, ("this example follows Triton-distributed's single-node AG-GEMM path: " + "WORLD_SIZE must equal LOCAL_WORLD_SIZE") + + a_local = torch.randn((args.m_per_rank, args.k), device=device, dtype=torch.float16) + b = torch.randn((args.n_per_rank, args.k), device=device, dtype=torch.float16) + + common_path = Path(__file__).parents[1] / "common" / "common-host.so" + host_path = Path(__file__).with_name("ag-gemm-host.so") + common = load_library(common_path) + host = load_library(host_path) + init_nvshmem_by_torch_pg(common, group) + install_cumodule_hook(common) + + bitcode_path = Path(__file__).with_name("ag-gemm-device.bc") + extern_libs = prepare_clang_bitcode( + common, local_rank, bitcode_path, publish_chunk, + public_api_names=["ag_mark_local_ready", "ag_publish_local_chunk", "ag_wait_ready"]) + + num_chunks = args.m_per_rank // args.chunk_m + ( + workspace_ptr, + ready_ptr, + workspace, + ready, + mype, + local_pe, + local_npes, + ) = create_workspace( + host, + world_size, + args.m_per_rank, + args.k, + num_chunks, + device, + ) + assert mype.value == rank + assert local_pe.value == local_rank + + comm_stream = torch.cuda.Stream(device=device) + 
compute_stream = torch.cuda.Stream(device=device) + c = torch.empty( + (world_size * args.m_per_rank, args.n_per_rank), + dtype=a_local.dtype, + device=device, + ) + gathered = torch.empty( + (world_size * args.m_per_rank, args.k), + dtype=a_local.dtype, + device=device, + ) + + try: + + def prepare_triton_mode(): + triton_prepare( + a_local, + workspace, + ready, + rank, + extern_libs, + args.chunk_m, + ) + + def triton_func(): + return ag_gemm_op( + a_local, + b, + c, + rank, + world_size, + workspace, + ready, + comm_stream, + compute_stream, + extern_libs, + args.chunk_m, + local_world_size, + ) + + def torch_func(): + return torch_ag_gemm(group, a_local, b, gathered) + + def run_correctness(): + if rank == 0: + print("[check] start correctness validation", flush=True) + prepare_triton_mode() + output, _ = triton_func() + torch.cuda.synchronize(device) + golden = torch_func() + torch.cuda.synchronize(device) + torch.testing.assert_close(output, golden, atol=2e-2, rtol=2e-2) + torch.distributed.barrier(group=group) + if rank == 0: + print("[check] Pass!", flush=True) + + def run_benchmark(): + if rank == 0: + print( + f"[bench] start benchmark: warmup={args.warmup}, " + f"iters={args.iters}", + flush=True, + ) + _, triton_profile = perf( + triton_func, + group, + args.warmup, + args.iters, + prepare_fn=prepare_triton_mode, + ) + torch.cuda.synchronize(device) + + _, torch_profile = perf( + torch_func, + group, + args.warmup, + args.iters, + ) + torch.cuda.synchronize(device) + torch.distributed.barrier(group=group) + + print_perf( + "dist-triton ag-gemm", + triton_profile["total"], + group, + rank, + world_size, + ) + print_perf( + "torch ag-gemm", + torch_profile["total"], + group, + rank, + world_size, + ) + print_perf( + "speedup", + torch_profile["total"] / triton_profile["total"], + group, + rank, + world_size, + unit="x", + ) + + if rank == 0: + print(f"configuration: GPUs={world_size}, " + f"A_local=({args.m_per_rank}, {args.k}), " + 
f"B_local=({args.n_per_rank}, {args.k}), " + f"chunk_m={args.chunk_m}, chunks/rank={num_chunks}") + if "comm" in triton_profile and "compute" in triton_profile: + print( + f"dist-triton detail: comm={triton_profile['comm']:.4f} ms, " + f"compute={triton_profile['compute']:.4f} ms", + flush=True, + ) + + if args.mode == "check": + run_correctness() + if args.mode == "perf": + run_benchmark() + + finally: + torch.cuda.synchronize(device) + result = host.ag_gemm_workspace_destroy(workspace_ptr, ready_ptr) + assert result == 0 + torch.distributed.barrier(group=group) + result = common.nvshmem_finalize_from_torch_distributed() + assert result == 0 + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/python/tutorials/tle/raw/nvshmem/common/build.py b/python/tutorials/tle/raw/nvshmem/common/build.py new file mode 100644 index 0000000000..b07995bd29 --- /dev/null +++ b/python/tutorials/tle/raw/nvshmem/common/build.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 + +import argparse +import ast +import importlib.util +import os +import subprocess +import sys +import tempfile +from pathlib import Path + +NVSHMEM_ROOT = Path(__file__).resolve().parents[1] +GENERATOR_FILE = Path(__file__).with_name("generate_extern_call.py") +_SPEC = importlib.util.spec_from_file_location("_tle_generate_extern_call", GENERATOR_FILE) +if _SPEC is None or _SPEC.loader is None: + raise ImportError(f"cannot load extern call generator from {GENERATOR_FILE}") +_GENERATOR = importlib.util.module_from_spec(_SPEC) +sys.modules[_SPEC.name] = _GENERATOR +_DONT_WRITE_BYTECODE = sys.dont_write_bytecode +try: + sys.dont_write_bytecode = True + _SPEC.loader.exec_module(_GENERATOR) +finally: + sys.dont_write_bytecode = _DONT_WRITE_BYTECODE +generate = _GENERATOR.generate + + +def _last_string(node): + if isinstance(node, ast.Constant) and isinstance(node.value, str): + return node.value + for child in reversed(list(ast.iter_child_nodes(node))): + value = 
_last_string(child) + if value is not None: + return value + return None + + +def _dialect_file_pairs(example_dir): + pairs = set() + for python_file in example_dir.glob("*.py"): + tree = ast.parse(python_file.read_text(encoding="utf-8"), filename=str(python_file)) + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + name = getattr(node.func, "id", None) or getattr(node.func, "attr", None) + if name != "dialect": + continue + keywords = {keyword.arg: keyword.value for keyword in node.keywords if keyword.arg} + cuda_file_name = _last_string(keywords.get("file")) if "file" in keywords else None + extern_file_name = _last_string(keywords.get("extern_file")) if "extern_file" in keywords else None + if cuda_file_name and extern_file_name: + pairs.add((example_dir / cuda_file_name, example_dir / extern_file_name)) + return pairs + + +def generate_extern_files(example_dir): + pairs = _dialect_file_pairs(example_dir) + generated = [] + for cuda_file, output_file in sorted(pairs, key=lambda pair: str(pair[1])): + if not cuda_file.is_file(): + raise FileNotFoundError(f"CUDA device file does not exist: {cuda_file}") + generate(cuda_file, output_file) + generated.append(output_file) + return generated + + +def detect_arch(explicit_arch): + arch = explicit_arch or "sm_80" + arch = arch.removeprefix("-arch=").removeprefix("sm_") + return f"sm_{arch}" + + +def resolve_nvshmem_home(explicit_home): + if not explicit_home: + raise ValueError("NVSHMEM_HOME is required; set the environment variable or pass --nvshmem-home") + home = Path(explicit_home).expanduser().resolve() + + if not (home / "include").is_dir() or not (home / "lib").is_dir(): + raise FileNotFoundError(f"invalid NVSHMEM_HOME {home}: expected include/ and lib/ directories") + return home + + +def compile_common_host(nvshmem_home, arch, force): + cuda_file = Path(__file__).with_name("common-host.cu") + if not cuda_file.is_file(): + return None + + nvcc = os.getenv("NVCC", "nvcc") + output_file 
= cuda_file.with_suffix(".so") + arch_file = output_file.with_suffix(".so.arch") + current_arch = arch_file.read_text(encoding="utf-8").strip() if arch_file.is_file() else "" + + rebuild_reason = None + if force: + rebuild_reason = "--force" + elif not output_file.exists(): + rebuild_reason = "missing output" + elif output_file.stat().st_mtime_ns < cuda_file.stat().st_mtime_ns: + rebuild_reason = "source newer than output" + elif current_arch != arch: + rebuild_reason = f"arch changed: {current_arch or 'unknown'} -> {arch}" + + if rebuild_reason is None: + print(f"[reuse] common host: {output_file} (arch={arch})") + return output_file + + print(f"[build] common host: {output_file} ({rebuild_reason})") + temporary = tempfile.NamedTemporaryFile( + prefix=f".{output_file.name}.", + suffix=".tmp", + dir=output_file.parent, + delete=False, + ) + temporary_path = Path(temporary.name) + temporary.close() + command = [ + nvcc, + "-shared", + "-Xcompiler", + "-fPIC", + "-rdc=true", + f"-arch={arch}", + f"-I{nvshmem_home / 'include'}", + f"-L{nvshmem_home / 'lib'}", + "-lnvshmem_host", + "-lnvshmem_device", + "-o", + str(temporary_path), + str(cuda_file), + ] + try: + subprocess.run(command, check=True) + os.replace(temporary_path, output_file) + arch_file.write_text(f"{arch}\n", encoding="utf-8") + finally: + temporary_path.unlink(missing_ok=True) + return output_file + + +def compile_host_files(example_dir, nvshmem_home, arch, force): + nvcc = os.getenv("NVCC", "nvcc") + outputs = [] + for cuda_file in sorted(example_dir.glob("*-host.cu")): + output_file = cuda_file.with_suffix(".so") + rebuild_reason = None + if force: + rebuild_reason = "--force" + elif not output_file.exists(): + rebuild_reason = "missing output" + elif output_file.stat().st_mtime_ns < cuda_file.stat().st_mtime_ns: + rebuild_reason = "source newer than output" + + if rebuild_reason is None: + print(f"[reuse] host: {output_file}") + outputs.append(output_file) + continue + + print(f"[build] host: 
{output_file} ({rebuild_reason})") + temporary = tempfile.NamedTemporaryFile( + prefix=f".{output_file.name}.", + suffix=".tmp", + dir=output_file.parent, + delete=False, + ) + temporary_path = Path(temporary.name) + temporary.close() + command = [ + nvcc, + "-shared", + "-Xcompiler", + "-fPIC", + "-rdc=true", + f"-arch={arch}", + f"-I{nvshmem_home / 'include'}", + f"-L{nvshmem_home / 'lib'}", + "-lnvshmem_host", + "-lnvshmem_device", + "-o", + str(temporary_path), + str(cuda_file), + ] + try: + subprocess.run(command, check=True) + os.replace(temporary_path, output_file) + finally: + temporary_path.unlink(missing_ok=True) + outputs.append(output_file) + return outputs + + +def main(): + parser = argparse.ArgumentParser(description="Generate Triton extern calls and compile NVSHMEM host libraries.") + parser.add_argument("target", type=Path, help="NVSHMEM example Python file") + parser.add_argument("--nvshmem-home", default=os.getenv("NVSHMEM_HOME")) + parser.add_argument("--arch", default="sm_90", help="CUDA architecture (default: sm_90)") + parser.add_argument("--force", action="store_true", help="always rebuild host libraries") + args = parser.parse_args() + + target = args.target.expanduser().resolve() + if not target.is_file(): + parser.error(f"target does not exist: {target}") + if NVSHMEM_ROOT not in target.parents: + parser.error(f"target must be below {NVSHMEM_ROOT}") + + nvshmem_home = resolve_nvshmem_home(args.nvshmem_home) + arch = detect_arch(args.arch) + generated = generate_extern_files(target.parent) + common = compile_common_host(nvshmem_home, arch, args.force) + libraries = compile_host_files(target.parent, nvshmem_home, arch, args.force) + + for path in generated: + print(f"[prepare] extern file: {path}") + if common is not None: + print(f"[prepare] common host: {common}") + for path in libraries: + print(f"[prepare] host: {path}") + print(f"[prepare] NVSHMEM_HOME={nvshmem_home}") + print(f"[prepare] CUDA architecture={arch}") + + +if __name__ == 
"__main__":
+    main()
diff --git a/python/tutorials/tle/raw/nvshmem/common/common-host.cu b/python/tutorials/tle/raw/nvshmem/common/common-host.cu
new file mode 100644
index 0000000000..442262b441
--- /dev/null
+++ b/python/tutorials/tle/raw/nvshmem/common/common-host.cu
@@ -0,0 +1,87 @@
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <nvshmem.h>
+#include <nvshmemx.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+// Runs a CUDA runtime call and terminates the process with exit(-1) after
+// printing the CUDA error string if it fails. The #undef drops any CUDA_CHECK
+// that an earlier header may already have defined.
+#undef CUDA_CHECK
+#define CUDA_CHECK(stmt) \
+  do { \
+    cudaError_t result = (stmt); \
+    if (cudaSuccess != result) { \
+      fprintf(stderr, "[%s:%d] cuda failed with %s \n", __FILE__, __LINE__, \
+              cudaGetErrorString(result)); \
+      exit(-1); \
+    } \
+  } while (0)
+
+// Passes a CUDA module to nvshmemx_cumodule_init and returns its status.
+extern "C" int nvshmemx_cumodule_init_wrapper(CUmodule module) {
+  return nvshmemx_cumodule_init(module);
+}
+
+// C-linkage entry point for nvshmem_barrier_all.
+extern "C" void nvshmem_barrier_all_wrapper() { nvshmem_barrier_all(); }
+
+// Creates an NVSHMEM unique id and copies its raw bytes into uid_buffer.
+//
+// Returns 0 on success, -1 if uid_buffer is null or smaller than
+// nvshmemx_uniqueid_t, or the non-zero status of nvshmemx_get_uniqueid.
+extern "C" int nvshmem_get_unique_id_bytes(void *uid_buffer,
+                                           size_t uid_buffer_size) {
+  if (uid_buffer == nullptr || uid_buffer_size < sizeof(nvshmemx_uniqueid_t)) {
+    return -1;
+  }
+
+  nvshmemx_uniqueid_t uid;
+  const int status = nvshmemx_get_uniqueid(&uid);
+  if (status != 0) {
+    return status; // do not hand an unset id to the caller
+  }
+  memcpy(uid_buffer, &uid, sizeof(uid));
+  return 0;
+}
+
+// Initializes NVSHMEM for this process from the raw bytes of an
+// nvshmemx_uniqueid_t.
+//
+// rank, nranks     this process's PE index and the total number of PEs
+// cuda_device      device selected with cudaSetDevice before initializing
+// uid_buffer       at least sizeof(nvshmemx_uniqueid_t) bytes of unique id
+//
+// Returns 0 on success, -1 on a bad uid buffer, or the non-zero status of the
+// failing NVSHMEM call. A CUDA failure terminates the process (CUDA_CHECK).
+extern "C" int nvshmem_init_from_torch_distributed(int rank, int nranks,
+                                                   int cuda_device,
+                                                   void *uid_buffer,
+                                                   size_t uid_buffer_size) {
+  if (uid_buffer == nullptr || uid_buffer_size < sizeof(nvshmemx_uniqueid_t)) {
+    return -1;
+  }
+
+  CUDA_CHECK(cudaSetDevice(cuda_device));
+
+  nvshmemx_uniqueid_t uid;
+  memcpy(&uid, uid_buffer, sizeof(uid));
+
+  nvshmemx_init_attr_t attr;
+  memset(&attr, 0, sizeof(attr));
+  int status = nvshmemx_set_attr_uniqueid_args(rank, nranks, &uid, &attr);
+  if (status != 0) {
+    return status;
+  }
+  // Report a failed bootstrap instead of claiming success unconditionally.
+  status = nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
+  return status;
+}
+
+// Finalizes NVSHMEM; always returns 0.
+extern "C" int nvshmem_finalize_from_torch_distributed() {
+  nvshmem_finalize();
+  return 0;
+}
diff --git a/python/tutorials/tle/raw/nvshmem/common/generate_extern_call.py b/python/tutorials/tle/raw/nvshmem/common/generate_extern_call.py
new file mode 100644
index 0000000000..6f6010c392
--- /dev/null
+++ b/python/tutorials/tle/raw/nvshmem/common/generate_extern_call.py
@@ -0,0
+1,197 @@ +"""Generate Triton extern-call wrappers from CUDA device functions.""" + +import re +from dataclasses import dataclass +from pathlib import Path + +_DTYPES = { + "bool": "int1", + "char": "int8", + "signed char": "int8", + "unsigned char": "uint8", + "int8_t": "int8", + "uint8_t": "uint8", + "short": "int16", + "short int": "int16", + "signed short": "int16", + "unsigned short": "uint16", + "int16_t": "int16", + "uint16_t": "uint16", + "int": "int32", + "signed": "int32", + "signed int": "int32", + "unsigned": "uint32", + "unsigned int": "uint32", + "int32_t": "int32", + "uint32_t": "uint32", + "long": "int64", + "long int": "int64", + "long long": "int64", + "long long int": "int64", + "unsigned long": "uint64", + "unsigned long int": "uint64", + "unsigned long long": "uint64", + "unsigned long long int": "uint64", + "int64_t": "int64", + "uint64_t": "uint64", + "size_t": "uint64", + "float": "fp32", + "double": "fp64", + "__half": "fp16", +} + +_QUALIFIERS = re.compile(r"\b(?:const|volatile|restrict|__restrict__|__restrict|__const__|" + r"__device__|__forceinline__|inline|static)\b") +_ATTRIBUTE = re.compile( + r"__attribute__\s*\(\((?:[^()]|\([^()]*\))*\)\)", + re.DOTALL, +) +_COMMENT = re.compile(r"//[^\n]*|/\*.*?\*/", re.DOTALL) +_FUNCTION = re.compile( + r'extern\s+"C"\s+(?P[^;{}]*?)\b(?P[A-Za-z_]\w*)\s*' + r"\((?P[^;{}]*?)\)\s*(?:noexcept\s*)?\{", + re.DOTALL, +) + + +@dataclass(frozen=True) +class CudaType: + dtype: str + pointer: bool = False + + @property + def triton_type(self) -> str: + dtype = f'core.dtype("{self.dtype}")' + return f"core.pointer_type({dtype})" if self.pointer else dtype + + +@dataclass(frozen=True) +class Parameter: + name: str + type: CudaType + + @property + def argument(self) -> str: + if self.type.pointer: + return self.name + return f"tl.cast({self.name}, tl.{self.type.dtype}, _semantic=_semantic)" + + +@dataclass(frozen=True) +class Function: + name: str + parameters: tuple[Parameter, ...] 
+ return_type: CudaType | None + + +def _split_parameters(parameters: str) -> list[str]: + result = [] + start = 0 + depth = 0 + for index, char in enumerate(parameters): + if char in "([": + depth += 1 + elif char in ")]": + depth -= 1 + elif char == "," and depth == 0: + result.append(parameters[start:index].strip()) + start = index + 1 + tail = parameters[start:].strip() + if tail and tail != "void": + result.append(tail) + return result + + +def _parse_type(spelling: str, *, context: str) -> CudaType | None: + spelling = _ATTRIBUTE.sub(" ", spelling) + pointer = "*" in spelling or "[" in spelling + spelling = re.sub(r"\[[^\]]*\]", " ", spelling) + spelling = spelling.replace("*", " ") + spelling = _QUALIFIERS.sub(" ", spelling) + spelling = " ".join(spelling.split()) + if spelling == "void" and not pointer: + return None + if spelling == "void" and pointer: + return CudaType("uint8", pointer=True) + dtype = _DTYPES.get(spelling) + if dtype is None: + raise ValueError(f"unsupported CUDA type {spelling!r} in {context}") + return CudaType(dtype, pointer=pointer) + + +def _parse_parameter(parameter: str, function_name: str) -> Parameter: + parameter = _ATTRIBUTE.sub(" ", parameter).strip() + match = re.search(r"([A-Za-z_]\w*)\s*(?:\[[^\]]*\])?\s*$", parameter) + if match is None: + raise ValueError(f"cannot parse parameter {parameter!r} in {function_name}") + name = match.group(1) + type_spelling = parameter[:match.start(1)] + parameter[match.end(1):] + parsed_type = _parse_type(type_spelling, context=f"{function_name}.{name}") + if parsed_type is None: + raise ValueError(f"parameter {name!r} in {function_name} cannot have type void") + return Parameter(name, parsed_type) + + +def parse_cuda_functions(source: str) -> tuple[Function, ...]: + source = _COMMENT.sub("", source) + source = _ATTRIBUTE.sub(" ", source) + functions = [] + for match in _FUNCTION.finditer(source): + if "__device__" not in match.group("prefix"): + continue + prefix = _QUALIFIERS.sub(" ", 
_ATTRIBUTE.sub(" ", match.group("prefix"))) + return_type = _parse_type(prefix, context=f"{match.group('name')} return type") + parameters = tuple( + _parse_parameter(parameter, match.group("name")) for parameter in _split_parameters(match.group("params"))) + functions.append(Function(match.group("name"), parameters, return_type)) + return tuple(functions) + + +def _render_function(function: Function) -> str: + names = ", ".join(parameter.name for parameter in function.parameters) + signature = f"{names}, _semantic=None" if names else "_semantic=None" + arguments = ",\n ".join(parameter.argument for parameter in function.parameters) + if arguments: + arguments = f"[\n {arguments},\n ]" + else: + arguments = "[]" + types = "\n".join(f" {parameter.type.triton_type}," for parameter in function.parameters) + returns = "()" if function.return_type is None else f"({function.return_type.triton_type},)" + return (f"@core.extern\n" + f"def {function.name}({signature}):\n" + f" return core.extern_call(\n" + f' "",\n' + f' "",\n' + f" {arguments},\n" + f" {{\n" + f" (\n{types}\n" + f' ): ("{function.name}", {returns}),\n' + f" }},\n" + f" is_pure=False,\n" + f" _semantic=_semantic,\n" + f" )\n") + + +def generate( + cuda_file: str | Path, + output_file: str | Path, + required_function: str | None = None, +) -> Path: + cuda_file = Path(cuda_file) + output_file = Path(output_file) + functions = parse_cuda_functions(cuda_file.read_text(encoding="utf-8")) + if not functions: + raise ValueError(f"no extern \"C\" CUDA functions found in {cuda_file}") + names = {function.name for function in functions} + if required_function is not None and required_function not in names: + raise ValueError(f"extern function {required_function!r} was not found in {cuda_file}; " + f"found: {', '.join(sorted(names))}") + content = ("# Generated from " + f"{cuda_file.name}; do not edit manually.\n" + "import triton.language as tl\n" + "import triton.language.core as core\n\n\n" + + 
"\n\n".join(_render_function(function) for function in functions) + "\n") + output_file.parent.mkdir(parents=True, exist_ok=True) + if not output_file.exists() or output_file.read_text(encoding="utf-8") != content: + output_file.write_text(content, encoding="utf-8") + return output_file diff --git a/python/tutorials/tle/raw/nvshmem/common/utils.py b/python/tutorials/tle/raw/nvshmem/common/utils.py new file mode 100644 index 0000000000..9406e399ed --- /dev/null +++ b/python/tutorials/tle/raw/nvshmem/common/utils.py @@ -0,0 +1,103 @@ +import ctypes +from pathlib import Path + +import os +import torch +import datetime +import triton.knobs as knobs + + +def load_library(library_path): + library_path = Path(library_path).expanduser().resolve() + return ctypes.CDLL(str(library_path)) + + +def install_cumodule_hook(common): + + def hook(*args, **kwargs): + key = kwargs["key"] + function = kwargs["fn"].jit_function + device = kwargs["compile"]["device"] + kernel = function.device_caches[device][0].get(key) + assert kernel is not None + kernel._init_handles() + result = common.nvshmemx_cumodule_init_wrapper(ctypes.c_void_p(kernel.module)) + assert result == 0, f"nvshmemx_cumodule_init failed: {result}" + + knobs.runtime.jit_post_compile_hook = hook + + +def init_torch_distributed(): + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + rank=rank, + world_size=world_size, + device_id=torch.device("cuda", local_rank), + timeout=datetime.timedelta(seconds=1800), + ) + group = torch.distributed.new_group(ranks=list(range(world_size)), backend="nccl") + torch.distributed.barrier(group=group) + return group + + +def init_nvshmem_by_torch_pg(common, group): + rank = group.rank() + world_size = group.size() + uid_size = 1024 + uid_buffer = ctypes.create_string_buffer(uid_size) + + if rank == 0: + result = 
common.nvshmem_get_unique_id_bytes(uid_buffer, uid_size)
+        assert result == 0, f"nvshmemx_get_uniqueid failed: {result}"
+        uid = bytes(uid_buffer.raw)
+    else:
+        uid = bytes(uid_size)
+
+    objects = [uid]
+    torch.distributed.broadcast_object_list(
+        objects,
+        src=torch.distributed.get_global_rank(group, 0),
+        group=group,
+    )
+    uid_buffer = ctypes.create_string_buffer(objects[0], uid_size)
+    result = common.nvshmem_init_from_torch_distributed(
+        rank,
+        world_size,
+        int(os.environ["LOCAL_RANK"]),
+        uid_buffer,
+        uid_size,
+    )
+    assert result == 0, f"NVSHMEM init failed: {result}"
+    torch.distributed.barrier(group=group)
+
+
+def tensor_from_pointer(pointer, shape, dtype, device):  # view existing memory as a tensor
+    elements = 1
+    for extent in shape:
+        elements *= extent  # element count is the product of all extents
+    storage = torch._C._construct_storage_from_data_pointer(
+        pointer.value,  # presumably a ctypes pointer wrapper; confirm with callers
+        device,
+        elements * dtype.itemsize,  # byte count: elements times item size
+    )
+    return torch.empty(0, dtype=dtype, device=device).set_(storage).view(shape)
+
+
+def prepare_clang_bitcode(
+    common,
+    local_rank,
+    bitcode_path,
+    dialect_function,
+    public_api_names=None,
+):
+    bitcode_path = Path(bitcode_path).expanduser().resolve()
+    if local_rank == 0:  # only local rank 0 calls make_bc
+        generated = dialect_function.make_bc(public_api_names)
+        assert generated == bitcode_path
+    common.nvshmem_barrier_all_wrapper()  # every rank passes here before the file check
+    assert bitcode_path.is_file(), f"missing device bitcode: {bitcode_path}"
+    return {bitcode_path.stem: str(bitcode_path)}  # file stem -> resolved path string
diff --git a/python/tutorials/tle/raw/nvshmem/run.sh b/python/tutorials/tle/raw/nvshmem/run.sh
new file mode 100755
index 0000000000..98daddd875
--- /dev/null
+++ b/python/tutorials/tle/raw/nvshmem/run.sh
@@ -0,0 +1,198 @@
+#!/usr/bin/env bash
+#
+# Runs common/build.py for one NVSHMEM tutorial example, then launches the
+# example on NP processes through nvshmrun or torchrun.
+
+set -euo pipefail
+
+SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
+# Defaults; the environment seeds them and the flags parsed below override.
+PYTHON_BIN="${PYTHON:-python}"
+NVSHMRUN_BIN="${NVSHMRUN:-nvshmrun}"
+TORCHRUN_BIN="${TORCHRUN:-torchrun}"
+LAUNCHER="${LAUNCHER:-nvshmrun}"
+NP=2
+ARCH="sm_90"
+FORCE=0
+NVSHMEM_HOME_ARG="${NVSHMEM_HOME:-}"
+
+usage() {
+  cat <<'EOF'
+Usage:
+  run.sh [options] EXAMPLE [-- EXAMPLE_ARGS...]
+
+Options must precede EXAMPLE; arguments after it go to the example script.
+
+Options:
+  --np N               Number of NVSHMEM processes (default: 2)
+  --launcher NAME      nvshmrun or torchrun (default: $LAUNCHER or nvshmrun)
+  --python PATH        Python executable (default: $PYTHON or python)
+  --nvshmrun PATH      nvshmrun executable (default: $NVSHMRUN or nvshmrun)
+  --torchrun PATH      torchrun executable (default: $TORCHRUN or torchrun)
+  --nvshmem-home PATH  NVSHMEM installation root (or set NVSHMEM_HOME)
+  --arch ARCH          CUDA architecture (default: sm_90)
+  --force              Rebuild the host shared library
+  -h, --help           Show this help
+
+Examples:
+  ./run.sh --np 2 01-simple-shift/simple-shift.py
+EOF
+}
+
+# Exit with a usage error when a value-taking option is the last argument.
+# Without this check, reading "$2" under `set -u` aborts with a bare
+# "unbound variable" message.
+#   $1: option name   $2: number of remaining arguments, option included
+need_value() {
+  if (($2 < 2)); then
+    echo "Option $1 requires a value" >&2
+    usage >&2
+    exit 2
+  fi
+}
+
+# Options are parsed until the first non-option word, which names the example;
+# whatever remains in "$@" afterwards is forwarded to the example untouched.
+while (($#)); do
+  case "$1" in
+    --np)
+      need_value "$1" "$#"
+      NP="$2"
+      shift 2
+      ;;
+    --launcher)
+      need_value "$1" "$#"
+      LAUNCHER="$2"
+      shift 2
+      ;;
+    --python)
+      need_value "$1" "$#"
+      PYTHON_BIN="$2"
+      shift 2
+      ;;
+    --nvshmrun)
+      need_value "$1" "$#"
+      NVSHMRUN_BIN="$2"
+      shift 2
+      ;;
+    --torchrun)
+      need_value "$1" "$#"
+      TORCHRUN_BIN="$2"
+      shift 2
+      ;;
+    --nvshmem-home)
+      need_value "$1" "$#"
+      NVSHMEM_HOME_ARG="$2"
+      shift 2
+      ;;
+    --arch)
+      need_value "$1" "$#"
+      ARCH="$2"
+      shift 2
+      ;;
+    --force)
+      FORCE=1
+      shift
+      ;;
+    -h|--help)
+      usage
+      exit 0
+      ;;
+    --)
+      shift
+      break
+      ;;
+    -*)
+      echo "Unknown option: $1" >&2
+      usage >&2
+      exit 2
+      ;;
+    *)
+      TARGET_INPUT="$1"
+      shift
+      break
+      ;;
+  esac
+done
+
+if [[ -z "${TARGET_INPUT:-}" ]]; then
+  usage >&2
+  exit 2
+fi
+
+if [[ ! "$NP" =~ ^[1-9][0-9]*$ ]]; then
+  echo "Process count must be a positive integer: $NP" >&2
+  exit 2
+fi
+
+if [[ "$LAUNCHER" != "nvshmrun" && "$LAUNCHER" != "torchrun" ]]; then
+  echo "Launcher must be one of: nvshmrun, torchrun" >&2
+  exit 2
+fi
+
+if [[ -z "$NVSHMEM_HOME_ARG" ]]; then
+  echo "NVSHMEM_HOME is required; set it or pass --nvshmem-home PATH" >&2
+  exit 2
+fi
+
+# Drop the optional "--" separating the example from its own arguments.
+if [[ "${1:-}" == "--" ]]; then
+  shift
+fi
+
+# Resolve the example: a path as given, a path relative to this script, or a
+# bare file name that matches exactly one file one directory below the script.
+if [[ -f "$TARGET_INPUT" ]]; then
+  TARGET="$(cd -- "$(dirname -- "$TARGET_INPUT")" && pwd)/$(basename -- "$TARGET_INPUT")"
+elif [[ -f "$SCRIPT_DIR/$TARGET_INPUT" ]]; then
+  TARGET="$SCRIPT_DIR/$TARGET_INPUT"
+else
+  mapfile -t MATCHES < <(find "$SCRIPT_DIR" -mindepth 2 -maxdepth 2 -type f -name "$TARGET_INPUT")
+  if ((${#MATCHES[@]} != 1)); then
+    echo "Could not uniquely resolve example: $TARGET_INPUT" >&2
+    exit 2
+  fi
+  TARGET="${MATCHES[0]}"
+fi
+
+# Turn the selected launcher into an absolute path; the script changes directory below.
+if [[ "$LAUNCHER" == "nvshmrun" ]]; then
+  if ! NVSHMRUN_PATH="$(command -v "$NVSHMRUN_BIN")"; then
+    echo "Could not find nvshmrun executable: $NVSHMRUN_BIN" >&2
+    exit 2
+  fi
+  NVSHMRUN_BIN="$(cd -- "$(dirname -- "$NVSHMRUN_PATH")" && pwd)/$(basename -- "$NVSHMRUN_PATH")"
+else
+  if ! TORCHRUN_PATH="$(command -v "$TORCHRUN_BIN")"; then
+    echo "Could not find torchrun executable: $TORCHRUN_BIN" >&2
+    exit 2
+  fi
+  TORCHRUN_BIN="$(cd -- "$(dirname -- "$TORCHRUN_PATH")" && pwd)/$(basename -- "$TORCHRUN_PATH")"
+fi
+
+if [[ -n "$NVSHMEM_HOME_ARG" ]]; then
+  NVSHMEM_HOME_ARG="$(cd -- "$NVSHMEM_HOME_ARG" && pwd)"
+fi
+
+PREPARE_ARGS=(
+  "$SCRIPT_DIR/common/build.py"
+  "$TARGET"
+  "--nvshmem-home"
+  "$NVSHMEM_HOME_ARG"
+)
+[[ -n "$ARCH" ]] && PREPARE_ARGS+=("--arch" "$ARCH")
+((FORCE)) && PREPARE_ARGS+=("--force")
+
+"$PYTHON_BIN" "${PREPARE_ARGS[@]}"
+
+export NVSHMEM_HOME="$NVSHMEM_HOME_ARG"
+export LD_LIBRARY_PATH="$NVSHMEM_HOME/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
+export PYTHONPATH="$SCRIPT_DIR${PYTHONPATH:+:$PYTHONPATH}"
+
+cd -- "$(dirname -- "$TARGET")"
+if [[ "$LAUNCHER" == "nvshmrun" ]]; then
+  exec "$NVSHMRUN_BIN" -np "$NP" "$PYTHON_BIN" "$(basename -- "$TARGET")" "$@"
+else
+  # torchrun is handed the script directly; $PYTHON_BIN is not on this command line.
+  exec "$TORCHRUN_BIN" --nproc_per_node="$NP" "$(basename -- "$TARGET")" "$@"
+fi
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
index 733f103c4a..d65658e38a 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
@@ -228,6 +228,8 @@ struct ConvertTritonGPUToLLVM
                                                     targetInfo, benefit);
     mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
                                                targetInfo, benefit);
+    mlir::triton::populateExternCallOpToLLVMPattern(typeConverter, patterns,
+                                                    targetInfo, benefit);
     mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
                                                      targetInfo, benefit);
     mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(typeConverter, patterns,