Commit

update
Heliang-Zheng committed Mar 24, 2021
1 parent 47b5663 commit ff71a6a
Showing 16 changed files with 2,573 additions and 0 deletions.
116 changes: 116 additions & 0 deletions calc_inception.py
@@ -0,0 +1,116 @@
import argparse
import pickle
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import inception_v3, Inception3
import numpy as np
from tqdm import tqdm

from inception import InceptionV3
from dataset import MultiResolutionDataset


class Inception3Feature(Inception3):
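    """Feature extractor built directly on torchvision's Inception3; the shape
    comments below give the input size of each layer. load_patched_inception_v3
    further down uses the InceptionV3 wrapper from inception.py instead of this class.
    """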
def forward(self, x):
if x.shape[2] != 299 or x.shape[3] != 299:
x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)

x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3
x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32
x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32
x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64

x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64
x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80
x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192

x = self.Mixed_5b(x) # 35 x 35 x 192
x = self.Mixed_5c(x) # 35 x 35 x 256
x = self.Mixed_5d(x) # 35 x 35 x 288

x = self.Mixed_6a(x) # 35 x 35 x 288
x = self.Mixed_6b(x) # 17 x 17 x 768
x = self.Mixed_6c(x) # 17 x 17 x 768
x = self.Mixed_6d(x) # 17 x 17 x 768
x = self.Mixed_6e(x) # 17 x 17 x 768

x = self.Mixed_7a(x) # 17 x 17 x 768
x = self.Mixed_7b(x) # 8 x 8 x 1280
x = self.Mixed_7c(x) # 8 x 8 x 2048

x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048

return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048


def load_patched_inception_v3():
# inception = inception_v3(pretrained=True)
# inception_feat = Inception3Feature()
# inception_feat.load_state_dict(inception.state_dict())
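    # InceptionV3 wrapper from the bundled inception.py; assuming the usual
    # pytorch-fid layout, output block 3 is the final 2048-d pooled feature map.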
inception_feat = InceptionV3([3], normalize_input=False)

return inception_feat


@torch.no_grad()
def extract_features(loader, inception, device):
pbar = tqdm(loader)

feature_list = []

for img in pbar:
img = img.to(device)
feature = inception(img)[0].view(img.shape[0], -1)
feature_list.append(feature.to('cpu'))

features = torch.cat(feature_list, 0)

return features


if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

parser = argparse.ArgumentParser(
description='Calculate Inception v3 features for datasets'
)
parser.add_argument('--size', type=int, default=256)
parser.add_argument('--batch', default=64, type=int, help='batch size')
parser.add_argument('--n_sample', type=int, default=50000)
parser.add_argument('--flip', action='store_true')
    parser.add_argument('path', metavar='PATH', help='path to dataset lmdb file')

args = parser.parse_args()

inception = load_patched_inception_v3()
inception = nn.DataParallel(inception).eval().to(device)

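    # NOTE: this transform (and the --flip flag) is currently unused here;
    # MultiResolutionDataset applies its own transform internally (see dataset.py).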
transform = transforms.Compose(
[
#transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
)

dset = MultiResolutionDataset(args.path)
loader = DataLoader(dset, batch_size=args.batch, num_workers=4)

features = extract_features(loader, inception, device).numpy()

features = features[: args.n_sample]

print('extracted {} features'.format(features.shape[0]))

mean = np.mean(features, 0)
cov = np.cov(features, rowvar=False)

name = os.path.splitext(os.path.basename(args.path))[0]

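    # Cache the real-image statistics; fid.py loads this pickle ('mean', 'cov')
    # and compares it against the statistics of generated samples.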
with open('inception_{}.pkl'.format(name), 'wb') as f:
pickle.dump({'mean': mean, 'cov': cov, 'size': args.size, 'path': args.path}, f)
38 changes: 38 additions & 0 deletions dataset.py
@@ -0,0 +1,38 @@
import io
import lmdb
from PIL import Image

import torch
from torchvision import transforms
import random

_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

class MultiResolutionDataset(torch.utils.data.Dataset):
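    """LMDB-backed image dataset storing 'total', 'width' and 'height' metadata
    keys plus one encoded image per index."""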
def __init__(self, path):
self.env = lmdb.open(path, max_readers=32, readonly=True, lock=False,
readahead=False, meminit=False,)

if not self.env:
raise IOError('Cannot open lmdb dataset', path)

with self.env.begin(write=False) as txn:
self.length = int(txn.get('total'.encode('utf-8')).decode('utf-8'))
self.width = int(txn.get('width'.encode('utf-8')).decode('utf-8'))
self.height = int(txn.get('height'.encode('utf-8')).decode('utf-8'))

def __len__(self):
return self.length

def __getitem__(self, index):
with self.env.begin(write=False) as txn:
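            # images are keyed as '{width}-{height}-{index:07d}'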
key = '{}-{}-{}'.format(self.width, self.height, str(index).zfill(7)).encode('utf-8')
img_bytes = txn.get(key)
buffer = io.BytesIO(img_bytes)
img = Image.open(buffer)
img = _transform(img)
        return img  # , random.random()
126 changes: 126 additions & 0 deletions distributed.py
@@ -0,0 +1,126 @@
import math
import pickle

import torch
from torch import distributed as dist
from torch.utils.data.sampler import Sampler


def get_rank():
if not dist.is_available():
return 0

if not dist.is_initialized():
return 0

return dist.get_rank()


def synchronize():
if not dist.is_available():
return

if not dist.is_initialized():
return

world_size = dist.get_world_size()

if world_size == 1:
return

dist.barrier()


def get_world_size():
if not dist.is_available():
return 1

if not dist.is_initialized():
return 1

return dist.get_world_size()


def reduce_sum(tensor):
if not dist.is_available():
return tensor

if not dist.is_initialized():
return tensor

tensor = tensor.clone()
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

return tensor


def gather_grad(params):
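    """All-reduce parameter gradients and average them across ranks."""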
world_size = get_world_size()

if world_size == 1:
return

for param in params:
if param.grad is not None:
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data.div_(world_size)


def all_gather(data):
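    """Gather arbitrary picklable objects from every rank.

    Each rank pickles its payload into a byte tensor, pads it to the largest
    length so all_gather sees uniform shapes, then every buffer is unpickled.
    """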
world_size = get_world_size()

if world_size == 1:
return [data]

buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to('cuda')

local_size = torch.IntTensor([tensor.numel()]).to('cuda')
size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)

tensor_list = []
for _ in size_list:
tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))

if local_size != max_size:
padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
tensor = torch.cat((tensor, padding), 0)

dist.all_gather(tensor_list, tensor)

data_list = []

for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))

return data_list


def reduce_loss_dict(loss_dict):
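    """Reduce a dict of scalar losses onto rank 0 and average by world size,
    so the main process can log world-averaged values."""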
world_size = get_world_size()

if world_size < 2:
return loss_dict

with torch.no_grad():
keys = []
losses = []

for k in sorted(loss_dict.keys()):
keys.append(k)
losses.append(loss_dict[k])

losses = torch.stack(losses, 0)
dist.reduce(losses, dst=0)

if dist.get_rank() == 0:
losses /= world_size

reduced_losses = {k: v for k, v in zip(keys, losses)}

return reduced_losses
107 changes: 107 additions & 0 deletions fid.py
@@ -0,0 +1,107 @@
import argparse
import pickle

import torch
from torch import nn
import numpy as np
from scipy import linalg
from tqdm import tqdm

from model import Generator
from calc_inception import load_patched_inception_v3


@torch.no_grad()
def extract_feature_from_samples(
generator, inception, truncation, truncation_latent, batch_size, n_sample, device
):
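    """Sample n_sample latents in batches, run them through the generator, and
    return the Inception pool features (moved to CPU to bound GPU memory)."""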
n_batch = n_sample // batch_size
resid = n_sample - (n_batch * batch_size)
    batch_sizes = [batch_size] * n_batch + ([resid] if resid > 0 else [])  # avoid an empty final batch
features = []

for batch in tqdm(batch_sizes):
latent = torch.randn(batch, 512, device=device)
        img, _, _, _, _ = generator([latent], truncation=truncation, truncation_latent=truncation_latent)
feat = inception(img)[0].view(img.shape[0], -1)
features.append(feat.to('cpu'))

features = torch.cat(features, 0)

return features


def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
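    """Fréchet distance between N(sample_mean, sample_cov) and N(real_mean, real_cov):

    FID = ||mu_s - mu_r||^2 + Tr(cov_s + cov_r - 2 * (cov_s @ cov_r)^(1/2))
    """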
cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)

if not np.isfinite(cov_sqrt).all():
print('product of cov matrices is singular')
offset = np.eye(sample_cov.shape[0]) * eps
cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))

if np.iscomplexobj(cov_sqrt):
if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
m = np.max(np.abs(cov_sqrt.imag))

raise ValueError('Imaginary component {}'.format(m))

cov_sqrt = cov_sqrt.real

mean_diff = sample_mean - real_mean
mean_norm = mean_diff @ mean_diff

trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)

fid = mean_norm + trace

return fid


if __name__ == '__main__':
device = 'cuda'

parser = argparse.ArgumentParser()

parser.add_argument('--truncation', type=float, default=1)
parser.add_argument('--truncation_mean', type=int, default=4096)
parser.add_argument('--batch', type=int, default=64)
parser.add_argument('--n_sample', type=int, default=50000)
parser.add_argument('--size', type=int, default=256)
parser.add_argument('--inception', type=str, default='./inception/inception_cat.pkl')
parser.add_argument('ckpt', metavar='CHECKPOINT')

args = parser.parse_args()

ckpt = torch.load(args.ckpt)

g = Generator(args.size, 512, 8).to(device)
g.load_state_dict(ckpt['g_ema'])
g = nn.DataParallel(g)
g.eval()

if args.truncation < 1:
with torch.no_grad():
            mean_latent = g.module.mean_latent(args.truncation_mean)  # .module: g is wrapped in DataParallel

else:
mean_latent = None

inception = nn.DataParallel(load_patched_inception_v3()).to(device)
inception.eval()

features = extract_feature_from_samples(
g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
).numpy()
print('extracted {} features'.format(features.shape[0]))

sample_mean = np.mean(features, 0)
sample_cov = np.cov(features, rowvar=False)

with open(args.inception, 'rb') as f:
embeds = pickle.load(f)
real_mean = embeds['mean']
real_cov = embeds['cov']

fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)

print('fid:', fid)
1 change: 1 addition & 0 deletions fid.sh
@@ -0,0 +1 @@
python3 fid.py cat_model.pt
Binary file added inception/inception_cat.pkl
(remaining changed files not shown)
