Skip to content

Commit e72704a

Browse files
github-actions[bot] and astroC86
authored and committed
Apply Ruff auto-fixes
1 parent ad03093 commit e72704a

File tree

1 file changed

+27
-47
lines changed

1 file changed

+27
-47
lines changed

tests/examples/test_load_latency.py

Lines changed: 27 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -9,35 +9,8 @@
99
import numpy as np
1010
import iris
1111
from iris._mpi_helpers import mpi_allgather
12-
# from examples.common.utils import read_realtime
13-
14-
@triton.jit
15-
def read_realtime():
16-
tmp = tl.inline_asm_elementwise(
17-
asm="mov.u64 $0, %globaltimer;",
18-
constraints=("=l"),
19-
args=[],
20-
dtype=tl.int64,
21-
is_pure=False,
22-
pack=1,
23-
)
24-
return tmp
12+
from examples.common.utils import read_realtime
2513

26-
@triton.jit()
27-
def gather_latencies(
28-
local_latency,
29-
global_latency,
30-
curr_rank,
31-
num_ranks ,
32-
BLOCK_SIZE: tl.constexpr,
33-
heap_bases: tl.tensor
34-
):
35-
pid = tl.program_id(0)
36-
block_start = pid * BLOCK_SIZE
37-
offsets = block_start + tl.arange(0, BLOCK_SIZE)
38-
39-
latency_mask = offsets < num_ranks
40-
iris.put(local_latency + offsets, global_latency + curr_rank * num_ranks + offsets, curr_rank, 0, heap_bases, mask=latency_mask)
4114

4215
@triton.jit()
4316
def ping_pong(
@@ -66,7 +39,7 @@ def ping_pong(
6639
start = read_realtime()
6740
tl.atomic_xchg(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
6841
first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
69-
token_first_done = i + 1
42+
token_first_done = i + 1
7043
token_second_done = i + 2
7144
if curr_rank == first_rank:
7245
iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
@@ -82,8 +55,9 @@ def ping_pong(
8255
stop = read_realtime()
8356
tl.atomic_xchg(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)
8457

58+
8559
if __name__ == "__main__":
86-
dtype = torch.int32
60+
dtype = torch.int32
8761
heap_size = 1 << 32
8862
shmem = iris.iris(heap_size)
8963
num_ranks = shmem.get_num_ranks()
@@ -96,42 +70,48 @@ def ping_pong(
9670
iter = 200
9771
skip = 1
9872
mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
99-
mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
73+
mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
10074

101-
local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
75+
local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
10276

10377
source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
10478
result_buffer = shmem.zeros_like(source_buffer)
105-
flag = shmem.ones(1, dtype=dtype)
79+
flag = shmem.ones(1, dtype=dtype)
10680

10781
grid = lambda meta: (1,)
10882
for source_rank in range(num_ranks):
10983
for destination_rank in range(num_ranks):
11084
if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
11185
peer_for_me = destination_rank if cur_rank == source_rank else source_rank
112-
ping_pong[grid](source_buffer,
113-
BUFFER_LEN,
114-
skip, iter,
115-
flag,
116-
cur_rank, peer_for_me,
117-
BLOCK_SIZE,
118-
heap_bases,
119-
mm_begin_timestamp,
120-
mm_end_timestamp)
86+
ping_pong[grid](
87+
source_buffer,
88+
BUFFER_LEN,
89+
skip,
90+
iter,
91+
flag,
92+
cur_rank,
93+
peer_for_me,
94+
BLOCK_SIZE,
95+
heap_bases,
96+
mm_begin_timestamp,
97+
mm_end_timestamp,
98+
)
12199
shmem.barrier()
122-
100+
123101
for destination_rank in range(num_ranks):
124-
local_latency[destination_rank] = (mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]) / iter
125-
102+
local_latency[destination_rank] = (
103+
mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]
104+
) / iter
105+
126106
latency_matrix = mpi_allgather(local_latency.cpu())
127107

128108
if cur_rank == 0:
129-
with open(f"latency.txt", "w") as f:
109+
with open("latency.txt", "w") as f:
130110
f.write(" ," + ", ".join(f"R{j}" for j in range(num_ranks)) + "\n")
131111
for i in range(num_ranks):
132112
row_entries = []
133113
for j in range(num_ranks):
134114
val = float(latency_matrix[i, j])
135115
row_entries.append(f"{val:0.6f}")
136116
line = f"R{i}," + ", ".join(row_entries) + "\n"
137-
f.write(line)
117+
f.write(line)

0 commit comments

Comments (0)