multi_processing.py
import numpy as np
import torch
import torch.multiprocessing as mp

from utils import merge_stat


class MultiProcessWorker(mp.Process):
    # TODO: Make environment init threadsafe
    def __init__(self, id, trainer_maker, comm, seed, *args, **kwargs):
        self.id = id
        self.seed = seed
        super(MultiProcessWorker, self).__init__()
        self.trainer = trainer_maker()
        self.comm = comm

    def run(self):
        # Offset the seed by worker id so every process gets its own RNG stream.
        torch.manual_seed(self.seed + self.id + 1)
        np.random.seed(self.seed + self.id + 1)

        while True:
            task = self.comm.recv()
            # 'run_batch' arrives as ['run_batch', epoch]; unpack the epoch.
            if isinstance(task, list):
                task, epoch = task

            if task == 'quit':
                return
            elif task == 'run_batch':
                batch, stat = self.trainer.run_batch(epoch)
                self.trainer.optimizer.zero_grad()
                s = self.trainer.compute_grad(batch)
                merge_stat(s, stat)
                self.comm.send(stat)
            elif task == 'send_grads':
                # Reply with the gradient tensors themselves so the master can
                # keep pointers to them instead of asking again every batch.
                grads = []
                for p in self.trainer.params:
                    if p._grad is not None:
                        grads.append(p._grad.data)
                self.comm.send(grads)
            elif task == 'get_gpu_mem':
                self.comm.send((self.id, self.trainer.get_memory_peak()))
            elif task == 'reset_gpu_mem':
                self.trainer.reset_memory_peak()
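

# Master/worker message protocol over each Pipe:
#   'quit'                -- stop the worker loop
#   ['run_batch', epoch]  -- run one training batch, reply with its statistics
#   'send_grads'          -- reply with the worker's gradient tensors
#   'get_gpu_mem'         -- reply with (id, (cpu_peak, gpu_peak))
#   'reset_gpu_mem'       -- clear the worker's memory-peak counters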
class MultiProcessTrainer(object):
    def __init__(self, args, trainer_maker):
        self.comms = []
        self.trainer = trainer_maker()
        # The master keeps a trainer of its own and does the same work as the
        # workers, so only nprocesses - 1 extra processes are spawned.
        self.nworkers = args.nprocesses - 1
        for i in range(self.nworkers):
            comm, comm_remote = mp.Pipe()
            self.comms.append(comm)
            worker = MultiProcessWorker(i, trainer_maker, comm_remote, seed=args.seed)
            worker.start()
        self.grads = None
        self.worker_grads = None
        self.is_random = args.random
        self.reset_mem_peak()

    def quit(self):
        for comm in self.comms:
            comm.send('quit')
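
    # Gradient tensors sent over a torch.multiprocessing Pipe are moved to
    # shared memory, so the pointers gathered below stay valid across batches:
    # each worker keeps writing into the same tensors the master reads from.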
    def obtain_grad_pointers(self):
        # Only needs to run once; afterwards the same tensors are reused.
        if self.grads is None:
            self.grads = []
            for p in self.trainer.params:
                if p._grad is not None:
                    self.grads.append(p._grad.data)
        if self.worker_grads is None:
            self.worker_grads = []
            for comm in self.comms:
                comm.send('send_grads')
                self.worker_grads.append(comm.recv())

    def train_batch(self, epoch):
        # Fresh per-batch peak-memory stats: index 0 is the master,
        # indices 1..nworkers are the workers.
        self.cpu_memory_peak = np.zeros(self.nworkers + 1)
        self.gpu_memory_peak = np.zeros(self.nworkers + 1)

        # Kick off the workers in parallel.
        for comm in self.comms:
            comm.send('reset_gpu_mem')
            comm.send(['run_batch', epoch])

        # The master runs a batch of its own while the workers run theirs.
        batch, stat = self.trainer.run_batch(epoch)
        self.trainer.optimizer.zero_grad()
        s = self.trainer.compute_grad(batch)
        merge_stat(s, stat)

        # Wait for the workers to finish and merge their statistics.
        for comm in self.comms:
            s = comm.recv()
            merge_stat(s, stat)

        # Accumulate the workers' gradients into the master's, average over
        # the total number of steps, then take one optimizer step.
        self.obtain_grad_pointers()
        for i in range(len(self.grads)):
            for g in self.worker_grads:
                self.grads[i] += g[i]
            self.grads[i] /= stat['num_steps']
        self.trainer.optimizer.step()

        # Collect peak-memory readings from the workers.
        for comm in self.comms:
            comm.send('get_gpu_mem')
            id, mem = comm.recv()
            self.update_mem_peak(id + 1, *mem)

        return stat, np.array(self.cpu_memory_peak), np.array(self.gpu_memory_peak)

    def state_dict(self):
        return self.trainer.state_dict()

    def load_state_dict(self, state):
        self.trainer.load_state_dict(state)

    def reset_mem_peak(self):
        self.cpu_memory_peak = np.zeros(self.nworkers + 1)
        self.gpu_memory_peak = np.zeros(self.nworkers + 1)
        self.trainer.reset_memory_peak()

    def update_mem_peak(self, id, cpu_mem, gpu_mem):
        if cpu_mem > self.cpu_memory_peak[id]:
            self.cpu_memory_peak[id] = cpu_mem
        if gpu_mem > self.gpu_memory_peak[id]:
            self.gpu_memory_peak[id] = gpu_mem
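

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original training code): a
# minimal stand-in showing the interface MultiProcessTrainer expects from
# trainer_maker. DummyTrainer and its constant statistics are hypothetical;
# a real trainer would roll out environments and backpropagate a policy loss.
# ---------------------------------------------------------------------------
class DummyTrainer(object):
    def __init__(self):
        self.model = torch.nn.Linear(4, 2)
        self.params = list(self.model.parameters())
        self.optimizer = torch.optim.SGD(self.params, lr=0.01)

    def run_batch(self, epoch):
        # Pretend rollout: a random batch plus the step count used for averaging.
        return torch.randn(8, 4), {'num_steps': 8}

    def compute_grad(self, batch):
        loss = self.model(batch).pow(2).mean()
        loss.backward()
        return {'loss': loss.item()}

    def get_memory_peak(self):
        return (0.0, 0.0)  # (cpu_peak, gpu_peak); a real trainer would measure

    def reset_memory_peak(self):
        pass

    def state_dict(self):
        return self.model.state_dict()

    def load_state_dict(self, state):
        self.model.load_state_dict(state)


if __name__ == '__main__':
    import argparse
    # Hypothetical args: two processes total, i.e. the master plus one worker.
    args = argparse.Namespace(nprocesses=2, seed=0, random=False)
    trainer = MultiProcessTrainer(args, DummyTrainer)
    stat, cpu_peak, gpu_peak = trainer.train_batch(epoch=0)
    print('merged stats:', stat)
    trainer.quit()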