diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 701fb824..3f8cd009 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -14,6 +14,7 @@ on: - 'docs/**' - 'iris/**' - 'examples/**' + - '.github/workflows/docs.yml' permissions: contents: read @@ -21,8 +22,8 @@ permissions: id-token: write concurrency: - group: "pages" - cancel-in-progress: true + group: "pages-${{ github.ref }}" + cancel-in-progress: false jobs: build: @@ -65,7 +66,7 @@ jobs: url: ${{ steps.deployment.outputs.page_url }} runs-on: ubuntu-latest needs: build - if: github.ref == 'refs/heads/main' + if: github.ref == 'refs/heads/main' && github.event_name == 'push' steps: - name: Deploy to GitHub Pages id: deployment diff --git a/docs/reference/api-iris-class.md b/docs/reference/api-iris-class.md index f16eccf5..a14fb680 100644 --- a/docs/reference/api-iris-class.md +++ b/docs/reference/api-iris-class.md @@ -24,6 +24,7 @@ Prefer using the convenience factory over calling the constructor directly: Use Iris-aware logging that automatically annotates each message with the current rank and world size. This is helpful when debugging multi-rank programs. ```{eval-rst} +.. autofunction:: iris.logging.set_logger_level .. automethod:: iris.iris.Iris.debug .. automethod:: iris.iris.Iris.info .. automethod:: iris.iris.Iris.warning @@ -31,6 +32,12 @@ Use Iris-aware logging that automatically annotates each message with the curren ``` +## Utility Functions + +```{eval-rst} +.. autofunction:: iris.util.do_bench +``` + ## Broadcast Helper Broadcast a Python scalar or small object from a source rank to all ranks. This is a convenience wrapper over the internal Torch Distributed helper. diff --git a/examples/08_gemm_atomics_all_reduce/benchmark.py b/examples/08_gemm_atomics_all_reduce/benchmark.py index c17fb5ba..31de4fa3 100755 --- a/examples/08_gemm_atomics_all_reduce/benchmark.py +++ b/examples/08_gemm_atomics_all_reduce/benchmark.py @@ -167,7 +167,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): def preamble(): shmem.barrier() - iris.memset_tensor(tile_completed, 0) + tile_completed.zero_() shmem.barrier() def run_experiment(): diff --git a/examples/09_gemm_one_shot_all_reduce/benchmark.py b/examples/09_gemm_one_shot_all_reduce/benchmark.py index a79fc5fc..212bc857 100755 --- a/examples/09_gemm_one_shot_all_reduce/benchmark.py +++ b/examples/09_gemm_one_shot_all_reduce/benchmark.py @@ -163,7 +163,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): def preamble(): shmem.barrier() - iris.memset_tensor(tile_completed, 0) + tile_completed.zero_() shmem.barrier() def run_experiment(): diff --git a/iris/__init__.py b/iris/__init__.py index d11dd587..d50319db 100644 --- a/iris/__init__.py +++ b/iris/__init__.py @@ -12,7 +12,7 @@ - Iris: Main class for multi-GPU operations - Atomic operations: add, sub, cas, xchg, xor, and, or, min, max - Memory operations: load, store, get, put -- Utility functions: do_bench, memset_tensor +- Utility functions: do_bench - HIP integration for AMD GPU support - Logging utilities with rank information @@ -46,7 +46,6 @@ from .util import ( do_bench, - memset_tensor, ) from . import hip @@ -98,7 +97,6 @@ "atomic_min", "atomic_max", "do_bench", - "memset_tensor", "hip", "set_logger_level", "logger", diff --git a/iris/iris.py b/iris/iris.py index dfa63a99..e29d1c58 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -4,13 +4,13 @@ """ Iris: Multi-GPU Communication and Memory Management Framework -Iris is a high-performance framework for multi-GPU communication and memory management, -providing efficient distributed tensor operations, atomic operations, and memory allocation -across multiple GPUs in a cluster. +Iris is a high-performance framework that enables seamless multi-GPU programming in Triton, +enabling fine-grained communication and compute overlap natively in Triton +across multiple GPUs with SHMEM-like Remote Memory Access (RMA) capabilities. Key Features: - Symmetric heap management across multiple GPUs -- High-performance atomic operations (add, sub, cas, xchg, xor, and, or, min, max) +- High-performance atomic operations (add, cas, xchg, xor, and, or, min, max) - Efficient load/store operations with rank-to-rank communication - Memory allocation and deallocation utilities - Built-in logging with rank information @@ -20,7 +20,6 @@ >>> import iris >>> ctx = iris.iris(heap_size=2**30) # 1GB heap >>> tensor = ctx.zeros(1024, 1024, dtype=torch.float32) - >>> ctx.atomic_add(tensor.data_ptr(), 1.0, 0, 1) """ import triton @@ -62,7 +61,7 @@ class Iris: Example: >>> ctx = iris.iris(heap_size=2**31) # 2GB heap - >>> print(f"Rank {ctx.cur_rank} of {ctx.num_ranks}") + >>> print(f"Rank {ctx.cur_rank} of {ctx.num_ranks}") # Rank 0 of 1 >>> tensor = ctx.zeros(1000, 1000, dtype=torch.float32) """ @@ -138,7 +137,9 @@ def debug(self, message): formatters can display the originating rank and world size. Example: - >>> iris_ctx.debug("Allocating buffers") + >>> ctx = iris.iris() + >>> iris.set_logger_level(iris.DEBUG) + >>> ctx.debug("Allocating buffers") # [Iris] [0/1] Allocating buffers """ self._log_with_rank(logging.DEBUG, message) @@ -150,7 +151,8 @@ def info(self, message): message (str): Human-readable message to log at info level. Example: - >>> iris_ctx.info("Starting iteration 0") + >>> ctx = iris.iris() + >>> ctx.info("Starting iteration 0") # [Iris] [0/1] Starting iteration 0 """ self._log_with_rank(logging.INFO, message) @@ -160,6 +162,10 @@ def warning(self, message): Args: message (str): Human-readable message to log at warning level. + + Example: + >>> ctx = iris.iris() + >>> ctx.warning("Memory usage is high") # [Iris] [0/1] Memory usage is high """ self._log_with_rank(logging.WARNING, message) @@ -169,6 +175,10 @@ def error(self, message): Args: message (str): Human-readable message to log at error level. + + Example: + >>> ctx = iris.iris() + >>> ctx.error("Failed to allocate memory") # [Iris] [0/1] Failed to allocate memory """ self._log_with_rank(logging.ERROR, message) @@ -185,8 +195,9 @@ def broadcast(self, value, source_rank): Any: The value broadcast to all ranks. Example: - >>> value = 42 if iris_ctx.get_rank() == 0 else None - >>> value = iris_ctx.broadcast(value, source_rank=0) + >>> ctx = iris.iris() + >>> value = 42 if ctx.cur_rank == 0 else None + >>> value = ctx.broadcast(value, source_rank=0) # All ranks get 42 """ return distributed_broadcast_scalar(value, source_rank) @@ -233,6 +244,12 @@ def zeros_like( Default: False. memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. Default: torch.preserve_format. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> input_tensor = ctx.ones(2, 3) + >>> zeros_tensor = ctx.zeros_like(input_tensor) + >>> print(zeros_tensor.shape) # torch.Size([2, 3]) """ self.debug( f"zeros_like: input_shape = {input.shape}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" @@ -304,6 +321,11 @@ def arange( device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor type. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.arange(0, 10, 2) # [0, 2, 4, 6, 8] + >>> print(tensor.shape) # torch.Size([5]) """ self.debug(f"arange: start = {start}, end = {end}, step = {step}, dtype = {dtype}, device = {device}") @@ -376,6 +398,12 @@ def zeros(self, *size, out=None, dtype=None, layout=torch.strided, device=None, Default: if None, uses the current device for the default tensor type. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.zeros(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([0., 0., 0.], device='cuda:0') """ self.debug(f"zeros: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") @@ -466,6 +494,12 @@ def randn( Default: False. pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.randn(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([ 0.3982, -0.0059, -0.4365], device='cuda:0') """ self.debug( f"randn: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" @@ -529,6 +563,12 @@ def ones(self, *size, out=None, dtype=None, layout=torch.strided, device=None, r Default: if None, uses the current device for the default tensor type. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.ones(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([1., 1., 1.], device='cuda:0') """ self.debug(f"ones: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") @@ -588,6 +628,12 @@ def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, Default: if None, uses the current device for the default tensor type. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.full((2, 3), 3.14) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([3.1400, 3.1400, 3.1400], device='cuda:0') """ self.debug( f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" @@ -649,6 +695,12 @@ def uniform(self, size, low=0.0, high=1.0, dtype=torch.float): Returns: Tensor: A tensor filled with random numbers from a uniform distribution. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.uniform((2, 3), low=0.0, high=1.0) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') """ self.debug(f"uniform: size = {size}, low = {low}, high = {high}, dtype = {dtype}") size, num_elements = self.__parse_size(size) @@ -694,6 +746,11 @@ def empty( Works only for CPU tensors. Default: False. Note: Iris tensors are always on GPU. memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. Default: torch.contiguous_format. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.empty(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) """ self.debug( f"empty: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" @@ -758,6 +815,12 @@ def randint( layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.randint(0, 10, (2, 3)) # Random integers [0, 10) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([7, 2, 9], device='cuda:0') """ self.debug(f"randint: args = {args}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") @@ -839,6 +902,11 @@ def linspace(self, start, end, steps, out=None, dtype=None, layout=torch.strided layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device. requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.linspace(0, 10, 5) # [0, 2.5, 5, 7.5, 10] + >>> print(tensor) # tensor([ 0.0000, 2.5000, 5.0000, 7.5000, 10.0000], device='cuda:0') """ self.debug( f"linspace: start = {start}, end = {end}, steps = {steps}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" @@ -940,6 +1008,12 @@ def rand( Default: False. pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False. Note: Iris tensors are always on GPU. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.rand(2, 3) # Random values in [0, 1) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') """ self.debug( f"rand: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" @@ -998,6 +1072,11 @@ def get_heap_bases(self): torch.Tensor: A 1D tensor of ``uint64`` heap base addresses of size ``num_ranks`` on the Iris device. Pass this to device-side Triton kernels that require heap translation. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> heap_bases = ctx.get_heap_bases() + >>> print(heap_bases.shape) # torch.Size([num_ranks]) """ return self.heap_bases @@ -1010,6 +1089,10 @@ def barrier(self, stream=None): ranks reach the same point before proceeding. Args: stream: If stream is given: wait only for that stream before barrier. If stream is None: legacy behavior (device-wide sync). + + Example: + >>> ctx = iris.iris(1 << 20) + >>> ctx.barrier() # Synchronize all ranks """ # Wait for all GPUs to finish work if stream is None: @@ -1026,6 +1109,11 @@ def get_device(self): Returns: torch.device: The CUDA device of Iris-managed memory. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> device = ctx.get_device() + >>> print(device) # cuda:0 """ return self.memory_pool.device @@ -1035,6 +1123,11 @@ def get_cu_count(self): Returns: int: Number of compute units on this rank's GPU. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> cu_count = ctx.get_cu_count() + >>> print(f"GPU has {cu_count} CUs") # GPU has 304 CUs """ return get_cu_count(self.gpu_id) @@ -1044,6 +1137,11 @@ def get_rank(self): Returns: int: Zero-based rank id of the current process. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> rank = ctx.get_rank() + >>> print(f"This is rank {rank}") # This is rank 0 """ return self.cur_rank @@ -1053,6 +1151,11 @@ def get_num_ranks(self): Returns: int: World size (number of ranks). + + Example: + >>> ctx = iris.iris(1 << 20) + >>> num_ranks = ctx.get_num_ranks() + >>> print(f"Total ranks: {num_ranks}") # Total ranks: 1 """ return self.num_ranks @@ -1379,6 +1482,15 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): Returns: Block: The loaded value from the target memory location. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Load data from rank 1's memory into the current rank + >>> cur_rank = 0 # Current rank + >>> remote_rank = 1 # Remote rank to load from + >>> data = iris.load(ptr, cur_rank, remote_rank, heap_bases) + >>> return data """ translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) result = tl.load(translated_ptr, mask=mask) @@ -1405,6 +1517,15 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): Returns: None + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Store value 42 into rank 1's heap from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> value = 42 + >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) tl.store(translated_ptr, value, mask=mask) @@ -1430,6 +1551,13 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): Returns: None + + Example: + >>> @triton.jit + >>> def kernel(remote_ptr, local_ptr, heap_bases): + >>> from_rank = 1 + >>> to_rank = 0 + >>> iris.get(remote_ptr, local_ptr, from_rank, to_rank, heap_bases) """ translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases) @@ -1457,6 +1585,13 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): Returns: None + + Example: + >>> @triton.jit + >>> def kernel(local_ptr, remote_ptr, heap_bases): + >>> from_rank = 0 + >>> to_rank = 1 + >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) """ translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) @@ -1487,11 +1622,56 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None Returns: Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically add 5 to rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> increment = 5 + >>> old_val = iris.atomic_add(ptr, increment, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) +@triton.jit +def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Atomically subtracts data from the specified rank's memory location. + + This function performs an atomic subtraction operation by translating the pointer + from the from_rank's address space to the to_rank's address space and atomically + subtracting the provided data from the to_rank memory location. If the from_rank and to_rank are the same, + this function performs a local atomic subtraction operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the from_rank's address space that will be translated to the to_rank's address space. Must be the current rank where the pointer is local. + val (Block): The tensor of elements to be subtracted atomically. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel". + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). Defaults to "gpu". + + Returns: + Block: The value at the memory location before the atomic subtraction. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically subtract 3 from rank 2's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 2 # Remote rank (destination) + >>> decrement = 3 + >>> old_val = iris.atomic_sub(ptr, decrement, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + @triton.jit def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None): """ @@ -1514,6 +1694,16 @@ def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scop Returns: Block: The value contained at the memory location before the atomic operation attempt. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Compare-and-swap on rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> expected = 0 + >>> new_val = 42 + >>> old_val = iris.atomic_cas(ptr, expected, new_val, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) @@ -1541,6 +1731,15 @@ def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=Non Returns: Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Exchange value with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> new_value = 99 + >>> old_val = iris.atomic_xchg(ptr, new_value, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1568,6 +1767,15 @@ def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None Returns: Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically XOR with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> mask_val = 0xFF + >>> old_val = iris.atomic_xor(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1595,6 +1803,15 @@ def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None Returns: Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically AND with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> mask_val = 0x0F + >>> old_val = iris.atomic_and(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1622,6 +1839,15 @@ def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, Returns: Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically OR with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> mask_val = 0xF0 + >>> old_val = iris.atomic_or(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1649,6 +1875,15 @@ def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None Returns: Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically find minimum with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> new_val = 10 + >>> old_val = iris.atomic_min(ptr, new_val, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1676,6 +1911,15 @@ def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None Returns: Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically find maximum with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> new_val = 100 + >>> old_val = iris.atomic_max(ptr, new_val, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) @@ -1690,5 +1934,10 @@ def iris(heap_size=1 << 30): Returns: Iris: An initialized Iris instance. + + Example: + >>> import iris + >>> iris_ctx = iris.iris(2**30) # 1GB heap + >>> tensor = iris_ctx.zeros(1024, 1024) """ return Iris(heap_size) diff --git a/iris/logging.py b/iris/logging.py index 51a27941..3aa5d232 100644 --- a/iris/logging.py +++ b/iris/logging.py @@ -51,5 +51,10 @@ def set_logger_level(level): Args: level: Logging level (iris.DEBUG, iris.INFO, iris.WARNING, iris.ERROR) + + Example: + >>> ctx = iris.iris() + >>> iris.set_logger_level(iris.DEBUG) + >>> ctx.debug("This will now be visible") # [Iris] [0/1] This will now be visible """ logger.setLevel(level) diff --git a/iris/util.py b/iris/util.py index 17da52de..8c861851 100644 --- a/iris/util.py +++ b/iris/util.py @@ -1,6 +1,28 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + import statistics import math import triton @@ -9,8 +31,6 @@ def get_empty_cache_for_benchmark(): - import torch - cache_size = 256 * 1024 * 1024 return torch.empty(int(cache_size // 4), dtype=torch.int, device="cuda") @@ -20,8 +40,6 @@ def clear_cache(cache): def create_timing_event(): - import torch - return torch.cuda.Event(enable_timing=True) @@ -68,6 +86,28 @@ def do_bench( quantiles=None, return_mode="mean", ): + """ + Benchmark a function by timing its execution. + + Args: + fn (callable): Function to benchmark. + barrier_fn (callable, optional): Function to call for synchronization. Default: no-op. + preamble_fn (callable, optional): Function to call before each execution. Default: no-op. + n_warmup (int, optional): Number of warmup iterations. Default: 25. + n_repeat (int, optional): Number of timing iterations. Default: 100. + quantiles (list, optional): Quantiles to return instead of summary statistic. Default: None. + return_mode (str, optional): Summary statistic to return ("mean", "min", "max", "median", "all"). Default: "mean". + + Returns: + float or list: Timing result(s) in milliseconds. + + Example: + >>> import iris + >>> iris_ctx = iris.iris(1 << 20) + >>> def test_fn(): + >>> tensor = iris_ctx.zeros(1000, 1000) + >>> time_ms = iris.do_bench(test_fn, barrier_fn=iris_ctx.barrier) + """ # Wait for anything that happened before barrier_fn() preamble_fn() @@ -102,27 +142,3 @@ def do_bench( times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)] return _summarize_statistics(times, quantiles, return_mode) - - -@triton.jit -def memset_kernel(ptr, value, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - v = tl.full([BLOCK_SIZE], value, dtype=tl.int32) - tl.store(ptr + offsets, v, mask=mask) - - -def memset_tensor(tensor, value): - assert tensor.is_contiguous(), "Tensor must be contiguous" - assert tensor.dtype == torch.int32, "Only torch.int32 tensors are supported" - n_elements = tensor.numel() - BLOCK_SIZE = 1024 - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - memset_kernel[grid]( - tensor, - value, - n_elements, - BLOCK_SIZE=BLOCK_SIZE, - )