|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +import pywt |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | +from taming.modules.diffusionmodules.model import Decoder |
| 6 | + |
| 7 | +from .pytorch_wavelets_utils import SFB2D, _SFB2D, prep_filt_sfb2d, mode_to_int |
| 8 | + |
| 9 | + |
class DecoderDWT(nn.Module):
    """Decode a latent into wavelet sub-bands and invert them to an image.

    The wrapped ``Decoder`` is always configured to emit 12 channels. The
    first 3 channels are handed to the inverse DWT as the low-pass band and
    the remaining 9 are viewed as 3 channels x 3 high-pass bands.

    Args:
        ddconfig: Mapping of keyword arguments for ``Decoder``. Must contain
            ``z_channels``; any ``out_ch`` entry is overridden with 12.
        embed_dim: Number of input channels of ``post_quant_conv``.
    """

    # Channel layout of the decoder output consumed by ``dwt_to_img``.
    _LOW_CH = 3
    _OUT_CH = 12

    def __init__(self, ddconfig, embed_dim):
        super().__init__()
        # Work on a shallow copy: the caller's config is left untouched and
        # plain dicts work as well as attribute-style config objects.
        decoder_kwargs = dict(ddconfig)
        decoder_kwargs['out_ch'] = self._OUT_CH
        self.post_quant_conv = torch.nn.Conv2d(
            embed_dim, decoder_kwargs['z_channels'], 1)
        self.decoder = Decoder(**decoder_kwargs)
        self.idwt = DWTInverse(mode='zero', wave='db1')

    def forward(self, x):
        """Decode ``x`` and return the inverse-DWT reconstruction."""
        # NOTE: ``post_quant_conv`` is registered in ``__init__`` but is
        # intentionally not applied here; ``x`` goes straight to the decoder.
        freq = self.decoder(x)
        return self.dwt_to_img(freq)

    def dwt_to_img(self, img):
        """Split a 12-channel tensor into DWT bands and run the inverse DWT.

        Args:
            img: Tensor of shape ``(b, 12, h, w)``.

        Returns:
            Whatever ``self.idwt`` returns for ``(low, [high])`` where
            ``low`` is ``(b, 3, h, w)`` and ``high`` is ``(b, 3, 3, h, w)``.

        Raises:
            ValueError: If ``img`` does not have exactly 12 channels.
        """
        b, c, h, w = img.size()
        if c != self._OUT_CH:
            raise ValueError(
                f"dwt_to_img expects {self._OUT_CH} channels, got {c}")
        low = img[:, :self._LOW_CH, :, :]
        # Remaining 9 channels -> (channel, band) = (3, 3).
        high = img[:, self._LOW_CH:, :, :].view(b, self._LOW_CH, 3, h, w)
        return self.idwt((low, [high]))
| 30 | + |
| 31 | + |
class DWTInverse(nn.Module):
    """ Performs a 2d DWT Inverse reconstruction of an image

    Args:
        wave (str or pywt.Wavelet or sequence): Which wavelet to use. A
            string is looked up with ``pywt.Wavelet``. Otherwise a sequence
            of 2 filters ``(g0, g1)`` used for both columns and rows, or of
            4 filters ``(g0_col, g1_col, g0_row, g1_row)``.
        mode (str): Padding mode name; converted with ``mode_to_int`` on
            every forward call.
        trace_model (bool): If True, call ``_SFB2D`` directly instead of
            going through ``SFB2D.apply``.

    Raises:
        ValueError: If ``wave`` is a sequence whose length is not 2 or 4.
    """

    def __init__(self, wave='db1', mode='zero', trace_model=False):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0_col, g1_col = wave.rec_lo, wave.rec_hi
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 2:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 4:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = wave[2], wave[3]
        else:
            # Previously fell through and failed later with an unbound
            # local; fail early with an actionable message instead.
            raise ValueError(
                "wave must be a wavelet name, a pywt.Wavelet, or a sequence "
                "of 2 or 4 filters; got a sequence of length "
                "{}".format(len(wave)))
        # Prepare the filters
        filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
        self.register_buffer('g0_col', filts[0])
        self.register_buffer('g1_col', filts[1])
        self.register_buffer('g0_row', filts[2])
        self.register_buffer('g1_row', filts[3])
        self.mode = mode
        self.trace_model = trace_model

    def forward(self, coeffs):
        """
        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
              yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
              W_{in}')` and yh is a list of bandpass tensors of shape
              :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
              the format returned by DWTForward

        Returns:
            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.

        Note:
            Can have None for any of the highpass scales and will treat the
            values as zeros (not in an efficient way though).
        """
        yl, yh = coeffs
        ll = yl
        mode = mode_to_int(self.mode)

        # Do a multilevel inverse transform, coarsest level first
        for h in yh[::-1]:
            if h is None:
                # Match dtype as well as device so half/double inputs do not
                # get mixed with a float32 placeholder.
                h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2],
                                ll.shape[-1], device=ll.device,
                                dtype=ll.dtype)

            # 'Unpad' added dimensions
            if ll.shape[-2] > h.shape[-2]:
                ll = ll[..., :-1, :]
            if ll.shape[-1] > h.shape[-1]:
                ll = ll[..., :-1]
            if not self.trace_model:
                ll = SFB2D.apply(ll, h, self.g0_col, self.g1_col,
                                 self.g0_row, self.g1_row, mode)
            else:
                ll = _SFB2D(ll, h, self.g0_col, self.g1_col,
                            self.g0_row, self.g1_row, mode)
        return ll
0 commit comments