[KMCompiler] [TLERaw]Triton v3.6.x support nvshmem #682

Draft
lizhangyu258 wants to merge 1 commit into flagos-ai:triton_v3.6.x from lizhangyu258:triton_v3.6.x_support_nvshmem


@lizhangyu258 lizhangyu258 commented Jun 10, 2026

Preliminary Triton support for invoking NVSHMEM communication operations via external calls.

run

# 01-simple-shift
cd python/tutorials/tle/raw/nvshmem
## nvcc compiler
export NVCC_FLAGS="-I/data/zyuli/miniconda3/envs/flagtree_triton_v3.6.x/lib/python3.12/site-packages/nvidia/nvshmem/include"
export NVLINK_FLAGS="-L/data/zyuli/miniconda3/envs/flagtree_triton_v3.6.x/lib/python3.12/site-packages/nvidia/nvshmem/lib -lnvshmem_device"
./run.sh --launcher nvshmrun --np 4 --arch sm_90a 01-simple-shift/simple-shift.py
## clang compiler
export CLANG_FLAGS="-Xclang -mlink-bitcode-file -Xclang /data/zyuli/miniconda3/envs/flagtree_triton_v3.6.x/lib/python3.12/site-packages/nvidia/nvshmem/lib/libnvshmem_device.bc"
./run.sh --launcher nvshmrun --np 4 --arch sm_90a 01-simple-shift/simple-shift.py


# 02-allgather-gemm
cd python/tutorials/tle/raw/nvshmem
export CLANG_FLAGS="-Xclang -mlink-bitcode-file -Xclang /data/zyuli/miniconda3/envs/flagtree_triton_v3.6.x/lib/python3.12/site-packages/nvidia/nvshmem/lib/libnvshmem_device.bc"
## correctness
./run.sh --launcher torchrun --np 4 --arch sm_90a 02-allgather-gemm/ag-gemm.py --m-per-rank 4096 --chunk-m 4096 --n-per-rank 8192 --k 8192 --mode check
## performance
./run.sh --launcher torchrun --np 4 --arch sm_90a 02-allgather-gemm/ag-gemm.py --m-per-rank 4096 --chunk-m 4096 --n-per-rank 8192 --k 8192 --mode perf
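
The three environment variables above all point into the same `nvidia/nvshmem` directory under `site-packages`. As a rough sketch, they could be derived from the active Python environment instead of a hard-coded path. The helper name `nvshmem_build_flags` is hypothetical and not part of this PR, and the sketch assumes the NVSHMEM wheel is installed under `<site-packages>/nvidia/nvshmem` with the `include` and `lib` layout shown in the commands above.

```python
import sysconfig
from pathlib import Path


def nvshmem_build_flags(site_packages=None):
    """Build the NVCC/NVLINK/CLANG flag strings used by the run commands.

    Hypothetical helper: assumes the wheel layout
    <site-packages>/nvidia/nvshmem/{include,lib}.
    """
    base = Path(site_packages or sysconfig.get_paths()["purelib"])
    root = base / "nvidia" / "nvshmem"
    include_dir = root / "include"
    lib_dir = root / "lib"
    bitcode = lib_dir / "libnvshmem_device.bc"
    return {
        # nvcc path: header search directory plus device library for nvlink
        "NVCC_FLAGS": f"-I{include_dir}",
        "NVLINK_FLAGS": f"-L{lib_dir} -lnvshmem_device",
        # clang path: link the NVSHMEM device bitcode file
        "CLANG_FLAGS": f"-Xclang -mlink-bitcode-file -Xclang {bitcode}",
    }


if __name__ == "__main__":
    # Print shell export lines that can be pasted before ./run.sh
    for name, value in nvshmem_build_flags().items():
        print(f'export {name}="{value}"')
```

Running the script prints `export` lines for the current environment; passing an explicit `site_packages` path targets a different environment.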

allgather-gemm performance

config

  • single node; 4 GPUs
  • allgather:
    • chunk_m = local_m
    • grid = num_ranks - 1
    • config = {num_warps=32}
  • gemm:
    • grid = cdiv(local_m * num_ranks, BM) * cdiv(n, BN)
    • config = {"BM": 128, "BN": 256, "BK": 64, "stage": 3, "num_warps": 8}
image
  • allgather
    • chunk_m = {128, 256, 512, 1024}
    • grid = (num_ranks - 1) * (local_m / chunk_m)
    • config = {num_warps=32}
image
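
The grid formulas in the config above can be checked with a few lines of arithmetic. The sketch below evaluates them for the 4-GPU setup, taking `local_m = 4096` from `--m-per-rank`. Using `n = 8192` is an assumption: the description does not say whether `n` in the GEMM grid formula is the per-rank or the total N dimension. The function names are illustrative and do not come from the PR.

```python
def cdiv(a, b):
    """Ceiling division, as used in the GEMM grid formula."""
    return (a + b - 1) // b


def allgather_grid(num_ranks, local_m, chunk_m):
    """grid = (num_ranks - 1) * (local_m / chunk_m)."""
    if local_m % chunk_m != 0:
        raise ValueError("chunk_m must divide local_m")
    return (num_ranks - 1) * (local_m // chunk_m)


def gemm_grid(local_m, num_ranks, n, bm, bn):
    """grid = cdiv(local_m * num_ranks, BM) * cdiv(n, BN)."""
    return cdiv(local_m * num_ranks, bm) * cdiv(n, bn)


if __name__ == "__main__":
    num_ranks, local_m = 4, 4096
    # First configuration: chunk_m = local_m, so the grid is num_ranks - 1
    print("allgather, chunk_m = local_m:", allgather_grid(num_ranks, local_m, local_m))
    # Second configuration: smaller chunks give proportionally more programs
    for chunk_m in (128, 256, 512, 1024):
        print(f"allgather, chunk_m = {chunk_m}:", allgather_grid(num_ranks, local_m, chunk_m))
    # n = 8192 is an assumed value (see the note above)
    print("gemm:", gemm_grid(local_m, num_ranks, n=8192, bm=128, bn=256))
```

With these inputs, the allgather grid grows from 3 programs at `chunk_m = local_m` to 96 at `chunk_m = 128`, while the GEMM grid does not depend on `chunk_m`.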

@i3wanna2 i3wanna2 marked this pull request as draft June 10, 2026 06:24
@lizhangyu258 lizhangyu258 changed the title Triton v3.6.x support nvshmem [KMCompiler] [TLE-Raw]Triton v3.6.x support nvshmem Jun 10, 2026
@lizhangyu258 lizhangyu258 changed the title [KMCompiler] [TLE-Raw]Triton v3.6.x support nvshmem [KMCompiler] [TLERaw]Triton v3.6.x support nvshmem Jun 10, 2026
@lizhangyu258 lizhangyu258 force-pushed the triton_v3.6.x_support_nvshmem branch from 8cea58e to 351a431 Compare June 29, 2026 10:19
@lizhangyu258 lizhangyu258 force-pushed the triton_v3.6.x_support_nvshmem branch from 351a431 to dc3cd01 Compare June 29, 2026 12:46

def __init__(self, fn: Any, file: Path, *args, **kwargs) -> None:
super().__init__(*args, **{k: v for k, v in kwargs.items() if k not in ("extern_func_name", "deferred")})
super().__init__(

Author

Since the parent class is object, simply calling super().__init__() is sufficient.
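
The point of this comment can be illustrated with a minimal sketch. When a class inherits directly from `object` and defines its own `__init__`, forwarding extra keyword arguments to `super().__init__()` raises `TypeError`, which is presumably why the earlier version filtered them out. Calling `super().__init__()` with no arguments avoids the problem. The class names and attribute handling below are hypothetical and are not the PR's actual code.

```python
from pathlib import Path
from typing import Any


class ForwardingSource:
    """Forwards everything to object.__init__, which rejects extra arguments."""

    def __init__(self, fn: Any, file: Path, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)  # TypeError if args/kwargs are non-empty


class ExternSource:
    """Hypothetical stand-in for the reviewed class; its parent is object."""

    def __init__(self, fn: Any, file: Path, *args, **kwargs) -> None:
        # object.__init__ takes no extra arguments, so nothing is forwarded
        super().__init__()
        self.fn = fn
        self.file = Path(file)
        # Consume the class-specific keyword arguments locally (illustrative)
        self.extern_func_name = kwargs.get("extern_func_name")
        self.deferred = kwargs.get("deferred", False)


if __name__ == "__main__":
    src = ExternSource(len, "kernel.cu", extern_func_name="shift", deferred=True)
    print(src.file, src.extern_func_name, src.deferred)
    try:
        ForwardingSource(len, "kernel.cu", extern_func_name="shift")
    except TypeError as exc:
        print("forwarding failed:", exc)
```

If the class later gains a non-`object` base that accepts these arguments, the forwarding decision would need to be revisited.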
