Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revision components related model.gans.pix2pix #883

Draft
wants to merge 54 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
15e6b10
add description
Jungwon-Lee Sep 7, 2022
722990a
mode weight_init to inner method
Jungwon-Lee Sep 7, 2022
3a7ad92
temp
Jungwon-Lee Sep 8, 2022
7aca0b9
Add typo
Jungwon-Lee Sep 8, 2022
414cf6f
fix bugs
Jungwon-Lee Sep 8, 2022
92f4cf6
add forward
Jungwon-Lee Sep 8, 2022
67806e0
rename generator
Jungwon-Lee Sep 8, 2022
b750a6e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2022
cd194dc
add test_pix2pix_components
Jungwon-Lee Sep 18, 2022
d36fcc2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2022
cc120fa
remove conflict
Jungwon-Lee Sep 18, 2022
9d94641
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2022
a991cd5
fix test_generator
Jungwon-Lee Sep 18, 2022
ead7789
Merge branch 'master' into revision_pix2pix
Borda Sep 19, 2022
99ca850
Merge branch 'master' into revision_pix2pix
otaj Sep 19, 2022
8ca7428
fix test_pix2pix
Jungwon-Lee Sep 19, 2022
69d7b40
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2022
638c98d
add pix2pix module test
Jungwon-Lee Sep 19, 2022
481bc78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2022
6f15f91
fix pix2pix_module test
Jungwon-Lee Sep 20, 2022
07682ad
Merge branch 'master' into revision_pix2pix
Borda Oct 27, 2022
2926bb7
Merge branch 'master' into revision_pix2pix
Nov 2, 2022
ca2fb9c
Merge branch 'master' into revision_pix2pix
Jungwon-Lee Nov 4, 2022
20bcb24
remove unnecessary args
Jungwon-Lee Sep 21, 2022
35c7c0e
add test case
Jungwon-Lee Sep 27, 2022
2743958
Merge branch 'master' into revision_pix2pix
otaj Nov 4, 2022
f9c8cdf
Merge branch 'master' into revision_pix2pix
Borda Jan 8, 2023
8004fc9
Merge branch 'master' into revision_pix2pix
Borda May 18, 2023
bd23c27
update mergify team
Borda May 19, 2023
dc9a4d0
Merge branch 'master' into revision_pix2pix
Borda May 19, 2023
be8afb5
Merge branch 'master' into revision_pix2pix
Borda May 19, 2023
a196a86
Merge branch 'master' into revision_pix2pix
mergify[bot] May 19, 2023
be21857
Merge branch 'master' into revision_pix2pix
mergify[bot] May 20, 2023
d2ffaf8
Merge branch 'master' into revision_pix2pix
mergify[bot] May 20, 2023
9df8915
Merge branch 'master' into revision_pix2pix
mergify[bot] May 20, 2023
729ce70
Merge branch 'master' into revision_pix2pix
mergify[bot] May 20, 2023
db5153a
Merge branch 'master' into revision_pix2pix
mergify[bot] May 20, 2023
007ba53
Merge branch 'master' into revision_pix2pix
mergify[bot] May 20, 2023
6a8b923
Merge branch 'master' into revision_pix2pix
Borda May 20, 2023
c168766
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2023
73b3adf
Merge branch 'master' into revision_pix2pix
mergify[bot] May 20, 2023
7e501ce
Merge branch 'master' into revision_pix2pix
mergify[bot] May 21, 2023
6a90c42
Merge branch 'master' into revision_pix2pix
mergify[bot] May 21, 2023
26ec2fd
Merge branch 'master' into revision_pix2pix
mergify[bot] May 22, 2023
6ec8477
Merge branch 'master' into revision_pix2pix
mergify[bot] May 22, 2023
591db73
Merge branch 'master' into revision_pix2pix
mergify[bot] May 29, 2023
d3ca74b
Merge branch 'master' into revision_pix2pix
mergify[bot] May 30, 2023
3fa4323
Merge branch 'master' into revision_pix2pix
Borda May 31, 2023
0152d2b
Merge branch 'master' into revision_pix2pix
mergify[bot] May 31, 2023
304869d
Merge branch 'master' into revision_pix2pix
Borda May 31, 2023
26ced42
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
f3f76ac
Merge branch 'master' into revision_pix2pix
mergify[bot] Jun 12, 2023
f7db4f3
Merge branch 'master' into revision_pix2pix
mergify[bot] Jun 16, 2023
9d7a520
Merge branch 'master' into revision_pix2pix
mergify[bot] Jun 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 55 additions & 74 deletions src/pl_bolts/models/gans/pix2pix/components.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,36 @@
import torch
from torch import nn
from torch import Tensor, nn

from pl_bolts.utils.stability import under_review


@under_review()
class UpSampleConv(nn.Module):
def __init__(
self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True, dropout=False
self,
in_channels: int,
out_channels: int,
batchnorm: bool = True,
dropout: bool = False,
) -> None:
super().__init__()
self.activation = activation
self.batchnorm = batchnorm
self.dropout = dropout

self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel, strides, padding)
layers = [nn.ConvTranspose2d(in_channels, out_channels, kernel=4, strides=2, padding=1)]

if batchnorm:
self.bn = nn.BatchNorm2d(out_channels)

if activation:
self.act = nn.ReLU(True)
layers.append(nn.BatchNorm2d(out_channels))
layers.append(nn.ReLU(True))

if dropout:
self.drop = nn.Dropout2d(0.5)

def forward(self, x):
x = self.deconv(x)
if self.batchnorm:
x = self.bn(x)
layers.append(nn.Dropout2d(0.5))
self.model = nn.Sequential(*layers)

if self.dropout:
x = self.drop(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self.model(x)


@under_review()
class DownSampleConv(nn.Module):
def __init__(
self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True
self,
in_channels: int,
out_channels: int,
batchnorm: bool = True,
) -> None:
"""Paper details:

Expand All @@ -47,102 +39,91 @@ def __init__(
- Convolutions in the encoder downsample by a factor of 2
"""
super().__init__()
self.activation = activation
self.batchnorm = batchnorm

self.conv = nn.Conv2d(in_channels, out_channels, kernel, strides, padding)
layers = [nn.Conv2d(in_channels, out_channels, kernel=4, strides=2, padding=1)]

if batchnorm:
self.bn = nn.BatchNorm2d(out_channels)
layers.append(nn.BatchNorm2d(out_channels))
layers.append(nn.LeakyReLU(0.2))

if activation:
self.act = nn.LeakyReLU(0.2)
self.model = nn.Sequential(*layers)

def forward(self, x):
x = self.conv(x)
if self.batchnorm:
x = self.bn(x)
if self.activation:
x = self.act(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self.model(x)


@under_review()
class Generator(nn.Module):
def __init__(self, in_channels, out_channels) -> None:
def __init__(self, in_channels: int, out_channels: int) -> None:
"""Paper details:

- Encoder: C64-C128-C256-C512-C512-C512-C512-C512
- All convolutions are 4×4 spatial filters applied with stride 2
- All convolutions are 4x4 spatial filters applied with stride 2
- Convolutions in the encoder downsample by a factor of 2
- Decoder: CD512-CD1024-CD1024-C1024-C1024-C512 -C256-C128
- Decoder: CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
"""
super().__init__()

# encoder/downsample convs
self.encoders = [
DownSampleConv(in_channels, 64, batchnorm=False), # bs x 64 x 128 x 128
DownSampleConv(64, 128), # bs x 128 x 64 x 64
DownSampleConv(128, 256), # bs x 256 x 32 x 32
DownSampleConv(256, 512), # bs x 512 x 16 x 16
DownSampleConv(512, 512), # bs x 512 x 8 x 8
DownSampleConv(512, 512), # bs x 512 x 4 x 4
DownSampleConv(512, 512), # bs x 512 x 2 x 2
DownSampleConv(512, 512, batchnorm=False), # bs x 512 x 1 x 1
]
self.encoders = nn.ModuleList(
[
DownSampleConv(in_channels, 64, batchnorm=False), # bs x 64 x 128 x 128
DownSampleConv(64, 128), # bs x 128 x 64 x 64
DownSampleConv(128, 256), # bs x 256 x 32 x 32
DownSampleConv(256, 512), # bs x 512 x 16 x 16
DownSampleConv(512, 512), # bs x 512 x 8 x 8
DownSampleConv(512, 512), # bs x 512 x 4 x 4
DownSampleConv(512, 512), # bs x 512 x 2 x 2
DownSampleConv(512, 512, batchnorm=False), # bs x 512 x 1 x 1
]
)

# decoder/upsample convs
self.decoders = [
UpSampleConv(512, 512, dropout=True), # bs x 512 x 2 x 2
UpSampleConv(1024, 512, dropout=True), # bs x 512 x 4 x 4
UpSampleConv(1024, 512, dropout=True), # bs x 512 x 8 x 8
UpSampleConv(1024, 512), # bs x 512 x 16 x 16
UpSampleConv(1024, 256), # bs x 256 x 32 x 32
UpSampleConv(512, 128), # bs x 128 x 64 x 64
UpSampleConv(256, 64), # bs x 64 x 128 x 128
]
self.decoder_channels = [512, 512, 512, 512, 256, 128, 64]
self.decoders = nn.ModuleList(
[
UpSampleConv(512, 512, dropout=True), # bs x 512 x 2 x 2
UpSampleConv(1024, 512, dropout=True), # bs x 512 x 4 x 4
UpSampleConv(1024, 512, dropout=True), # bs x 512 x 8 x 8
UpSampleConv(1024, 512), # bs x 512 x 16 x 16
UpSampleConv(1024, 256), # bs x 256 x 32 x 32
UpSampleConv(512, 128), # bs x 128 x 64 x 64
UpSampleConv(256, 64), # bs x 64 x 128 x 128
]
)
self.final_conv = nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1)
self.tanh = nn.Tanh()

self.encoders = nn.ModuleList(self.encoders)
self.decoders = nn.ModuleList(self.decoders)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
skips_cons = []
for encoder in self.encoders:
x = encoder(x)

skips_cons.append(x)

skips_cons = list(reversed(skips_cons[:-1]))
decoders = self.decoders[:-1]

for decoder, skip in zip(decoders, skips_cons):
x = decoder(x)
# print(x.shape, skip.shape)
x = torch.cat((x, skip), axis=1)

x = self.decoders[-1](x)
# print(x.shape)
x = self.final_conv(x)
return self.tanh(x)


@under_review()
class PatchGAN(nn.Module):
def __init__(self, input_channels) -> None:
def __init__(self, input_channels: int) -> None:
super().__init__()
self.d1 = DownSampleConv(input_channels, 64, batchnorm=False)
self.d2 = DownSampleConv(64, 128)
self.d3 = DownSampleConv(128, 256)
self.d4 = DownSampleConv(256, 512)
self.final = nn.Conv2d(512, 1, kernel_size=1)
self.sigmoid = nn.Sigmoid()

def forward(self, x, y):
def forward(self, x: Tensor, y: Tensor) -> Tensor:
x = torch.cat([x, y], axis=1)
x0 = self.d1(x)
x1 = self.d2(x0)
x2 = self.d3(x1)
x3 = self.d4(x2)
return self.final(x3)
xn = self.final(x3)
return self.sigmoid(xn)
70 changes: 50 additions & 20 deletions src/pl_bolts/models/gans/pix2pix/pix2pix_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,55 @@
from torch import nn

from pl_bolts.models.gans.pix2pix.components import Generator, PatchGAN
from pl_bolts.utils.stability import under_review


@under_review()
def _weights_init(m):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
if isinstance(m, nn.BatchNorm2d):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
torch.nn.init.constant_(m.bias, 0)


@under_review()
class Pix2Pix(LightningModule):
def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=200) -> None:
Pix2Pix implementation from the paper.
Paper: `Image-to-Image Translation with Conditional Adversarial Networks <https://arxiv.org/abs/1611.07004>`_.

Example::

from pl_bolts.models.gans import Pix2Pix

model = Pix2Pix(in_channels=3, out_channels=3)
datamodule = CustomDataModule()

Trainer(gpus=1).fit(model, datamodule)
"""

def __init__(
self, in_channels: int, out_channels: int, learning_rate: float = 0.0002, lambda_recon: int = 200
) -> None:
"""
Args:
in_channels: Number of channels of the conditional images from the dataset
out_channels: Number of channels of the real images from the dataset
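learning_rate: Learning rate shared by the generator and discriminator Adam optimizers
lambda_recon: Weight applied to the L1 reconstruction loss before it is added to the adversarial loss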
"""
super().__init__()
self.save_hyperparameters()

self.gen = Generator(in_channels, out_channels)
# networks
self.generator = Generator(in_channels, out_channels)
self.patch_gan = PatchGAN(in_channels + out_channels)

# initializing weights
self.gen = self.gen.apply(_weights_init)
self.patch_gan = self.patch_gan.apply(_weights_init)
self.generator = self.generator.apply(self._weights_init)
self.patch_gan = self.patch_gan.apply(self._weights_init)

# criterion
self.adversarial_criterion = nn.BCEWithLogitsLoss()
self.recon_criterion = nn.L1Loss()

def forward(self, x):
return self.generator(x)

def _gen_step(self, real_images, conditioned_images):
# Pix2Pix has adversarial and a reconstruction loss
# First calculate the adversarial loss
fake_images = self.gen(conditioned_images)
# discriminate fake image
fake_images = self.generator(conditioned_images)
disc_logits = self.patch_gan(fake_images, conditioned_images)

# calculate adversarial loss
adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))

# calculate reconstruction loss
Expand All @@ -45,28 +61,42 @@ def _gen_step(self, real_images, conditioned_images):
return adversarial_loss + lambda_recon * recon_loss

def _disc_step(self, real_images, conditioned_images):
fake_images = self.gen(conditioned_images).detach()
# discriminate fake image
fake_images = self.generator(conditioned_images).detach()
fake_logits = self.patch_gan(fake_images, conditioned_images)

# discriminate real image
real_logits = self.patch_gan(real_images, conditioned_images)

# calculate adversarial loss
fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
return (real_loss + fake_loss) / 2

@staticmethod
def _weights_init(m):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
if isinstance(m, nn.BatchNorm2d):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
torch.nn.init.constant_(m.bias, 0.0)

def configure_optimizers(self):
lr = self.hparams.learning_rate
gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr)
gen_opt = torch.optim.Adam(self.generator.parameters(), lr=lr)
disc_opt = torch.optim.Adam(self.patch_gan.parameters(), lr=lr)
return disc_opt, gen_opt

def training_step(self, batch, batch_idx, optimizer_idx):
real, condition = batch

loss = None
# Train discriminator (patchGAN)
if optimizer_idx == 0:
loss = self._disc_step(real, condition)
self.log("PatchGAN Loss", loss)

# Train generator
elif optimizer_idx == 1:
loss = self._gen_step(real, condition)
self.log("Generator Loss", loss)
Expand Down
24 changes: 23 additions & 1 deletion tests/models/gans/integration/test_gans.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import pytest
from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule
from pl_bolts.datasets.dummy_dataset import DummyDataset
from pl_bolts.datasets.sr_mnist_dataset import SRMNIST
from pl_bolts.models.gans import DCGAN, GAN, SRGAN, SRResNet
from pl_bolts.models.gans import DCGAN, GAN, SRGAN, Pix2Pix, SRResNet
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data.dataloader import DataLoader
Expand Down Expand Up @@ -66,3 +67,24 @@ def test_sr_modules(tmpdir, datadir, sr_module_cls, scale_factor):
model = sr_module_cls(image_channels=1, scale_factor=scale_factor)
trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(model, dl)


@pytest.mark.parametrize("dataset_cls", [DummyDataset])
@pytest.mark.parametrize(
("in_shape", "out_shape"),
[
pytest.param((3, 256, 256), (3, 256, 256), id="img shape (3, 256, 256), (3, 256, 256)"),
pytest.param((1, 256, 256), (3, 256, 256), id="img shape (1, 256, 256), (3, 256, 256)"),
Comment on lines +76 to +77
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pix2pix should be able to handle images of lower sizes than 256x256

pytest.param((3, 128, 128), (3, 128, 128), id="img shape (3, 128, 128), (3, 128, 128)"),
pytest.param((1, 128, 128), (3, 128, 128), id="img shape (1, 128, 128), (3, 128, 128)"),
pytest.param((3, 64, 64), (3, 64, 64), id="img shape (3, 64, 64), (3, 64, 64)"),
pytest.param((1, 64, 64), (3, 64, 64), id="img shape (1, 64, 64), (3, 64, 64)"),
],
)
def test_pix2pix(tmpdir, datadir, dataset_cls, in_shape, out_shape):
seed_everything(42)

dl = DataLoader(dataset_cls(out_shape, in_shape))
model = Pix2Pix(in_channels=in_shape[0], out_channels=out_shape[0])
trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(model, dl)
Loading