Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added HW89/homework/ckpt/discriminator_epoch_19.pt
Binary file not shown.
Binary file added HW89/homework/ckpt/discriminator_epoch_5.pt
Binary file not shown.
Binary file added HW89/homework/ckpt/generator_epoch_19.pt
Binary file not shown.
Binary file added HW89/homework/ckpt/generator_epoch_5.pt
Binary file not shown.
Empty file added HW89/homework/dcgan/__init__.py
Empty file.
69 changes: 69 additions & 0 deletions HW89/homework/dcgan/dcgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@

#portions of this code adapted from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dcgan/dcgan.py

import torch.nn as nn


class DCGenerator(nn.Module):
    """DCGAN-style generator: maps a 100-dimensional latent vector to a
    3-channel square image of side ``image_size`` with values in [-1, 1].
    """

    def __init__(self, image_size):
        """
        Args:
            image_size: side length of the generated square image; should be
                divisible by 4 (the network upsamples x2 twice).
        """
        super(DCGenerator, self).__init__()
        self.channels = 3

        # Spatial size of the feature map produced by the linear layer;
        # two x2 upsampling stages below restore it to image_size.
        self.init_size = image_size // 4
        self.l1 = nn.Sequential(nn.Linear(100, 128*self.init_size**2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            # BUG FIX: the upstream code passed 0.8 positionally, which set
            # eps=0.8 (vs the 1e-5 default) instead of the intended momentum.
            nn.BatchNorm2d(128, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, self.channels, 3, stride=1, padding=1),
            nn.Tanh()  # squashes output to [-1, 1]
        )

    def forward(self, data):
        """Generate a batch of images from latent noise.

        Args:
            data: latent tensor of shape (batch, 100) or (batch, 100, 1, 1).

        Returns:
            Tensor of shape (batch, 3, image_size, image_size).
        """
        # BUG FIX: .squeeze() also removed the batch dimension when
        # batch size == 1, breaking the view() below; flatten explicitly.
        out = self.l1(data.view(data.size(0), -1))
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class DCDiscriminator(nn.Module):
    """DCGAN-style discriminator: scores a 3-channel square image with a
    single probability in (0, 1) that the image is real.
    """

    def __init__(self, image_size):
        """
        Args:
            image_size: side length of the input square image; should be
                divisible by 16 (four stride-2 convolutions halve it 4 times).
        """
        super(DCDiscriminator, self).__init__()

        self.channels = 3

        def discriminator_block(in_filters, out_filters, bn=True):
            # stride-2 conv halves height/width; dropout regularizes
            block = [ nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                      nn.LeakyReLU(0.2, inplace=True),
                      nn.Dropout2d(0.25)]
            if bn:
                # BUG FIX: the upstream code passed 0.8 positionally, which
                # set eps=0.8 instead of the intended momentum=0.8.
                block.append(nn.BatchNorm2d(out_filters, momentum=0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(self.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image after 4 stride-2 convs
        ds_size = image_size // 2**4
        self.adv_layer = nn.Sequential( nn.Linear(128*ds_size**2, 1),
                                        nn.Sigmoid())

    def forward(self, data):
        """Classify a batch of images.

        Args:
            data: tensor of shape (batch, 3, image_size, image_size).

        Returns:
            Tensor of shape (batch, 1) with realness scores in (0, 1).
        """
        out = self.model(data)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity
116 changes: 116 additions & 0 deletions HW89/homework/dcgan/trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import logging
import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torchvision.utils as vutils
from tensorboardX import SummaryWriter

import metric


class DCGANTrainer:
    """Adversarial training loop for a DCGAN generator/discriminator pair.

    Losses and sample image grids are written to TensorBoard (tensorboardX);
    generation-quality metrics are computed periodically via the project's
    ``metric.compute_metrics``. Checkpoints are saved once per epoch.
    """

    def __init__(self, discriminator, generator, optimizer_d, optimizer_g, latent_size=100,
                 device='cpu', metrics_dir='metrics', save_root='ckpt', log_dir=None, start_epoch=0):
        """
        Args:
            discriminator: discriminator network (e.g. DCDiscriminator).
            generator: generator network (e.g. DCGenerator).
            optimizer_d: optimizer over the discriminator's parameters.
            optimizer_g: optimizer over the generator's parameters.
            latent_size: dimensionality of the generator's input noise.
            device: torch device the networks are moved to.
            metrics_dir: root directory handed to metric.compute_metrics.
            save_root: directory where checkpoints are written.
            log_dir: TensorBoard log directory (None = tensorboardX default).
            start_epoch: epoch to resume from; weights are re-initialized
                only on a fresh run (start_epoch == 0).
        """
        self.net_g = generator
        self.net_d = discriminator
        self.optimizer_d = optimizer_d
        self.optimizer_g = optimizer_g
        self.latent_size = latent_size
        self.device = device

        self.metric_dir = metrics_dir
        self.save_root = save_root

        self.net_g.to(device)
        self.net_d.to(device)

        self.start_epoch = start_epoch
        if self.start_epoch == 0:
            # Fresh run: apply the DCGAN paper's weight initialization.
            self.net_g.apply(self._weights_init)
            self.net_d.apply(self._weights_init)

        self.writer = SummaryWriter(log_dir=log_dir)

    @staticmethod
    def _weights_init(m):
        """DCGAN init: conv weights ~ N(0, 0.02); batch-norm scale ~ N(1, 0.02), bias 0."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def save(self, epoch):
        """Write generator and discriminator state dicts for ``epoch`` into save_root."""
        os.makedirs(self.save_root, exist_ok=True)
        torch.save(self.net_g.state_dict(), os.path.join(self.save_root, 'generator_epoch_{}.pt'.format(epoch)))
        torch.save(self.net_d.state_dict(), os.path.join(self.save_root, 'discriminator_epoch_{}.pt'.format(epoch)))

    def train(self, dataloader, n_epoch=25, n_show_samples=8, show_img_every=10, log_metrics_every=100,
              metrics_dataset='cifar10', metrics_to_log=('inception-score', 'mode-score', 'fid')):
        """Run adversarial training from start_epoch to n_epoch.

        Args:
            dataloader: yields (images, labels) batches; labels are ignored.
            n_epoch: total number of epochs.
            n_show_samples: number of images in each logged grid.
            show_img_every: iterations between image-grid logging.
            log_metrics_every: iterations between quality-metric evaluations.
            metrics_dataset: reference dataset name for metric.compute_metrics.
            metrics_to_log: keys of the metric report to write to TensorBoard.
        """
        criterion = nn.BCELoss()

        global_step = 0
        for epoch in range(self.start_epoch, n_epoch):
            for i, data in enumerate(dataloader):

                # --- Discriminator: real batch scored against all-ones target ---
                self.net_d.zero_grad()
                real, _ = data
                real = real.to(self.device)

                target = torch.ones(real.size()[0], device=self.device)

                output = self.net_d(real)
                err_d_real = criterion(output, target)

                noise = torch.randn(real.size()[0], self.latent_size, 1, 1, device=self.device)
                fake = self.net_g(noise)

                if global_step % show_img_every == 0:
                    x = vutils.make_grid(fake[:n_show_samples, :, :, :], normalize=True, scale_each=True)
                    self.writer.add_image('img/fake', x, global_step)

                    y = vutils.make_grid(real[:n_show_samples, :, :, :], normalize=True, scale_each=True)
                    self.writer.add_image('img/real', y, global_step)

                # --- Discriminator: fake batch scored against all-zeros target ---
                # detach() keeps this backward pass out of the generator.
                target = torch.zeros(real.size()[0], device=self.device)
                output = self.net_d(fake.detach())
                err_d_fake = criterion(output, target)

                err_d = err_d_real + err_d_fake
                err_d.backward()
                self.optimizer_d.step()

                # --- Generator: try to fool the discriminator (target = ones) ---
                self.net_g.zero_grad()
                target = torch.ones(real.size()[0], device=self.device)
                output = self.net_d(fake)
                err_g = criterion(output, target)
                err_g.backward()
                self.optimizer_g.step()

                # Lazy %-style logging; .item() extracts plain floats from the
                # 0-dim loss tensors. (Replaces a broken string-concatenation
                # with leftover commented-out f-string code.)
                logging.info('epoch: [%d/%d] iter: [%d/%d] loss_D: %.4f loss_G: %.4f',
                             epoch, n_epoch, i, len(dataloader), err_d.item(), err_g.item())
                self.writer.add_scalar('data/loss_discriminator', err_d, global_step)
                self.writer.add_scalar('data/loss_generator', err_g, global_step)

                if global_step % log_metrics_every == 0:
                    # Toggle eval mode only around metric computation (the
                    # original flipped eval/train on every iteration).
                    self.net_g.eval()
                    image_size = real.shape[-1]
                    report_dict = metric.compute_metrics(metrics_dataset,
                                                         image_size=image_size,
                                                         metrics_root=Path(self.metric_dir),
                                                         batch_size=dataloader.batch_size, netG=self.net_g)

                    for mtrc in metrics_to_log:
                        self.writer.add_scalar('data/{}'.format(mtrc), report_dict[mtrc], global_step)
                    self.net_g.train()
                global_step += 1

            self.save(epoch)
Loading