train.py
"""
------------------------------------------------------------------------
Modified from HumanSD (https://github.com/IDEA-Research/HumanSD/tree/main)
------------------------------------------------------------------------
"""
import json
import cv2
import numpy as np
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import Dataset
import os
import argparse
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from PIL import Image
from ldm.util import instantiate_from_config, load_model_from_config
from cldm.utils import ImageLogger, CUDACallback, save_configs, load_state_dict
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
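# Collects the names of Trainer arguments that differ from PyTorch Lightning's defaults
# (kept from the original HumanSD script; it does not appear to be called in main() below).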
def nondefault_trainer_args(opt):
parser = argparse.ArgumentParser()
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args([])
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config",
type=str,
help="path to config which constructs model")
parser.add_argument("--max_epochs",
type=int,
default=10,
help="how many samples to produce for each given prompt. A.k.a. batch size")
parser.add_argument("--devices",
type=int,
default=1,
help="how many gpus to train on")
parser.add_argument("-r",
"--resume",
type=str,
nargs="?",
const=True,
help="resume from checkpoint")
parser.add_argument("-s",
"--seed",
type=int,
default=23,
help="seed for seed_everything")
parser.add_argument("--log_frequency",
type=int,
default=300,
help="log images every certain steps")
parser.add_argument("--scale_lr",
type=str2bool,
nargs="?",
const=True,
default=True,
help="scale base-lr by ngpu * batch_size * n_accumulate")
# argument for ControlNet only
parser.add_argument("--control_ckpt",
type=str,
default=None,
help="path to the pre-generated model, please see tool_add_control.py in https://github.com/lllyasviel/ControlNet/tree/main")
parser.add_argument("--sd_locked",
default=True,
type=str2bool,
help="freeze SD decoder layers")
parser.add_argument("--only_mid_control",
default=False,
type=str2bool,
help="output of controlnet is only added to middle SD block")
parser.add_argument("--config_metrics",
type=str,
default="utils/metrics/mini_metrics.yaml",
help="path to config evaluation metrics, used in validation step")
opt = parser.parse_args()
seed_everything(opt.seed)
config = OmegaConf.load(opt.config)
run_name, model_name = config.name, opt.config.split('/')[-2]
print(f'training model {model_name}')
if not os.path.exists(os.path.join('experiments', model_name)):
os.mkdir(os.path.join('experiments', model_name))
# Configs
max_epochs = opt.max_epochs
logger_freq = opt.log_frequency
batch_size = config.data.params.batch_size
learning_rate = config.model.learning_rate
lightning_config = config.pop("lightning", OmegaConf.create())
trainer_config = lightning_config.get("trainer", OmegaConf.create())
trainer_config["accelerator"] = "gpu"
trainer_config["max_epochs"] = max_epochs
trainer_config["devices"] = opt.devices
# check if resume
if opt.resume:
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
trainer_config["resume_from_checkpoint"] = opt.resume
trainer_opt = argparse.Namespace(**trainer_config)
# define model
# Load the model on CPU first; PyTorch Lightning will move it to the GPU(s) automatically.
model = instantiate_from_config(config.model).cpu()
assert opt.control_ckpt is not None, 'please specify the control_ckpt argument, see tool_add_control.py in https://github.com/lllyasviel/ControlNet/tree/main'
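# strict=False: missing and unexpected keys are collected and reported below instead of raising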
m, u = model.load_state_dict(load_state_dict(opt.control_ckpt, location='cpu'), strict=False)
if len(m) > 0:
print("missing keys:")
print(m)
if len(u) > 0:
print("unexpected keys:")
print(u)
model.sd_locked = opt.sd_locked
model.only_mid_control = opt.only_mid_control
model.learning_rate = learning_rate
# define dataset
train_set = instantiate_from_config(config.data.params.train)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
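# DataLoader keeps the default num_workers=0; raise it if data loading becomes a bottleneck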
# if the save directory already exists and we are not resuming, append a timestamp to the run name
if os.path.exists(os.path.join('experiments', model_name, run_name)) and not opt.resume:
import time
current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
run_name = run_name + '_' + current_time
print(f'Warning: run name already exists in experiments! Appending timestamp {current_time} to the run name.')
output_path = os.path.join('experiments', model_name, run_name)
if not os.path.exists(output_path):
os.mkdir(output_path)
# save missing/unexpected keys
unloaded_keys = {'missing_keys':m, 'unexpected_keys':u}
save_json_path = os.path.join(output_path, 'unloaded_keys.json')
with open(save_json_path, "w") as outfile:
json.dump(unloaded_keys, outfile)
# define callbacks
# checkpoints callback
default_modelckpt_cfg = {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": output_path,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
}
}
if hasattr(model, "monitor"):
print(f"Monitoring {model.monitor} as checkpoint metric.")
default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 3
if "modelcheckpoint" in lightning_config:
modelckpt_cfg = lightning_config.modelcheckpoint
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
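# values from lightning_config.modelcheckpoint take precedence over the defaults above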
# other callbacks
default_callbacks_cfg = {
"img_logger": {
"target": "train.ImageLogger",
"params": {
"batch_frequency": logger_freq,
"run_name": run_name
}
},
"cuda_callback": {
"target": "train.CUDACallback"
},
"learning_rate_logger": {
"target": "train.LearningRateMonitor",
"params": {
"logging_interval": "step",
}
},
}
default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
else:
callbacks_cfg = OmegaConf.create()
if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
print(
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
default_metrics_over_trainsteps_ckpt_dict = {
'metrics_over_trainsteps_checkpoint':
{"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
'params': {
"dirpath": os.path.join(output_path, 'trainstep_checkpoints'),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
'save_top_k': -1,
'every_n_train_steps': 50000,
'save_weights_only': True
}
}
}
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
callbacks = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
# define trainer
trainer_kwargs = dict()
tb_logger = TensorBoardLogger(os.path.join("experiments", model_name), name=run_name)
trainer_kwargs["logger"] = tb_logger
trainer_kwargs["callbacks"] = callbacks
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
# configure learning rate
if opt.scale_lr:
bs, base_lr = config.data.params.batch_size, config.model.learning_rate
ngpu = opt.devices
if 'accumulate_grad_batches' in lightning_config.trainer:
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
else:
accumulate_grad_batches = 1
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print(
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
else:
model.learning_rate = config.model.learning_rate
print("++++ NOT USING LR SCALING ++++")
print(f"Setting learning rate to {model.learning_rate:.2e}")
# Train!
save_configs([config, callbacks_cfg], output_path, opt.config)
trainer.fit(model, train_loader)
print('training done.')
# save the final model weights
print(f'saving model to {output_path}')
torch.save(model.state_dict(), os.path.join(output_path, 'final.pth'))
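# final.pth holds a plain state_dict; it can be reloaded later via model.load_state_dict(torch.load(...))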
if __name__ == "__main__":
main()