
Commit 5488289

Merge pull request #43 from sberbank-ai/feature/dwt_vae (Feature/dwt vae)
2 parents: a23a834 + a1980cf

10 files changed: +533 −10 lines

.coveragerc (+4)

@@ -0,0 +1,4 @@
+[run]
+omit =
+    # omit this single file
+    rudalle/vae/pytorch_wavelets_utils.py

README.md (+2, −1)

@@ -7,7 +7,7 @@
 [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)

 ```
-pip install rudalle==0.0.1rc6
+pip install rudalle==0.0.1rc7
 ```
 ### 🤗 HF Models:
 [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich)

@@ -92,6 +92,7 @@ skyes = [red_sky, sunny_sky, cloudy_sky, night_sky]

 ### 🚀 Contributors 🚀

+- [@bes](https://github.com/bes-dev) shared [great idea and realization with IDWT](https://github.com/bes-dev/vqvae_dwt_distiller.pytorch) for decoding images with higher quality 512x512! 😈💪
 - [@neverix](https://www.kaggle.com/neverix) thanks a lot for contributing for speed up of inference
 - [@Igor Pavlov](https://github.com/boomb0om) trained model and prepared code with [super-resolution](https://github.com/boomb0om/Real-ESRGAN-colab)
 - [@oriBetelgeuse](https://github.com/oriBetelgeuse) thanks a lot for easy API of generation using image prompt

requirements.txt (+1)

@@ -4,6 +4,7 @@ transformers~=4.10.2
 youtokentome~=1.0.6
 omegaconf>=2.0.0
 einops~=0.3.2
+PyWavelets==1.1.1
 torch
 torchvision
 matplotlib

rudalle/__init__.py (+1, −1)

@@ -22,4 +22,4 @@
     'image_prompts',
 ]

-__version__ = '0.0.1-rc6'
+__version__ = '0.0.1-rc7'

rudalle/vae/__init__.py (+10, −4)

@@ -8,17 +8,23 @@
 from .model import VQGanGumbelVAE


-def get_vae(pretrained=True, cache_dir='/tmp/rudalle'):
+def get_vae(pretrained=True, dwt=False, cache_dir='/tmp/rudalle'):
     # TODO
     config = OmegaConf.load(join(dirname(abspath(__file__)), 'vqgan.gumbelf8-sber.config.yml'))
-    vae = VQGanGumbelVAE(config)
+    vae = VQGanGumbelVAE(config, dwt=dwt)
     if pretrained:
         repo_id = 'shonenkov/rudalle-utils'
-        filename = 'vqgan.gumbelf8-sber.model.ckpt'
+        if dwt:
+            filename = 'vqgan.gumbelf8-sber-dwt.model.ckpt'
+        else:
+            filename = 'vqgan.gumbelf8-sber.model.ckpt'
         cache_dir = join(cache_dir, 'vae')
         config_file_url = hf_hub_url(repo_id=repo_id, filename=filename)
         cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
         checkpoint = torch.load(join(cache_dir, filename), map_location='cpu')
-        vae.model.load_state_dict(checkpoint['state_dict'], strict=False)
+        if dwt:
+            vae.load_state_dict(checkpoint['state_dict'])
+        else:
+            vae.model.load_state_dict(checkpoint['state_dict'], strict=False)
     print('vae --> ready')
     return vae
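Not part of the diff: a minimal usage sketch of the new flag, assuming the package is installed and that `get_vae` downloads the checkpoint on first call as before.

```python
from rudalle import get_vae

# dwt=False keeps the previous behaviour (vqgan.gumbelf8-sber.model.ckpt);
# dwt=True loads vqgan.gumbelf8-sber-dwt.model.ckpt and decodes through the IDWT head.
vae = get_vae(pretrained=True, dwt=True)
```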

rudalle/vae/decoder_dwt.py (+102)

@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+import pywt
+import torch
+import torch.nn as nn
+from taming.modules.diffusionmodules.model import Decoder
+
+from .pytorch_wavelets_utils import SFB2D, _SFB2D, prep_filt_sfb2d, mode_to_int
+
+
+class DecoderDWT(nn.Module):
+    def __init__(self, ddconfig, embed_dim):
+        super().__init__()
+        if ddconfig.out_ch != 12:
+            ddconfig.out_ch = 12
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1)
+        self.decoder = Decoder(**ddconfig)
+        self.idwt = DWTInverse(mode='zero', wave='db1')
+
+    def forward(self, x):
+        # x = self.post_quant_conv(x)
+        freq = self.decoder(x)
+        img = self.dwt_to_img(freq)
+        return img
+
+    def dwt_to_img(self, img):
+        b, c, h, w = img.size()
+        low = img[:, :3, :, :]
+        high = img[:, 3:, :, :].view(b, 3, 3, h, w)
+        return self.idwt((low, [high]))
+
+
+class DWTInverse(nn.Module):
+    """ Performs a 2d DWT Inverse reconstruction of an image
+
+    Args:
+        wave (str or pywt.Wavelet): Which wavelet to use
+        C: deprecated, will be removed in future
+    """
+
+    def __init__(self, wave='db1', mode='zero', trace_model=False):
+        super().__init__()
+        if isinstance(wave, str):
+            wave = pywt.Wavelet(wave)
+        if isinstance(wave, pywt.Wavelet):
+            g0_col, g1_col = wave.rec_lo, wave.rec_hi
+            g0_row, g1_row = g0_col, g1_col
+        else:
+            if len(wave) == 2:
+                g0_col, g1_col = wave[0], wave[1]
+                g0_row, g1_row = g0_col, g1_col
+            elif len(wave) == 4:
+                g0_col, g1_col = wave[0], wave[1]
+                g0_row, g1_row = wave[2], wave[3]
+        # Prepare the filters
+        filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
+        self.register_buffer('g0_col', filts[0])
+        self.register_buffer('g1_col', filts[1])
+        self.register_buffer('g0_row', filts[2])
+        self.register_buffer('g1_row', filts[3])
+        self.mode = mode
+        self.trace_model = trace_model
+
+    def forward(self, coeffs):
+        """
+        Args:
+            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
+              yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
+              W_{in}')` and yh is a list of bandpass tensors of shape
+              :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
+              the format returned by DWTForward
+
+        Returns:
+            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
+
+        Note:
+            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
+            downsampled shapes of the DWT pyramid.
+
+        Note:
+            Can have None for any of the highpass scales and will treat the
+            values as zeros (not in an efficient way though).
+        """
+        yl, yh = coeffs
+        ll = yl
+        mode = mode_to_int(self.mode)
+
+        # Do a multilevel inverse transform
+        for h in yh[::-1]:
+            if h is None:
+                h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2],
+                                ll.shape[-1], device=ll.device)
+
+            # 'Unpad' added dimensions
+            if ll.shape[-2] > h.shape[-2]:
+                ll = ll[..., :-1, :]
+            if ll.shape[-1] > h.shape[-1]:
+                ll = ll[..., :-1]
+            if not self.trace_model:
+                ll = SFB2D.apply(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
+            else:
+                ll = _SFB2D(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
        return ll
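Why `out_ch` is forced to 12: the decoder now predicts single-level wavelet coefficients instead of pixels, one low-frequency RGB band plus three high-frequency RGB bands (3 + 3×3 = 12 channels), and the inverse Haar DWT ('db1') doubles the spatial resolution. A minimal illustration with PyWavelets, not from the PR (the 256×256 decoder resolution is an assumption; the PR itself uses the SFB2D-based DWTInverse above):

```python
import numpy as np
import pywt

h = w = 256                          # assumed decoder output resolution
low = np.random.randn(h, w)          # LL band for one colour channel
high = np.random.randn(3, h, w)      # horizontal, vertical, diagonal detail bands

# Same kind of reconstruction DecoderDWT.dwt_to_img performs per channel
rec = pywt.idwt2((low, tuple(high)), 'db1', mode='zero')
print(rec.shape)                     # (512, 512): twice the decoder resolution
```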

rudalle/vae/model.py (+11, −4)

@@ -8,16 +8,19 @@
 from einops import rearrange
 from taming.modules.diffusionmodules.model import Encoder, Decoder

+from .decoder_dwt import DecoderDWT
+

 class VQGanGumbelVAE(torch.nn.Module):

-    def __init__(self, config):
+    def __init__(self, config, dwt=False):
         super().__init__()
         model = GumbelVQ(
             ddconfig=config.model.params.ddconfig,
             n_embed=config.model.params.n_embed,
             embed_dim=config.model.params.embed_dim,
             kl_weight=config.model.params.kl_weight,
+            dwt=dwt,
         )
         self.model = model
         self.num_layers = int(log(config.model.params.ddconfig.attn_resolutions[0]) / log(2))
@@ -79,11 +82,12 @@ def forward(self, z, temp=None, return_logits=False):

 class GumbelVQ(nn.Module):

-    def __init__(self, ddconfig, n_embed, embed_dim, kl_weight=1e-8):
+    def __init__(self, ddconfig, n_embed, embed_dim, dwt=False, kl_weight=1e-8):
         super().__init__()
         z_channels = ddconfig['z_channels']
+        self.dwt = dwt
         self.encoder = Encoder(**ddconfig)
-        self.decoder = Decoder(**ddconfig)
+        self.decoder = DecoderDWT(ddconfig, embed_dim) if dwt else Decoder(**ddconfig)
         self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed, kl_weight=kl_weight, temp_init=1.0)
         self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1)
@@ -95,6 +99,9 @@ def encode(self, x):
         return quant, emb_loss, info

     def decode(self, quant):
-        quant = self.post_quant_conv(quant)
+        if self.dwt:
+            quant = self.decoder.post_quant_conv(quant)
+        else:
+            quant = self.post_quant_conv(quant)
         dec = self.decoder(quant)
         return dec
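In the dwt path, `decode` routes through `self.decoder.post_quant_conv` because `DecoderDWT` carries its own 1×1 convolution sized for the new checkpoint, and its `forward` does not apply it again. A hedged end-to-end sketch, not from the PR (the 32×32 latent grid and the 512×512 output are assumptions for the f8 Gumbel config, and the call downloads the checkpoint):

```python
import torch
from rudalle import get_vae

vae = get_vae(pretrained=True, dwt=True)

# Fake quantized latents; quant_conv maps z_channels -> embed_dim, so its
# out_channels equals the channel count decode() expects.
embed_dim = vae.model.quant_conv.out_channels
quant = torch.randn(1, embed_dim, 32, 32)

with torch.no_grad():
    img = vae.model.decode(quant)   # decoder.post_quant_conv -> DecoderDWT -> IDWT
print(img.shape)                    # expected (1, 3, 512, 512) for a 32x32 latent grid
```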
