Commit ff71a6a (1 parent: 47b5663)
Showing 16 changed files with 2,573 additions and 0 deletions.
calc_inception.py
@@ -0,0 +1,116 @@
import argparse
import pickle
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import inception_v3, Inception3
import numpy as np
from tqdm import tqdm

from inception import InceptionV3
from dataset import MultiResolutionDataset
class Inception3Feature(Inception3):
    def forward(self, x):
        if x.shape[2] != 299 or x.shape[3] != 299:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)

        x = self.Conv2d_1a_3x3(x)  # 299 x 299 x 3
        x = self.Conv2d_2a_3x3(x)  # 149 x 149 x 32
        x = self.Conv2d_2b_3x3(x)  # 147 x 147 x 32
        x = F.max_pool2d(x, kernel_size=3, stride=2)  # 147 x 147 x 64

        x = self.Conv2d_3b_1x1(x)  # 73 x 73 x 64
        x = self.Conv2d_4a_3x3(x)  # 73 x 73 x 80
        x = F.max_pool2d(x, kernel_size=3, stride=2)  # 71 x 71 x 192

        x = self.Mixed_5b(x)  # 35 x 35 x 192
        x = self.Mixed_5c(x)  # 35 x 35 x 256
        x = self.Mixed_5d(x)  # 35 x 35 x 288

        x = self.Mixed_6a(x)  # 35 x 35 x 288
        x = self.Mixed_6b(x)  # 17 x 17 x 768
        x = self.Mixed_6c(x)  # 17 x 17 x 768
        x = self.Mixed_6d(x)  # 17 x 17 x 768
        x = self.Mixed_6e(x)  # 17 x 17 x 768

        x = self.Mixed_7a(x)  # 17 x 17 x 768
        x = self.Mixed_7b(x)  # 8 x 8 x 1280
        x = self.Mixed_7c(x)  # 8 x 8 x 2048

        x = F.avg_pool2d(x, kernel_size=8)  # 8 x 8 x 2048

        return x.view(x.shape[0], x.shape[1])  # 1 x 1 x 2048


def load_patched_inception_v3():
    # Unused alternative: build Inception3Feature from torchvision's pretrained weights.
    # inception = inception_v3(pretrained=True)
    # inception_feat = Inception3Feature()
    # inception_feat.load_state_dict(inception.state_dict())
    inception_feat = InceptionV3([3], normalize_input=False)

    return inception_feat
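load_patched_inception_v3 returns the pool3 (2048-dimensional) block of Inception v3. A minimal sketch of how it is called, assuming the bundled inception.InceptionV3 follows the pytorch-fid convention of returning one feature map per requested block:

import torch

inception = load_patched_inception_v3().eval()

with torch.no_grad():
    imgs = torch.randn(4, 3, 256, 256)  # hypothetical batch, roughly in [-1, 1]
    feat = inception(imgs)[0]           # the single requested block
    print(feat.view(imgs.shape[0], -1).shape)  # torch.Size([4, 2048])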
@torch.no_grad()
def extract_features(loader, inception, device):
    pbar = tqdm(loader)

    feature_list = []

    for img in pbar:
        img = img.to(device)
        feature = inception(img)[0].view(img.shape[0], -1)
        feature_list.append(feature.to('cpu'))

    features = torch.cat(feature_list, 0)

    return features
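extract_features streams a DataLoader through the extractor and collects the pooled features on the CPU. A hedged sketch with a stand-in dataset (random tensors instead of a real lmdb):

import torch
from torch.utils.data import DataLoader

imgs = torch.randn(32, 3, 256, 256)  # hypothetical stand-in for real images
loader = DataLoader(imgs, batch_size=8)

inception = load_patched_inception_v3().eval()
feats = extract_features(loader, inception, 'cpu')
print(feats.shape)  # torch.Size([32, 2048])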
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser(
        description='Calculate Inception v3 features for datasets'
    )
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--batch', default=64, type=int, help='batch size')
    parser.add_argument('--n_sample', type=int, default=50000)
    parser.add_argument('--flip', action='store_true')
    parser.add_argument('path', metavar='PATH', help='path to dataset lmdb file')

    args = parser.parse_args()

    inception = load_patched_inception_v3()
    inception = nn.DataParallel(inception).eval().to(device)

    # Note: this transform is currently unused; MultiResolutionDataset applies
    # its own module-level transform, so --flip has no effect here.
    transform = transforms.Compose(
        [
            # transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dset = MultiResolutionDataset(args.path)
    loader = DataLoader(dset, batch_size=args.batch, num_workers=4)

    features = extract_features(loader, inception, device).numpy()

    features = features[: args.n_sample]

    print('extracted {} features'.format(features.shape[0]))

    mean = np.mean(features, 0)
    cov = np.cov(features, rowvar=False)

    name = os.path.splitext(os.path.basename(args.path))[0]

    with open('inception_{}.pkl'.format(name), 'wb') as f:
        pickle.dump({'mean': mean, 'cov': cov, 'size': args.size, 'path': args.path}, f)
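The saved pickle carries everything fid.py needs about the real distribution. A hedged sketch of reading it back (the keys match the dump above; the file name depends on the lmdb path, e.g. a path of cats.lmdb yields inception_cats.pkl):

import pickle

with open('inception_cats.pkl', 'rb') as f:  # hypothetical file name
    stats = pickle.load(f)

print(stats['mean'].shape, stats['cov'].shape)  # (2048,) (2048, 2048)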
dataset.py
@@ -0,0 +1,38 @@
import io

import lmdb
from PIL import Image

import torch
from torchvision import transforms

_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


class MultiResolutionDataset(torch.utils.data.Dataset):
    def __init__(self, path):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('total'.encode('utf-8')).decode('utf-8'))
            self.width = int(txn.get('width'.encode('utf-8')).decode('utf-8'))
            self.height = int(txn.get('height'.encode('utf-8')).decode('utf-8'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            key = '{}-{}-{}'.format(self.width, self.height, str(index).zfill(7)).encode('utf-8')
            img_bytes = txn.get(key)

        buffer = io.BytesIO(img_bytes)
        img = Image.open(buffer)
        img = _transform(img)

        return img
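MultiResolutionDataset expects four kinds of keys: total, width, height, and one {width}-{height}-{index:07d} entry per encoded image file. A minimal writer sketch under those assumptions (the function name and map_size are hypothetical):

import lmdb

def write_dataset(path, image_files, width=256, height=256):
    with lmdb.open(path, map_size=1024 ** 4) as env:
        with env.begin(write=True) as txn:
            for i, file in enumerate(image_files):
                key = '{}-{}-{}'.format(width, height, str(i).zfill(7)).encode('utf-8')
                with open(file, 'rb') as fp:
                    # Store raw encoded bytes (JPEG/PNG); the reader decodes with PIL.
                    txn.put(key, fp.read())

            txn.put(b'total', str(len(image_files)).encode('utf-8'))
            txn.put(b'width', str(width).encode('utf-8'))
            txn.put(b'height', str(height).encode('utf-8'))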
distributed.py (filename presumed from its contents)
@@ -0,0 +1,126 @@
import math
import pickle

import torch
from torch import distributed as dist
from torch.utils.data.sampler import Sampler


def get_rank():
    if not dist.is_available():
        return 0

    if not dist.is_initialized():
        return 0

    return dist.get_rank()


def synchronize():
    if not dist.is_available():
        return

    if not dist.is_initialized():
        return

    world_size = dist.get_world_size()

    if world_size == 1:
        return

    dist.barrier()


def get_world_size():
    if not dist.is_available():
        return 1

    if not dist.is_initialized():
        return 1

    return dist.get_world_size()


def reduce_sum(tensor):
    if not dist.is_available():
        return tensor

    if not dist.is_initialized():
        return tensor

    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

    return tensor

def gather_grad(params):
    world_size = get_world_size()

    if world_size == 1:
        return

    for param in params:
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data.div_(world_size)


def all_gather(data):
    world_size = get_world_size()

    if world_size == 1:
        return [data]

    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to('cuda')

    local_size = torch.IntTensor([tensor.numel()]).to('cuda')
    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))

    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
        tensor = torch.cat((tensor, padding), 0)

    dist.all_gather(tensor_list, tensor)

    data_list = []

    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
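all_gather pickles an arbitrary object per rank, pads the byte tensors to a common length, and exchanges them, so ranks can trade non-tensor data. A hedged usage sketch, assuming torch.distributed was initialized with a CUDA-capable backend such as NCCL (the helper stages its tensors on 'cuda'):

payload = {'rank': get_rank(), 'n_images': 1280}  # hypothetical per-rank payload
gathered = all_gather(payload)                    # list with one entry per rank

if get_rank() == 0:
    print('images across ranks:', sum(p['n_images'] for p in gathered))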
def reduce_loss_dict(loss_dict):
    world_size = get_world_size()

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        keys = []
        losses = []

        for k in sorted(loss_dict.keys()):
            keys.append(k)
            losses.append(loss_dict[k])

        losses = torch.stack(losses, 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        reduced_losses = {k: v for k, v in zip(keys, losses)}

    return reduced_losses
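reduce_loss_dict sums each loss onto rank 0 and divides there, so only rank 0 sees the averaged values; other ranks get the un-averaged tensors back. A hedged logging sketch (the loss names and tensors are hypothetical):

loss_dict = {'d': d_loss, 'g': g_loss}  # hypothetical 0-dim loss tensors
loss_reduced = reduce_loss_dict(loss_dict)

if get_rank() == 0:
    print('; '.join('{}: {:.4f}'.format(k, v.item()) for k, v in loss_reduced.items()))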
fid.py
@@ -0,0 +1,107 @@
import argparse
import pickle

import torch
from torch import nn
import numpy as np
from scipy import linalg
from tqdm import tqdm

from model import Generator
from calc_inception import load_patched_inception_v3


@torch.no_grad()
def extract_feature_from_samples(
    generator, inception, truncation, truncation_latent, batch_size, n_sample, device
):
    n_batch = n_sample // batch_size
    resid = n_sample - (n_batch * batch_size)
    # Skip the residual batch when n_sample divides evenly, so the generator
    # never sees an empty batch.
    batch_sizes = [batch_size] * n_batch + ([resid] if resid > 0 else [])
    features = []

    for batch in tqdm(batch_sizes):
        latent = torch.randn(batch, 512, device=device)
        # This fork's Generator returns four extra values; only the image is used.
        img, _, _, _, _ = generator(
            [latent], truncation=truncation, truncation_latent=truncation_latent
        )
        feat = inception(img)[0].view(img.shape[0], -1)
        features.append(feat.to('cpu'))

    features = torch.cat(features, 0)

    return features
def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    # FID = ||mu_s - mu_r||^2 + Tr(C_s + C_r - 2 (C_s C_r)^(1/2))
    cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)

    if not np.isfinite(cov_sqrt).all():
        print('product of cov matrices is singular')
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))

    if np.iscomplexobj(cov_sqrt):
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))

            raise ValueError('Imaginary component {}'.format(m))

        cov_sqrt = cov_sqrt.real

    mean_diff = sample_mean - real_mean
    mean_norm = mean_diff @ mean_diff

    trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)

    fid = mean_norm + trace

    return fid
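A quick sanity check of calc_fid (hedged sketch): comparing a distribution with itself should give an FID of approximately zero.

import numpy as np

rng = np.random.default_rng(0)
feats = rng.standard_normal((1000, 8))
mu, cov = feats.mean(0), np.cov(feats, rowvar=False)

print(calc_fid(mu, cov, mu, cov))  # ~0.0 up to numerical noise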
if __name__ == '__main__':
    device = 'cuda'

    parser = argparse.ArgumentParser()

    parser.add_argument('--truncation', type=float, default=1)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--batch', type=int, default=64)
    parser.add_argument('--n_sample', type=int, default=50000)
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--inception', type=str, default='./inception/inception_cat.pkl')
    parser.add_argument('ckpt', metavar='CHECKPOINT')

    args = parser.parse_args()

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, 512, 8).to(device)
    g.load_state_dict(ckpt['g_ema'])
    g = nn.DataParallel(g)
    g.eval()

    if args.truncation < 1:
        with torch.no_grad():
            # mean_latent lives on the wrapped module, not on DataParallel itself
            mean_latent = g.module.mean_latent(args.truncation_mean)

    else:
        mean_latent = None

    inception = nn.DataParallel(load_patched_inception_v3()).to(device)
    inception.eval()

    features = extract_feature_from_samples(
        g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
    ).numpy()
    print('extracted {} features'.format(features.shape[0]))

    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    with open(args.inception, 'rb') as f:
        embeds = pickle.load(f)
        real_mean = embeds['mean']
        real_cov = embeds['cov']

    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)

    print('fid:', fid)
fid.sh (filename presumed; a one-line helper script)
@@ -0,0 +1 @@
python3 fid.py cat_model.pt
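A hedged sketch of the full workflow the helper assumes (flags as defined in the files above; the dataset and file names are hypothetical): first cache the real-image statistics, then score a checkpoint against them.

python3 calc_inception.py --size 256 cats.lmdb        # writes inception_cats.pkl
python3 fid.py --inception inception_cats.pkl --batch 64 --n_sample 50000 cat_model.pt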
Binary file not shown.