Skip to content

Commit 19229bd

Browse files
committed
update training code.
1 parent 17aaae0 commit 19229bd

31 files changed

+2717
-1790
lines changed

README.md

+34-6
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,17 @@ You can find a video recording of the demo [here](https://youtu.be/_yEziPl9AkM?t
9494
To get the pre-trained generator, encoder, and editing directions, run:
9595

9696
```python
97-
import model
97+
import models
9898

9999
pretrained_type = 'generator' # choosing from ['generator', 'encoder', 'boundary']
100100
config_name = 'anycost-ffhq-config-f' # replace the config name for other models
101-
model.get_pretrained(pretrained_type, config=config_name)
101+
models.get_pretrained(pretrained_type, config=config_name)
102102
```
103103

104104
We also provide the face attribute classifier (which is general for different generators) for computing the editing directions. You can get it by running:
105105

106106
```python
107-
model.get_pretrained('attribute-predictor')
107+
models.get_pretrained('attribute-predictor')
108108
```
109109

110110
The attribute classifier takes in the face images in FFHQ format.
@@ -114,9 +114,9 @@ The attribute classifier takes in the face images in FFHQ format.
114114
After loading the Anycost generator, we can run it at a wide range of computational costs. For example:
115115

116116
```python
117-
from model.dynamic_channel import set_uniform_channel_ratio, reset_generator
117+
from models.dynamic_channel import set_uniform_channel_ratio, reset_generator
118118

119-
g = model.get_pretrained('generator', config='anycost-ffhq-config-f') # anycost uniform
119+
g = models.get_pretrained('generator', config='anycost-ffhq-config-f') # anycost uniform
120120
set_uniform_channel_ratio(g, 0.5) # set channel
121121
g.target_res = 512 # set resolution
122122
out, _ = g(...) # generate image
@@ -223,7 +223,35 @@ python metrics/eval_encoder.py \
223223

224224
### Training
225225

226-
The training code will be updated shortly.
226+
We provide the scripts to train Anycost GAN on FFHQ dataset.
227+
228+
- Training the original StyleGAN2 on FFHQ
229+
230+
```
231+
horovodrun -np 8 bash scripts/train_stylegan2_ffhq.sh
232+
```
233+
234+
The training of original StyleGAN2 is time-consuming. We recommend downloading the converted checkpoints from [here](https://www.dropbox.com/sh/l8g9amoduz99kjh/AAAY9LYZk2CnsO43ywDrLZpEa?dl=0) and place it under `checkpoint/`.
235+
236+
- Training Anycost GAN: mult-resolution
237+
238+
```
239+
horovodrun -np 8 bash scripts/train_stylegan2_multires_ffhq.sh
240+
```
241+
242+
Note that after each epoch, we evaluate the FIDs of two resolutions (1024&512) to better monitor the training progress. We also apply distillation to accelearte the convergence, which is not used for the ablation in the paper.
243+
244+
- Training Anycost GAN: adaptive-channel
245+
246+
```
247+
horovodrun -np 8 bash scripts/train_stylegan2_multires_adach_ffhq.sh
248+
```
249+
250+
Here we set a longer training epoch for a more stable reproduction, which might not be necessary (depending on the randomness).
251+
252+
253+
254+
**Note**: We trained our models on Titan RTX GPUs with 24GB memory. For GPUs with smaller memory, you may need to reduce the resolution/model size/batch size/etc. and adjust other hyper-parameters accordingly.
227255

228256

229257

demo.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import numpy as np
33
import os
44
from PIL import Image
5-
from model.dynamic_channel import set_uniform_channel_ratio, reset_generator
6-
import model
5+
from models.dynamic_channel import set_uniform_channel_ratio, reset_generator
6+
import models
77
import time
88

99
import sys
@@ -109,7 +109,7 @@ def __init__(self):
109109
self.set_text_format(attr_label, 'right', 15)
110110
attr_label.move(520 - 110, 470 + i_slider * 40 + 2)
111111

112-
# build model sliders
112+
# build models sliders
113113
base_h = 560
114114
channel_label = QLabel(self)
115115
channel_label.setText('channel:')
@@ -187,7 +187,7 @@ def load_assets(self):
187187
self.anycost_resolution = 1024
188188

189189
# build the generator
190-
self.generator = model.get_pretrained('generator', config).to(device)
190+
self.generator = models.get_pretrained('generator', config).to(device)
191191
self.generator.eval()
192192
self.mean_latent = self.generator.mean_style(10000)
193193

@@ -213,7 +213,7 @@ def load_assets(self):
213213
'mustache': '22_Mustache',
214214
}
215215

216-
boundaries = model.get_pretrained('boundary', config)
216+
boundaries = models.get_pretrained('boundary', config)
217217
self.direction_dict = dict()
218218
for k, v in direction_map.items():
219219
self.direction_dict[k] = boundaries[v].view(1, 1, -1)

environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ dependencies:
2020
- torchprofile==0.0.2
2121
- pyqt5==5.15.2
2222
- horovod==0.21.3
23+
- tensorboard==2.4.1

metrics/__init__.py

Whitespace-only changes.

metrics/attribute_consistency.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import math
66
import torch
77
from tqdm import tqdm
8-
import model
8+
import models
99
import horovod.torch as hvd
10-
from utils import adaptive_resize
10+
from utils.torch_utils import adaptive_resize
1111

1212

1313
def compute_attribute_consistency(g, sub_g, n_sample, batch_size):
14-
attr_pred = model.get_pretrained('attribute-predictor').to(device)
14+
attr_pred = models.get_pretrained('attribute-predictor').to(device)
1515
attr_pred.eval()
1616

1717
n_batch = math.ceil(n_sample * 1. / batch_size / hvd.size())
@@ -56,11 +56,11 @@ def compute_attribute_consistency(g, sub_g, n_sample, batch_size):
5656
hvd.init()
5757
torch.cuda.set_device(hvd.local_rank())
5858

59-
generator = model.get_pretrained('generator', args.config).to(device).eval()
59+
generator = models.get_pretrained('generator', args.config).to(device).eval()
6060

61-
sub_generator = model.get_pretrained('generator', args.config).to(device).eval()
61+
sub_generator = models.get_pretrained('generator', args.config).to(device).eval()
6262
if args.channel_ratio:
63-
from model.dynamic_channel import set_uniform_channel_ratio
63+
from models.dynamic_channel import set_uniform_channel_ratio
6464
set_uniform_channel_ratio(sub_generator, args.channel_ratio)
6565

6666
if args.target_res is not None:

metrics/eval_encoder.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
import torch
77
import torch.nn as nn
88
from tqdm import tqdm
9-
import model
10-
from utils import adaptive_resize
9+
import models
10+
from utils.torch_utils import adaptive_resize
1111
from thirdparty.celeba_hq_split import get_celeba_hq_split
1212
from torchvision import transforms
1313
import lpips
14-
from utils import AverageMeter
15-
from model.dynamic_channel import set_uniform_channel_ratio, remove_sub_channel_config
16-
from utils import NativeDataset
14+
from utils.torch_utils import AverageMeter
15+
from models.dynamic_channel import set_uniform_channel_ratio, remove_sub_channel_config
16+
from utils.datasets import NativeDataset
1717

1818

1919
def validate():
@@ -89,10 +89,10 @@ def validate():
8989
args = parser.parse_args()
9090

9191
# build models
92-
generator = model.get_pretrained('generator', args.config).to(device)
92+
generator = models.get_pretrained('generator', args.config).to(device)
9393
generator.eval()
9494

95-
encoder = model.get_pretrained('encoder', args.config).to(device)
95+
encoder = models.get_pretrained('encoder', args.config).to(device)
9696
encoder.eval()
9797

9898
# build test dataset

metrics/fid.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
from scipy import linalg
99
from tqdm import tqdm
10-
import model
10+
import models
1111

1212

1313
def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
@@ -79,12 +79,12 @@ def compute_fid():
7979
hvd.init()
8080
torch.cuda.set_device(hvd.local_rank())
8181

82-
generator = model.get_pretrained('generator', args.config).to(device)
82+
generator = models.get_pretrained('generator', args.config).to(device)
8383
generator.eval()
8484

8585
# set sub-generator
8686
if args.channel_ratio:
87-
from model.dynamic_channel import set_uniform_channel_ratio, CHANNEL_CONFIGS
87+
from models.dynamic_channel import set_uniform_channel_ratio, CHANNEL_CONFIGS
8888

8989
assert args.channel_ratio in CHANNEL_CONFIGS
9090
set_uniform_channel_ratio(generator, args.channel_ratio)
@@ -103,7 +103,7 @@ def compute_fid():
103103
except:
104104
print(' * Profiling failed. Passed.')
105105

106-
inception = model.get_pretrained('inception').to(device)
106+
inception = models.get_pretrained('inception').to(device)
107107
inception.eval()
108108

109109
inception_features = extract_feature_from_samples()

metrics/ppl.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
from tqdm import tqdm
1010
import lpips
11-
import model
11+
import models
1212
import horovod.torch as hvd
1313

1414

@@ -108,11 +108,11 @@ def compute_ppl(g, n_sample, batch_size, space='w', sampling='end', eps=1e-4, cr
108108
hvd.init()
109109
torch.cuda.set_device(hvd.local_rank())
110110

111-
generator = model.get_pretrained('generator', args.config).to(device)
111+
generator = models.get_pretrained('generator', args.config).to(device)
112112
generator.eval()
113113

114114
if args.channel_ratio:
115-
from model.dynamic_channel import set_uniform_channel_ratio
115+
from models.dynamic_channel import set_uniform_channel_ratio
116116
set_uniform_channel_ratio(generator, args.channel_ratio)
117117

118118
if args.target_res is not None:

model/dynamic_channel.py

-78
This file was deleted.

model/__init__.py models/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .anycost_gan import Generator
22
import torch
33
from torchvision import models
4-
from utils import safe_load_state_dict_from_url
4+
from utils.torch_utils import safe_load_state_dict_from_url
55

66
URL_TEMPLATE = 'https://hanlab.mit.edu/projects/anycost-gan/files/{}_{}.pt'
77

@@ -44,7 +44,7 @@ def get_pretrained(model, config=None):
4444
style_dim = 512
4545
else:
4646
raise NotImplementedError
47-
from model.encoder import ResNet50Encoder
47+
from models.encoder import ResNet50Encoder
4848
model = ResNet50Encoder(n_style=n_style, style_dim=style_dim)
4949
model.load_state_dict(load_state_dict_from_url(url, 'state_dict'))
5050
return model
@@ -53,7 +53,7 @@ def get_pretrained(model, config=None):
5353
predictor.fc = torch.nn.Linear(predictor.fc.in_features, 40 * 2)
5454
predictor.load_state_dict(load_state_dict_from_url(url, 'state_dict'))
5555
return predictor
56-
elif model == 'inception': # inception model
56+
elif model == 'inception': # inception models
5757
from thirdparty.inception import InceptionV3
5858
return InceptionV3([3], normalize_input=False, resize_input=True)
5959
elif model == 'boundary':

model/anycost_gan.py models/anycost_gan.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
from torch import nn
66

7-
from model.ops import *
7+
from models.ops import *
88

99
G_CHANNEL_CONFIG = {
1010
4: 4096,
@@ -214,8 +214,8 @@ def __init__(self, resolution, channel_multiplier=2, channel_max=512, blur_kerne
214214
EqualLinear(channels[4], 1),
215215
)
216216

217-
def forward(self, input):
218-
out = self.convs(input)
217+
def forward(self, x):
218+
out = self.convs(x)
219219

220220
batch, channel, height, width = out.shape
221221
group = min(batch, self.stddev_group)
@@ -237,7 +237,7 @@ def forward(self, input):
237237

238238
class DiscriminatorMultiRes(nn.Module):
239239
def __init__(self, resolution, channel_multiplier=2, channel_max=512, blur_kernel=(1, 3, 3, 1), act_func='lrelu',
240-
n_res=1):
240+
n_res=1, modulate=False):
241241
super().__init__()
242242

243243
channels = {k: min(channel_max, int(v * channel_multiplier)) for k, v in D_CHANNEL_CONFIG.items()}
@@ -255,7 +255,11 @@ def __init__(self, resolution, channel_multiplier=2, channel_max=512, blur_kerne
255255
self.blocks = nn.ModuleList()
256256
for i in range(log_res, 2, -1):
257257
out_channel = channels[2 ** (i - 1)] # the out channel corresponds to a lower resolution
258-
self.blocks.append(ResBlock(in_channel, out_channel, blur_kernel, act_func=act_func))
258+
self.blocks.append(
259+
ResBlock(in_channel, out_channel, blur_kernel, act_func=act_func,
260+
modulate=modulate and i in list(range(log_res, 2, -1))[-2:], # add g_arch modulation
261+
g_arch_len=4 * (log_res * 2 - 2))
262+
)
259263
in_channel = out_channel
260264

261265
self.stddev_group = 4
@@ -267,12 +271,12 @@ def __init__(self, resolution, channel_multiplier=2, channel_max=512, blur_kerne
267271
EqualLinear(channels[4], 1),
268272
)
269273

270-
def forward(self, x):
274+
def forward(self, x, g_arch=None):
271275
res = x.shape[-1]
272276
idx = self.res2idx[res]
273277
out = self.convs[idx](x)
274278
for i in range(idx, len(self.blocks)):
275-
out = self.blocks[i](out)
279+
out = self.blocks[i](out, g_arch)
276280

277281
out = self.minibatch_discrimination(out, self.stddev_group, self.stddev_feat)
278282
out = self.final_conv(out).view(out.shape[0], -1)

0 commit comments

Comments
 (0)