Improve docs #49

Open · mawad-amd wants to merge 6 commits into main

Conversation

@mawad-amd (Collaborator) commented Jul 13, 2025

First attempt at improving docs. local/remote is still a bit confusing.

Closes #46

  • Translate should be private
  • Improve docs

Copilot AI review requested due to automatic review settings, July 13, 2025 08:09
@mawad-amd requested review from neoblizz and BKP as code owners, July 13, 2025 08:09
Copilot AI left a comment

Pull Request Overview

This PR refactors the internal pointer translation helper, tightens up public exports, and enriches docstrings for remote memory operations.

  • Renames translate to private __translate and updates all function signatures to use local_ptr/local_rank/remote_rank conventions
  • Enhances docstrings for load, store, get, put, and atomic operations with detailed parameter and return descriptions
  • Removes translate from public exports in __init__.py

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File | Description
iris/iris.py | Renamed translate to __translate, updated signatures, and improved docstrings
iris/__init__.py | Removed translate from imports and __all__ to make it an internal helper
Comments suppressed due to low confidence (3)

iris/iris.py:318

  • Add a docstring for __translate explaining its purpose, parameters, and return value to maintain consistency with other Triton helper functions.
def __translate(local_ptr, local_rank, remote_rank, heap_bases, debug=False):
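
A hedged sketch of what such a docstring could look like; the offset-based description of the translation is an assumption inferred from the other docstrings in this PR, not the actual implementation, and the body is elided:

import triton

# Sketch only: the semantics described here are inferred from this PR's docstrings.
@triton.jit
def __translate(local_ptr, local_rank, remote_rank, heap_bases, debug=False):
    """
    Translate local_ptr from local_rank's address space into remote_rank's.

    The pointer's offset within local_rank's symmetric heap is applied to
    remote_rank's heap base, yielding a pointer that load/store/get/put can
    dereference on the remote rank.

    Args:
        local_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in local_rank's address space.
        local_rank (int): The rank whose heap local_ptr currently points into.
        remote_rank (int): The rank whose heap the returned pointer should point into.
        heap_bases (int): The heap bases.
        debug (bool, optional): Enables extra debug output. Defaults to False.

    Returns:
        The translated pointer in remote_rank's address space.
    """
    ...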

iris/__init__.py:11

  • Removing translate from the public exports is a breaking change; consider deprecating it for one release or updating the changelog to notify users of this API change.
    iris,

iris/iris.py:318

  • The new pointer translation logic in __translate (and its use in load/store/get/put) should have unit tests to verify correctness across different rank configurations.
def __translate(local_ptr, local_rank, remote_rank, heap_bases, debug=False):
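
A hedged sketch of what a host-side test of that arithmetic could check; reference_translate and the heap-base values below are made up for illustration and are not part of this PR:

# Hypothetical host-side check of the offset arithmetic __translate is assumed to perform.
def reference_translate(ptr, local_rank, remote_rank, heap_bases):
    # Re-base the pointer: keep its byte offset, swap the heap base.
    return heap_bases[remote_rank] + (ptr - heap_bases[local_rank])

def test_translate_offset_arithmetic():
    heap_bases = [0x1000, 0x8000, 0x20000]  # one symmetric-heap base per rank
    ptr = heap_bases[1] + 0x40  # 0x40 bytes into rank 1's heap
    translated = reference_translate(ptr, 1, 2, heap_bases)
    assert translated == heap_bases[2] + 0x40  # same offset in rank 2's heap
    assert reference_translate(translated, 2, 1, heap_bases) == ptr  # round trip
    assert reference_translate(ptr, 1, 1, heap_bases) == ptr  # same-rank translation is the identity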

@neoblizz (Member) commented

Why do we have to describe it as local/remote only? If we use iris.load, it should do a local load if your source and destination ranks are the same, correct?

@mawad-amd (Collaborator, Author) commented

Correct. It is confusing and I would like to resolve that. Do you have suggestions?

I have a few:

# 1. Emphasizes direction of data flow
def load(pointer, to_rank, from_rank, heap_bases, mask=None):

# 2. Uses 'cur' to indicate the calling rank
def load(pointer, cur_rank, dst_rank, heap_bases, mask=None):

# 3. Generalizes roles as caller/target
def load(pointer, caller_rank, target_rank, heap_bases, mask=None):

The pointer argument can be named local_ptr, sym_ptr, address, or pointer.

I do like the from/to at the moment. Here are all APIs:

def load(pointer, to_rank, from_rank, heap_bases, mask=None):

def store(pointer, val, from_rank, to_rank, heap_bases, mask=None):

def get(from_pointer, to_pointer, to_rank, from_rank, heap_bases, mask=None):

def put(from_pointer, to_pointer, from_rank, to_rank, heap_bases, mask=None):

def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):

def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):

def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None):

def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):
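
For illustration, here is a hypothetical kernel using option 1's spelling, assuming data flows from from_rank to to_rank and the caller passes its own rank as to_rank; the kernel, buffer, and imports below are made up and not part of this PR:

import triton
import triton.language as tl
import iris  # assumed import; the exact module layout may differ

@triton.jit
def pull_kernel(buffer, my_rank, peer_rank, heap_bases, n, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n
    # Remote read: pull peer_rank's values over to my_rank.
    remote_vals = iris.load(buffer + offsets, my_rank, peer_rank, heap_bases, mask=mask)
    # Same rank on both sides: this degenerates to a plain local load.
    local_vals = iris.load(buffer + offsets, my_rank, my_rank, heap_bases, mask=mask)
    tl.store(buffer + offsets, remote_vals + local_vals, mask=mask)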

@neoblizz (Member) commented Jul 13, 2025

> Correct. It is confusing and I would like to resolve that. Do you have suggestions?
>
> I have a few:
>
> # 1. Emphasizes direction of data flow
> def load(pointer, to_rank, from_rank, heap_bases, mask=None):
>
> # 2. Uses 'cur' to indicate the calling rank
> def load(pointer, cur_rank, dst_rank, heap_bases, mask=None):
>
> # 3. Generalizes roles as caller/target
> def load(pointer, caller_rank, target_rank, heap_bases, mask=None):
>
> The pointer argument can be named local_ptr, sym_ptr, address, or pointer.
>
> I do like the from/to at the moment. Here are all APIs:
>
> def load(pointer, to_rank, from_rank, heap_bases, mask=None):
>
> def store(pointer, val, from_rank, to_rank, heap_bases, mask=None):
>
> def get(from_pointer, to_pointer, to_rank, from_rank, heap_bases, mask=None):
>
> def put(from_pointer, to_pointer, from_rank, to_rank, heap_bases, mask=None):
>
> def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):
>
> def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):
>
> def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None):
>
> def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):

Notes

  • I do like to/from, but not for the from_pointer field.
  • Is there a notion of "triggering" a load/store/put/get from a rank that's not the "current" or "remote" rank? If so, current cannot be used.

Other suggestions:

  • Use value instead of val
  • Use semantics instead of sem
  • The problematic (imo) APIs are:
def get(from_pointer, to_pointer, to_rank, from_rank, heap_bases, mask=None):
def put(from_pointer, to_pointer, from_rank, to_rank, heap_bases, mask=None):

@mawad-amd (Collaborator, Author) commented Jul 14, 2025

> Is there a notion of "triggering" a load/store/put/get from a rank that's not the "current" or "remote" rank? If so, current cannot be used.

No. For that reason, I was considering adding text to warn against that, e.g., for load: "to_rank must be the same rank issuing the operation." But if there is an interesting use case, we can accommodate it.
Eventually the "current rank" should be implicit and removed, alongside heap_bases.

The shortened val and sem are to match Triton, but I am okay with either (and I slightly prefer spelling out the complete words).

How about this:

def load(pointer, to_rank, from_rank, heap_bases, mask=None):
def store(pointer, val, from_rank, to_rank, heap_bases, mask=None):
def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):
def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):
def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None):
def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):

and,

def get(dst_ptr, src_ptr, to_rank, from_rank, heap_bases, mask=None):
def put(dst_ptr, src_ptr, to_rank, from_rank, heap_bases, mask=None):
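
And a hypothetical counterpart for get/put under the same data-flow reading (the caller's rank is to_rank for get and from_rank for put); the kernel and buffer names are made up and not part of this PR:

# Imports as in the earlier load sketch; illustrative only.
@triton.jit
def exchange_kernel(recv_buf, send_buf, my_rank, peer_rank, heap_bases, n, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n
    # get: copy from peer_rank's send_buf into this rank's recv_buf.
    iris.get(recv_buf + offsets, send_buf + offsets, my_rank, peer_rank, heap_bases, mask=mask)
    # put: copy from this rank's send_buf into peer_rank's recv_buf.
    iris.put(recv_buf + offsets, send_buf + offsets, peer_rank, my_rank, heap_bases, mask=mask)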

@neoblizz (Member) commented Jul 14, 2025

> No. For that reason, I was considering adding text to warn against that, e.g., for load: "to_rank must be the same rank issuing the operation." But if there is an interesting use case, we can accommodate it.
> Eventually the "current rank" should be implicit and removed, alongside heap_bases.

I am not sure I understand the warning. You mean from_rank must be the same rank issuing the op?
Separately, I think there can be cases where that's not true, where we have one GPU thread somehow initiate a copy/move. Let's discuss separately in a call with @BKP.

> The shortened val and sem are to match Triton, but I am okay with either (and I slightly prefer spelling out the complete words).
>
> How about this:
>
> def load(pointer, to_rank, from_rank, heap_bases, mask=None):
> def store(pointer, val, from_rank, to_rank, heap_bases, mask=None):
> def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):
> def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):
> def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None):
> def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):
>
> and,
>
> def get(dst_ptr, src_ptr, to_rank, from_rank, heap_bases, mask=None):
> def put(dst_ptr, src_ptr, to_rank, from_rank, heap_bases, mask=None):

I prefer spelled out as well; Triton is not consistent with its naming, and you can easily find examples of both styles.

I like this! (With sem, val, maybe even cmp fully spelled out. I know it's minor, but sem could be semaphore or semantics or something else, idk.)

@neoblizz (Member) left a comment

Still lots of "remote"; let's discuss the best way to word some of the descriptions.

"""
- Loads a value from the specified memory location and rank.
+ Loads a value from a remote rank's memory location.
@neoblizz (Member):

Suggested change:
- Loads a value from a remote rank's memory location.
+ Loads a value stored from a pointer of the specified rank.


This function performs a remote memory read operation by translating the pointer
from the from_rank's address space to the to_rank's address space and loading
data from the remote memory location.
@neoblizz (Member):

Suggested change:
- data from the remote memory location.
+ data from the remote memory location. If `to_rank` is the same as `from_rank`,
+ this function performs a local load operation instead.

heap_bases (int): The heap bases.
mask (Optional[tl.tensor], optional): A boolean tensor used to guard memory accesses.
pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the from_rank's address space that will be translated to the to_rank's address space.
to_rank (int): The rank ID to which the pointer will be translated. Must be the current rank where the pointer is local.
@neoblizz (Member):

Suggested change:
- to_rank (int): The rank ID to which the pointer will be translated. Must be the current rank where the pointer is local.
+ to_rank (int): The rank ID for the pointer where the load will occur. `to_rank` must be the rank where the pointer resides.

mask (Optional[tl.tensor], optional): A boolean tensor used to guard memory accesses. Defaults to None.
pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the from_rank's address space that will be translated to the to_rank's address space.
val (Block): The tensor of elements to be stored.
from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local.
@neoblizz (Member):

Suggested change:
- from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local.
+ from_rank (int): The rank ID from which the pointer originates. `from_rank` must be the rank where the pointer resides.

"""
- Loads a value from the specified memory location and rank.
+ Copies data from a remote rank's memory to the current rank's local memory.
@neoblizz (Member):

Use of the word remote again.

Successfully merging this pull request may close these issues: Update get and put docs