@@ -74,6 +74,140 @@ def test_run():
74
74
assert output == fake_results
75
75
76
76
77
def test_run_with_local_rank_simple():
    """Verify the jobs ``run`` submits when ``pass_local_rank=True``.

    ``workers`` and ``host`` are not defined here; they are taken from the
    enclosing scope (presumably module-level fixtures -- confirm against the
    top of the file). For ranks 0 through 3 the test expects one
    ``dispatch_with_ddp`` submission pinned to the rank-th worker in sorted
    key order, each with ``local_rank=0``.
    """
    dask_client = Mock()
    dask_client.scheduler_info = Mock(return_value={"workers": workers})

    train_fn = Mock()

    ordered_workers = sorted(workers.keys())
    # One stand-in future per worker; ``submit`` hands them out in order.
    expected_futures = [
        Mock(result=Mock(return_value=position))
        for position, _ in enumerate(ordered_workers)
    ]

    dask_client.submit = Mock(side_effect=expected_futures)
    returned = run(dask_client, train_fn, pass_local_rank=True)

    # The four expected submissions differ only in rank and target worker.
    for rank in range(4):
        dask_client.submit.assert_any_call(
            dispatch_with_ddp,
            pytorch_function=train_fn,
            master_addr=host,
            master_port=23456,
            rank=rank,
            local_rank=0,
            world_size=len(workers),
            workers=[ordered_workers[rank]],
            backend="nccl",
        )
    assert returned == expected_futures
138
+
139
+
140
def test_run_with_local_rank_complex():
    """Verify ``run(..., pass_local_rank=True)`` when hosts have several workers.

    The fake scheduler reports four workers, two on each of two hosts. In
    sorted key order the expected submissions carry ranks 0..3 and local
    ranks 0, 1, 0, 1, with the first sorted worker's host as master address.
    """
    cluster = {
        "tcp://1.2.3.4:8786": {"host": "1.2.3.4"},
        "tcp://1.2.3.4:8787": {"host": "1.2.3.4"},
        "tcp://3.2.3.4:8786": {"host": "3.2.3.4"},
        "tcp://3.2.3.4:8787": {"host": "3.2.3.4"},
    }
    ordered_workers = sorted(cluster.keys())
    master_host = cluster[ordered_workers[0]]["host"]

    dask_client = Mock()
    dask_client.scheduler_info = Mock(return_value={"workers": cluster})

    train_fn = Mock()

    # One stand-in future per worker; ``submit`` hands them out in order.
    expected_futures = [
        Mock(result=Mock(return_value=position))
        for position, _ in enumerate(ordered_workers)
    ]

    dask_client.submit = Mock(side_effect=expected_futures)
    returned = run(dask_client, train_fn, pass_local_rank=True)

    # Expected local rank for each global rank, in sorted worker order.
    expected_local_ranks = (0, 1, 0, 1)
    for rank, local_rank in enumerate(expected_local_ranks):
        dask_client.submit.assert_any_call(
            dispatch_with_ddp,
            pytorch_function=train_fn,
            master_addr=master_host,
            master_port=23456,
            rank=rank,
            local_rank=local_rank,
            world_size=len(cluster),
            workers=[ordered_workers[rank]],
            backend="nccl",
        )
    assert returned == expected_futures
209
+
210
+
77
211
def test_dispatch_with_ddp ():
78
212
pytorch_func = Mock ()
79
213
0 commit comments