From aea849e1b336e56ec58a0e7497e26c17dfc59ef1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 4 Sep 2025 02:25:54 +0000 Subject: [PATCH 01/13] Initial plan From 320b421e253e489531afae5bea7f13fb4f746a39 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 4 Sep 2025 02:34:48 +0000 Subject: [PATCH 02/13] Add examples to logging, tensor creation, and utility methods Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- iris/iris.py | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/iris/iris.py b/iris/iris.py index f90a3052..dac005df 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -160,6 +160,9 @@ def warning(self, message): Args: message (str): Human-readable message to log at warning level. + + Example: + >>> iris_ctx.warning("Memory usage is high") """ self._log_with_rank(logging.WARNING, message) @@ -169,6 +172,9 @@ def error(self, message): Args: message (str): Human-readable message to log at error level. + + Example: + >>> iris_ctx.error("Failed to allocate memory") """ self._log_with_rank(logging.ERROR, message) @@ -233,6 +239,12 @@ def zeros_like( Default: False. memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. Default: torch.preserve_format. + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> input_tensor = iris_ctx.ones(2, 3) + >>> zeros_tensor = iris_ctx.zeros_like(input_tensor) + >>> print(zeros_tensor.shape) # torch.Size([2, 3]) """ self.debug( f"zeros_like: input_shape = {input.shape}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" @@ -376,6 +388,11 @@ def zeros(self, *size, out=None, dtype=None, layout=torch.strided, device=None, Default: if None, uses the current device for the default tensor type. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> tensor = iris_ctx.zeros(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) """ self.debug(f"zeros: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") @@ -466,6 +483,11 @@ def randn( Default: False. pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False. + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> tensor = iris_ctx.randn(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) """ self.debug( f"randn: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" @@ -529,6 +551,11 @@ def ones(self, *size, out=None, dtype=None, layout=torch.strided, device=None, r Default: if None, uses the current device for the default tensor type. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> tensor = iris_ctx.ones(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) """ self.debug(f"ones: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") @@ -588,6 +615,11 @@ def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, Default: if None, uses the current device for the default tensor type. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. 
+ + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> tensor = iris_ctx.full((2, 3), 3.14) + >>> print(tensor.shape) # torch.Size([2, 3]) """ self.debug( f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" @@ -649,6 +681,11 @@ def uniform(self, size, low=0.0, high=1.0, dtype=torch.float): Returns: Tensor: A tensor filled with random numbers from a uniform distribution. + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> tensor = iris_ctx.uniform((2, 3), low=0.0, high=1.0) + >>> print(tensor.shape) # torch.Size([2, 3]) """ self.debug(f"uniform: size = {size}, low = {low}, high = {high}, dtype = {dtype}") size, num_elements = self.__parse_size(size) @@ -998,6 +1035,11 @@ def get_heap_bases(self): torch.Tensor: A 1D tensor of ``uint64`` heap base addresses of size ``num_ranks`` on the Iris device. Pass this to device-side Triton kernels that require heap translation. + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> heap_bases = iris_ctx.get_heap_bases() + >>> print(heap_bases.shape) # torch.Size([num_ranks]) """ return self.heap_bases @@ -1010,6 +1052,10 @@ def barrier(self, stream=None): ranks reach the same point before proceeding. Args: stream: If stream is given: wait only for that stream before MPI_Barrier. If stream is None: legacy behavior (device-wide sync). + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> iris_ctx.barrier() # Synchronize all ranks """ # Wait for all GPUs to finish work if stream is None: @@ -1026,6 +1072,11 @@ def get_device(self): Returns: torch.device: The CUDA device of Iris-managed memory. + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> device = iris_ctx.get_device() + >>> print(device) # cuda:0 """ return self.memory_pool.device @@ -1035,6 +1086,11 @@ def get_cu_count(self): Returns: int: Number of compute units on this rank's GPU. + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> cu_count = iris_ctx.get_cu_count() + >>> print(f"GPU has {cu_count} CUs") """ return get_cu_count(self.gpu_id) @@ -1044,6 +1100,11 @@ def get_rank(self): Returns: int: Zero-based rank id of the current process. + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> rank = iris_ctx.get_rank() + >>> print(f"This is rank {rank}") """ return self.cur_rank @@ -1053,6 +1114,11 @@ def get_num_ranks(self): Returns: int: World size (number of ranks). + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> num_ranks = iris_ctx.get_num_ranks() + >>> print(f"Total ranks: {num_ranks}") """ return self.num_ranks @@ -1379,6 +1445,13 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): Returns: Block: The loaded value from the target memory location. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> rank = 0 + >>> data = iris.load(ptr, rank, rank, heap_bases) + >>> return data """ translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) result = tl.load(translated_ptr, mask=mask) @@ -1405,6 +1478,12 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): Returns: None + + Example: + >>> @triton.jit + >>> def kernel(ptr, data, heap_bases): + >>> rank = 0 + >>> iris.store(ptr, data, rank, rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) tl.store(translated_ptr, value, mask=mask) @@ -1487,6 +1566,13 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None Returns: Block: The data stored at pointer before the atomic operation. 
+ + Example: + >>> @triton.jit + >>> def kernel(ptr, increment, heap_bases): + >>> rank = 0 + >>> old_val = iris.atomic_add(ptr, increment, rank, rank, heap_bases) + >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) From 4b82d5a4f4b0be94ba657a6ba391f5e285dfe032 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 4 Sep 2025 02:38:58 +0000 Subject: [PATCH 03/13] Add examples to remaining device functions, tensor methods, and utility functions Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- iris/iris.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ iris/util.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/iris/iris.py b/iris/iris.py index dac005df..3d42f228 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -316,6 +316,11 @@ def arange( device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor type. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> tensor = iris_ctx.arange(0, 10, 2) # [0, 2, 4, 6, 8] + >>> print(tensor.shape) # torch.Size([5]) """ self.debug(f"arange: start = {start}, end = {end}, step = {step}, dtype = {dtype}, device = {device}") @@ -731,6 +736,11 @@ def empty( Works only for CPU tensors. Default: False. Note: Iris tensors are always on GPU. memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. Default: torch.contiguous_format. + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> tensor = iris_ctx.empty(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) """ self.debug( f"empty: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" @@ -1509,6 +1519,13 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): Returns: None + + Example: + >>> @triton.jit + >>> def kernel(remote_ptr, local_ptr, heap_bases): + >>> from_rank = 1 + >>> to_rank = 0 + >>> iris.get(remote_ptr, local_ptr, from_rank, to_rank, heap_bases) """ translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases) @@ -1536,6 +1553,13 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): Returns: None + + Example: + >>> @triton.jit + >>> def kernel(local_ptr, remote_ptr, heap_bases): + >>> from_rank = 0 + >>> to_rank = 1 + >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) """ translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) @@ -1627,6 +1651,15 @@ def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scop Returns: Block: The value contained at the memory location before the atomic operation attempt. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> rank = 0 + >>> expected = 0 + >>> new_val = 42 + >>> old_val = iris.atomic_cas(ptr, expected, new_val, rank, rank, heap_bases) + >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) @@ -1654,6 +1687,13 @@ def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=Non Returns: Block: The data stored at pointer before the atomic operation. 
+ + Example: + >>> @triton.jit + >>> def kernel(ptr, new_value, heap_bases): + >>> rank = 0 + >>> old_val = iris.atomic_xchg(ptr, new_value, rank, rank, heap_bases) + >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1803,5 +1843,10 @@ def iris(heap_size=1 << 30): Returns: Iris: An initialized Iris instance. + + Example: + >>> import iris + >>> iris_ctx = iris.iris(2**30) # 1GB heap + >>> tensor = iris_ctx.zeros(1024, 1024) """ return Iris(heap_size) diff --git a/iris/util.py b/iris/util.py index 17da52de..03d24a46 100644 --- a/iris/util.py +++ b/iris/util.py @@ -68,6 +68,28 @@ def do_bench( quantiles=None, return_mode="mean", ): + """ + Benchmark a function by timing its execution. + + Args: + fn (callable): Function to benchmark. + barrier_fn (callable, optional): Function to call for synchronization. Default: no-op. + preamble_fn (callable, optional): Function to call before each execution. Default: no-op. + n_warmup (int, optional): Number of warmup iterations. Default: 25. + n_repeat (int, optional): Number of timing iterations. Default: 100. + quantiles (list, optional): Quantiles to return instead of summary statistic. Default: None. + return_mode (str, optional): Summary statistic to return ("mean", "min", "max", "median", "all"). Default: "mean". + + Returns: + float or list: Timing result(s) in milliseconds. + + Example: + >>> import iris + >>> iris_ctx = iris.iris(1 << 20) + >>> def test_fn(): + >>> tensor = iris_ctx.zeros(1000, 1000) + >>> time_ms = iris.do_bench(test_fn, barrier_fn=iris_ctx.barrier) + """ # Wait for anything that happened before barrier_fn() preamble_fn() @@ -115,6 +137,20 @@ def memset_kernel(ptr, value, n_elements, BLOCK_SIZE: tl.constexpr): def memset_tensor(tensor, value): + """ + Set all elements of a tensor to a specified value using a Triton kernel. + + Args: + tensor (torch.Tensor): Contiguous int32 tensor to modify in-place. + value (int): Value to set all elements to. + + Example: + >>> import iris + >>> import torch + >>> tensor = torch.zeros(100, dtype=torch.int32, device='cuda') + >>> iris.memset_tensor(tensor, 42) + >>> assert torch.all(tensor == 42) + """ assert tensor.is_contiguous(), "Tensor must be contiguous" assert tensor.dtype == torch.int32, "Only torch.int32 tensors are supported" n_elements = tensor.numel() From a6c88c3b4bfe4fb03c7fd29933b25b16ae150a69 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 4 Sep 2025 02:41:18 +0000 Subject: [PATCH 04/13] Complete docstring examples for all remaining tensor creation and atomic methods Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- iris/iris.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/iris/iris.py b/iris/iris.py index 3d42f228..df803e36 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -805,6 +805,11 @@ def randint( layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. 
+ + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> tensor = iris_ctx.randint(0, 10, (2, 3)) # Random integers [0, 10) + >>> print(tensor.shape) # torch.Size([2, 3]) """ self.debug(f"randint: args = {args}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") @@ -886,6 +891,11 @@ def linspace(self, start, end, steps, out=None, dtype=None, layout=torch.strided layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> tensor = iris_ctx.linspace(0, 10, 5) # [0, 2.5, 5, 7.5, 10] + >>> print(tensor.shape) # torch.Size([5]) """ self.debug( f"linspace: start = {start}, end = {end}, steps = {steps}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" @@ -987,6 +997,11 @@ def rand( Default: False. pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False. Note: Iris tensors are always on GPU. + + Example: + >>> iris_ctx = iris.iris(1 << 20) + >>> tensor = iris_ctx.rand(2, 3) # Random values in [0, 1) + >>> print(tensor.shape) # torch.Size([2, 3]) """ self.debug( f"rand: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" @@ -1624,6 +1639,13 @@ def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None Returns: Block: The value at the memory location before the atomic subtraction. + + Example: + >>> @triton.jit + >>> def kernel(ptr, decrement, heap_bases): + >>> rank = 0 + >>> old_val = iris.atomic_sub(ptr, decrement, rank, rank, heap_bases) + >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) From 733df66850ab7622336a8d66dd57635a5b994bae Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 4 Sep 2025 02:58:40 +0000 Subject: [PATCH 05/13] Update Triton API docstring examples to show cross-GPU operations Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- iris/iris.py | 98 +++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 82 insertions(+), 16 deletions(-) diff --git a/iris/iris.py b/iris/iris.py index df803e36..f67b4ed2 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1474,8 +1474,10 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): Example: >>> @triton.jit >>> def kernel(ptr, heap_bases): - >>> rank = 0 - >>> data = iris.load(ptr, rank, rank, heap_bases) + >>> # Load data from rank 1's memory into the current rank + >>> cur_rank = 0 # Current rank + >>> remote_rank = 1 # Remote rank to load from + >>> data = iris.load(ptr, cur_rank, remote_rank, heap_bases) >>> return data """ translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) @@ -1506,9 +1508,12 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): Example: >>> @triton.jit - >>> def kernel(ptr, data, heap_bases): - >>> rank = 0 - >>> iris.store(ptr, data, rank, rank, heap_bases) + >>> def kernel(ptr, heap_bases): + >>> # Store value 42 into rank 1's heap from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> value = 42 + 
>>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) tl.store(translated_ptr, value, mask=mask) @@ -1608,9 +1613,12 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None Example: >>> @triton.jit - >>> def kernel(ptr, increment, heap_bases): - >>> rank = 0 - >>> old_val = iris.atomic_add(ptr, increment, rank, rank, heap_bases) + >>> def kernel(ptr, heap_bases): + >>> # Atomically add 5 to rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> increment = 5 + >>> old_val = iris.atomic_add(ptr, increment, cur_rank, remote_rank, heap_bases) >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) @@ -1642,9 +1650,12 @@ def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None Example: >>> @triton.jit - >>> def kernel(ptr, decrement, heap_bases): - >>> rank = 0 - >>> old_val = iris.atomic_sub(ptr, decrement, rank, rank, heap_bases) + >>> def kernel(ptr, heap_bases): + >>> # Atomically subtract 3 from rank 2's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 2 # Remote rank (destination) + >>> decrement = 3 + >>> old_val = iris.atomic_sub(ptr, decrement, cur_rank, remote_rank, heap_bases) >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) @@ -1677,10 +1688,12 @@ def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scop Example: >>> @triton.jit >>> def kernel(ptr, heap_bases): - >>> rank = 0 + >>> # Compare-and-swap on rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) >>> expected = 0 >>> new_val = 42 - >>> old_val = iris.atomic_cas(ptr, expected, new_val, rank, rank, heap_bases) + >>> old_val = iris.atomic_cas(ptr, expected, new_val, cur_rank, remote_rank, heap_bases) >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) @@ -1712,9 +1725,12 @@ def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=Non Example: >>> @triton.jit - >>> def kernel(ptr, new_value, heap_bases): - >>> rank = 0 - >>> old_val = iris.atomic_xchg(ptr, new_value, rank, rank, heap_bases) + >>> def kernel(ptr, heap_bases): + >>> # Exchange value with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> new_value = 99 + >>> old_val = iris.atomic_xchg(ptr, new_value, cur_rank, remote_rank, heap_bases) >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) @@ -1743,6 +1759,16 @@ def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None Returns: Block: The data stored at pointer before the atomic operation. 
+ + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically XOR with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> mask_val = 0xFF + >>> old_val = iris.atomic_xor(ptr, mask_val, cur_rank, remote_rank, heap_bases) + >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1770,6 +1796,16 @@ def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None Returns: Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically AND with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> mask_val = 0x0F + >>> old_val = iris.atomic_and(ptr, mask_val, cur_rank, remote_rank, heap_bases) + >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1797,6 +1833,16 @@ def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, Returns: Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically OR with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> mask_val = 0xF0 + >>> old_val = iris.atomic_or(ptr, mask_val, cur_rank, remote_rank, heap_bases) + >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1824,6 +1870,16 @@ def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None Returns: Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically find minimum with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> new_val = 10 + >>> old_val = iris.atomic_min(ptr, new_val, cur_rank, remote_rank, heap_bases) + >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1851,6 +1907,16 @@ def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None Returns: Block: The data stored at pointer before the atomic operation. 
+ + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically find maximum with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> new_val = 100 + >>> old_val = iris.atomic_max(ptr, new_val, cur_rank, remote_rank, heap_bases) + >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) From 7b93dd5a802eae08347d5cb36d96a24a412ac0eb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 9 Sep 2025 11:59:18 +0000 Subject: [PATCH 06/13] Sync with main - merge PyTorch distributed changes Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- .github/copilot-instructions.md | 7 +- .github/workflows/iris-tests-apptainer.yml | 17 +- README.md | 62 +++++--- apptainer/iris.def | 6 +- docker/Dockerfile | 12 +- docs/CONTRIBUTING.md | 8 +- docs/conceptual/programming-model.md | 2 +- docs/conf.py | 2 +- docs/getting-started/installation.md | 3 +- docs/index.md | 60 ++++--- docs/reference/api-iris-class.md | 2 +- examples/00_load/README.md | 2 +- examples/00_load/load_bench.py | 27 +++- examples/01_store/store_bench.py | 27 +++- examples/02_all_load/README.md | 2 +- examples/02_all_load/all_load_bench.py | 29 +++- examples/03_all_store/README.md | 2 +- examples/03_all_store/all_store_bench.py | 28 +++- examples/04_atomic_add/README.md | 2 +- examples/04_atomic_add/atomic_add_bench.py | 35 ++++- examples/05_atomic_xchg/README.md | 2 +- examples/05_atomic_xchg/atomic_xchg_bench.py | 35 ++++- .../message_passing_load_store.py | 29 +++- .../06_message_passing/message_passing_put.py | 27 +++- examples/07_gemm_all_scatter/benchmark.py | 45 ++++-- .../08_gemm_atomics_all_reduce/benchmark.py | 32 +++- .../matmul_wrapper.py | 4 +- .../09_gemm_one_shot_all_reduce/benchmark.py | 30 +++- .../matmul_wrapper.py | 4 +- .../benchmark.py | 24 ++- .../benchmark.py | 24 ++- .../benchmark.py | 24 ++- examples/README.md | 28 ++-- examples/benchmark/reference/gemm.py | 1 + iris/__init__.py | 2 + iris/_distributed_helpers.py | 148 ++++++++++++++++++ iris/_mpi_helpers.py | 51 ------ iris/iris.py | 40 ++--- iris/logging.py | 2 +- pyproject.toml | 1 - tests/run_tests_distributed.py | 89 +++++++++++ 41 files changed, 747 insertions(+), 230 deletions(-) create mode 100644 iris/_distributed_helpers.py delete mode 100644 iris/_mpi_helpers.py create mode 100755 tests/run_tests_distributed.py diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 3a613ec9..cf5f9a0d 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -7,7 +7,7 @@ Iris is a Triton-based framework for Remote Memory Access (RMA) operations on AM - Clean abstractions with full symmetric heap implementation - Pythonic PyTorch-like host APIs for tensor operations - Triton-style device APIs for load, store, and atomic operations -- Minimal dependencies (Triton, PyTorch, HIP runtime, mpi4py) +- Minimal dependencies (Triton, PyTorch, HIP runtime) - Comprehensive examples showing communication/computation overlap **FOLLOW THESE INSTRUCTIONS EXACTLY. 
Reference these instructions first before using search or bash commands.** @@ -17,7 +17,6 @@ Iris is a Triton-based framework for Remote Memory Access (RMA) operations on AM - **GPU**: AMD GPUs with ROCm compatibility (tested on MI300X, MI350X & MI355X) > **Note**: See below for instructions on development without AMD GPU access - **ROCm/HIP Toolkit**: Required for building C++/HIP components -- **MPI**: Required for multi-GPU operations - **Docker/Apptainer**: Recommended for containerized development ## Build @@ -78,8 +77,8 @@ pytest tests/unittests/ # Run example tests pytest tests/examples/ -# Run specific example (requires MPI and GPU) -mpirun -np 8 python examples/00_load/load_bench.py +# Run specific example +python examples/00_load/load_bench.py ``` ### Code Quality diff --git a/.github/workflows/iris-tests-apptainer.yml b/.github/workflows/iris-tests-apptainer.yml index d39e5814..5e2d9a85 100644 --- a/.github/workflows/iris-tests-apptainer.yml +++ b/.github/workflows/iris-tests-apptainer.yml @@ -52,7 +52,7 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - - name: Run Iris Tests with ${{ matrix.ranks }} MPI ranks + - name: Run Iris Tests with ${{ matrix.ranks }} ranks run: | apptainer exec ~/apptainer/iris-dev.sif bash -c " set -e # Exit on any error @@ -60,20 +60,17 @@ jobs: # Install iris first pip install -e . - # Create function for mpirun with root permissions - mpirun-root() { mpirun --allow-run-as-root \"\$@\"; } - - # Run examples tests one at a time + # Run examples tests one at a time using distributed wrapper echo 'Running examples tests one at a time...' for test_file in tests/examples/test_*.py; do - echo \"Testing: \$test_file with ${{ matrix.ranks }} MPI ranks\" - mpirun-root -np ${{ matrix.ranks }} python -m pytest \"\$test_file\" -v --tb=short + echo \"Testing: \$test_file with ${{ matrix.ranks }} ranks\" + python tests/run_tests_distributed.py --num_ranks ${{ matrix.ranks }} \"\$test_file\" -v --tb=short done - # Run unit tests one at a time + # Run unit tests one at a time using distributed wrapper echo 'Running unit tests one at a time...' for test_file in tests/unittests/test_*.py; do - echo \"Testing: \$test_file with ${{ matrix.ranks }} MPI ranks\" - mpirun-root -np ${{ matrix.ranks }} python -m pytest \"\$test_file\" -v --tb=short + echo \"Testing: \$test_file with ${{ matrix.ranks }} ranks\" + python tests/run_tests_distributed.py --num_ranks ${{ matrix.ranks }} \"\$test_file\" -v --tb=short done " \ No newline at end of file diff --git a/README.md b/README.md index 7f44fa5b..639e8d9f 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,10 @@ Iris is a Triton-based framework for Remote Memory Access (RMA) operations. 
Iris Here's a simple example showing how to perform remote memory operations between GPUs using Iris: ```python +import os import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import triton.language as tl import iris @@ -47,7 +50,7 @@ def kernel(buffer, buffer_size: tl.constexpr, block_size: tl.constexpr, heap_bas pid = tl.program_id(0) block_start = pid * block_size offsets = block_start + tl.arange(0, block_size) - + # Guard for out-of-bounds accesses mask = offsets < buffer_size @@ -58,29 +61,40 @@ def kernel(buffer, buffer_size: tl.constexpr, block_size: tl.constexpr, heap_bas source_rank, target_rank, heap_bases_ptr, mask=mask) -# Iris initialization -heap_size = 2**30 # 1GiB symmetric heap for inter-GPU communication -iris_ctx = iris.iris(heap_size) -cur_rank = iris_ctx.get_rank() - -# Iris tensor allocation -buffer_size = 4096 # 4K elements buffer -buffer = iris_ctx.zeros(buffer_size, device="cuda", dtype=torch.float32) - -# Launch the kernel on rank 0 -block_size = 1024 -grid = lambda meta: (triton.cdiv(buffer_size, meta["block_size"]),) -source_rank = 0 -if cur_rank == source_rank: - kernel[grid]( - buffer, - buffer_size, - block_size, - iris_ctx.get_heap_bases(), - ) - -# Synchronize all ranks -iris_ctx.barrier() +def _worker(rank, world_size): + # Torch distributed initialization + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29500") + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + + # Iris initialization + heap_size = 2**30 # 1GiB symmetric heap for inter-GPU communication + iris_ctx = iris.iris(heap_size) + cur_rank = iris_ctx.get_rank() + + # Iris tensor allocation + buffer_size = 4096 # 4K elements buffer + buffer = iris_ctx.zeros(buffer_size, device="cuda", dtype=torch.float32) + + # Launch the kernel on rank 0 + block_size = 1024 + grid = lambda meta: (triton.cdiv(buffer_size, meta["block_size"]),) + source_rank = 0 + if cur_rank == source_rank: + kernel[grid]( + buffer, + buffer_size, + block_size, + iris_ctx.get_heap_bases(), + ) + + # Synchronize all ranks + iris_ctx.barrier() + dist.destroy_process_group() + +if __name__ == "__main__": + world_size = 2 # Using two ranks + mp.spawn(_worker, args=(world_size,), nprocs=world_size, join=True) ``` ## Quick Start Guide diff --git a/apptainer/iris.def b/apptainer/iris.def index d3ca16d3..31960182 100644 --- a/apptainer/iris.def +++ b/apptainer/iris.def @@ -10,7 +10,7 @@ From: rocm/pytorch:rocm6.3.1_ubuntu22.04_py3.10_pytorch export TRITON_PATH=/workspace/triton conda env list source /opt/conda/bin/activate py_3.10 - conda install -y -n py_3.10 -c conda-forge mpi4py openmpi jupyter ninja cmake wheel + conda install -y -n py_3.10 -c conda-forge jupyter ninja cmake wheel git clone https://github.com/triton-lang/triton.git \$TRITON_PATH cd \$TRITON_PATH git checkout dd5823453bcc7973eabadb65f9d827c43281c434 @@ -23,9 +23,9 @@ From: rocm/pytorch:rocm6.3.1_ubuntu22.04_py3.10_pytorch # Define environment variables export TRITON_PATH=/workspace/triton export PYTHONPATH=$TRITON_PATH/python/ - export LD_LIBRARY_PATH=/opt/rocm/lib:/usr/lib/openmpi/lib:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=/opt/rocm/lib:$LD_LIBRARY_PATH export ROCM_PATH=/opt/rocm - export PATH=/opt/conda/envs/py_3.10/bin:/opt/rocm/bin:/usr/lib/openmpi/bin:$PATH + export PATH=/opt/conda/envs/py_3.10/bin:/opt/rocm/bin:$PATH export OMPI_MCA_mtl="^ofi" export OMPI_MCA_pml="ob1" diff --git a/docker/Dockerfile b/docker/Dockerfile index 73738239..8b49c01a 
100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -12,8 +12,8 @@ ENV TRITON_PATH=/opt/triton \ OMPI_MCA_mtl="^ofi" \ OMPI_MCA_pml="ob1" -ENV LD_LIBRARY_PATH=$ROCM_PATH/lib:/usr/lib/openmpi/lib:$LD_LIBRARY_PATH \ - PATH="$ROCM_PATH/bin:/usr/lib/openmpi/bin:$PATH" +ENV LD_LIBRARY_PATH=$ROCM_PATH/lib:$LD_LIBRARY_PATH \ + PATH="$ROCM_PATH/bin:$PATH" ENV OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \ OMPI_ALLOW_RUN_AS_ROOT=1 @@ -21,19 +21,13 @@ ENV OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \ # Install system packages RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y \ - git wget ninja-build cmake python3-pip python3-dev build-essential \ - openmpi-bin libopenmpi-dev && \ + git wget ninja-build cmake python3-pip python3-dev build-essential && \ rm -rf /var/lib/apt/lists/* # Install Python packages with pip RUN pip3 install --upgrade pip && \ pip3 install wheel jupyter -# This needs sudo, I can only get it to install with sudo -# or using conda, but conda runs into issues with too many requests. -# https://stackoverflow.com/a/54052470/5729690 -RUN sudo pip3 install mpi4py - # Clone and install Triton WORKDIR $TRITON_PATH RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index d21d49ac..3f579361 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -20,8 +20,12 @@ git checkout -b $USER/your-feature-name ruff check . ruff format . -# Run tests -pytest +# Run tests +python tests/run_tests_distributed.py tests/examples/test_all_load_bench.py --num_ranks 2 -v +python tests/run_tests_distributed.py tests/unittests/ --num_ranks 2 -v + +# Or run individual test files +python tests/run_tests_distributed.py tests/examples/test_load_bench.py --num_ranks 2 -v ``` ### 4. Commit and Push diff --git a/docs/conceptual/programming-model.md b/docs/conceptual/programming-model.md index cf078042..ea2f772f 100644 --- a/docs/conceptual/programming-model.md +++ b/docs/conceptual/programming-model.md @@ -25,7 +25,7 @@ html[data-theme="light"] .theme-switch-wrapper .light-theme-img { 1. **Designed by Experts, Built for Scale** - Written from scratch by GPU and distributed computing experts - - Minimal dependencies: only Triton, PyTorch, HIP runtime and mpi4py (for initialization) + - Minimal dependencies: only Triton, PyTorch, and HIP runtime - No external frameworks or heavyweight runtimes beyond core stack 2. **Clean Abstractions** diff --git a/docs/conf.py b/docs/conf.py index 8210169d..9f509212 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -94,7 +94,7 @@ "triton", "triton.language", "numpy", - "iris._mpi_helpers", + "iris._distributed_helpers", "iris.hip", ] diff --git a/docs/getting-started/installation.md b/docs/getting-started/installation.md index 9af00357..5a38cb99 100644 --- a/docs/getting-started/installation.md +++ b/docs/getting-started/installation.md @@ -4,7 +4,7 @@ This guide covers how to install Iris on your system using various methods. ## Overview -Iris has minimal dependencies including Python, PyTorch, ROCm HIP runtime, MPI, and Triton. This guide will walk you through the installation process using different approaches. +Iris has minimal dependencies including Python, PyTorch, ROCm HIP runtime, and Triton. This guide will walk you through the installation process using different approaches. 
## Prerequisites @@ -22,7 +22,6 @@ Iris has minimal dependencies including Python, PyTorch, ROCm HIP runtime, MPI, - Python 3.10+ - PyTorch 2.0+ (ROCm version) - ROCm 6.3.1+ HIP runtime -- OpenMPI - Git - CMake, Ninja, build-essential - Triton (specific commit: [dd5823453bcc7973eabadb65f9d827c43281c434](https://github.com/triton-lang/triton/tree/dd5823453bcc7973eabadb65f9d827c43281c434)) diff --git a/docs/index.md b/docs/index.md index 2513188b..ea3e3d90 100644 --- a/docs/index.md +++ b/docs/index.md @@ -46,7 +46,10 @@ cd iris && pip install -e . Here's a simple example showing how to perform remote memory operations between GPUs using Iris: ```python +import os import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import triton.language as tl import iris @@ -69,29 +72,40 @@ def kernel(buffer, buffer_size: tl.constexpr, block_size: tl.constexpr, heap_bas source_rank, target_rank, heap_bases_ptr, mask=mask) -# Iris initialization -heap_size = 2**30 # 1GiB symmetric heap for inter-GPU communication -iris_ctx = iris.iris(heap_size) -cur_rank = iris_ctx.get_rank() - -# Iris tensor allocation -buffer_size = 4096 # 4K elements buffer -buffer = iris_ctx.zeros(buffer_size, device="cuda", dtype=torch.float32) - -# Launch the kernel on rank 0 -block_size = 1024 -grid = lambda meta: (triton.cdiv(buffer_size, meta["block_size"]),) -source_rank = 0 -if cur_rank == source_rank: - kernel[grid]( - buffer, - buffer_size, - block_size, - iris_ctx.get_heap_bases(), - ) - -# Synchronize all ranks -iris_ctx.barrier() +def _worker(rank, world_size): + # Torch distributed initialization + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29500") + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + + # Iris initialization + heap_size = 2**30 # 1GiB symmetric heap for inter-GPU communication + iris_ctx = iris.iris(heap_size) + cur_rank = iris_ctx.get_rank() + + # Iris tensor allocation + buffer_size = 4096 # 4K elements buffer + buffer = iris_ctx.zeros(buffer_size, device="cuda", dtype=torch.float32) + + # Launch the kernel on rank 0 + block_size = 1024 + grid = lambda meta: (triton.cdiv(buffer_size, meta["block_size"]),) + source_rank = 0 + if cur_rank == source_rank: + kernel[grid]( + buffer, + buffer_size, + block_size, + iris_ctx.get_heap_bases(), + ) + + # Synchronize all ranks + iris_ctx.barrier() + dist.destroy_process_group() + +if __name__ == "__main__": + world_size = 2 # Using two ranks + mp.spawn(_worker, args=(world_size,), nprocs=world_size, join=True) ``` For more examples, see the [Examples](reference/examples.md) page with ready-to-run scripts and usage patterns. diff --git a/docs/reference/api-iris-class.md b/docs/reference/api-iris-class.md index aa673af5..f16eccf5 100644 --- a/docs/reference/api-iris-class.md +++ b/docs/reference/api-iris-class.md @@ -33,7 +33,7 @@ Use Iris-aware logging that automatically annotates each message with the curren ## Broadcast Helper -Broadcast a Python scalar or small object from a source rank to all ranks. This is a convenience wrapper over the internal MPI helper. +Broadcast a Python scalar or small object from a source rank to all ranks. This is a convenience wrapper over the internal Torch Distributed helper. ```{eval-rst} .. 
automethod:: iris.iris.Iris.broadcast diff --git a/examples/00_load/README.md b/examples/00_load/README.md index e985c7e2..aab1e0f1 100644 --- a/examples/00_load/README.md +++ b/examples/00_load/README.md @@ -10,7 +10,7 @@ Load benchmark using Iris. ## Usage ```terminal -mpirun -np 8 python examples/00_load/load_bench.py +python examples/00_load/load_bench.py --num_ranks 8 ``` On an MI300X, this example will run on 8 GPUs. It prints: ```terminal diff --git a/examples/00_load/load_bench.py b/examples/00_load/load_bench.py index 1ef16775..fc8e4148 100755 --- a/examples/00_load/load_bench.py +++ b/examples/00_load/load_bench.py @@ -5,6 +5,8 @@ import argparse import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import triton.language as tl import random @@ -97,6 +99,7 @@ def parse_args(): parser.add_argument("-o", "--output_file", type=str, default="", help="Output file") parser.add_argument("-n", "--num_experiments", type=int, default=10, help="Number of experiments") parser.add_argument("-w", "--num_warmup", type=int, default=1, help="Number of warmup iterations") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -229,9 +232,12 @@ def print_bandwidth_matrix(matrix, label="Unidirectional LOAD bandwidth GiB/s [R raise ValueError(f"Unsupported output file extension: {output_file}") -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + # Main benchmark logic shmem = iris.iris(args["heap_size"]) num_ranks = shmem.get_num_ranks() bandwidth_matrix = np.zeros((num_ranks, num_ranks), dtype=np.float32) @@ -262,6 +268,23 @@ def main(): if shmem.get_rank() == 0: print_bandwidth_matrix(bandwidth_matrix, output_file=args["output_file"]) + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + if __name__ == "__main__": main() diff --git a/examples/01_store/store_bench.py b/examples/01_store/store_bench.py index 835809f2..80e7a7e0 100755 --- a/examples/01_store/store_bench.py +++ b/examples/01_store/store_bench.py @@ -5,6 +5,8 @@ import argparse import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import triton.language as tl import random @@ -82,6 +84,7 @@ def parse_args(): parser.add_argument("-o", "--output_file", type=str, default="", help="Output file") parser.add_argument("-n", "--num_experiments", type=int, default=10, help="Number of experiments") parser.add_argument("-w", "--num_warmup", type=int, default=1, help="Number of warmup iterations") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -202,9 +205,12 @@ def print_bandwidth_matrix(matrix, label="Unidirectional STORE bandwidth GiB/s [ raise ValueError(f"Unsupported output file extension: {output_file}") -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if 
torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + # Main benchmark logic shmem = iris.iris(args["heap_size"]) num_ranks = shmem.get_num_ranks() bandwidth_matrix = np.zeros((num_ranks, num_ranks), dtype=np.float32) @@ -222,6 +228,23 @@ def main(): if shmem.get_rank() == 0: print_bandwidth_matrix(bandwidth_matrix, output_file=args["output_file"]) + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + if __name__ == "__main__": main() diff --git a/examples/02_all_load/README.md b/examples/02_all_load/README.md index fd7f396a..13a8def2 100644 --- a/examples/02_all_load/README.md +++ b/examples/02_all_load/README.md @@ -10,7 +10,7 @@ All-Load benchmark using Iris. ## Usage ```terminal -mpirun -np 8 python examples/02_all_load/all_load_bench.py +python examples/02_all_load/all_load_bench.py --num_ranks 8 ``` On an MI300X, this example will run on 8 GPUs. It prints: diff --git a/examples/02_all_load/all_load_bench.py b/examples/02_all_load/all_load_bench.py index d3ab765a..6fb65c79 100755 --- a/examples/02_all_load/all_load_bench.py +++ b/examples/02_all_load/all_load_bench.py @@ -4,6 +4,8 @@ import argparse import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import triton.language as tl import random @@ -12,7 +14,6 @@ import iris - torch.manual_seed(123) random.seed(123) @@ -124,6 +125,8 @@ def parse_args(): parser.add_argument("-w", "--num_warmup", type=int, default=2, help="Number of warmup experiments") parser.add_argument("-a", "--active_ranks", type=int, default=8, help="Number of active ranks") parser.add_argument("-o", "--output_file", type=str, default="", help="Output file") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + return vars(parser.parse_args()) @@ -310,9 +313,12 @@ def print_bandwidth_matrix( raise ValueError(f"Unsupported output file extension: {output_file}") -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + # Main benchmark logic heap_size = args["heap_size"] shmem = iris.iris(heap_size) num_ranks = shmem.get_num_ranks() @@ -354,6 +360,23 @@ def main(): if shmem.get_rank() == 0: print_bandwidth_matrix(bandwidth_data, buffer_sizes, output_file=args["output_file"]) + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + if __name__ == "__main__": main() diff --git a/examples/03_all_store/README.md b/examples/03_all_store/README.md index ed2b5685..7e691735 100644 --- a/examples/03_all_store/README.md +++ b/examples/03_all_store/README.md @@ -10,7 +10,7 @@ All-Store benchmark using Iris. ## Usage ```terminal -mpirun -np 8 python examples/03_all_store/all_store_bench.py +python examples/03_all_store/all_store_bench.py --num_ranks 8 ``` On an MI300X, this example will run on 8 GPUs. 
It prints: diff --git a/examples/03_all_store/all_store_bench.py b/examples/03_all_store/all_store_bench.py index ce17a1e3..eac5dd5d 100755 --- a/examples/03_all_store/all_store_bench.py +++ b/examples/03_all_store/all_store_bench.py @@ -5,6 +5,8 @@ import argparse import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import triton.language as tl import random @@ -13,7 +15,6 @@ import iris - torch.manual_seed(123) random.seed(123) @@ -87,6 +88,7 @@ def parse_args(): parser.add_argument("-w", "--num_warmup", type=int, default=2, help="Number of warmup experiments") parser.add_argument("-a", "--active_ranks", type=int, default=8, help="Number of active ranks") parser.add_argument("-o", "--output_file", type=str, default="", help="Output file") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -240,9 +242,12 @@ def print_bandwidth_matrix( raise ValueError(f"Unsupported output file extension: {output_file}") -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + # Main benchmark logic heap_size = args["heap_size"] shmem = iris.iris(heap_size) num_ranks = shmem.get_num_ranks() @@ -284,6 +289,23 @@ def main(): if shmem.get_rank() == 0: print_bandwidth_matrix(bandwidth_data, buffer_sizes, output_file=args["output_file"]) + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + if __name__ == "__main__": main() diff --git a/examples/04_atomic_add/README.md b/examples/04_atomic_add/README.md index 242cabe5..127e51fb 100644 --- a/examples/04_atomic_add/README.md +++ b/examples/04_atomic_add/README.md @@ -10,7 +10,7 @@ Load benchmark using Iris. ## Usage ```terminal -mpirun -np 8 python examples/04_atomic_add/atomic_add_bench.py +python examples/04_atomic_add/atomic_add_bench.py --num_ranks 8 ``` On an MI300X, this example will run on 8 GPUs. It prints: ```terminal diff --git a/examples/04_atomic_add/atomic_add_bench.py b/examples/04_atomic_add/atomic_add_bench.py index 6b292736..9b6dfb4f 100755 --- a/examples/04_atomic_add/atomic_add_bench.py +++ b/examples/04_atomic_add/atomic_add_bench.py @@ -3,15 +3,17 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
import argparse +import json +import random +import numpy as np import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import triton.language as tl -import random -import numpy as np -import json -import iris +import iris torch.manual_seed(123) random.seed(123) @@ -78,6 +80,7 @@ def parse_args(): parser.add_argument("-x", "--num_experiments", type=int, default=16, help="Number of experiments") parser.add_argument("-w", "--num_warmup", type=int, default=4, help="Number of warmup experiments") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -202,9 +205,12 @@ def print_bandwidth_matrix( raise ValueError(f"Unsupported output file extension: {output_file}") -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + # Main benchmark logic shmem = iris.iris(args["heap_size"]) num_ranks = shmem.get_num_ranks() bandwidth_matrix = np.zeros((num_ranks, num_ranks), dtype=np.float32) @@ -223,6 +229,23 @@ def main(): if shmem.get_rank() == 0: print_bandwidth_matrix(bandwidth_matrix, output_file=args["output_file"]) + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + if __name__ == "__main__": main() diff --git a/examples/05_atomic_xchg/README.md b/examples/05_atomic_xchg/README.md index fff66eed..691b70e0 100644 --- a/examples/05_atomic_xchg/README.md +++ b/examples/05_atomic_xchg/README.md @@ -10,7 +10,7 @@ Load benchmark using Iris. ## Usage ```terminal -mpirun -np 8 python examples/05_atomic_xchg/atomic_xchg_bench.py +python examples/05_atomic_xchg/atomic_xchg_bench.py --num_ranks 8 ``` On an MI300X, this example will run on 8 GPUs. It prints: ```terminal diff --git a/examples/05_atomic_xchg/atomic_xchg_bench.py b/examples/05_atomic_xchg/atomic_xchg_bench.py index 51ae2410..89d7792f 100755 --- a/examples/05_atomic_xchg/atomic_xchg_bench.py +++ b/examples/05_atomic_xchg/atomic_xchg_bench.py @@ -3,15 +3,17 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
import argparse +import json +import random +import numpy as np import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import triton.language as tl -import random -import numpy as np -import json -import iris +import iris torch.manual_seed(123) random.seed(123) @@ -82,6 +84,7 @@ def parse_args(): parser.add_argument("-x", "--num_experiments", type=int, default=16, help="Number of experiments") parser.add_argument("-w", "--num_warmup", type=int, default=4, help="Number of warmup experiments") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -206,9 +209,12 @@ def print_bandwidth_matrix( raise ValueError(f"Unsupported output file extension: {output_file}") -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + # Main benchmark logic shmem = iris.iris(args["heap_size"]) num_ranks = shmem.get_num_ranks() bandwidth_matrix = np.zeros((num_ranks, num_ranks), dtype=np.float32) @@ -227,6 +233,23 @@ def main(): if shmem.get_rank() == 0: print_bandwidth_matrix(bandwidth_matrix, output_file=args["output_file"]) + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + if __name__ == "__main__": main() diff --git a/examples/06_message_passing/message_passing_load_store.py b/examples/06_message_passing/message_passing_load_store.py index 9e19f564..37db8bcb 100755 --- a/examples/06_message_passing/message_passing_load_store.py +++ b/examples/06_message_passing/message_passing_load_store.py @@ -3,11 +3,13 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
import argparse +import random import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import triton.language as tl -import random import iris @@ -125,13 +127,17 @@ def parse_args(): parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + # Main benchmark logic shmem = iris.iris(args["heap_size"]) dtype = torch_dtype_from_str(args["datatype"]) cur_rank = shmem.get_rank() @@ -199,6 +205,23 @@ def main(): shmem.barrier() + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + if __name__ == "__main__": main() diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index fc849b8b..b4d54064 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -4,6 +4,8 @@ import argparse import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import triton.language as tl import random @@ -113,13 +115,17 @@ def parse_args(): parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + # Main benchmark logic shmem = iris.iris(args["heap_size"]) dtype = torch_dtype_from_str(args["datatype"]) cur_rank = shmem.get_rank() @@ -187,6 +193,23 @@ def main(): shmem.barrier() + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + if __name__ == "__main__": main() diff --git a/examples/07_gemm_all_scatter/benchmark.py b/examples/07_gemm_all_scatter/benchmark.py index fa10b0c3..6d872e61 100755 --- a/examples/07_gemm_all_scatter/benchmark.py +++ b/examples/07_gemm_all_scatter/benchmark.py @@ -2,20 +2,21 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
-import torch -import triton -import random -import sys -import os import argparse import json +import os +import random +import sys -from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set -from examples.common.validation import validate_gemm +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +from matmul_wrapper import matmul import iris - -from matmul_wrapper import matmul +from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set +from examples.common.validation import validate_gemm torch.manual_seed(123) random.seed(123) @@ -52,13 +53,17 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument("--gemm_sms", type=int, default=304, help="Number of SMs for persistent GEMM algorithm") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + # Main benchmark logic shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() @@ -238,6 +243,24 @@ def run_experiment(): shmem.barrier() + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + # Use command line argument if provided, otherwise use num_ranks parameter + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + if __name__ == "__main__": main() diff --git a/examples/08_gemm_atomics_all_reduce/benchmark.py b/examples/08_gemm_atomics_all_reduce/benchmark.py index 45503492..c17fb5ba 100755 --- a/examples/08_gemm_atomics_all_reduce/benchmark.py +++ b/examples/08_gemm_atomics_all_reduce/benchmark.py @@ -3,6 +3,8 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import random import sys @@ -68,14 +70,19 @@ def parse_args(): # For All Scatter, use: 288 # For One Shot, use: 256 - parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for Stream-K") + parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM") parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + return vars(parser.parse_args()) -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + # Main benchmark logic shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() @@ -128,7 +135,7 @@ def main(): total_tiles = total_blocks_M * total_blocks_N if args["gemm_sms"] >= args["total_sms"]: - print(f"Invalid number of stream-K SMs. 
{args['gemm_sms']} >= {args['total_sms']}") + print(f"Invalid number of GEMM SMs. {args['gemm_sms']} >= {args['total_sms']}") exit(1) tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) @@ -278,6 +285,23 @@ def run_experiment(): shmem.barrier() + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + if __name__ == "__main__": main() diff --git a/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py b/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py index 6d6fcf2e..ab014faf 100644 --- a/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py +++ b/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py @@ -70,10 +70,10 @@ def _call( total_tiles = total_blocks_M * total_blocks_N even_k = K % BLK_K == 0 - if total_programs_streamk > 0: # Stream-K + if total_programs_streamk > 0: # GEMM # last wave may occupy less than total_programs_streamk SMs total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper + # for two-tile GEMM + data-parallel from original paper # if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: # total_tiles_streamk += total_programs_streamk # remaining tiles are computed using classical blocking diff --git a/examples/09_gemm_one_shot_all_reduce/benchmark.py b/examples/09_gemm_one_shot_all_reduce/benchmark.py index f7c531e9..a79fc5fc 100755 --- a/examples/09_gemm_one_shot_all_reduce/benchmark.py +++ b/examples/09_gemm_one_shot_all_reduce/benchmark.py @@ -3,6 +3,8 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import random import sys @@ -66,15 +68,16 @@ def parse_args(): parser.add_argument("--kpack", type=int, default=2, help="K packing size") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") - # For All Scatter, use: 288 - # For One Shot, use: 256 - parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for Stream-K") + parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM") parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() @@ -128,7 +131,7 @@ def main(): total_tiles = total_blocks_M * total_blocks_N if args["gemm_sms"] >= args["total_sms"]: - print(f"Invalid number of stream-K SMs. {args['gemm_sms']} >= {args['total_sms']}") + print(f"Invalid number of GEMM SMs. 
{args['gemm_sms']} >= {args['total_sms']}") exit(1) tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) @@ -277,6 +280,21 @@ def run_experiment(): timestamps.to_json(filename, gpu_freq) shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) if __name__ == "__main__": diff --git a/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py b/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py index 468a16d7..83c9326f 100644 --- a/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py +++ b/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py @@ -70,10 +70,10 @@ def _call( total_tiles = total_blocks_M * total_blocks_N even_k = K % BLK_K == 0 - if total_programs_streamk > 0: # Stream-K + if total_programs_streamk > 0: # GEMM # last wave may occupy less than total_programs_streamk SMs total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper + # for two-tile GEMM + data-parallel from original paper # if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: # total_tiles_streamk += total_programs_streamk # remaining tiles are computed using classical blocking diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index f36c85a8..bb49bacb 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -3,6 +3,8 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import random import sys @@ -55,12 +57,15 @@ def parse_args(): "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" ) parser.add_argument("--num_sms", type=int, default=304, help="Number of total SMs for gemm + scatter kernel") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() @@ -245,6 +250,21 @@ def run_experiment(): timestamps.to_json(filename, gpu_freq) shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) if __name__ == "__main__": diff --git a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py index 19e08c7e..41c164a8 100755 --- a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py +++ b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py @@ -3,6 +3,8 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import random import sys @@ -56,12 +58,15 @@ def parse_args(): "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" ) parser.add_argument("--comm_sms", type=int, default=48, help="Number of SMs for All-Scatter kernel") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() @@ -279,6 +284,21 @@ def run_experiment(): timestamps.to_json(filename, gpu_freq) shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) if __name__ == "__main__": diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py index 840db13d..5cdc3819 100755 --- a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py @@ -3,6 +3,8 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. import torch +import torch.distributed as dist +import torch.multiprocessing as mp import triton import random import sys @@ -56,12 +58,15 @@ def parse_args(): "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" ) parser.add_argument("--comm_sms", type=int, default=256, help="Number of SMs for All-Scatter kernel") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def main(): - args = parse_args() +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() @@ -275,6 +280,21 @@ def run_experiment(): timestamps.to_json(filename, gpu_freq) shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) if __name__ == "__main__": diff --git a/examples/README.md b/examples/README.md index 92c394f3..0794d70f 100644 --- a/examples/README.md +++ b/examples/README.md @@ -34,39 +34,39 @@ This directory contains various algorithm implementations for distributed comput ### Basic Operations ```terminal # Example command to run distributed load operations -mpirun -np 8 python examples/00_load/load_bench.py # Load across GPUs -mpirun -np 8 python examples/02_all_load/all_load_bench.py # Simultaneous load on all GPUs +python examples/00_load/load_bench.py --num_ranks 8 # Load across GPUs +python examples/02_all_load/all_load_bench.py --num_ranks 8 # Simultaneous load on all GPUs # Example 
command to run distributed store operations -mpirun -np 8 python examples/01_store/store_bench.py # Store across GPUs -mpirun -np 8 python examples/03_all_store/all_store_bench.py # Simultaneous store on all GPUs +python examples/01_store/store_bench.py --num_ranks 8 # Store across GPUs +python examples/03_all_store/all_store_bench.py --num_ranks 8 # Simultaneous store on all GPUs # Example command to run atomic operations -mpirun -np 8 python examples/04_atomic_add/atomic_add_bench.py # Atomic add across GPUs -mpirun -np 8 python examples/05_atomic_xchg/atomic_xchg_bench.py # Atomic exchange across GPUs +python examples/04_atomic_add/atomic_add_bench.py --num_ranks 8 # Atomic add across GPUs +python examples/05_atomic_xchg/atomic_xchg_bench.py --num_ranks 8 # Atomic exchange across GPUs # Example command to run message passing -python examples/06_message_passing/message_passing_put.py -python examples/06_message_passing/message_passing_load_store.py +python examples/06_message_passing/message_passing_put.py --num_ranks 8 +python examples/06_message_passing/message_passing_load_store.py --num_ranks 8 ``` ### GEMM Operations ```terminal # Example command to run benchmark with all-scatter algorithm -mpirun -np 8 python examples/07_gemm_all_scatter/benchmark.py --benchmark --validate +python examples/07_gemm_all_scatter/benchmark.py --benchmark --validate --num_ranks 8 # Example command to run benchmark with all-reduce algorithm -mpirun -np 8 python examples/08_gemm_atomics_all_reduce/benchmark.py --benchmark --validate +python examples/08_gemm_atomics_all_reduce/benchmark.py --benchmark --validate --num_ranks 8 # Example command to run benchmark with one-shot all-reduce algorithm -mpirun -np 8 python examples/09_gemm_one_shot_all_reduce/benchmark.py --benchmark --validate +python examples/09_gemm_one_shot_all_reduce/benchmark.py --benchmark --validate --num_ranks 8 # Example command to run benchmark with all-scatter and workgroup specialization -mpirun -np 8 python examples/10_gemm_all_scatter_wg_specialization/benchmark.py --benchmark --validate +python examples/10_gemm_all_scatter_wg_specialization/benchmark.py --benchmark --validate --num_ranks 8 # Example command to run benchmark with all-scatter producer-consumer pattern -mpirun -np 8 python examples/11_gemm_all_scatter_producer_consumer/benchmark.py --benchmark --validate +python examples/11_gemm_all_scatter_producer_consumer/benchmark.py --benchmark --validate --num_ranks 8 # Example command to run benchmark with all-scatter bulk synchronous approach -mpirun -np 8 python examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py --benchmark --validate +python examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py --benchmark --validate --num_ranks 8 ``` diff --git a/examples/benchmark/reference/gemm.py b/examples/benchmark/reference/gemm.py index 61bed8c8..b37b7e5a 100755 --- a/examples/benchmark/reference/gemm.py +++ b/examples/benchmark/reference/gemm.py @@ -2,6 +2,7 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
import torch +import torch.distributed as dist import triton import torch.cuda.nvtx as nvtx diff --git a/iris/__init__.py b/iris/__init__.py index f252ac4d..088a5334 100644 --- a/iris/__init__.py +++ b/iris/__init__.py @@ -62,6 +62,8 @@ ERROR, ) +# Launcher functionality is now user code - see examples and documentation + # Pipe allocations via finegrained allocator current_dir = os.path.dirname(__file__) # Look for the library in the installed package location diff --git a/iris/_distributed_helpers.py b/iris/_distributed_helpers.py new file mode 100644 index 00000000..ce656f93 --- /dev/null +++ b/iris/_distributed_helpers.py @@ -0,0 +1,148 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import torch.distributed as dist +import numpy as np + + +def _infer_device(): + if not dist.is_initialized(): + raise RuntimeError("PyTorch distributed is not initialized") + try: + backend = str(dist.get_backend()).lower() + except Exception: + backend = "gloo" + if backend == "nccl" and torch.cuda.is_available(): + return torch.device("cuda", torch.cuda.current_device()) + return torch.device("cpu") + + +def _nccl_dtype_supported(t: torch.Tensor) -> bool: + """Conservative whitelist for NCCL tensor dtypes.""" + supported = { + torch.int8, + torch.uint8, + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.float64, + } + # bfloat16 is commonly supported in recent stacks; include if available + if hasattr(torch, "bfloat16"): + supported.add(torch.bfloat16) + return t.dtype in supported + + +def distributed_allgather(data): + """ + All-gather operation using PyTorch distributed. + + Args: + data: 1D numpy array to gather across all ranks + + Returns: + 2D numpy array with shape (world_size, len(data)) + """ + if not dist.is_initialized(): + raise RuntimeError("PyTorch distributed is not initialized") + + data = np.asarray(data) + assert data.ndim == 1, "Only 1D arrays are supported." + + world_size = dist.get_world_size() + device = _infer_device() + backend = str(dist.get_backend()).lower() + + # Fast path: tensor all_gather if dtype is NCCL-supported or backend != nccl + data_tensor = torch.from_numpy(data) + use_tensor_collective = backend != "nccl" or _nccl_dtype_supported(data_tensor) + + if use_tensor_collective: + data_tensor = data_tensor.to(device) + gathered_tensors = [torch.empty_like(data_tensor) for _ in range(world_size)] + dist.all_gather(gathered_tensors, data_tensor) + return torch.stack(gathered_tensors, dim=0).to("cpu").numpy() + + # Fallback for NCCL-unsupported dtypes (e.g., uint64/bool/etc.) + obj_list = [None for _ in range(world_size)] + # Use object collective (works across backends) + dist.all_gather_object(obj_list, data) + # Ensure uniform shapes and stack + return np.stack(obj_list, axis=0) + + +def distributed_broadcast_scalar(value=None, root=0): + """ + Broadcast a scalar value from root to all ranks. 
+ + Args: + value: Value to broadcast (only used on root rank) + root: Root rank to broadcast from + + Returns: + Broadcasted value + """ + if not dist.is_initialized(): + raise RuntimeError("PyTorch distributed is not initialized") + + rank = dist.get_rank() + device = _infer_device() + backend = str(dist.get_backend()).lower() + + # First agree on dtype (numpy dtype object) + if rank == root: + if value is None: + raise ValueError("Root must provide a value.") + np_val = np.array(value) # captures dtype + dtype = np_val.dtype + else: + np_val = None + dtype = None + + dtype_obj = [dtype] + dist.broadcast_object_list(dtype_obj, src=root) + dtype = dtype_obj[0] + + # If NCCL can't handle this dtype, just broadcast the object directly. + if backend == "nccl": + # Try a quick check using a tiny tensor of the dtype + torch_dtype = torch.from_numpy(np.array(0, dtype=dtype)).dtype + dummy = torch.empty((), dtype=torch_dtype) + if not _nccl_dtype_supported(dummy): + obj = [value if rank == root else None] + dist.broadcast_object_list(obj, src=root) + return obj[0] + + # Tensor path: create a 0-D tensor, broadcast on the selected device + if rank != root: + np_val = np.empty((), dtype=dtype) + val_t = torch.from_numpy(np_val).to(device) + dist.broadcast(val_t, src=root) + return val_t.to("cpu").item() + + +def distributed_barrier(): + """ + Synchronization barrier using PyTorch distributed. + """ + if not dist.is_initialized(): + raise RuntimeError("PyTorch distributed is not initialized") + dist.barrier() + + +def init_distributed(): + """ + Initialize PyTorch distributed and return communicator info. + + Returns: + tuple: (communicator_placeholder, rank, world_size) + Note: communicator_placeholder is None since PyTorch distributed + uses global state rather than explicit communicator objects + """ + if not dist.is_initialized(): + raise RuntimeError("PyTorch distributed is not initialized. Call dist.init_process_group() first.") + rank = dist.get_rank() + world_size = dist.get_world_size() + return None, rank, world_size diff --git a/iris/_mpi_helpers.py b/iris/_mpi_helpers.py deleted file mode 100644 index ff59a650..00000000 --- a/iris/_mpi_helpers.py +++ /dev/null @@ -1,51 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. - -from mpi4py import MPI -import numpy as np - - -def mpi_allgather(data): - thread_comm = MPI.COMM_WORLD - shmcomm = thread_comm.Split_type(MPI.COMM_TYPE_SHARED) - shm_size = shmcomm.Get_size() - data = np.asarray(data) - assert len(data.shape) == 1, "Only 1D arrays are supported." - recv_data = np.empty(len(data) * shm_size, dtype=data.dtype) - shmcomm.Allgather(sendbuf=data, recvbuf=recv_data) - shmcomm.Free() - reshaped = recv_data.reshape(shm_size, len(data)) - return reshaped - - -def mpi_broadcast_scalar(value=None, root=0): - thread_comm = MPI.COMM_WORLD - shmcomm = thread_comm.Split_type(MPI.COMM_TYPE_SHARED) - shm_rank = shmcomm.Get_rank() - - if shm_rank == root: - assert value is not None, "Root must provide a value." 
- value = np.array(value) - dtype = value.dtype - else: - value = None - dtype = None - dtype = shmcomm.bcast(dtype, root=root) - if shm_rank != root: - value = np.empty(1, dtype=dtype) - else: - value = np.array([value], dtype=dtype) - shmcomm.Bcast(value, root=root) - shmcomm.Free() - return value[0] - - -def world_barrier(): - MPI.COMM_WORLD.Barrier() - - -def init_mpi(): - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - world_size = comm.Get_size() - return comm, rank, world_size diff --git a/iris/iris.py b/iris/iris.py index f67b4ed2..3ed89f3b 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -14,7 +14,7 @@ - Efficient load/store operations with rank-to-rank communication - Memory allocation and deallocation utilities - Built-in logging with rank information -- MPI integration for distributed computing +- PyTorch distributed integration for distributed computing Example: >>> import iris @@ -26,11 +26,11 @@ import triton import triton.language as tl -from iris._mpi_helpers import ( - init_mpi, - mpi_allgather, - world_barrier, - mpi_broadcast_scalar, +from iris._distributed_helpers import ( + init_distributed, + distributed_allgather, + distributed_barrier, + distributed_broadcast_scalar, ) from iris.hip import ( set_device, @@ -68,7 +68,7 @@ class Iris: def __init__(self, heap_size=1 << 30): # Initialize - comm, cur_rank, num_ranks = init_mpi() + comm, cur_rank, num_ranks = init_distributed() num_gpus = count_devices() gpu_id = cur_rank % num_gpus @@ -92,12 +92,12 @@ def __init__(self, heap_size=1 << 30): ipc_handles = np.zeros((num_ranks, 64), dtype=np.uint8) ipc_handle = get_ipc_handle(heap_base_ptr, cur_rank) - world_barrier() + distributed_barrier() - all_ipc_handles = mpi_allgather(np.frombuffer(ipc_handle, dtype=np.uint8)) - all_heap_bases = mpi_allgather(np.array([heap_bases[cur_rank]], dtype=np.uint64)) + all_ipc_handles = distributed_allgather(np.frombuffer(ipc_handle, dtype=np.uint8)) + all_heap_bases = distributed_allgather(np.array([heap_bases[cur_rank]], dtype=np.uint64)) - world_barrier() + distributed_barrier() ipc_heap_bases = np.zeros(num_ranks, dtype=np.uintp) for rank in range(num_ranks): @@ -110,10 +110,10 @@ def __init__(self, heap_size=1 << 30): for i in range(num_ranks): self.debug(f"GPU {i}: Heap base {hex(int(ipc_heap_bases[i]))}") - world_barrier() + distributed_barrier() self.heap_bases = torch.from_numpy(ipc_heap_bases).to(device=self.device, dtype=torch.uint64) - world_barrier() + distributed_barrier() def _log_with_rank(self, level, message): """Helper method to log with rank information injected into the record.""" @@ -194,7 +194,7 @@ def broadcast(self, value, source_rank): >>> value = 42 if iris_ctx.get_rank() == 0 else None >>> value = iris_ctx.broadcast(value, source_rank=0) """ - return mpi_broadcast_scalar(value, source_rank) + return distributed_broadcast_scalar(value, source_rank) def __allocate(self, num_elements, dtype): self.debug(f"allocate: num_elements = {num_elements}, dtype = {dtype}") @@ -1073,10 +1073,10 @@ def barrier(self, stream=None): Synchronize all ranks and their CUDA devices. This first calls ``torch.cuda.synchronize()`` or ``stream.synchronize()`` to ensure the local GPU has - finished all queued work, then performs a global MPI barrier so that all + finished all queued work, then performs a global distributed barrier so that all ranks reach the same point before proceeding. Args: - stream: If stream is given: wait only for that stream before MPI_Barrier. If stream is None: legacy behavior (device-wide sync). 
+ stream: If stream is given: wait only for that stream before barrier. If stream is None: legacy behavior (device-wide sync). Example: >>> iris_ctx = iris.iris(1 << 20) @@ -1088,8 +1088,8 @@ def barrier(self, stream=None): else: stream.synchronize() - # MPI barrier - world_barrier() + # Distributed barrier + distributed_barrier() def get_device(self): """ @@ -1121,7 +1121,7 @@ def get_cu_count(self): def get_rank(self): """ - Get this process's rank id in the MPI communicator. + Get this process's rank id in the distributed communicator. Returns: int: Zero-based rank id of the current process. @@ -1135,7 +1135,7 @@ def get_rank(self): def get_num_ranks(self): """ - Get the total number of ranks in the MPI communicator. + Get the total number of ranks in the distributed communicator. Returns: int: World size (number of ranks). diff --git a/iris/logging.py b/iris/logging.py index e1a4d6fb..51a27941 100644 --- a/iris/logging.py +++ b/iris/logging.py @@ -2,7 +2,7 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. """ -Iris logging module - provides logging functionality without MPI dependencies. +Iris logging module - provides logging functionality. """ import logging diff --git a/pyproject.toml b/pyproject.toml index b700c83d..56ba24ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ requires-python = ">=3.8" dependencies = [ "numpy", "requests", - "mpi4py", "ruff", "triton @ git+https://github.com/triton-lang/triton.git@dd5823453bcc7973eabadb65f9d827c43281c434" ] diff --git a/tests/run_tests_distributed.py b/tests/run_tests_distributed.py new file mode 100755 index 00000000..8c4af8e1 --- /dev/null +++ b/tests/run_tests_distributed.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Simple wrapper to run pytest tests within a single distributed process group. +This avoids the overhead of creating/destroying process groups for each test case. +""" + +import sys +import subprocess +import torch.multiprocessing as mp +import torch.distributed as dist +import socket +import os + + +def _find_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def _distributed_worker(rank, world_size, test_file, pytest_args): + """Worker function that runs pytest within a distributed process group.""" + # Initialize distributed once for all tests + init_method = "tcp://127.0.0.1:12355" + dist.init_process_group( + backend="nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + ) + + try: + # Import and run pytest directly + import pytest + import sys + + # Set up sys.argv for pytest + original_argv = sys.argv[:] + sys.argv = ["pytest", test_file] + pytest_args + + try: + # Run pytest directly in this process + exit_code = pytest.main([test_file] + pytest_args) + return exit_code + finally: + # Restore original argv + sys.argv = original_argv + + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def main(): + if len(sys.argv) < 2: + print("Usage: python run_tests_distributed.py [--num_ranks N] [pytest_args...] 
") + sys.exit(1) + + # Get number of ranks from args or default to 2 + num_ranks = 2 + args = sys.argv[1:] + + if "--num_ranks" in args: + idx = args.index("--num_ranks") + if idx + 1 < len(args): + num_ranks = int(args[idx + 1]) + # Remove --num_ranks and its value from args + args = args[:idx] + args[idx + 2 :] + + # The test file is the first argument after --num_ranks, everything else is pytest args + if not args: + print("Error: No test file specified") + sys.exit(1) + + test_file = args[0] + pytest_args = args[1:] # Everything after the test file + + print(f"Running {test_file} with {num_ranks} ranks") + print(f"args={args}, test_file={test_file}, pytest_args={pytest_args}") + + # Run all tests within a single distributed process group + mp.spawn(_distributed_worker, args=(num_ranks, test_file, pytest_args), nprocs=num_ranks, join=True) + + +if __name__ == "__main__": + main() From 60633e824151e8094234c729a4a24bc7cca1a476 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 9 Sep 2025 12:02:27 +0000 Subject: [PATCH 07/13] Revert "Sync with main - merge PyTorch distributed changes" This reverts commit 7b93dd5a802eae08347d5cb36d96a24a412ac0eb. --- .github/copilot-instructions.md | 7 +- .github/workflows/iris-tests-apptainer.yml | 17 +- README.md | 62 +++----- apptainer/iris.def | 6 +- docker/Dockerfile | 12 +- docs/CONTRIBUTING.md | 8 +- docs/conceptual/programming-model.md | 2 +- docs/conf.py | 2 +- docs/getting-started/installation.md | 3 +- docs/index.md | 60 +++---- docs/reference/api-iris-class.md | 2 +- examples/00_load/README.md | 2 +- examples/00_load/load_bench.py | 27 +--- examples/01_store/store_bench.py | 27 +--- examples/02_all_load/README.md | 2 +- examples/02_all_load/all_load_bench.py | 29 +--- examples/03_all_store/README.md | 2 +- examples/03_all_store/all_store_bench.py | 28 +--- examples/04_atomic_add/README.md | 2 +- examples/04_atomic_add/atomic_add_bench.py | 35 +---- examples/05_atomic_xchg/README.md | 2 +- examples/05_atomic_xchg/atomic_xchg_bench.py | 35 +---- .../message_passing_load_store.py | 29 +--- .../06_message_passing/message_passing_put.py | 27 +--- examples/07_gemm_all_scatter/benchmark.py | 45 ++---- .../08_gemm_atomics_all_reduce/benchmark.py | 32 +--- .../matmul_wrapper.py | 4 +- .../09_gemm_one_shot_all_reduce/benchmark.py | 30 +--- .../matmul_wrapper.py | 4 +- .../benchmark.py | 24 +-- .../benchmark.py | 24 +-- .../benchmark.py | 24 +-- examples/README.md | 28 ++-- examples/benchmark/reference/gemm.py | 1 - iris/__init__.py | 2 - iris/_distributed_helpers.py | 148 ------------------ iris/_mpi_helpers.py | 51 ++++++ iris/iris.py | 40 ++--- iris/logging.py | 2 +- pyproject.toml | 1 + tests/run_tests_distributed.py | 89 ----------- 41 files changed, 230 insertions(+), 747 deletions(-) delete mode 100644 iris/_distributed_helpers.py create mode 100644 iris/_mpi_helpers.py delete mode 100755 tests/run_tests_distributed.py diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index cf5f9a0d..3a613ec9 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -7,7 +7,7 @@ Iris is a Triton-based framework for Remote Memory Access (RMA) operations on AM - Clean abstractions with full symmetric heap implementation - Pythonic PyTorch-like host APIs for tensor operations - Triton-style device APIs for load, store, and atomic operations -- Minimal dependencies (Triton, PyTorch, HIP runtime) +- Minimal dependencies (Triton, PyTorch, HIP runtime, 
mpi4py) - Comprehensive examples showing communication/computation overlap **FOLLOW THESE INSTRUCTIONS EXACTLY. Reference these instructions first before using search or bash commands.** @@ -17,6 +17,7 @@ Iris is a Triton-based framework for Remote Memory Access (RMA) operations on AM - **GPU**: AMD GPUs with ROCm compatibility (tested on MI300X, MI350X & MI355X) > **Note**: See below for instructions on development without AMD GPU access - **ROCm/HIP Toolkit**: Required for building C++/HIP components +- **MPI**: Required for multi-GPU operations - **Docker/Apptainer**: Recommended for containerized development ## Build @@ -77,8 +78,8 @@ pytest tests/unittests/ # Run example tests pytest tests/examples/ -# Run specific example -python examples/00_load/load_bench.py +# Run specific example (requires MPI and GPU) +mpirun -np 8 python examples/00_load/load_bench.py ``` ### Code Quality diff --git a/.github/workflows/iris-tests-apptainer.yml b/.github/workflows/iris-tests-apptainer.yml index 5e2d9a85..d39e5814 100644 --- a/.github/workflows/iris-tests-apptainer.yml +++ b/.github/workflows/iris-tests-apptainer.yml @@ -52,7 +52,7 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - - name: Run Iris Tests with ${{ matrix.ranks }} ranks + - name: Run Iris Tests with ${{ matrix.ranks }} MPI ranks run: | apptainer exec ~/apptainer/iris-dev.sif bash -c " set -e # Exit on any error @@ -60,17 +60,20 @@ jobs: # Install iris first pip install -e . - # Run examples tests one at a time using distributed wrapper + # Create function for mpirun with root permissions + mpirun-root() { mpirun --allow-run-as-root \"\$@\"; } + + # Run examples tests one at a time echo 'Running examples tests one at a time...' for test_file in tests/examples/test_*.py; do - echo \"Testing: \$test_file with ${{ matrix.ranks }} ranks\" - python tests/run_tests_distributed.py --num_ranks ${{ matrix.ranks }} \"\$test_file\" -v --tb=short + echo \"Testing: \$test_file with ${{ matrix.ranks }} MPI ranks\" + mpirun-root -np ${{ matrix.ranks }} python -m pytest \"\$test_file\" -v --tb=short done - # Run unit tests one at a time using distributed wrapper + # Run unit tests one at a time echo 'Running unit tests one at a time...' for test_file in tests/unittests/test_*.py; do - echo \"Testing: \$test_file with ${{ matrix.ranks }} ranks\" - python tests/run_tests_distributed.py --num_ranks ${{ matrix.ranks }} \"\$test_file\" -v --tb=short + echo \"Testing: \$test_file with ${{ matrix.ranks }} MPI ranks\" + mpirun-root -np ${{ matrix.ranks }} python -m pytest \"\$test_file\" -v --tb=short done " \ No newline at end of file diff --git a/README.md b/README.md index 639e8d9f..7f44fa5b 100644 --- a/README.md +++ b/README.md @@ -35,10 +35,7 @@ Iris is a Triton-based framework for Remote Memory Access (RMA) operations. 
Iris Here's a simple example showing how to perform remote memory operations between GPUs using Iris: ```python -import os import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import triton.language as tl import iris @@ -50,7 +47,7 @@ def kernel(buffer, buffer_size: tl.constexpr, block_size: tl.constexpr, heap_bas pid = tl.program_id(0) block_start = pid * block_size offsets = block_start + tl.arange(0, block_size) - + # Guard for out-of-bounds accesses mask = offsets < buffer_size @@ -61,40 +58,29 @@ def kernel(buffer, buffer_size: tl.constexpr, block_size: tl.constexpr, heap_bas source_rank, target_rank, heap_bases_ptr, mask=mask) -def _worker(rank, world_size): - # Torch distributed initialization - os.environ.setdefault("MASTER_ADDR", "127.0.0.1") - os.environ.setdefault("MASTER_PORT", "29500") - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) - - # Iris initialization - heap_size = 2**30 # 1GiB symmetric heap for inter-GPU communication - iris_ctx = iris.iris(heap_size) - cur_rank = iris_ctx.get_rank() - - # Iris tensor allocation - buffer_size = 4096 # 4K elements buffer - buffer = iris_ctx.zeros(buffer_size, device="cuda", dtype=torch.float32) - - # Launch the kernel on rank 0 - block_size = 1024 - grid = lambda meta: (triton.cdiv(buffer_size, meta["block_size"]),) - source_rank = 0 - if cur_rank == source_rank: - kernel[grid]( - buffer, - buffer_size, - block_size, - iris_ctx.get_heap_bases(), - ) - - # Synchronize all ranks - iris_ctx.barrier() - dist.destroy_process_group() - -if __name__ == "__main__": - world_size = 2 # Using two ranks - mp.spawn(_worker, args=(world_size,), nprocs=world_size, join=True) +# Iris initialization +heap_size = 2**30 # 1GiB symmetric heap for inter-GPU communication +iris_ctx = iris.iris(heap_size) +cur_rank = iris_ctx.get_rank() + +# Iris tensor allocation +buffer_size = 4096 # 4K elements buffer +buffer = iris_ctx.zeros(buffer_size, device="cuda", dtype=torch.float32) + +# Launch the kernel on rank 0 +block_size = 1024 +grid = lambda meta: (triton.cdiv(buffer_size, meta["block_size"]),) +source_rank = 0 +if cur_rank == source_rank: + kernel[grid]( + buffer, + buffer_size, + block_size, + iris_ctx.get_heap_bases(), + ) + +# Synchronize all ranks +iris_ctx.barrier() ``` ## Quick Start Guide diff --git a/apptainer/iris.def b/apptainer/iris.def index 31960182..d3ca16d3 100644 --- a/apptainer/iris.def +++ b/apptainer/iris.def @@ -10,7 +10,7 @@ From: rocm/pytorch:rocm6.3.1_ubuntu22.04_py3.10_pytorch export TRITON_PATH=/workspace/triton conda env list source /opt/conda/bin/activate py_3.10 - conda install -y -n py_3.10 -c conda-forge jupyter ninja cmake wheel + conda install -y -n py_3.10 -c conda-forge mpi4py openmpi jupyter ninja cmake wheel git clone https://github.com/triton-lang/triton.git \$TRITON_PATH cd \$TRITON_PATH git checkout dd5823453bcc7973eabadb65f9d827c43281c434 @@ -23,9 +23,9 @@ From: rocm/pytorch:rocm6.3.1_ubuntu22.04_py3.10_pytorch # Define environment variables export TRITON_PATH=/workspace/triton export PYTHONPATH=$TRITON_PATH/python/ - export LD_LIBRARY_PATH=/opt/rocm/lib:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=/opt/rocm/lib:/usr/lib/openmpi/lib:$LD_LIBRARY_PATH export ROCM_PATH=/opt/rocm - export PATH=/opt/conda/envs/py_3.10/bin:/opt/rocm/bin:$PATH + export PATH=/opt/conda/envs/py_3.10/bin:/opt/rocm/bin:/usr/lib/openmpi/bin:$PATH export OMPI_MCA_mtl="^ofi" export OMPI_MCA_pml="ob1" diff --git a/docker/Dockerfile b/docker/Dockerfile index 8b49c01a..73738239 
100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -12,8 +12,8 @@ ENV TRITON_PATH=/opt/triton \ OMPI_MCA_mtl="^ofi" \ OMPI_MCA_pml="ob1" -ENV LD_LIBRARY_PATH=$ROCM_PATH/lib:$LD_LIBRARY_PATH \ - PATH="$ROCM_PATH/bin:$PATH" +ENV LD_LIBRARY_PATH=$ROCM_PATH/lib:/usr/lib/openmpi/lib:$LD_LIBRARY_PATH \ + PATH="$ROCM_PATH/bin:/usr/lib/openmpi/bin:$PATH" ENV OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \ OMPI_ALLOW_RUN_AS_ROOT=1 @@ -21,13 +21,19 @@ ENV OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \ # Install system packages RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y \ - git wget ninja-build cmake python3-pip python3-dev build-essential && \ + git wget ninja-build cmake python3-pip python3-dev build-essential \ + openmpi-bin libopenmpi-dev && \ rm -rf /var/lib/apt/lists/* # Install Python packages with pip RUN pip3 install --upgrade pip && \ pip3 install wheel jupyter +# This needs sudo, I can only get it to install with sudo +# or using conda, but conda runs into issues with too many requests. +# https://stackoverflow.com/a/54052470/5729690 +RUN sudo pip3 install mpi4py + # Clone and install Triton WORKDIR $TRITON_PATH RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 3f579361..d21d49ac 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -20,12 +20,8 @@ git checkout -b $USER/your-feature-name ruff check . ruff format . -# Run tests -python tests/run_tests_distributed.py tests/examples/test_all_load_bench.py --num_ranks 2 -v -python tests/run_tests_distributed.py tests/unittests/ --num_ranks 2 -v - -# Or run individual test files -python tests/run_tests_distributed.py tests/examples/test_load_bench.py --num_ranks 2 -v +# Run tests +pytest ``` ### 4. Commit and Push diff --git a/docs/conceptual/programming-model.md b/docs/conceptual/programming-model.md index ea2f772f..cf078042 100644 --- a/docs/conceptual/programming-model.md +++ b/docs/conceptual/programming-model.md @@ -25,7 +25,7 @@ html[data-theme="light"] .theme-switch-wrapper .light-theme-img { 1. **Designed by Experts, Built for Scale** - Written from scratch by GPU and distributed computing experts - - Minimal dependencies: only Triton, PyTorch, and HIP runtime + - Minimal dependencies: only Triton, PyTorch, HIP runtime and mpi4py (for initialization) - No external frameworks or heavyweight runtimes beyond core stack 2. **Clean Abstractions** diff --git a/docs/conf.py b/docs/conf.py index 9f509212..8210169d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -94,7 +94,7 @@ "triton", "triton.language", "numpy", - "iris._distributed_helpers", + "iris._mpi_helpers", "iris.hip", ] diff --git a/docs/getting-started/installation.md b/docs/getting-started/installation.md index 5a38cb99..9af00357 100644 --- a/docs/getting-started/installation.md +++ b/docs/getting-started/installation.md @@ -4,7 +4,7 @@ This guide covers how to install Iris on your system using various methods. ## Overview -Iris has minimal dependencies including Python, PyTorch, ROCm HIP runtime, and Triton. This guide will walk you through the installation process using different approaches. +Iris has minimal dependencies including Python, PyTorch, ROCm HIP runtime, MPI, and Triton. This guide will walk you through the installation process using different approaches. 
## Prerequisites @@ -22,6 +22,7 @@ Iris has minimal dependencies including Python, PyTorch, ROCm HIP runtime, and T - Python 3.10+ - PyTorch 2.0+ (ROCm version) - ROCm 6.3.1+ HIP runtime +- OpenMPI - Git - CMake, Ninja, build-essential - Triton (specific commit: [dd5823453bcc7973eabadb65f9d827c43281c434](https://github.com/triton-lang/triton/tree/dd5823453bcc7973eabadb65f9d827c43281c434)) diff --git a/docs/index.md b/docs/index.md index ea3e3d90..2513188b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -46,10 +46,7 @@ cd iris && pip install -e . Here's a simple example showing how to perform remote memory operations between GPUs using Iris: ```python -import os import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import triton.language as tl import iris @@ -72,40 +69,29 @@ def kernel(buffer, buffer_size: tl.constexpr, block_size: tl.constexpr, heap_bas source_rank, target_rank, heap_bases_ptr, mask=mask) -def _worker(rank, world_size): - # Torch distributed initialization - os.environ.setdefault("MASTER_ADDR", "127.0.0.1") - os.environ.setdefault("MASTER_PORT", "29500") - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) - - # Iris initialization - heap_size = 2**30 # 1GiB symmetric heap for inter-GPU communication - iris_ctx = iris.iris(heap_size) - cur_rank = iris_ctx.get_rank() - - # Iris tensor allocation - buffer_size = 4096 # 4K elements buffer - buffer = iris_ctx.zeros(buffer_size, device="cuda", dtype=torch.float32) - - # Launch the kernel on rank 0 - block_size = 1024 - grid = lambda meta: (triton.cdiv(buffer_size, meta["block_size"]),) - source_rank = 0 - if cur_rank == source_rank: - kernel[grid]( - buffer, - buffer_size, - block_size, - iris_ctx.get_heap_bases(), - ) - - # Synchronize all ranks - iris_ctx.barrier() - dist.destroy_process_group() - -if __name__ == "__main__": - world_size = 2 # Using two ranks - mp.spawn(_worker, args=(world_size,), nprocs=world_size, join=True) +# Iris initialization +heap_size = 2**30 # 1GiB symmetric heap for inter-GPU communication +iris_ctx = iris.iris(heap_size) +cur_rank = iris_ctx.get_rank() + +# Iris tensor allocation +buffer_size = 4096 # 4K elements buffer +buffer = iris_ctx.zeros(buffer_size, device="cuda", dtype=torch.float32) + +# Launch the kernel on rank 0 +block_size = 1024 +grid = lambda meta: (triton.cdiv(buffer_size, meta["block_size"]),) +source_rank = 0 +if cur_rank == source_rank: + kernel[grid]( + buffer, + buffer_size, + block_size, + iris_ctx.get_heap_bases(), + ) + +# Synchronize all ranks +iris_ctx.barrier() ``` For more examples, see the [Examples](reference/examples.md) page with ready-to-run scripts and usage patterns. diff --git a/docs/reference/api-iris-class.md b/docs/reference/api-iris-class.md index f16eccf5..aa673af5 100644 --- a/docs/reference/api-iris-class.md +++ b/docs/reference/api-iris-class.md @@ -33,7 +33,7 @@ Use Iris-aware logging that automatically annotates each message with the curren ## Broadcast Helper -Broadcast a Python scalar or small object from a source rank to all ranks. This is a convenience wrapper over the internal Torch Distributed helper. +Broadcast a Python scalar or small object from a source rank to all ranks. This is a convenience wrapper over the internal MPI helper. ```{eval-rst} .. 
automethod:: iris.iris.Iris.broadcast diff --git a/examples/00_load/README.md b/examples/00_load/README.md index aab1e0f1..e985c7e2 100644 --- a/examples/00_load/README.md +++ b/examples/00_load/README.md @@ -10,7 +10,7 @@ Load benchmark using Iris. ## Usage ```terminal -python examples/00_load/load_bench.py --num_ranks 8 +mpirun -np 8 python examples/00_load/load_bench.py ``` On an MI300X, this example will run on 8 GPUs. It prints: ```terminal diff --git a/examples/00_load/load_bench.py b/examples/00_load/load_bench.py index fc8e4148..1ef16775 100755 --- a/examples/00_load/load_bench.py +++ b/examples/00_load/load_bench.py @@ -5,8 +5,6 @@ import argparse import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import triton.language as tl import random @@ -99,7 +97,6 @@ def parse_args(): parser.add_argument("-o", "--output_file", type=str, default="", help="Output file") parser.add_argument("-n", "--num_experiments", type=int, default=10, help="Number of experiments") parser.add_argument("-w", "--num_warmup", type=int, default=1, help="Number of warmup iterations") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -232,12 +229,9 @@ def print_bandwidth_matrix(matrix, label="Unidirectional LOAD bandwidth GiB/s [R raise ValueError(f"Unsupported output file extension: {output_file}") -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() - # Main benchmark logic shmem = iris.iris(args["heap_size"]) num_ranks = shmem.get_num_ranks() bandwidth_matrix = np.zeros((num_ranks, num_ranks), dtype=np.float32) @@ -268,23 +262,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if shmem.get_rank() == 0: print_bandwidth_matrix(bandwidth_matrix, output_file=args["output_file"]) - dist.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) - if __name__ == "__main__": main() diff --git a/examples/01_store/store_bench.py b/examples/01_store/store_bench.py index 80e7a7e0..835809f2 100755 --- a/examples/01_store/store_bench.py +++ b/examples/01_store/store_bench.py @@ -5,8 +5,6 @@ import argparse import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import triton.language as tl import random @@ -84,7 +82,6 @@ def parse_args(): parser.add_argument("-o", "--output_file", type=str, default="", help="Output file") parser.add_argument("-n", "--num_experiments", type=int, default=10, help="Number of experiments") parser.add_argument("-w", "--num_warmup", type=int, default=1, help="Number of warmup iterations") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -205,12 +202,9 @@ def print_bandwidth_matrix(matrix, label="Unidirectional STORE bandwidth GiB/s [ raise ValueError(f"Unsupported output file extension: {output_file}") -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" 
if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() - # Main benchmark logic shmem = iris.iris(args["heap_size"]) num_ranks = shmem.get_num_ranks() bandwidth_matrix = np.zeros((num_ranks, num_ranks), dtype=np.float32) @@ -228,23 +222,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if shmem.get_rank() == 0: print_bandwidth_matrix(bandwidth_matrix, output_file=args["output_file"]) - dist.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) - if __name__ == "__main__": main() diff --git a/examples/02_all_load/README.md b/examples/02_all_load/README.md index 13a8def2..fd7f396a 100644 --- a/examples/02_all_load/README.md +++ b/examples/02_all_load/README.md @@ -10,7 +10,7 @@ All-Load benchmark using Iris. ## Usage ```terminal -python examples/02_all_load/all_load_bench.py --num_ranks 8 +mpirun -np 8 python examples/02_all_load/all_load_bench.py ``` On an MI300X, this example will run on 8 GPUs. It prints: diff --git a/examples/02_all_load/all_load_bench.py b/examples/02_all_load/all_load_bench.py index 6fb65c79..d3ab765a 100755 --- a/examples/02_all_load/all_load_bench.py +++ b/examples/02_all_load/all_load_bench.py @@ -4,8 +4,6 @@ import argparse import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import triton.language as tl import random @@ -14,6 +12,7 @@ import iris + torch.manual_seed(123) random.seed(123) @@ -125,8 +124,6 @@ def parse_args(): parser.add_argument("-w", "--num_warmup", type=int, default=2, help="Number of warmup experiments") parser.add_argument("-a", "--active_ranks", type=int, default=8, help="Number of active ranks") parser.add_argument("-o", "--output_file", type=str, default="", help="Output file") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") - return vars(parser.parse_args()) @@ -313,12 +310,9 @@ def print_bandwidth_matrix( raise ValueError(f"Unsupported output file extension: {output_file}") -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() - # Main benchmark logic heap_size = args["heap_size"] shmem = iris.iris(heap_size) num_ranks = shmem.get_num_ranks() @@ -360,23 +354,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if shmem.get_rank() == 0: print_bandwidth_matrix(bandwidth_data, buffer_sizes, output_file=args["output_file"]) - dist.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) - if __name__ == "__main__": main() diff --git a/examples/03_all_store/README.md b/examples/03_all_store/README.md index 7e691735..ed2b5685 100644 --- a/examples/03_all_store/README.md +++ b/examples/03_all_store/README.md @@ -10,7 +10,7 @@ All-Store benchmark using Iris. 
## Usage ```terminal -python examples/03_all_store/all_store_bench.py --num_ranks 8 +mpirun -np 8 python examples/03_all_store/all_store_bench.py ``` On an MI300X, this example will run on 8 GPUs. It prints: diff --git a/examples/03_all_store/all_store_bench.py b/examples/03_all_store/all_store_bench.py index eac5dd5d..ce17a1e3 100755 --- a/examples/03_all_store/all_store_bench.py +++ b/examples/03_all_store/all_store_bench.py @@ -5,8 +5,6 @@ import argparse import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import triton.language as tl import random @@ -15,6 +13,7 @@ import iris + torch.manual_seed(123) random.seed(123) @@ -88,7 +87,6 @@ def parse_args(): parser.add_argument("-w", "--num_warmup", type=int, default=2, help="Number of warmup experiments") parser.add_argument("-a", "--active_ranks", type=int, default=8, help="Number of active ranks") parser.add_argument("-o", "--output_file", type=str, default="", help="Output file") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -242,12 +240,9 @@ def print_bandwidth_matrix( raise ValueError(f"Unsupported output file extension: {output_file}") -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() - # Main benchmark logic heap_size = args["heap_size"] shmem = iris.iris(heap_size) num_ranks = shmem.get_num_ranks() @@ -289,23 +284,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if shmem.get_rank() == 0: print_bandwidth_matrix(bandwidth_data, buffer_sizes, output_file=args["output_file"]) - dist.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) - if __name__ == "__main__": main() diff --git a/examples/04_atomic_add/README.md b/examples/04_atomic_add/README.md index 127e51fb..242cabe5 100644 --- a/examples/04_atomic_add/README.md +++ b/examples/04_atomic_add/README.md @@ -10,7 +10,7 @@ Load benchmark using Iris. ## Usage ```terminal -python examples/04_atomic_add/atomic_add_bench.py --num_ranks 8 +mpirun -np 8 python examples/04_atomic_add/atomic_add_bench.py ``` On an MI300X, this example will run on 8 GPUs. It prints: ```terminal diff --git a/examples/04_atomic_add/atomic_add_bench.py b/examples/04_atomic_add/atomic_add_bench.py index 9b6dfb4f..6b292736 100755 --- a/examples/04_atomic_add/atomic_add_bench.py +++ b/examples/04_atomic_add/atomic_add_bench.py @@ -3,18 +3,16 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
import argparse -import json -import random -import numpy as np import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import triton.language as tl - +import random +import numpy as np +import json import iris + torch.manual_seed(123) random.seed(123) @@ -80,7 +78,6 @@ def parse_args(): parser.add_argument("-x", "--num_experiments", type=int, default=16, help="Number of experiments") parser.add_argument("-w", "--num_warmup", type=int, default=4, help="Number of warmup experiments") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -205,12 +202,9 @@ def print_bandwidth_matrix( raise ValueError(f"Unsupported output file extension: {output_file}") -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() - # Main benchmark logic shmem = iris.iris(args["heap_size"]) num_ranks = shmem.get_num_ranks() bandwidth_matrix = np.zeros((num_ranks, num_ranks), dtype=np.float32) @@ -229,23 +223,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if shmem.get_rank() == 0: print_bandwidth_matrix(bandwidth_matrix, output_file=args["output_file"]) - dist.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) - if __name__ == "__main__": main() diff --git a/examples/05_atomic_xchg/README.md b/examples/05_atomic_xchg/README.md index 691b70e0..fff66eed 100644 --- a/examples/05_atomic_xchg/README.md +++ b/examples/05_atomic_xchg/README.md @@ -10,7 +10,7 @@ Load benchmark using Iris. ## Usage ```terminal -python examples/05_atomic_xchg/atomic_xchg_bench.py --num_ranks 8 +mpirun -np 8 python examples/05_atomic_xchg/atomic_xchg_bench.py ``` On an MI300X, this example will run on 8 GPUs. It prints: ```terminal diff --git a/examples/05_atomic_xchg/atomic_xchg_bench.py b/examples/05_atomic_xchg/atomic_xchg_bench.py index 89d7792f..51ae2410 100755 --- a/examples/05_atomic_xchg/atomic_xchg_bench.py +++ b/examples/05_atomic_xchg/atomic_xchg_bench.py @@ -3,18 +3,16 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
import argparse -import json -import random -import numpy as np import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import triton.language as tl - +import random +import numpy as np +import json import iris + torch.manual_seed(123) random.seed(123) @@ -84,7 +82,6 @@ def parse_args(): parser.add_argument("-x", "--num_experiments", type=int, default=16, help="Number of experiments") parser.add_argument("-w", "--num_warmup", type=int, default=4, help="Number of warmup experiments") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -209,12 +206,9 @@ def print_bandwidth_matrix( raise ValueError(f"Unsupported output file extension: {output_file}") -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() - # Main benchmark logic shmem = iris.iris(args["heap_size"]) num_ranks = shmem.get_num_ranks() bandwidth_matrix = np.zeros((num_ranks, num_ranks), dtype=np.float32) @@ -233,23 +227,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if shmem.get_rank() == 0: print_bandwidth_matrix(bandwidth_matrix, output_file=args["output_file"]) - dist.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) - if __name__ == "__main__": main() diff --git a/examples/06_message_passing/message_passing_load_store.py b/examples/06_message_passing/message_passing_load_store.py index 37db8bcb..9e19f564 100755 --- a/examples/06_message_passing/message_passing_load_store.py +++ b/examples/06_message_passing/message_passing_load_store.py @@ -3,13 +3,11 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
import argparse -import random import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import triton.language as tl +import random import iris @@ -127,17 +125,13 @@ def parse_args(): parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() - # Main benchmark logic shmem = iris.iris(args["heap_size"]) dtype = torch_dtype_from_str(args["datatype"]) cur_rank = shmem.get_rank() @@ -205,23 +199,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): shmem.barrier() - dist.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) - if __name__ == "__main__": main() diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index b4d54064..fc849b8b 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -4,8 +4,6 @@ import argparse import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import triton.language as tl import random @@ -115,17 +113,13 @@ def parse_args(): parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() - # Main benchmark logic shmem = iris.iris(args["heap_size"]) dtype = torch_dtype_from_str(args["datatype"]) cur_rank = shmem.get_rank() @@ -193,23 +187,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): shmem.barrier() - dist.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) - if __name__ == "__main__": main() diff --git a/examples/07_gemm_all_scatter/benchmark.py b/examples/07_gemm_all_scatter/benchmark.py index 6d872e61..fa10b0c3 100755 --- a/examples/07_gemm_all_scatter/benchmark.py +++ b/examples/07_gemm_all_scatter/benchmark.py @@ -2,22 +2,21 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
-import argparse -import json -import os -import random -import sys - import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton -from matmul_wrapper import matmul +import random +import sys +import os +import argparse +import json -import iris from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm +import iris + +from matmul_wrapper import matmul + torch.manual_seed(123) random.seed(123) @@ -53,17 +52,13 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument("--gemm_sms", type=int, default=304, help="Number of SMs for persistent GEMM algorithm") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() - # Main benchmark logic shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() @@ -243,24 +238,6 @@ def run_experiment(): shmem.barrier() - dist.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - # Use command line argument if provided, otherwise use num_ranks parameter - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) - if __name__ == "__main__": main() diff --git a/examples/08_gemm_atomics_all_reduce/benchmark.py b/examples/08_gemm_atomics_all_reduce/benchmark.py index c17fb5ba..45503492 100755 --- a/examples/08_gemm_atomics_all_reduce/benchmark.py +++ b/examples/08_gemm_atomics_all_reduce/benchmark.py @@ -3,8 +3,6 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import random import sys @@ -70,19 +68,14 @@ def parse_args(): # For All Scatter, use: 288 # For One Shot, use: 256 - parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM") + parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for Stream-K") parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") - return vars(parser.parse_args()) -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() - # Main benchmark logic shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() @@ -135,7 +128,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): total_tiles = total_blocks_M * total_blocks_N if args["gemm_sms"] >= args["total_sms"]: - print(f"Invalid number of GEMM SMs. 
{args['gemm_sms']} >= {args['total_sms']}") + print(f"Invalid number of stream-K SMs. {args['gemm_sms']} >= {args['total_sms']}") exit(1) tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) @@ -285,23 +278,6 @@ def run_experiment(): shmem.barrier() - dist.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) - if __name__ == "__main__": main() diff --git a/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py b/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py index ab014faf..6d6fcf2e 100644 --- a/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py +++ b/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py @@ -70,10 +70,10 @@ def _call( total_tiles = total_blocks_M * total_blocks_N even_k = K % BLK_K == 0 - if total_programs_streamk > 0: # GEMM + if total_programs_streamk > 0: # Stream-K # last wave may occupy less than total_programs_streamk SMs total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile GEMM + data-parallel from original paper + # for two-tile Stream-K + data-parallel from original paper # if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: # total_tiles_streamk += total_programs_streamk # remaining tiles are computed using classical blocking diff --git a/examples/09_gemm_one_shot_all_reduce/benchmark.py b/examples/09_gemm_one_shot_all_reduce/benchmark.py index a79fc5fc..f7c531e9 100755 --- a/examples/09_gemm_one_shot_all_reduce/benchmark.py +++ b/examples/09_gemm_one_shot_all_reduce/benchmark.py @@ -3,8 +3,6 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import random import sys @@ -68,16 +66,15 @@ def parse_args(): parser.add_argument("--kpack", type=int, default=2, help="K packing size") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") - parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM") + # For All Scatter, use: 288 + # For One Shot, use: 256 + parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for Stream-K") parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() @@ -131,7 +128,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): total_tiles = total_blocks_M * total_blocks_N if args["gemm_sms"] >= args["total_sms"]: - print(f"Invalid number of GEMM SMs. {args['gemm_sms']} >= {args['total_sms']}") + print(f"Invalid number of stream-K SMs. 
{args['gemm_sms']} >= {args['total_sms']}") exit(1) tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) @@ -280,21 +277,6 @@ def run_experiment(): timestamps.to_json(filename, gpu_freq) shmem.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) if __name__ == "__main__": diff --git a/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py b/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py index 83c9326f..468a16d7 100644 --- a/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py +++ b/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py @@ -70,10 +70,10 @@ def _call( total_tiles = total_blocks_M * total_blocks_N even_k = K % BLK_K == 0 - if total_programs_streamk > 0: # GEMM + if total_programs_streamk > 0: # Stream-K # last wave may occupy less than total_programs_streamk SMs total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile GEMM + data-parallel from original paper + # for two-tile Stream-K + data-parallel from original paper # if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: # total_tiles_streamk += total_programs_streamk # remaining tiles are computed using classical blocking diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index bb49bacb..f36c85a8 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -3,8 +3,6 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import random import sys @@ -57,15 +55,12 @@ def parse_args(): "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" ) parser.add_argument("--num_sms", type=int, default=304, help="Number of total SMs for gemm + scatter kernel") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() @@ -250,21 +245,6 @@ def run_experiment(): timestamps.to_json(filename, gpu_freq) shmem.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) if __name__ == "__main__": diff --git a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py index 41c164a8..19e08c7e 100755 --- a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py +++ b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py @@ -3,8 +3,6 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import random import sys @@ -58,15 +56,12 @@ def parse_args(): "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" ) parser.add_argument("--comm_sms", type=int, default=48, help="Number of SMs for All-Scatter kernel") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() @@ -284,21 +279,6 @@ def run_experiment(): timestamps.to_json(filename, gpu_freq) shmem.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) if __name__ == "__main__": diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py index 5cdc3819..840db13d 100755 --- a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py @@ -3,8 +3,6 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. import torch -import torch.distributed as dist -import torch.multiprocessing as mp import triton import random import sys @@ -58,15 +56,12 @@ def parse_args(): "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" ) parser.add_argument("--comm_sms", type=int, default=256, help="Number of SMs for All-Scatter kernel") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) -def _worker(local_rank: int, world_size: int, init_url: str, args: dict): - """Worker function for PyTorch distributed execution.""" - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) +def main(): + args = parse_args() shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() @@ -280,21 +275,6 @@ def run_experiment(): timestamps.to_json(filename, gpu_freq) shmem.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) if __name__ == "__main__": diff --git a/examples/README.md b/examples/README.md index 0794d70f..92c394f3 100644 --- a/examples/README.md +++ b/examples/README.md @@ -34,39 +34,39 @@ This directory contains various algorithm implementations for distributed comput ### Basic Operations ```terminal # Example command to run distributed load operations -python examples/00_load/load_bench.py --num_ranks 8 # Load across GPUs -python examples/02_all_load/all_load_bench.py --num_ranks 8 # Simultaneous load on all GPUs +mpirun -np 8 python examples/00_load/load_bench.py # Load across GPUs +mpirun -np 8 python examples/02_all_load/all_load_bench.py # Simultaneous load on all GPUs # Example 
command to run distributed store operations -python examples/01_store/store_bench.py --num_ranks 8 # Store across GPUs -python examples/03_all_store/all_store_bench.py --num_ranks 8 # Simultaneous store on all GPUs +mpirun -np 8 python examples/01_store/store_bench.py # Store across GPUs +mpirun -np 8 python examples/03_all_store/all_store_bench.py # Simultaneous store on all GPUs # Example command to run atomic operations -python examples/04_atomic_add/atomic_add_bench.py --num_ranks 8 # Atomic add across GPUs -python examples/05_atomic_xchg/atomic_xchg_bench.py --num_ranks 8 # Atomic exchange across GPUs +mpirun -np 8 python examples/04_atomic_add/atomic_add_bench.py # Atomic add across GPUs +mpirun -np 8 python examples/05_atomic_xchg/atomic_xchg_bench.py # Atomic exchange across GPUs # Example command to run message passing -python examples/06_message_passing/message_passing_put.py --num_ranks 8 -python examples/06_message_passing/message_passing_load_store.py --num_ranks 8 +python examples/06_message_passing/message_passing_put.py +python examples/06_message_passing/message_passing_load_store.py ``` ### GEMM Operations ```terminal # Example command to run benchmark with all-scatter algorithm -python examples/07_gemm_all_scatter/benchmark.py --benchmark --validate --num_ranks 8 +mpirun -np 8 python examples/07_gemm_all_scatter/benchmark.py --benchmark --validate # Example command to run benchmark with all-reduce algorithm -python examples/08_gemm_atomics_all_reduce/benchmark.py --benchmark --validate --num_ranks 8 +mpirun -np 8 python examples/08_gemm_atomics_all_reduce/benchmark.py --benchmark --validate # Example command to run benchmark with one-shot all-reduce algorithm -python examples/09_gemm_one_shot_all_reduce/benchmark.py --benchmark --validate --num_ranks 8 +mpirun -np 8 python examples/09_gemm_one_shot_all_reduce/benchmark.py --benchmark --validate # Example command to run benchmark with all-scatter and workgroup specialization -python examples/10_gemm_all_scatter_wg_specialization/benchmark.py --benchmark --validate --num_ranks 8 +mpirun -np 8 python examples/10_gemm_all_scatter_wg_specialization/benchmark.py --benchmark --validate # Example command to run benchmark with all-scatter producer-consumer pattern -python examples/11_gemm_all_scatter_producer_consumer/benchmark.py --benchmark --validate --num_ranks 8 +mpirun -np 8 python examples/11_gemm_all_scatter_producer_consumer/benchmark.py --benchmark --validate # Example command to run benchmark with all-scatter bulk synchronous approach -python examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py --benchmark --validate --num_ranks 8 +mpirun -np 8 python examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py --benchmark --validate ``` diff --git a/examples/benchmark/reference/gemm.py b/examples/benchmark/reference/gemm.py index b37b7e5a..61bed8c8 100755 --- a/examples/benchmark/reference/gemm.py +++ b/examples/benchmark/reference/gemm.py @@ -2,7 +2,6 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
import torch -import torch.distributed as dist import triton import torch.cuda.nvtx as nvtx diff --git a/iris/__init__.py b/iris/__init__.py index 088a5334..f252ac4d 100644 --- a/iris/__init__.py +++ b/iris/__init__.py @@ -62,8 +62,6 @@ ERROR, ) -# Launcher functionality is now user code - see examples and documentation - # Pipe allocations via finegrained allocator current_dir = os.path.dirname(__file__) # Look for the library in the installed package location diff --git a/iris/_distributed_helpers.py b/iris/_distributed_helpers.py deleted file mode 100644 index ce656f93..00000000 --- a/iris/_distributed_helpers.py +++ /dev/null @@ -1,148 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. - -import torch -import torch.distributed as dist -import numpy as np - - -def _infer_device(): - if not dist.is_initialized(): - raise RuntimeError("PyTorch distributed is not initialized") - try: - backend = str(dist.get_backend()).lower() - except Exception: - backend = "gloo" - if backend == "nccl" and torch.cuda.is_available(): - return torch.device("cuda", torch.cuda.current_device()) - return torch.device("cpu") - - -def _nccl_dtype_supported(t: torch.Tensor) -> bool: - """Conservative whitelist for NCCL tensor dtypes.""" - supported = { - torch.int8, - torch.uint8, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - } - # bfloat16 is commonly supported in recent stacks; include if available - if hasattr(torch, "bfloat16"): - supported.add(torch.bfloat16) - return t.dtype in supported - - -def distributed_allgather(data): - """ - All-gather operation using PyTorch distributed. - - Args: - data: 1D numpy array to gather across all ranks - - Returns: - 2D numpy array with shape (world_size, len(data)) - """ - if not dist.is_initialized(): - raise RuntimeError("PyTorch distributed is not initialized") - - data = np.asarray(data) - assert data.ndim == 1, "Only 1D arrays are supported." - - world_size = dist.get_world_size() - device = _infer_device() - backend = str(dist.get_backend()).lower() - - # Fast path: tensor all_gather if dtype is NCCL-supported or backend != nccl - data_tensor = torch.from_numpy(data) - use_tensor_collective = backend != "nccl" or _nccl_dtype_supported(data_tensor) - - if use_tensor_collective: - data_tensor = data_tensor.to(device) - gathered_tensors = [torch.empty_like(data_tensor) for _ in range(world_size)] - dist.all_gather(gathered_tensors, data_tensor) - return torch.stack(gathered_tensors, dim=0).to("cpu").numpy() - - # Fallback for NCCL-unsupported dtypes (e.g., uint64/bool/etc.) - obj_list = [None for _ in range(world_size)] - # Use object collective (works across backends) - dist.all_gather_object(obj_list, data) - # Ensure uniform shapes and stack - return np.stack(obj_list, axis=0) - - -def distributed_broadcast_scalar(value=None, root=0): - """ - Broadcast a scalar value from root to all ranks. 
- - Args: - value: Value to broadcast (only used on root rank) - root: Root rank to broadcast from - - Returns: - Broadcasted value - """ - if not dist.is_initialized(): - raise RuntimeError("PyTorch distributed is not initialized") - - rank = dist.get_rank() - device = _infer_device() - backend = str(dist.get_backend()).lower() - - # First agree on dtype (numpy dtype object) - if rank == root: - if value is None: - raise ValueError("Root must provide a value.") - np_val = np.array(value) # captures dtype - dtype = np_val.dtype - else: - np_val = None - dtype = None - - dtype_obj = [dtype] - dist.broadcast_object_list(dtype_obj, src=root) - dtype = dtype_obj[0] - - # If NCCL can't handle this dtype, just broadcast the object directly. - if backend == "nccl": - # Try a quick check using a tiny tensor of the dtype - torch_dtype = torch.from_numpy(np.array(0, dtype=dtype)).dtype - dummy = torch.empty((), dtype=torch_dtype) - if not _nccl_dtype_supported(dummy): - obj = [value if rank == root else None] - dist.broadcast_object_list(obj, src=root) - return obj[0] - - # Tensor path: create a 0-D tensor, broadcast on the selected device - if rank != root: - np_val = np.empty((), dtype=dtype) - val_t = torch.from_numpy(np_val).to(device) - dist.broadcast(val_t, src=root) - return val_t.to("cpu").item() - - -def distributed_barrier(): - """ - Synchronization barrier using PyTorch distributed. - """ - if not dist.is_initialized(): - raise RuntimeError("PyTorch distributed is not initialized") - dist.barrier() - - -def init_distributed(): - """ - Initialize PyTorch distributed and return communicator info. - - Returns: - tuple: (communicator_placeholder, rank, world_size) - Note: communicator_placeholder is None since PyTorch distributed - uses global state rather than explicit communicator objects - """ - if not dist.is_initialized(): - raise RuntimeError("PyTorch distributed is not initialized. Call dist.init_process_group() first.") - rank = dist.get_rank() - world_size = dist.get_world_size() - return None, rank, world_size diff --git a/iris/_mpi_helpers.py b/iris/_mpi_helpers.py new file mode 100644 index 00000000..ff59a650 --- /dev/null +++ b/iris/_mpi_helpers.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +from mpi4py import MPI +import numpy as np + + +def mpi_allgather(data): + thread_comm = MPI.COMM_WORLD + shmcomm = thread_comm.Split_type(MPI.COMM_TYPE_SHARED) + shm_size = shmcomm.Get_size() + data = np.asarray(data) + assert len(data.shape) == 1, "Only 1D arrays are supported." + recv_data = np.empty(len(data) * shm_size, dtype=data.dtype) + shmcomm.Allgather(sendbuf=data, recvbuf=recv_data) + shmcomm.Free() + reshaped = recv_data.reshape(shm_size, len(data)) + return reshaped + + +def mpi_broadcast_scalar(value=None, root=0): + thread_comm = MPI.COMM_WORLD + shmcomm = thread_comm.Split_type(MPI.COMM_TYPE_SHARED) + shm_rank = shmcomm.Get_rank() + + if shm_rank == root: + assert value is not None, "Root must provide a value." 
+ value = np.array(value) + dtype = value.dtype + else: + value = None + dtype = None + dtype = shmcomm.bcast(dtype, root=root) + if shm_rank != root: + value = np.empty(1, dtype=dtype) + else: + value = np.array([value], dtype=dtype) + shmcomm.Bcast(value, root=root) + shmcomm.Free() + return value[0] + + +def world_barrier(): + MPI.COMM_WORLD.Barrier() + + +def init_mpi(): + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + return comm, rank, world_size diff --git a/iris/iris.py b/iris/iris.py index 3ed89f3b..f67b4ed2 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -14,7 +14,7 @@ - Efficient load/store operations with rank-to-rank communication - Memory allocation and deallocation utilities - Built-in logging with rank information -- PyTorch distributed integration for distributed computing +- MPI integration for distributed computing Example: >>> import iris @@ -26,11 +26,11 @@ import triton import triton.language as tl -from iris._distributed_helpers import ( - init_distributed, - distributed_allgather, - distributed_barrier, - distributed_broadcast_scalar, +from iris._mpi_helpers import ( + init_mpi, + mpi_allgather, + world_barrier, + mpi_broadcast_scalar, ) from iris.hip import ( set_device, @@ -68,7 +68,7 @@ class Iris: def __init__(self, heap_size=1 << 30): # Initialize - comm, cur_rank, num_ranks = init_distributed() + comm, cur_rank, num_ranks = init_mpi() num_gpus = count_devices() gpu_id = cur_rank % num_gpus @@ -92,12 +92,12 @@ def __init__(self, heap_size=1 << 30): ipc_handles = np.zeros((num_ranks, 64), dtype=np.uint8) ipc_handle = get_ipc_handle(heap_base_ptr, cur_rank) - distributed_barrier() + world_barrier() - all_ipc_handles = distributed_allgather(np.frombuffer(ipc_handle, dtype=np.uint8)) - all_heap_bases = distributed_allgather(np.array([heap_bases[cur_rank]], dtype=np.uint64)) + all_ipc_handles = mpi_allgather(np.frombuffer(ipc_handle, dtype=np.uint8)) + all_heap_bases = mpi_allgather(np.array([heap_bases[cur_rank]], dtype=np.uint64)) - distributed_barrier() + world_barrier() ipc_heap_bases = np.zeros(num_ranks, dtype=np.uintp) for rank in range(num_ranks): @@ -110,10 +110,10 @@ def __init__(self, heap_size=1 << 30): for i in range(num_ranks): self.debug(f"GPU {i}: Heap base {hex(int(ipc_heap_bases[i]))}") - distributed_barrier() + world_barrier() self.heap_bases = torch.from_numpy(ipc_heap_bases).to(device=self.device, dtype=torch.uint64) - distributed_barrier() + world_barrier() def _log_with_rank(self, level, message): """Helper method to log with rank information injected into the record.""" @@ -194,7 +194,7 @@ def broadcast(self, value, source_rank): >>> value = 42 if iris_ctx.get_rank() == 0 else None >>> value = iris_ctx.broadcast(value, source_rank=0) """ - return distributed_broadcast_scalar(value, source_rank) + return mpi_broadcast_scalar(value, source_rank) def __allocate(self, num_elements, dtype): self.debug(f"allocate: num_elements = {num_elements}, dtype = {dtype}") @@ -1073,10 +1073,10 @@ def barrier(self, stream=None): Synchronize all ranks and their CUDA devices. This first calls ``torch.cuda.synchronize()`` or ``stream.synchronize()`` to ensure the local GPU has - finished all queued work, then performs a global distributed barrier so that all + finished all queued work, then performs a global MPI barrier so that all ranks reach the same point before proceeding. Args: - stream: If stream is given: wait only for that stream before barrier. If stream is None: legacy behavior (device-wide sync). 
+ stream: If stream is given: wait only for that stream before MPI_Barrier. If stream is None: legacy behavior (device-wide sync). Example: >>> iris_ctx = iris.iris(1 << 20) @@ -1088,8 +1088,8 @@ def barrier(self, stream=None): else: stream.synchronize() - # Distributed barrier - distributed_barrier() + # MPI barrier + world_barrier() def get_device(self): """ @@ -1121,7 +1121,7 @@ def get_cu_count(self): def get_rank(self): """ - Get this process's rank id in the distributed communicator. + Get this process's rank id in the MPI communicator. Returns: int: Zero-based rank id of the current process. @@ -1135,7 +1135,7 @@ def get_rank(self): def get_num_ranks(self): """ - Get the total number of ranks in the distributed communicator. + Get the total number of ranks in the MPI communicator. Returns: int: World size (number of ranks). diff --git a/iris/logging.py b/iris/logging.py index 51a27941..e1a4d6fb 100644 --- a/iris/logging.py +++ b/iris/logging.py @@ -2,7 +2,7 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. """ -Iris logging module - provides logging functionality. +Iris logging module - provides logging functionality without MPI dependencies. """ import logging diff --git a/pyproject.toml b/pyproject.toml index 56ba24ad..b700c83d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ requires-python = ">=3.8" dependencies = [ "numpy", "requests", + "mpi4py", "ruff", "triton @ git+https://github.com/triton-lang/triton.git@dd5823453bcc7973eabadb65f9d827c43281c434" ] diff --git a/tests/run_tests_distributed.py b/tests/run_tests_distributed.py deleted file mode 100755 index 8c4af8e1..00000000 --- a/tests/run_tests_distributed.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. - -""" -Simple wrapper to run pytest tests within a single distributed process group. -This avoids the overhead of creating/destroying process groups for each test case. -""" - -import sys -import subprocess -import torch.multiprocessing as mp -import torch.distributed as dist -import socket -import os - - -def _find_free_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - -def _distributed_worker(rank, world_size, test_file, pytest_args): - """Worker function that runs pytest within a distributed process group.""" - # Initialize distributed once for all tests - init_method = "tcp://127.0.0.1:12355" - dist.init_process_group( - backend="nccl", - init_method=init_method, - rank=rank, - world_size=world_size, - ) - - try: - # Import and run pytest directly - import pytest - import sys - - # Set up sys.argv for pytest - original_argv = sys.argv[:] - sys.argv = ["pytest", test_file] + pytest_args - - try: - # Run pytest directly in this process - exit_code = pytest.main([test_file] + pytest_args) - return exit_code - finally: - # Restore original argv - sys.argv = original_argv - - finally: - if dist.is_initialized(): - dist.destroy_process_group() - - -def main(): - if len(sys.argv) < 2: - print("Usage: python run_tests_distributed.py [--num_ranks N] [pytest_args...] 
") - sys.exit(1) - - # Get number of ranks from args or default to 2 - num_ranks = 2 - args = sys.argv[1:] - - if "--num_ranks" in args: - idx = args.index("--num_ranks") - if idx + 1 < len(args): - num_ranks = int(args[idx + 1]) - # Remove --num_ranks and its value from args - args = args[:idx] + args[idx + 2 :] - - # The test file is the first argument after --num_ranks, everything else is pytest args - if not args: - print("Error: No test file specified") - sys.exit(1) - - test_file = args[0] - pytest_args = args[1:] # Everything after the test file - - print(f"Running {test_file} with {num_ranks} ranks") - print(f"args={args}, test_file={test_file}, pytest_args={pytest_args}") - - # Run all tests within a single distributed process group - mp.spawn(_distributed_worker, args=(num_ranks, test_file, pytest_args), nprocs=num_ranks, join=True) - - -if __name__ == "__main__": - main() From 15f266d2aa11e75c9abfb735da77d68748b9991f Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Wed, 10 Sep 2025 18:20:24 -0500 Subject: [PATCH 08/13] Sanitize and test code examples Signed-off-by: Muhammad Awad --- iris/iris.py | 115 +++++++++++++++++++++++++----------------------- iris/logging.py | 5 +++ 2 files changed, 64 insertions(+), 56 deletions(-) diff --git a/iris/iris.py b/iris/iris.py index 3ed89f3b..f12da204 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -20,7 +20,6 @@ >>> import iris >>> ctx = iris.iris(heap_size=2**30) # 1GB heap >>> tensor = ctx.zeros(1024, 1024, dtype=torch.float32) - >>> ctx.atomic_add(tensor.data_ptr(), 1.0, 0, 1) """ import triton @@ -62,7 +61,7 @@ class Iris: Example: >>> ctx = iris.iris(heap_size=2**31) # 2GB heap - >>> print(f"Rank {ctx.cur_rank} of {ctx.num_ranks}") + >>> print(f"Rank {ctx.cur_rank} of {ctx.num_ranks}") # Rank 0 of 1 >>> tensor = ctx.zeros(1000, 1000, dtype=torch.float32) """ @@ -138,7 +137,9 @@ def debug(self, message): formatters can display the originating rank and world size. Example: - >>> iris_ctx.debug("Allocating buffers") + >>> ctx = iris.iris() + >>> iris.set_logger_level(iris.DEBUG) + >>> ctx.debug("Allocating buffers") # [Iris] [0/1] Allocating buffers """ self._log_with_rank(logging.DEBUG, message) @@ -150,7 +151,8 @@ def info(self, message): message (str): Human-readable message to log at info level. Example: - >>> iris_ctx.info("Starting iteration 0") + >>> ctx = iris.iris() + >>> ctx.info("Starting iteration 0") # [Iris] [0/1] Starting iteration 0 """ self._log_with_rank(logging.INFO, message) @@ -162,7 +164,8 @@ def warning(self, message): message (str): Human-readable message to log at warning level. Example: - >>> iris_ctx.warning("Memory usage is high") + >>> ctx = iris.iris() + >>> ctx.warning("Memory usage is high") # [Iris] [0/1] Memory usage is high """ self._log_with_rank(logging.WARNING, message) @@ -174,7 +177,8 @@ def error(self, message): message (str): Human-readable message to log at error level. Example: - >>> iris_ctx.error("Failed to allocate memory") + >>> ctx = iris.iris() + >>> ctx.error("Failed to allocate memory") # [Iris] [0/1] Failed to allocate memory """ self._log_with_rank(logging.ERROR, message) @@ -191,8 +195,9 @@ def broadcast(self, value, source_rank): Any: The value broadcast to all ranks. 
Example: - >>> value = 42 if iris_ctx.get_rank() == 0 else None - >>> value = iris_ctx.broadcast(value, source_rank=0) + >>> ctx = iris.iris() + >>> value = 42 if ctx.cur_rank == 0 else None + >>> value = ctx.broadcast(value, source_rank=0) # All ranks get 42 """ return distributed_broadcast_scalar(value, source_rank) @@ -241,9 +246,9 @@ def zeros_like( Default: torch.preserve_format. Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> input_tensor = iris_ctx.ones(2, 3) - >>> zeros_tensor = iris_ctx.zeros_like(input_tensor) + >>> ctx = iris.iris(1 << 20) + >>> input_tensor = ctx.ones(2, 3) + >>> zeros_tensor = ctx.zeros_like(input_tensor) >>> print(zeros_tensor.shape) # torch.Size([2, 3]) """ self.debug( @@ -318,8 +323,8 @@ def arange( requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> tensor = iris_ctx.arange(0, 10, 2) # [0, 2, 4, 6, 8] + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.arange(0, 10, 2) # [0, 2, 4, 6, 8] >>> print(tensor.shape) # torch.Size([5]) """ self.debug(f"arange: start = {start}, end = {end}, step = {step}, dtype = {dtype}, device = {device}") @@ -395,9 +400,10 @@ def zeros(self, *size, out=None, dtype=None, layout=torch.strided, device=None, Default: False. Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> tensor = iris_ctx.zeros(2, 3) + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.zeros(2, 3) >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([0., 0., 0.], device='cuda:0') """ self.debug(f"zeros: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") @@ -490,9 +496,10 @@ def randn( Works only for CPU tensors. Default: False. Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> tensor = iris_ctx.randn(2, 3) + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.randn(2, 3) >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([ 0.3982, -0.0059, -0.4365], device='cuda:0') """ self.debug( f"randn: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" @@ -558,9 +565,10 @@ def ones(self, *size, out=None, dtype=None, layout=torch.strided, device=None, r Default: False. Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> tensor = iris_ctx.ones(2, 3) + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.ones(2, 3) >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([1., 1., 1.], device='cuda:0') """ self.debug(f"ones: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") @@ -622,9 +630,10 @@ def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, Default: False. Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> tensor = iris_ctx.full((2, 3), 3.14) + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.full((2, 3), 3.14) >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([3.1400, 3.1400, 3.1400], device='cuda:0') """ self.debug( f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" @@ -688,9 +697,10 @@ def uniform(self, size, low=0.0, high=1.0, dtype=torch.float): Tensor: A tensor filled with random numbers from a uniform distribution. 
Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> tensor = iris_ctx.uniform((2, 3), low=0.0, high=1.0) + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.uniform((2, 3), low=0.0, high=1.0) >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') """ self.debug(f"uniform: size = {size}, low = {low}, high = {high}, dtype = {dtype}") size, num_elements = self.__parse_size(size) @@ -738,8 +748,8 @@ def empty( Default: torch.contiguous_format. Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> tensor = iris_ctx.empty(2, 3) + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.empty(2, 3) >>> print(tensor.shape) # torch.Size([2, 3]) """ self.debug( @@ -807,9 +817,10 @@ def randint( requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> tensor = iris_ctx.randint(0, 10, (2, 3)) # Random integers [0, 10) + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.randint(0, 10, (2, 3)) # Random integers [0, 10) >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([7, 2, 9], device='cuda:0') """ self.debug(f"randint: args = {args}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") @@ -893,9 +904,9 @@ def linspace(self, start, end, steps, out=None, dtype=None, layout=torch.strided requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> tensor = iris_ctx.linspace(0, 10, 5) # [0, 2.5, 5, 7.5, 10] - >>> print(tensor.shape) # torch.Size([5]) + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.linspace(0, 10, 5) # [0, 2.5, 5, 7.5, 10] + >>> print(tensor) # tensor([ 0.0000, 2.5000, 5.0000, 7.5000, 10.0000], device='cuda:0') """ self.debug( f"linspace: start = {start}, end = {end}, steps = {steps}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" @@ -999,9 +1010,10 @@ def rand( Works only for CPU tensors. Default: False. Note: Iris tensors are always on GPU. Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> tensor = iris_ctx.rand(2, 3) # Random values in [0, 1) + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.rand(2, 3) # Random values in [0, 1) >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') """ self.debug( f"rand: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" @@ -1062,8 +1074,8 @@ def get_heap_bases(self): heap translation. Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> heap_bases = iris_ctx.get_heap_bases() + >>> ctx = iris.iris(1 << 20) + >>> heap_bases = ctx.get_heap_bases() >>> print(heap_bases.shape) # torch.Size([num_ranks]) """ return self.heap_bases @@ -1079,8 +1091,8 @@ def barrier(self, stream=None): stream: If stream is given: wait only for that stream before barrier. If stream is None: legacy behavior (device-wide sync). Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> iris_ctx.barrier() # Synchronize all ranks + >>> ctx = iris.iris(1 << 20) + >>> ctx.barrier() # Synchronize all ranks """ # Wait for all GPUs to finish work if stream is None: @@ -1099,8 +1111,8 @@ def get_device(self): torch.device: The CUDA device of Iris-managed memory. 
Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> device = iris_ctx.get_device() + >>> ctx = iris.iris(1 << 20) + >>> device = ctx.get_device() >>> print(device) # cuda:0 """ return self.memory_pool.device @@ -1113,9 +1125,9 @@ def get_cu_count(self): int: Number of compute units on this rank's GPU. Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> cu_count = iris_ctx.get_cu_count() - >>> print(f"GPU has {cu_count} CUs") + >>> ctx = iris.iris(1 << 20) + >>> cu_count = ctx.get_cu_count() + >>> print(f"GPU has {cu_count} CUs") # GPU has 304 CUs """ return get_cu_count(self.gpu_id) @@ -1127,9 +1139,9 @@ def get_rank(self): int: Zero-based rank id of the current process. Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> rank = iris_ctx.get_rank() - >>> print(f"This is rank {rank}") + >>> ctx = iris.iris(1 << 20) + >>> rank = ctx.get_rank() + >>> print(f"This is rank {rank}") # This is rank 0 """ return self.cur_rank @@ -1141,9 +1153,9 @@ def get_num_ranks(self): int: World size (number of ranks). Example: - >>> iris_ctx = iris.iris(1 << 20) - >>> num_ranks = iris_ctx.get_num_ranks() - >>> print(f"Total ranks: {num_ranks}") + >>> ctx = iris.iris(1 << 20) + >>> num_ranks = ctx.get_num_ranks() + >>> print(f"Total ranks: {num_ranks}") # Total ranks: 1 """ return self.num_ranks @@ -1619,7 +1631,6 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> remote_rank = 1 # Remote rank (destination) >>> increment = 5 >>> old_val = iris.atomic_add(ptr, increment, cur_rank, remote_rank, heap_bases) - >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1656,7 +1667,6 @@ def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> remote_rank = 2 # Remote rank (destination) >>> decrement = 3 >>> old_val = iris.atomic_sub(ptr, decrement, cur_rank, remote_rank, heap_bases) - >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1694,7 +1704,6 @@ def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scop >>> expected = 0 >>> new_val = 42 >>> old_val = iris.atomic_cas(ptr, expected, new_val, cur_rank, remote_rank, heap_bases) - >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) @@ -1731,7 +1740,6 @@ def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=Non >>> remote_rank = 1 # Remote rank (destination) >>> new_value = 99 >>> old_val = iris.atomic_xchg(ptr, new_value, cur_rank, remote_rank, heap_bases) - >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1768,7 +1776,6 @@ def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> remote_rank = 1 # Remote rank (destination) >>> mask_val = 0xFF >>> old_val = iris.atomic_xor(ptr, mask_val, cur_rank, remote_rank, heap_bases) - >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1805,7 +1812,6 @@ def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> remote_rank = 1 # Remote rank (destination) >>> mask_val = 0x0F >>> old_val = 
iris.atomic_and(ptr, mask_val, cur_rank, remote_rank, heap_bases) - >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1842,7 +1848,6 @@ def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, >>> remote_rank = 1 # Remote rank (destination) >>> mask_val = 0xF0 >>> old_val = iris.atomic_or(ptr, mask_val, cur_rank, remote_rank, heap_bases) - >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1879,7 +1884,6 @@ def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> remote_rank = 1 # Remote rank (destination) >>> new_val = 10 >>> old_val = iris.atomic_min(ptr, new_val, cur_rank, remote_rank, heap_bases) - >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1916,7 +1920,6 @@ def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> remote_rank = 1 # Remote rank (destination) >>> new_val = 100 >>> old_val = iris.atomic_max(ptr, new_val, cur_rank, remote_rank, heap_bases) - >>> return old_val """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) diff --git a/iris/logging.py b/iris/logging.py index 51a27941..3aa5d232 100644 --- a/iris/logging.py +++ b/iris/logging.py @@ -51,5 +51,10 @@ def set_logger_level(level): Args: level: Logging level (iris.DEBUG, iris.INFO, iris.WARNING, iris.ERROR) + + Example: + >>> ctx = iris.iris() + >>> iris.set_logger_level(iris.DEBUG) + >>> ctx.debug("This will now be visible") # [Iris] [0/1] This will now be visible """ logger.setLevel(level) From 8d037799f6608a71dfdc79dcae26ee499e1cabf5 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Wed, 10 Sep 2025 18:21:25 -0500 Subject: [PATCH 09/13] Add `set_logger_level` docstring Signed-off-by: Muhammad Awad --- docs/reference/api-iris-class.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/reference/api-iris-class.md b/docs/reference/api-iris-class.md index f16eccf5..48ddaa66 100644 --- a/docs/reference/api-iris-class.md +++ b/docs/reference/api-iris-class.md @@ -24,6 +24,7 @@ Prefer using the convenience factory over calling the constructor directly: Use Iris-aware logging that automatically annotates each message with the current rank and world size. This is helpful when debugging multi-rank programs. ```{eval-rst} +.. autofunction:: iris.logging.set_logger_level .. automethod:: iris.iris.Iris.debug .. automethod:: iris.iris.Iris.info .. automethod:: iris.iris.Iris.warning From bc8b58e20b73e959ab48ec942aabbd0d53f1204a Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Wed, 10 Sep 2025 18:28:20 -0500 Subject: [PATCH 10/13] Improve docstring Signed-off-by: Muhammad Awad --- iris/iris.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/iris/iris.py b/iris/iris.py index f12da204..e29d1c58 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -4,13 +4,13 @@ """ Iris: Multi-GPU Communication and Memory Management Framework -Iris is a high-performance framework for multi-GPU communication and memory management, -providing efficient distributed tensor operations, atomic operations, and memory allocation -across multiple GPUs in a cluster.
+Iris is a high-performance framework that enables seamless multi-GPU programming in Triton, +providing fine-grained communication and compute overlap natively across multiple GPUs +with SHMEM-like Remote Memory Access (RMA) capabilities. Key Features: - Symmetric heap management across multiple GPUs -- High-performance atomic operations (add, sub, cas, xchg, xor, and, or, min, max) +- High-performance atomic operations (add, cas, xchg, xor, and, or, min, max) - Efficient load/store operations with rank-to-rank communication - Memory allocation and deallocation utilities - Built-in logging with rank information From e1fd7d516425281a0e3fbd1e3e3792fdde024eb2 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Wed, 10 Sep 2025 18:32:56 -0500 Subject: [PATCH 11/13] Fix concurrency of deploying docs Signed-off-by: Muhammad Awad --- .github/workflows/docs.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 701fb824..3f8cd009 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -14,6 +14,7 @@ on: - 'docs/**' - 'iris/**' - 'examples/**' + - '.github/workflows/docs.yml' permissions: contents: read @@ -21,8 +22,8 @@ permissions: id-token: write concurrency: - group: "pages" - cancel-in-progress: true + group: "pages-${{ github.ref }}" + cancel-in-progress: false jobs: build: @@ -65,7 +66,7 @@ jobs: url: ${{ steps.deployment.outputs.page_url }} runs-on: ubuntu-latest needs: build - if: github.ref == 'refs/heads/main' + if: github.ref == 'refs/heads/main' && github.event_name == 'push' steps: - name: Deploy to GitHub Pages id: deployment From e00f714b33836c427eeef12857a6291cb536805d Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Wed, 10 Sep 2025 18:41:23 -0500 Subject: [PATCH 12/13] Remove unneeded function Signed-off-by: Muhammad Awad --- docs/reference/api-iris-class.md | 6 ++ .../08_gemm_atomics_all_reduce/benchmark.py | 2 +- .../09_gemm_one_shot_all_reduce/benchmark.py | 2 +- iris/__init__.py | 4 +- iris/util.py | 60 +++++++------------ 5 files changed, 31 insertions(+), 43 deletions(-) diff --git a/docs/reference/api-iris-class.md b/docs/reference/api-iris-class.md index 48ddaa66..a14fb680 100644 --- a/docs/reference/api-iris-class.md +++ b/docs/reference/api-iris-class.md @@ -32,6 +32,12 @@ Use Iris-aware logging that automatically annotates each message with the current rank and world size. ``` +## Utility Functions + +```{eval-rst} +.. autofunction:: iris.util.do_bench +``` + ## Broadcast Helper Broadcast a Python scalar or small object from a source rank to all ranks. This is a convenience wrapper over the internal Torch Distributed helper.
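A minimal sketch of the replacement pattern the benchmark hunks below adopt, dropping the removed `iris.memset_tensor` helper in favor of PyTorch's in-place `Tensor.zero_()`; the heap and tile sizes are illustrative, not taken from the benchmarks:

    import torch
    import iris

    shmem = iris.iris(1 << 20)
    tile_completed = shmem.zeros((16,), device="cuda", dtype=torch.int32)

    shmem.barrier()
    # Previously reset via iris.memset_tensor(tile_completed, 0); the zero-fill
    # case is covered by a plain in-place PyTorch op on the Iris-allocated tensor.
    tile_completed.zero_()
    shmem.barrier()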
diff --git a/examples/08_gemm_atomics_all_reduce/benchmark.py b/examples/08_gemm_atomics_all_reduce/benchmark.py index c17fb5ba..31de4fa3 100755 --- a/examples/08_gemm_atomics_all_reduce/benchmark.py +++ b/examples/08_gemm_atomics_all_reduce/benchmark.py @@ -167,7 +167,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): def preamble(): shmem.barrier() - iris.memset_tensor(tile_completed, 0) + tile_completed.zero_() shmem.barrier() def run_experiment(): diff --git a/examples/09_gemm_one_shot_all_reduce/benchmark.py b/examples/09_gemm_one_shot_all_reduce/benchmark.py index a79fc5fc..212bc857 100755 --- a/examples/09_gemm_one_shot_all_reduce/benchmark.py +++ b/examples/09_gemm_one_shot_all_reduce/benchmark.py @@ -163,7 +163,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): def preamble(): shmem.barrier() - iris.memset_tensor(tile_completed, 0) + tile_completed.zero_() shmem.barrier() def run_experiment(): diff --git a/iris/__init__.py b/iris/__init__.py index d11dd587..d50319db 100644 --- a/iris/__init__.py +++ b/iris/__init__.py @@ -12,7 +12,7 @@ - Iris: Main class for multi-GPU operations - Atomic operations: add, sub, cas, xchg, xor, and, or, min, max - Memory operations: load, store, get, put -- Utility functions: do_bench, memset_tensor +- Utility functions: do_bench - HIP integration for AMD GPU support - Logging utilities with rank information @@ -46,7 +46,6 @@ from .util import ( do_bench, - memset_tensor, ) from . import hip @@ -98,7 +97,6 @@ "atomic_min", "atomic_max", "do_bench", - "memset_tensor", "hip", "set_logger_level", "logger", diff --git a/iris/util.py b/iris/util.py index 03d24a46..36b70e9b 100644 --- a/iris/util.py +++ b/iris/util.py @@ -1,6 +1,28 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ import statistics import math import triton @@ -124,41 +146,3 @@ def do_bench( times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)] return _summarize_statistics(times, quantiles, return_mode) - - -@triton.jit -def memset_kernel(ptr, value, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - v = tl.full([BLOCK_SIZE], value, dtype=tl.int32) - tl.store(ptr + offsets, v, mask=mask) - - -def memset_tensor(tensor, value): - """ - Set all elements of a tensor to a specified value using a Triton kernel. - - Args: - tensor (torch.Tensor): Contiguous int32 tensor to modify in-place. - value (int): Value to set all elements to. - - Example: - >>> import iris - >>> import torch - >>> tensor = torch.zeros(100, dtype=torch.int32, device='cuda') - >>> iris.memset_tensor(tensor, 42) - >>> assert torch.all(tensor == 42) - """ - assert tensor.is_contiguous(), "Tensor must be contiguous" - assert tensor.dtype == torch.int32, "Only torch.int32 tensors are supported" - n_elements = tensor.numel() - BLOCK_SIZE = 1024 - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - memset_kernel[grid]( - tensor, - value, - n_elements, - BLOCK_SIZE=BLOCK_SIZE, - ) From 2365dde7c85d26fecde3217ffd8498c4b2f00163 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 10 Sep 2025 23:41:43 +0000 Subject: [PATCH 13/13] Apply Ruff auto-fixes --- iris/util.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/iris/util.py b/iris/util.py index 36b70e9b..8c861851 100644 --- a/iris/util.py +++ b/iris/util.py @@ -31,8 +31,6 @@ def get_empty_cache_for_benchmark(): - import torch - cache_size = 256 * 1024 * 1024 return torch.empty(int(cache_size // 4), dtype=torch.int, device="cuda") @@ -42,8 +40,6 @@ def clear_cache(cache): def create_timing_event(): - import torch - return torch.cuda.Event(enable_timing=True)
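The example patches above remove the per-example `_worker`/`mp.spawn` launcher and the `--num_ranks` flag, so every example now follows one launch pattern: a plain `main()` whose world size is supplied by the MPI launcher (`mpirun -np 8 python <example>.py`). A minimal sketch of that pattern, assuming only the public Iris API shown in this series; the script name and heap size are illustrative:

    # Launched as: mpirun -np 8 python hello_iris.py
    import argparse

    import iris


    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument("-p", "--heap_size", type=int, default=1 << 30, help="Iris heap size")
        args = vars(parser.parse_args())

        # Rank and world size come from the MPI launcher, not from a CLI flag.
        shmem = iris.iris(args["heap_size"])
        shmem.info(f"Hello from rank {shmem.get_rank()} of {shmem.get_num_ranks()}")

        # Local device sync followed by an MPI barrier across all ranks.
        shmem.barrier()


    if __name__ == "__main__":
        main()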