def make_cutouts(
    self,
    input: torch.Tensor,
    side_x,
    side_y,
    cut_size,
    device=DEVICE,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sample cutouts from ``input`` using the classic pytti sampler.

    Thin wrapper: the sampling logic lives in
    ``cutouts_samplers.pytti_classic``; this method only forwards the
    embedder's own configuration (``padding``, ``cutn``, ``cut_pow``,
    ``border_mode``, ``augs``, ``noise_fac``) together with the per-call
    arguments.

    input: (Tensor) image tensor to cut from
    side_x, side_y: image width / height handed to the sampler
    cut_size: side length of each returned cutout
    device: device the returned offsets and sizes are moved to
    returns: (cutouts, offsets, sizes) exactly as produced by the sampler
    """
    cutouts, offsets, sizes = cutouts_samplers.pytti_classic(
        input=input,
        side_x=side_x,
        side_y=side_y,
        cut_size=cut_size,
        padding=self.padding,
        cutn=self.cutn,
        cut_pow=self.cut_pow,
        border_mode=self.border_mode,
        augs=self.augs,
        noise_fac=self.noise_fac,
        # Forward the caller's choice. Passing the module-level DEVICE here
        # would silently ignore the ``device`` argument of this method.
        device=device,
    )
    return cutouts, offsets, sizes
"""
Methods for obtaining cutouts, agnostic to augmentations.

Cutout choices have a significant impact on the performance of the perceptors and the
overall look of the image.

The objects defined here are probably only being used in pytti.Perceptor.Embedder.HDMultiClipEmbedder,
but they should be sufficiently general for use in notebooks without pyttitools otherwise in use.
"""

from typing import Tuple

import torch
from torch.nn import functional as F

PADDING_MODES = {
    "mirror": "reflect",
    "smear": "replicate",
    "wrap": "circular",
    "black": "constant",
}


def pytti_classic(
    input: torch.Tensor,
    side_x,
    side_y,
    cut_size,
    padding,
    cutn,
    cut_pow,
    border_mode,
    augs,
    noise_fac,
    device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    This is the cutout method that was already in use in the original pytti.

    Draws ``cutn`` random square crops from ``input``, pools each one down to
    ``cut_size``, runs the stacked crops through ``augs`` and optionally adds
    gaussian noise.

    input: (Tensor) indexed as [:, :, y, x]. For any ``border_mode`` other than
        "clamp" the crops are taken at ``padding + offset``, so the tensor is
        presumably already padded by paddingx / paddingy on each side -- confirm
        against the caller.
    side_x, side_y: width / height used to bound the crop size and offsets and
        to normalize the returned offsets and sizes
    cut_size: output side length given to ``adaptive_avg_pool2d``; also the
        lower bound of the sampled crop size
    padding: fraction of each side usable as border; converted to pixels as
        ``min(round(side * padding), side)``
    cutn: number of crops to draw
    cut_pow: exponent applied to the clipped size fraction
    border_mode: "clamp" clamps offsets into the image; every other value
        samples offsets that may reach into the padded border
    augs: callable applied to the concatenated crops
    noise_fac: if truthy, upper bound of the per-crop noise scale
    device: device the returned offsets and sizes are moved to
    returns: (cutouts, offsets, sizes). ``offsets`` and ``sizes`` have shape
        [cutn, 2] and are divided by (side_x, side_y).
    """
    max_size = min(side_x, side_y)
    paddingx = min(round(side_x * padding), side_x)
    paddingy = min(round(side_y * padding), side_y)
    cutouts = []
    offsets = []
    sizes = []
    for _ in range(cutn):
        # Size fraction is drawn from a normal with mean 0.8 and standard
        # deviation 0.3, clipped to [cut_size / max_size, 1.0], then raised
        # to cut_pow.
        size = int(
            max_size
            * (
                torch.zeros(
                    1,
                )
                .normal_(mean=0.8, std=0.3)
                .clip(cut_size / max_size, 1.0)
                ** cut_pow
            )
        )
        # NOTE(review): the upper bound is side - size + 1, so in "clamp" mode
        # a crop can start one pixel past the last fully-inside position if
        # ``input`` is exactly side_x by side_y; slicing then truncates it.
        # Left unchanged to keep the original pytti sampling -- confirm intent.
        offsetx_max = side_x - size + 1
        offsety_max = side_y - size + 1
        if border_mode == "clamp":
            offsetx = torch.clamp(
                (torch.rand([]) * (offsetx_max + 2 * paddingx) - paddingx)
                .floor()
                .int(),
                0,
                offsetx_max,
            )
            offsety = torch.clamp(
                (torch.rand([]) * (offsety_max + 2 * paddingy) - paddingy)
                .floor()
                .int(),
                0,
                offsety_max,
            )
            cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
        else:
            # A crop never reaches further into the border than its own size.
            px = min(size, paddingx)
            py = min(size, paddingy)
            offsetx = (torch.rand([]) * (offsetx_max + 2 * px) - px).floor().int()
            offsety = (torch.rand([]) * (offsety_max + 2 * py) - py).floor().int()
            cutout = input[
                :,
                :,
                paddingy + offsety : paddingy + offsety + size,
                paddingx + offsetx : paddingx + offsetx + size,
            ]
        cutouts.append(F.adaptive_avg_pool2d(cutout, cut_size))
        offsets.append(
            torch.as_tensor([[offsetx / side_x, offsety / side_y]]).to(device)
        )
        sizes.append(torch.as_tensor([[size / side_x, size / side_y]]).to(device))
    cutouts = augs(torch.cat(cutouts))
    offsets = torch.cat(offsets)
    sizes = torch.cat(sizes)
    if noise_fac:
        # One noise scale per row of the stacked crops. Using the actual row
        # count (instead of cutn) keeps the broadcast valid when ``input``
        # carries more than one image; for a single image it equals cutn.
        facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(0, noise_fac)
        cutouts.add_(facs * torch.randn_like(cutouts))
    return cutouts, offsets, sizes
named_rearrange +# for typing +import torch +from pytti.LossAug.BaseLossClass import Loss + SUPPORTED_MODES = ["L", "RGB", "I", "F"] @@ -25,13 +29,13 @@ def __init__(self, width: int, height: int, pixel_format: str = "RGB"): self.lr = 0.02 self.latent_strength = 0 - def decode_training_tensor(self): + def decode_training_tensor(self) -> torch.Tensor: """ returns a decoded tensor of this image for training """ return self.decode_tensor() - def get_image_tensor(self): + def get_image_tensor(self) -> torch.Tensor: """ optional method: returns an [n x w_i x h_i] tensor representing the local image data those data will be used for animation if afforded @@ -41,26 +45,26 @@ def get_image_tensor(self): def clone(self): raise NotImplementedError - def get_latent_tensor(self, detach=False): + def get_latent_tensor(self, detach=False) -> torch.Tensor: if detach: return self.get_image_tensor().detach() else: return self.get_image_tensor() - def set_image_tensor(self, tensor): + def set_image_tensor(self, tensor: torch.Tensor): """ optional method: accepts an [n x w_i x h_i] tensor representing the local image data those data will be by the animation system """ raise NotImplementedError - def decode_tensor(self): + def decode_tensor(self) -> torch.Tensor: """ returns a decoded tensor of this image """ raise NotImplementedError - def encode_image(self, pil_image): + def encode_image(self, pil_image: Image): """ overwrites this image with the input image pil_image: (Image) input image @@ -79,7 +83,7 @@ def update(self): """ pass - def make_latent(self, pil_image): + def make_latent(self, pil_image: Image) -> torch.Tensor: try: dummy = self.clone() except NotImplementedError: @@ -88,7 +92,7 @@ def make_latent(self, pil_image): return dummy.get_latent_tensor(detach=True) @classmethod - def get_preferred_loss(cls): + def get_preferred_loss(cls) -> Loss: from pytti.LossAug.HSVLossClass import HSVLoss return HSVLoss @@ -96,7 +100,7 @@ def get_preferred_loss(cls): def image_loss(self): 
"""
Smoke tests for the pytti image models.

Each test logs the class's preferred loss, builds a small instance and logs a
couple of its attributes; the VQGAN test checks that an invalid ``model``
argument is rejected.
"""
import pytest
from loguru import logger
import torch

import pytti.image_models
from pytti.image_models.differentiable_image import DifferentiableImage
from pytti.image_models.ema import EMAImage
from pytti.image_models.pixel import PixelImage
from pytti.image_models.rgb_image import RGBImage
from pytti.image_models.vqgan import VQGANImage


## simple models ##


def test_differentiabble_image_model():
    """
    Test that the DifferentiableImage can be instantiated
    """
    logger.debug(
        DifferentiableImage.get_preferred_loss()
    )  # pytti.LossAug.HSVLossClass.HSVLoss
    image = DifferentiableImage(
        width=10,
        height=10,
    )
    logger.debug(image.output_axes)  # x y s
    logger.debug(image.lr)  # 0.02
    assert image


def test_rgb_image_model():
    """
    Test that the RGBImage can be instantiated
    """
    logger.debug(RGBImage.get_preferred_loss())  # pytti.LossAug.HSVLossClass.HSVLoss
    image = RGBImage(
        width=10,
        height=10,
    )
    logger.debug(image.output_axes)  # n x y s ... when does n != 1?
    logger.debug(image.lr)  # 0.02
    assert image


## more complex models ##


def test_ema_image():
    """
    Test that the EMAImage can be instantiated
    """
    logger.debug(EMAImage.get_preferred_loss())  # pytti.LossAug.HSVLossClass.HSVLoss
    image = EMAImage(
        width=10,
        height=10,
        tensor=torch.zeros(10, 10),
        decay=0.5,
    )
    logger.debug(image.output_axes)  # x y s
    logger.debug(image.lr)  # 0.02
    assert image


def test_pixel_image():
    """
    Test that the PixelImage can be instantiated
    """
    logger.debug(PixelImage.get_preferred_loss())  # pytti.LossAug.HSVLossClass.HSVLoss
    image = PixelImage(
        width=10,
        height=10,
        scale=1,
        pallet_size=1,
        n_pallets=1,
    )
    logger.debug(image.output_axes)  # n s y x ... uh ok, sure.
    logger.debug(image.lr)  # 0.02
    assert image


# TODO: enable once a VQGAN model fixture is available.
# def test_vqgan_image_valid():
#     """
#     Test that the VQGANImage can be instantiated
#     """
#     image = VQGANImage(
#         width=10,
#         height=10,
#         model=SOME_VQGAN_MODEL,
#     )
#     logger.debug(image.output_axes)
#     logger.debug(image.lr)  ### self.lr = 0.15 if VQGAN_IS_GUMBEL else 0.1
#     assert image


def test_vqgan_image_invalid_string():
    """
    Test that constructing a VQGANImage with an invalid ``model`` value
    raises AttributeError
    """
    logger.debug(
        VQGANImage.get_preferred_loss()
    )  # pytti.LossAug.LatentLossClass.LatentLoss
    # Only the constructor call belongs inside the context manager: anything
    # placed after it would never run once the expected exception is raised.
    with pytest.raises(AttributeError):
        VQGANImage(
            width=10,
            height=10,
            model="this isn't actually a valid value for this field",
        )