import numpy as np
import iris
from iris._mpi_helpers import mpi_allgather
-# from examples.common.utils import read_realtime
-
-@triton.jit
-def read_realtime():
-    tmp = tl.inline_asm_elementwise(
-        asm="mov.u64 $0, %globaltimer;",
-        constraints=("=l"),
-        args=[],
-        dtype=tl.int64,
-        is_pure=False,
-        pack=1,
-    )
-    return tmp
+from examples.common.utils import read_realtime

-@triton.jit()
-def gather_latencies(
-    local_latency,
-    global_latency,
-    curr_rank,
-    num_ranks,
-    BLOCK_SIZE: tl.constexpr,
-    heap_bases: tl.tensor
-):
-    pid = tl.program_id(0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-
-    latency_mask = offsets < num_ranks
-    iris.put(local_latency + offsets, global_latency + curr_rank * num_ranks + offsets, curr_rank, 0, heap_bases, mask=latency_mask)

@triton.jit()
def ping_pong(
@@ -66,7 +39,7 @@ def ping_pong(
            start = read_realtime()
            tl.atomic_xchg(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
        first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
-        token_first_done = i + 1
+        token_first_done = i + 1
        token_second_done = i + 2
        if curr_rank == first_rank:
            iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
@@ -82,8 +55,9 @@ def ping_pong(
    stop = read_realtime()
    tl.atomic_xchg(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)

+
if __name__ == "__main__":
-    dtype = torch.int32
+    dtype = torch.int32
    heap_size = 1 << 32
    shmem = iris.iris(heap_size)
    num_ranks = shmem.get_num_ranks()
@@ -96,42 +70,48 @@ def ping_pong(
    iter = 200
    skip = 1
    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
-    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
+    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")

-    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
+    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")

    source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
    result_buffer = shmem.zeros_like(source_buffer)
-    flag = shmem.ones(1, dtype=dtype)
+    flag = shmem.ones(1, dtype=dtype)

    grid = lambda meta: (1,)
    for source_rank in range(num_ranks):
        for destination_rank in range(num_ranks):
            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
-                ping_pong[grid](source_buffer,
-                    BUFFER_LEN,
-                    skip, iter,
-                    flag,
-                    cur_rank, peer_for_me,
-                    BLOCK_SIZE,
-                    heap_bases,
-                    mm_begin_timestamp,
-                    mm_end_timestamp)
+                ping_pong[grid](
+                    source_buffer,
+                    BUFFER_LEN,
+                    skip,
+                    iter,
+                    flag,
+                    cur_rank,
+                    peer_for_me,
+                    BLOCK_SIZE,
+                    heap_bases,
+                    mm_begin_timestamp,
+                    mm_end_timestamp,
+                )
            shmem.barrier()
-
+
    for destination_rank in range(num_ranks):
-        local_latency[destination_rank] = (mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]) / iter
-
+        local_latency[destination_rank] = (
+            mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]
+        ) / iter
+
    latency_matrix = mpi_allgather(local_latency.cpu())

    if cur_rank == 0:
-        with open(f"latency.txt", "w") as f:
+        with open("latency.txt", "w") as f:
            f.write(" ," + ",".join(f"R{j}" for j in range(num_ranks)) + "\n")
            for i in range(num_ranks):
                row_entries = []
                for j in range(num_ranks):
                    val = float(latency_matrix[i, j])
                    row_entries.append(f"{val:0.6f}")
                line = f"R{i}," + ",".join(row_entries) + "\n"
-                f.write(line)
+                f.write(line)
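
For readers of the output: rank 0 writes latency.txt as a small CSV-style matrix (a header row " ,R0,R1,...", then one "Ri,..." row per rank). A minimal sketch of loading it back for inspection, assuming the file sits in the working directory and that the values are the raw %globaltimer deltas divided by the iteration count (roughly nanoseconds per ping-pong iteration):

# Minimal sketch: parse the latency matrix written by the benchmark above.
# Assumption: latency.txt follows the format produced by this script.
import numpy as np

with open("latency.txt") as f:
    f.readline()  # header " ,R0,R1,...": column j corresponds to peer rank j
    rows = [line.strip().split(",") for line in f if line.strip()]

labels = [r[0] for r in rows]  # "R0", "R1", ...
matrix = np.array([[float(v) for v in r[1:]] for r in rows])

# matrix[i, j] is the value measured on rank i when paired with rank j;
# the diagonal stays 0.0 because a rank is never paired with itself.
print(labels)
print(matrix)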