
Commit 48b72ce

Docstring updates
1 parent 4405774 commit 48b72ce

File tree: 6 files changed, +28 -15 lines

helion/_compiler/helper_function.py

Lines changed: 3 additions & 3 deletions
@@ -80,9 +80,9 @@ def create_combine_function_wrapper(
     Args:
         combine_fn: The original combine function
         is_tuple_input: Whether the input is a tuple
-        target_format: Either 'tuple' or 'unpacked' format
-            - 'tuple': expects (left_tuple, right_tuple) for tuple inputs
-            - 'unpacked': expects (left_elem0, left_elem1, ..., right_elem0, right_elem1, ...) for tuple inputs
+        target_format: Either 'tuple' or 'unpacked'. The 'tuple' option expects
+            (left_tuple, right_tuple) inputs, while 'unpacked' expects
+            (left_elem0, left_elem1, ..., right_elem0, right_elem1, ...) inputs

     Returns:
         A wrapper function that converts between the formats
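
For reference, the two target_format conventions in the new wording differ only in how tuple operands reach the combine function. A minimal sketch, using a hypothetical add_combine function over a two-element state (the name and state shape are illustrative, not part of this commit):

# Hypothetical combine function over a two-element state (value, index).
def add_combine(left, right):
    return (left[0] + right[0], left[1] + right[1])

# 'tuple' format: the wrapper calls combine_fn(left_tuple, right_tuple).
result = add_combine((1.0, 2), (3.0, 4))

# 'unpacked' format: the wrapper instead receives flattened arguments
# (left_elem0, left_elem1, right_elem0, right_elem1) and regroups them.
def add_combine_unpacked(l0, l1, r0, r1):
    return add_combine((l0, l1), (r0, r1))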

helion/autotuner/base_search.py

Lines changed: 1 addition & 0 deletions
@@ -863,6 +863,7 @@ def wait_for_all(

     Args:
         futures: A list of PrecompileFuture objects.
+        desc: Optional description used for the progress display.

     Returns:
         A list of boolean values indicating completion status.

helion/language/creation_ops.py

Lines changed: 5 additions & 2 deletions
@@ -36,6 +36,7 @@ def zeros(
     Args:
         shape: A list of sizes (or tile indices which are implicitly converted to sizes)
         dtype: Data type of the tensor (default: torch.float32)
+        device: Device must match the current compile environment device

     Returns:
         torch.Tensor: A device tensor of the given shape and dtype filled with zeros

@@ -82,6 +83,7 @@ def full(
         shape: A list of sizes (or tile indices which are implicitly converted to sizes)
         value: The value to fill the tensor with
         dtype: The data type of the tensor (default: torch.float32)
+        device: Device must match the current compile environment device

     Returns:
         torch.Tensor: A device tensor of the given shape and dtype filled with value

@@ -192,9 +194,10 @@ def arange(
     automatically using the current kernel's device and index dtype.

     Args:
-        *args: Variable arguments passed to torch.arange(start, end, step).
+        args: Positional arguments passed to torch.arange(start, end, step).
         dtype: Data type of the result tensor (defaults to kernel's index dtype)
-        **kwargs: Additional keyword arguments passed to torch.arange
+        device: Device must match the current compile environment device
+        kwargs: Additional keyword arguments passed to torch.arange

     Returns:
         torch.Tensor: 1D tensor containing the sequence
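
The new device notes apply to Helion's device-side factory ops. A minimal sketch of typical use inside a kernel body, assuming the hl.zeros/hl.full signatures documented above (the kernel itself is illustrative and needs a GPU compile environment):

import torch
import helion
import helion.language as hl

@helion.kernel
def add_bias(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty_like(x)
    for tile_m in hl.tile(m):
        # Allocated on the current compile environment's device; an explicit
        # device= argument would have to match that device.
        acc = hl.zeros([tile_m, n], dtype=torch.float32)
        acc += x[tile_m, :]
        # hl.full works the same way, filling with a given value.
        bias = hl.full([tile_m, n], 1.0, dtype=torch.float32)
        out[tile_m, :] = acc + bias
    return out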

helion/language/random_ops.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ def rand(
     Args:
         shape: A list of sizes for the output tensor
         seed: A single element int64 tensor or int literal
+        device: Device must match the current compile environment device

     Returns:
         torch.Tensor: A device tensor of float32 dtype filled with uniform random values in [0, 1)
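
A similar sketch for hl.rand, assuming the signature documented above (a shape plus a seed); the kernel below is illustrative only and assumes a 1D float32 input:

import torch
import helion
import helion.language as hl

@helion.kernel
def uniform_like(x: torch.Tensor, seed: int) -> torch.Tensor:
    out = torch.empty(x.size(), dtype=torch.float32, device=x.device)
    for tile in hl.tile(x.size(0)):
        # float32 uniform values in [0, 1), generated on the kernel's device;
        # an explicit device= would have to match the compile environment.
        out[tile] = hl.rand([tile], seed)
    return out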

helion/runtime/kernel.py

Lines changed: 10 additions & 7 deletions
@@ -268,7 +268,7 @@ def autotune(
     Args:
         args: Example arguments used for benchmarking during autotuning.
         force: If True, force full autotuning even if a config is provided.
-        **options: Additional options for autotuning.
+        options: Additional keyword options forwarded to the autotuner.

     Returns:
         Config: The best configuration found during autotuning.

@@ -490,7 +490,7 @@ def autotune(
     Args:
         args: Example arguments used for benchmarking during autotuning.
         force: If True, force full autotuning even if a config is provided.
-        **kwargs: Additional options for autotuning.
+        kwargs: Additional keyword options forwarded to the autotuner.

     Returns:
         Config: The best configuration found during autotuning.

@@ -678,13 +678,16 @@ def kernel(

     Args:
         fn: The function to be wrapped by the Kernel. If None, a decorator is returned.
-        config: A single configuration to use for the kernel. See :class:`~helion.Config` for details.
-        configs: A list of configurations to use for the kernel. Can only specify one of config or configs.
-            See :class:`~helion.Config` for details.
+        config: A single configuration to use for the kernel. Refer to the
+            ``helion.Config`` class for details.
+        configs: A list of configurations to use for the kernel. Can only specify
+            one of config or configs. Refer to the ``helion.Config`` class for
+            details.
         key: Optional callable returning a hashable that augments the specialization key.
         settings: Keyword arguments representing settings for the Kernel.
-            Can also use settings=Settings(...) to pass a Settings object directly.
-            See :class:`~helion.Settings` for available options.
+            Can also use settings=Settings(...) to pass a Settings object
+            directly. Refer to the ``helion.Settings`` class for available
+            options.

     Returns:
         object: A Kernel object or a decorator that returns a Kernel object.
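
The config/configs/settings wording above maps onto the decorator roughly as follows; a minimal sketch, where the specific Config fields and the autotune call are illustrative rather than prescribed by this diff:

import torch
import helion
import helion.language as hl

# Pin a single config up front (block_sizes/num_warps are example fields)...
@helion.kernel(config=helion.Config(block_sizes=[64], num_warps=4))
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        out[tile] = x[tile] + 1
    return out

# ...or pass example arguments to the autotuner; force=True runs full
# autotuning even though a config was provided, and extra keyword options
# are forwarded to the autotuner.
x = torch.randn(8192, device="cuda")
best_config = add_one.autotune((x,), force=True)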

helion/runtime/triton_helpers.py

Lines changed: 8 additions & 3 deletions
@@ -138,9 +138,12 @@ def triton_wait_multiple_signal(
     sync_before: tl.constexpr = False,  # pyright: ignore[reportArgumentType]
 ) -> None:
     """
-    Simultenuoslly wait for multiple global memory barrier to reach the expected value.
+    Simultaneously wait for multiple global memory barriers to reach the
+    expected value.

-    This function implements each thread in a CTA spin-waits and continuously checks a memory location until it reaches the expected value, providing synchronization across CTAs.
+    Each thread in a CTA spin-waits and continuously checks its assigned memory
+    location until it reaches the expected value, providing synchronization
+    across CTAs.

     Args:
         addr: Memory addresses of the barriers to wait on (Maximum 32 barriers)

@@ -149,7 +152,9 @@ def triton_wait_multiple_signal(
         sem: Memory semantics for the atomic operation. Options: "acquire", "relaxed".
         scope: Scope of the atomic operation. Options: "gpu", "sys"
         op: Atomic operation type: "ld", "atomic_cas"
-        skip_sync: Skip CTA synchronization after acquiring the barrier. (default: False)
+        skip_sync: Skip CTA synchronization after acquiring the barrier
+            (default False).
+        sync_before: Add a CTA sync before the wait (default False)
     """
     tl.static_assert(
         (sem == "acquire" or sem == "relaxed") or sem == "release",
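
The spin-wait behavior described in the updated docstring follows a standard pattern. A simplified single-barrier sketch of that pattern (this is not the library helper itself; the multi-barrier handling, sem/scope/op selection, and the skip_sync/sync_before options are all omitted):

import triton
import triton.language as tl

@triton.jit
def spin_wait_single(addr, expect: tl.constexpr):
    # Repeatedly load the barrier until it holds the expected value
    # (volatile load so the read is not hoisted out of the loop).
    val = tl.load(addr, volatile=True)
    while val != expect:
        val = tl.load(addr, volatile=True)
    # CTA-wide sync so every thread observes the signal before proceeding.
    tl.debug_barrier()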

0 commit comments
