Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions examples/07_gemm_all_scatter/gemm_all_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import triton
import triton.language as tl
from examples.common.utils import read_realtime
from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder

import sys
import os
Expand Down Expand Up @@ -47,8 +47,7 @@ def persistent_gemm_all_scatter(
):
pid = tl.program_id(0)

if NUM_XCDS != 1:
pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
total_tiles = num_pid_m * num_pid_n
Expand All @@ -67,12 +66,7 @@ def persistent_gemm_all_scatter(
timestamp = read_realtime()
tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp)

num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)

rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
Expand Down
21 changes: 4 additions & 17 deletions examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import triton
import triton.language as tl
from examples.common.utils import read_realtime
from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder

import sys
import os
Expand All @@ -22,15 +22,8 @@ def tile_id_to_index_range(
):
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n

group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)

tile_in_group = tile_id % num_pid_in_group
pid_m = first_pid_m + (tile_in_group % group_size_m)
pid_n = tile_in_group // group_size_m
pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)

rm_start = pid_m * BLOCK_SIZE_M
rn_start = pid_n * BLOCK_SIZE_N
Expand Down Expand Up @@ -132,8 +125,7 @@ def persistent_gemm_all_reduce(
):
pid = tl.program_id(0)

if NUM_XCDS != 1:
pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
total_tiles = num_pid_m * num_pid_n
Expand All @@ -152,12 +144,7 @@ def persistent_gemm_all_reduce(
timestamp = read_realtime()
tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp)

num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)

rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import triton
import triton.language as tl
from examples.common.utils import read_realtime
from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder

import sys
import os
Expand All @@ -22,15 +22,8 @@ def tile_id_to_index_range(
):
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n

group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)

tile_in_group = tile_id % num_pid_in_group
pid_m = first_pid_m + (tile_in_group % group_size_m)
pid_n = tile_in_group // group_size_m
pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)

rm_start = pid_m * BLOCK_SIZE_M
rn_start = pid_n * BLOCK_SIZE_N
Expand Down Expand Up @@ -132,8 +125,7 @@ def persistent_gemm_all_reduce(
):
pid = tl.program_id(0)

if NUM_XCDS != 1:
pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
total_tiles = num_pid_m * num_pid_n
Expand All @@ -152,12 +144,7 @@ def persistent_gemm_all_reduce(
timestamp = read_realtime()
tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp)

num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)

rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import triton
import triton.language as tl
from examples.common.utils import read_realtime
from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder

import sys
import os
Expand Down Expand Up @@ -49,8 +49,7 @@ def persistent_gemm_all_scatter_wg_specialization(
):
pid = tl.program_id(0)

if NUM_XCDS != 1:
pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
total_tiles = num_pid_m * num_pid_n
Expand All @@ -74,12 +73,7 @@ def persistent_gemm_all_scatter_wg_specialization(
timestamp = read_realtime()
tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp)

num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)

rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
Expand Down Expand Up @@ -149,12 +143,7 @@ def persistent_gemm_all_scatter_wg_specialization(
COMM_SMS = NUM_SMS - GEMM_SMS
pid = pid - GEMM_SMS
for tile_id in range(pid, total_tiles, COMM_SMS):
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)

# Begin: See the if segment for explanation:
rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import triton
import triton.language as tl
from examples.common.utils import read_realtime
from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder

import sys
import os
Expand Down Expand Up @@ -45,8 +45,7 @@ def persistent_gemm(
):
pid = tl.program_id(0)

if NUM_XCDS != 1:
pid = (pid % NUM_XCDS) * (GEMM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
pid = chiplet_reorder(pid, NUM_XCDS, GEMM_SMS)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
total_tiles = num_pid_m * num_pid_n
Expand All @@ -65,12 +64,7 @@ def persistent_gemm(
timestamp = read_realtime()
tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp)

num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)

rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
Expand Down Expand Up @@ -159,19 +153,13 @@ def persistent_all_scatter(
):
pid = tl.program_id(0)

if NUM_XCDS != 1:
pid = (pid % NUM_XCDS) * (COMM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
pid = chiplet_reorder(pid, NUM_XCDS, COMM_SMS)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
total_tiles = num_pid_m * num_pid_n

for tile_id in range(pid, total_tiles, COMM_SMS):
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)

tl.assume(pid_m >= 0)
tl.assume(pid_n >= 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import triton
import triton.language as tl
from examples.common.utils import read_realtime
from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder

import sys
import os
Expand Down Expand Up @@ -44,8 +44,7 @@ def persistent_gemm(
):
pid = tl.program_id(0)

if NUM_XCDS != 1:
pid = (pid % NUM_XCDS) * (GEMM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
pid = chiplet_reorder(pid, NUM_XCDS, GEMM_SMS)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
total_tiles = num_pid_m * num_pid_n
Expand All @@ -64,12 +63,7 @@ def persistent_gemm(
timestamp = read_realtime()
tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp)

num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)

rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
Expand Down Expand Up @@ -155,19 +149,13 @@ def persistent_all_scatter(
):
pid = tl.program_id(0)

if NUM_XCDS != 1:
pid = (pid % NUM_XCDS) * (COMM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
pid = chiplet_reorder(pid, NUM_XCDS, COMM_SMS)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
total_tiles = num_pid_m * num_pid_n

for tile_id in range(pid, total_tiles, COMM_SMS):
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)

tl.assume(pid_m >= 0)
tl.assume(pid_n >= 0)
Expand Down
11 changes: 3 additions & 8 deletions examples/14_all_gather_gemm/all_gather_gemm_pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import triton
import triton.language as tl
from examples.common.utils import chiplet_reorder, program_id_reorder
import torch
import iris

Expand Down Expand Up @@ -35,8 +36,7 @@ def persistent_ag_gemm(
):
pid = tl.program_id(0)

if NUM_XCDS != 1:
pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS)

num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
Expand All @@ -52,12 +52,7 @@ def persistent_ag_gemm(
acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32

for tile_id in range(pid, total_tiles, NUM_SMS):
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)

rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
Expand Down
11 changes: 3 additions & 8 deletions examples/14_all_gather_gemm/all_gather_gemm_push.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import triton
import triton.language as tl
from examples.common.utils import chiplet_reorder, program_id_reorder
import iris
import torch

Expand Down Expand Up @@ -99,8 +100,7 @@ def gemm_push_kernel(
world_size: tl.constexpr,
):
pid = tl.program_id(0)
if NUM_XCDS != 1:
pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS)

num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
Expand All @@ -121,12 +121,7 @@ def gemm_push_kernel(
acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32

for tile_id in range(pid, total_tiles, NUM_SMS):
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)

tl.assume(pid_m >= 0)
tl.assume(pid_n >= 0)
Expand Down
Loading