Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions iris/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
iris,
load,
store,
copy,
get,
put,
atomic_add,
Expand Down Expand Up @@ -64,6 +65,7 @@
"iris",
"load",
"store",
"copy",
"get",
"put",
"atomic_add",
Expand Down
51 changes: 51 additions & 0 deletions iris/iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,6 +1535,57 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
tl.store(translated_ptr, value, mask=mask)


@triton.jit
def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
    """
    Copies data from the specified rank's memory into the destination rank's memory.

    Both src_ptr and dst_ptr are interpreted relative to cur_rank's heap base:
    their byte offsets from heap_bases[cur_rank] are computed and then re-applied
    to heap_bases[from_rank] (for the load) and heap_bases[to_rank] (for the
    store). A masked load from the translated source is followed by a masked
    store to the translated destination.
    If from_rank and to_rank are the same, this function performs a local copy operation.
    It is undefined behaviour if neither from_rank nor to_rank is the cur_rank.

    Args:
        src_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer, expressed in cur_rank's address space, whose heap offset identifies the data to read on from_rank.
        dst_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer, expressed in cur_rank's address space, whose heap offset identifies where to write on to_rank.
        from_rank (int): The rank ID whose heap is read (source rank).
        to_rank (int): The rank ID whose heap is written (destination rank).
        cur_rank (int): The rank ID issuing the copy operation. Must be either from_rank or to_rank.
        heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
        mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to the translated dst_ptr[idx]. Defaults to None.

    Returns:
        None

    Example:
        >>> @triton.jit
        >>> def kernel(remote_ptr, local_ptr, heap_bases):
        >>>     from_rank = 1
        >>>     to_rank = 0
        >>>     iris.copy(remote_ptr, local_ptr, from_rank, to_rank, to_rank, heap_bases)
    """

    # Heap base addresses of the three ranks involved.
    cur_base = tl.load(heap_bases + cur_rank)
    from_base = tl.load(heap_bases + from_rank)
    to_base = tl.load(heap_bases + to_rank)

    # Byte offsets of both pointers within the issuing rank's heap.
    src_offset = tl.cast(src_ptr, tl.uint64) - cur_base
    dst_offset = tl.cast(dst_ptr, tl.uint64) - cur_base

    # View the bases as int8 pointers so the additions below are byte-granular.
    from_base_byte = tl.cast(from_base, tl.pointer_type(tl.int8))
    to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8))

    # Restore each pointer's own type. The destination must use dst_ptr.dtype
    # (not src_ptr.dtype) so the store is typed by the buffer it writes to.
    translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype)
    translated_dst = tl.cast(to_base_byte + dst_offset, dst_ptr.dtype)

    data = tl.load(translated_src, mask=mask)
    tl.store(translated_dst, data, mask=mask)


@triton.jit
def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
"""
Expand Down
75 changes: 75 additions & 0 deletions tests/unittests/test_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import torch
import triton
import triton.language as tl
import pytest
import iris


@triton.jit
def copy_kernel(
    data,
    results,
    cur_rank: tl.constexpr,
    num_ranks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
):
    """
    Test kernel that issues one iris.copy per rank.

    For each peer in range(num_ranks), the block at offset BLOCK_SIZE * cur_rank
    of `data` is used as the source (with from_rank=peer) and the block at
    offset BLOCK_SIZE * peer of `results` is used as the destination
    (with to_rank=cur_rank).
    """
    # Element indices handled by this program instance, masked to one block.
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < BLOCK_SIZE

    # The source block depends only on cur_rank, so it is computed once.
    src_block = data + BLOCK_SIZE * cur_rank + offsets

    for peer in range(num_ranks):
        dst_block = results + BLOCK_SIZE * peer + offsets
        iris.copy(src_block, dst_block, peer, cur_rank, cur_rank, heap_bases, mask)


@pytest.mark.parametrize(
    "dtype",
    [
        torch.int8,
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ],
)
@pytest.mark.parametrize(
    "BLOCK_SIZE",
    [
        1,
        8,
        16,
        32,
    ],
)
def test_copy(dtype, BLOCK_SIZE):
    """
    Check iris.copy by pulling one row from every rank into local `results`.

    Each rank fills row i of `data` with (cur_rank + num_ranks) * (i + 1).
    The kernel then copies row cur_rank of every rank's `data` into row
    `rank_id` of the local `results`, so the expected value in that row is
    (rank_id + num_ranks) * (cur_rank + 1).
    """
    shmem = iris.iris(1 << 20)
    num_ranks = shmem.get_num_ranks()
    heap_bases = shmem.get_heap_bases()
    cur_rank = shmem.get_rank()

    data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype)
    base = cur_rank + num_ranks
    for i in range(num_ranks):
        data[i, :] = base * (i + 1)

    results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype)

    # The kernel reads `data` owned by other ranks, so every rank must have
    # finished filling its buffer before any rank launches. Without this
    # barrier a fast rank can observe a peer's still-zero `data`.
    # NOTE(review): presumably shmem.barrier() also synchronizes the device - confirm.
    shmem.barrier()

    grid = lambda meta: (1,)
    copy_kernel[grid](data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases)
    shmem.barrier()

    expected = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype)
    for rank_id in range(num_ranks):
        expected[rank_id, :] = (rank_id + num_ranks) * (cur_rank + 1)

    try:
        torch.testing.assert_close(results, expected, rtol=0, atol=0)
    except AssertionError as e:
        # Dump both tensors to make a mismatch easy to diagnose, then fail.
        print(e)
        print("Expected:", expected)
        print("Actual:", results)
        raise
Loading