diff --git a/iris/__init__.py b/iris/__init__.py
index 3d9e9ab7..8eb088f2 100644
--- a/iris/__init__.py
+++ b/iris/__init__.py
@@ -29,6 +29,7 @@
     iris,
     load,
     store,
+    copy,
     get,
     put,
     atomic_add,
@@ -64,6 +65,7 @@
     "iris",
     "load",
     "store",
+    "copy",
     "get",
     "put",
     "atomic_add",
diff --git a/iris/iris.py b/iris/iris.py
index de66336d..a56592bd 100644
--- a/iris/iris.py
+++ b/iris/iris.py
@@ -1535,6 +1535,57 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
     tl.store(translated_ptr, value, mask=mask)
 
 
+@triton.jit
+def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
+    """
+    Copies data from the source rank's memory into the destination rank's memory.
+    Both src_ptr and dst_ptr are expressed in cur_rank's address space; this function
+    rebases src_ptr onto from_rank's heap and dst_ptr onto to_rank's heap, performs a
+    masked load from the translated source, and stores the result to the translated destination.
+    If from_rank and to_rank are the same, this function performs a local copy operation.
+    It is undefined behaviour if neither from_rank nor to_rank is the cur_rank.
+
+    Args:
+        src_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the from_rank's local memory from which to read data.
+        dst_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the to_rank's local memory where the data will be written.
+        from_rank (int): The rank ID that owns src_ptr (source rank).
+        to_rank (int): The rank ID that will receive the data (destination rank).
+        cur_rank (int): The rank ID issuing the copy operation. Must be either from_rank or to_rank.
+        heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
+        mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None.
+
+    Returns:
+        None
+
+    Example:
+        >>> @triton.jit
+        >>> def kernel(remote_ptr, local_ptr, heap_bases):
+        >>>     from_rank = 1
+        >>>     to_rank = 0
+        >>>     iris.copy(remote_ptr, local_ptr, from_rank, to_rank, to_rank, heap_bases)
+    """
+
+    cur_base = tl.load(heap_bases + cur_rank)
+
+    from_base = tl.load(heap_bases + from_rank)
+    to_base = tl.load(heap_bases + to_rank)
+
+    src_ptr_int = tl.cast(src_ptr, tl.uint64)
+    src_offset = src_ptr_int - cur_base  # byte offset of src within cur_rank's heap
+
+    dst_ptr_int = tl.cast(dst_ptr, tl.uint64)
+    dst_offset = dst_ptr_int - cur_base  # byte offset of dst within cur_rank's heap
+
+    from_base_byte = tl.cast(from_base, tl.pointer_type(tl.int8))
+    to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8))
+
+    translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype)
+    translated_dst = tl.cast(to_base_byte + dst_offset, dst_ptr.dtype)  # fix: dst pointer's own element type, not src's
+
+    data = tl.load(translated_src, mask=mask)
+    tl.store(translated_dst, data, mask=mask)
+
+
 @triton.jit
 def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
     """
diff --git a/tests/unittests/test_copy.py b/tests/unittests/test_copy.py
new file mode 100644
index 00000000..b241d3c2
--- /dev/null
+++ b/tests/unittests/test_copy.py
@@ -0,0 +1,75 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+
+import torch
+import triton
+import triton.language as tl
+import pytest
+import iris
+
+
+@triton.jit
+def copy_kernel(
+    data,
+    results,
+    cur_rank: tl.constexpr,
+    num_ranks: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    heap_bases: tl.tensor,
+):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < BLOCK_SIZE  # all-true for the single program; exercises the mask path
+
+    for target_rank in range(num_ranks):
+        src_data = data + BLOCK_SIZE * cur_rank  # row cur_rank of target_rank's `data`
+        dest_data = results + BLOCK_SIZE * target_rank  # local row for that rank
+        iris.copy(src_data + offsets, dest_data + offsets, target_rank, cur_rank, cur_rank, heap_bases, mask)
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        torch.int8,
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ],
+)
+@pytest.mark.parametrize(
+    "BLOCK_SIZE",
+    [
+        1,
+        8,
+        16,
+        32,
+    ],
+)
+def test_copy(dtype, BLOCK_SIZE):
+    shmem = iris.iris(1 << 20)
+    num_ranks = shmem.get_num_ranks()
+    heap_bases = shmem.get_heap_bases()
+    cur_rank = shmem.get_rank()
+
+    data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype)
+    base = cur_rank + num_ranks
+    for i in range(num_ranks):
+        data[i, :] = base * (i + 1)
+
+    results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype)
+    shmem.barrier()  # fix: every rank must finish writing `data` before any remote read
+    copy_kernel[(1,)](data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases)
+    shmem.barrier()
+
+    expected = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype)
+    for rank_id in range(num_ranks):
+        expected[rank_id, :] = (rank_id + num_ranks) * (cur_rank + 1)
+
+    try:
+        torch.testing.assert_close(results, expected, rtol=0, atol=0)
+    except AssertionError as e:
+        print(e)
+        print("Expected:", expected)
+        print("Actual:", results)
+        raise