Skip to content

Commit 76cf3a5

Browse files
authored
Merge pull request #14 from saturncloud/feat/local-rank-support
feat/local rank support
2 parents d9a6e4f + a5b65f8 commit 76cf3a5

File tree

2 files changed

+198
-23
lines changed

2 files changed

+198
-23
lines changed

dask_pytorch_ddp/dispatch.py

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,47 +3,88 @@
33
"""
44

55
import os
6-
from typing import List, Callable, Tuple, Any
6+
from typing import List, Callable, Any, Dict
77
from dask.distributed import Client
88
import torch.distributed as dist
99

1010

11-
def _get_worker_info(client: Client) -> Tuple[List[str], str]:
11+
def _get_worker_info(client: Client) -> List[Dict]:
1212
"""
1313
returns a list of dicts, one per worker (keys: worker, local_rank, global_rank, host), ordered by sorted host and then by sorted worker key within each host
1414
The master is the 0th worker's host
1515
"""
1616
workers = client.scheduler_info()["workers"]
1717
worker_keys = sorted(workers.keys())
18+
workers_by_host: Dict[str, List[str]] = {}
19+
for key in worker_keys:
20+
worker = workers[key]
21+
host = worker["host"]
22+
workers_by_host.setdefault(host, []).append(key)
1823
host = workers[worker_keys[0]]["host"]
19-
return worker_keys, host
24+
all_workers = []
25+
global_rank = 0
26+
for host in sorted(workers_by_host.keys()):
27+
local_rank = 0
28+
for worker in workers_by_host[host]:
29+
all_workers.append(
30+
dict(
31+
worker=worker,
32+
local_rank=local_rank,
33+
global_rank=global_rank,
34+
host=host,
35+
)
36+
)
37+
local_rank += 1
38+
global_rank += 1
39+
return all_workers
2040

2141

22-
def run(client: Client, pytorch_function: Callable, *args, backend: str = "nccl", **kwargs):
42+
def run(
43+
client: Client,
44+
pytorch_function: Callable,
45+
*args,
46+
backend: str = "nccl",
47+
pass_local_rank: bool = False,
48+
**kwargs
49+
):
2350
"""
2451
Dispatch a pytorch function over a dask cluster, and returns a list of futures
2552
for the resulting tasks
2653
"""
27-
worker_keys, host = _get_worker_info(client)
28-
world_size = len(worker_keys)
54+
all_workers = _get_worker_info(client)
55+
world_size = len(all_workers)
2956
port = 23456 # pick a free port?
30-
31-
futures = [
32-
client.submit(
33-
dispatch_with_ddp,
34-
pytorch_function=pytorch_function,
35-
master_addr=host,
36-
master_port=port,
37-
rank=idx,
38-
world_size=world_size,
39-
*args,
40-
backend=backend,
41-
workers=[w],
42-
**kwargs
43-
)
44-
for idx, w in enumerate(worker_keys)
45-
]
46-
57+
host = all_workers[0]["host"]
58+
futures = []
59+
for worker in all_workers:
60+
if pass_local_rank:
61+
fut = client.submit(
62+
dispatch_with_ddp,
63+
pytorch_function=pytorch_function,
64+
master_addr=host,
65+
master_port=port,
66+
rank=worker["global_rank"],
67+
world_size=world_size,
68+
*args,
69+
local_rank=worker["local_rank"],
70+
backend=backend,
71+
workers=[worker["worker"]],
72+
**kwargs
73+
)
74+
else:
75+
fut = client.submit(
76+
dispatch_with_ddp,
77+
pytorch_function=pytorch_function,
78+
master_addr=host,
79+
master_port=port,
80+
rank=worker["global_rank"],
81+
world_size=world_size,
82+
*args,
83+
backend=backend,
84+
workers=[worker["worker"]],
85+
**kwargs
86+
)
87+
futures.append(fut)
4788
return futures
4889

4990

tests/test_dispatch.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,140 @@ def test_run():
7474
assert output == fake_results
7575

7676

77+
def test_run_with_local_rank_simple():
78+
client = Mock()
79+
client.scheduler_info = Mock(return_value={"workers": workers})
80+
81+
fake_pytorch_func = Mock()
82+
83+
fake_results = []
84+
worker_keys = sorted(workers.keys())
85+
for idx, worker in enumerate(worker_keys):
86+
r = Mock()
87+
r.result = Mock(return_value=idx)
88+
fake_results.append(r)
89+
90+
client.submit = Mock(side_effect=fake_results)
91+
output = run(client, fake_pytorch_func, pass_local_rank=True)
92+
93+
client.submit.assert_any_call(
94+
dispatch_with_ddp,
95+
pytorch_function=fake_pytorch_func,
96+
master_addr=host,
97+
master_port=23456,
98+
rank=0,
99+
local_rank=0,
100+
world_size=len(workers),
101+
workers=[worker_keys[0]],
102+
backend="nccl",
103+
)
104+
client.submit.assert_any_call(
105+
dispatch_with_ddp,
106+
pytorch_function=fake_pytorch_func,
107+
master_addr=host,
108+
master_port=23456,
109+
rank=1,
110+
local_rank=0,
111+
workers=[worker_keys[1]],
112+
world_size=len(workers),
113+
backend="nccl",
114+
)
115+
client.submit.assert_any_call(
116+
dispatch_with_ddp,
117+
pytorch_function=fake_pytorch_func,
118+
master_addr=host,
119+
master_port=23456,
120+
rank=2,
121+
local_rank=0,
122+
workers=[worker_keys[2]],
123+
world_size=len(workers),
124+
backend="nccl",
125+
)
126+
client.submit.assert_any_call(
127+
dispatch_with_ddp,
128+
pytorch_function=fake_pytorch_func,
129+
master_addr=host,
130+
master_port=23456,
131+
rank=3,
132+
local_rank=0,
133+
workers=[worker_keys[3]],
134+
world_size=len(workers),
135+
backend="nccl",
136+
)
137+
assert output == fake_results
138+
139+
140+
def test_run_with_local_rank_complex():
141+
workers = {
142+
"tcp://1.2.3.4:8786": {"host": "1.2.3.4"},
143+
"tcp://1.2.3.4:8787": {"host": "1.2.3.4"},
144+
"tcp://3.2.3.4:8786": {"host": "3.2.3.4"},
145+
"tcp://3.2.3.4:8787": {"host": "3.2.3.4"},
146+
}
147+
host_name = sorted(workers.keys())[0]
148+
host = workers[host_name]["host"]
149+
client = Mock()
150+
client.scheduler_info = Mock(return_value={"workers": workers})
151+
152+
fake_pytorch_func = Mock()
153+
154+
fake_results = []
155+
worker_keys = sorted(workers.keys())
156+
for idx, worker in enumerate(worker_keys):
157+
r = Mock()
158+
r.result = Mock(return_value=idx)
159+
fake_results.append(r)
160+
161+
client.submit = Mock(side_effect=fake_results)
162+
output = run(client, fake_pytorch_func, pass_local_rank=True)
163+
164+
client.submit.assert_any_call(
165+
dispatch_with_ddp,
166+
pytorch_function=fake_pytorch_func,
167+
master_addr=host,
168+
master_port=23456,
169+
rank=0,
170+
local_rank=0,
171+
world_size=len(workers),
172+
workers=[worker_keys[0]],
173+
backend="nccl",
174+
)
175+
client.submit.assert_any_call(
176+
dispatch_with_ddp,
177+
pytorch_function=fake_pytorch_func,
178+
master_addr=host,
179+
master_port=23456,
180+
rank=1,
181+
local_rank=1,
182+
workers=[worker_keys[1]],
183+
world_size=len(workers),
184+
backend="nccl",
185+
)
186+
client.submit.assert_any_call(
187+
dispatch_with_ddp,
188+
pytorch_function=fake_pytorch_func,
189+
master_addr=host,
190+
master_port=23456,
191+
rank=2,
192+
local_rank=0,
193+
workers=[worker_keys[2]],
194+
world_size=len(workers),
195+
backend="nccl",
196+
)
197+
client.submit.assert_any_call(
198+
dispatch_with_ddp,
199+
pytorch_function=fake_pytorch_func,
200+
master_addr=host,
201+
master_port=23456,
202+
rank=3,
203+
local_rank=1,
204+
workers=[worker_keys[3]],
205+
world_size=len(workers),
206+
backend="nccl",
207+
)
208+
assert output == fake_results
209+
210+
77211
def test_dispatch_with_ddp():
78212
pytorch_func = Mock()
79213

0 commit comments

Comments
 (0)