
image model type hints #187

Open · wants to merge 8 commits into base: test
111 changes: 35 additions & 76 deletions src/pytti/Perceptor/Embedder.py
@@ -10,7 +10,13 @@
from torch import nn
from torch.nn import functional as F

import kornia.augmentation as K

# import .cutouts
# import .cutouts as cutouts
# import cutouts

from .cutouts import augs as cutouts_augs
from .cutouts import samplers as cutouts_samplers

PADDING_MODES = {
"mirror": "reflect",
@@ -43,19 +49,7 @@ def __init__(
self.cut_sizes = [p.visual.input_resolution for p in perceptors]
self.cutn = cutn
self.noise_fac = noise_fac
self.augs = nn.Sequential(
K.RandomHorizontalFlip(p=0.3),
K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
K.RandomPerspective(
0.2,
p=0.4,
),
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
K.RandomErasing(
scale=(0.1, 0.4), ratio=(0.3, 1 / 0.3), same_on_batch=False, p=0.7
),
nn.Identity(),
)
self.augs = cutouts_augs.pytti_classic()
self.input_axes = ("n", "s", "y", "x")
self.output_axes = ("c", "n", "i")
self.perceptors = perceptors
@@ -64,69 +58,34 @@ def __init__(
self.border_mode = border_mode

def make_cutouts(
self, input: torch.Tensor, side_x, side_y, cut_size, device=DEVICE
self,
input: torch.Tensor,
side_x,
side_y,
cut_size,
####
# padding,
# cutn,
# cut_pow,
# border_mode,
# augs,
# noise_fac,
####
device=DEVICE,
) -> Tuple[list, list, list]:
min_size = min(side_x, side_y, cut_size)
max_size = min(side_x, side_y)
paddingx = min(round(side_x * self.padding), side_x)
paddingy = min(round(side_y * self.padding), side_y)
cutouts = []
offsets = []
sizes = []
for _ in range(self.cutn):
# mean is 0.8
# std is 0.3
size = int(
max_size
* (
torch.zeros(
1,
)
.normal_(mean=0.8, std=0.3)
.clip(cut_size / max_size, 1.0)
** self.cut_pow
)
)
offsetx_max = side_x - size + 1
offsety_max = side_y - size + 1
if self.border_mode == "clamp":
offsetx = torch.clamp(
(torch.rand([]) * (offsetx_max + 2 * paddingx) - paddingx)
.floor()
.int(),
0,
offsetx_max,
)
offsety = torch.clamp(
(torch.rand([]) * (offsety_max + 2 * paddingy) - paddingy)
.floor()
.int(),
0,
offsety_max,
)
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
else:
px = min(size, paddingx)
py = min(size, paddingy)
offsetx = (torch.rand([]) * (offsetx_max + 2 * px) - px).floor().int()
offsety = (torch.rand([]) * (offsety_max + 2 * py) - py).floor().int()
cutout = input[
:,
:,
paddingy + offsety : paddingy + offsety + size,
paddingx + offsetx : paddingx + offsetx + size,
]
cutouts.append(F.adaptive_avg_pool2d(cutout, cut_size))
offsets.append(
torch.as_tensor([[offsetx / side_x, offsety / side_y]]).to(device)
)
sizes.append(torch.as_tensor([[size / side_x, size / side_y]]).to(device))
cutouts = self.augs(torch.cat(cutouts))
offsets = torch.cat(offsets)
sizes = torch.cat(sizes)
if self.noise_fac:
facs = cutouts.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
cutouts.add_(facs * torch.randn_like(cutouts))
cutouts, offsets, sizes = cutouts_samplers.pytti_classic(
input=input,
side_x=side_x,
side_y=side_y,
cut_size=cut_size,
padding=self.padding,
cutn=self.cutn,
cut_pow=self.cut_pow,
border_mode=self.border_mode,
augs=self.augs,
noise_fac=self.noise_fac,
device=device,
)
return cutouts, offsets, sizes

def forward(
Empty file.
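With the sampling logic extracted into the new `cutouts.samplers` module (shown below), `make_cutouts` becomes a thin wrapper, and alternative augmentation pipelines can be swapped in without touching the sampler. A minimal sketch under that assumption — `make_cutouts_no_augs` and all hyperparameter values here are invented for illustration, not pytti defaults:

```python
import torch
from torch import nn
from pytti.Perceptor.cutouts import samplers

def make_cutouts_no_augs(input: torch.Tensor, side_x: int, side_y: int, cut_size: int):
    # Same sampler the embedder delegates to, but with an identity
    # "augmentation" pipeline and no post-hoc noise.
    return samplers.pytti_classic(
        input=input,
        side_x=side_x,
        side_y=side_y,
        cut_size=cut_size,
        padding=0.25,        # illustrative value
        cutn=8,
        cut_pow=1.0,
        border_mode="clamp",
        augs=nn.Identity(),  # no augmentation
        noise_fac=0.0,       # falsy, so the noise branch is skipped
        device="cpu",
    )
```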
18 changes: 18 additions & 0 deletions src/pytti/Perceptor/cutouts/augs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import kornia.augmentation as K
from torch import nn


def pytti_classic():
return nn.Sequential(
K.RandomHorizontalFlip(p=0.3),
K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
K.RandomPerspective(
0.2,
p=0.4,
),
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
K.RandomErasing(
scale=(0.1, 0.4), ratio=(0.3, 1 / 0.3), same_on_batch=False, p=0.7
),
nn.Identity(),
)
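The factored-out pipeline can also be exercised on its own. A minimal sketch, assuming only that the package is importable, with a random NCHW float batch standing in for real cutouts:

```python
import torch
from pytti.Perceptor.cutouts import augs as cutouts_augs

# Build the kornia pipeline that used to live inline in the embedder.
pipeline = cutouts_augs.pytti_classic()

# Kornia augmentations expect a float NCHW batch with values in [0, 1].
batch = torch.rand(8, 3, 224, 224)
augmented = pipeline(batch)
print(augmented.shape)  # torch.Size([8, 3, 224, 224])
```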
117 changes: 117 additions & 0 deletions src/pytti/Perceptor/cutouts/samplers.py
@@ -0,0 +1,117 @@
"""
Methods for obtaining cutouts, agnostic to augmentations.

Cutout choices have a significant impact on the performance of the perceptors and the
overall look of the image.

The objects defined here are probably only used in pytti.Perceptor.Embedder.HDMultiClipEmbedder, but they
should be sufficiently general for use in notebooks without pyttitools otherwise in use.
"""

import torch
from typing import Tuple
from torch.nn import functional as F

PADDING_MODES = {
"mirror": "reflect",
"smear": "replicate",
"wrap": "circular",
"black": "constant",
}

# (
# cut_size = 64
# cut_pow = 0.5
# noise_fac = 0.0
# cutn = 8
# border_mode = "clamp"
# augs = None
# return Cutout(
# cut_size=cut_size,
# cut_pow=cut_pow,
# noise_fac=noise_fac,
# cutn=cutn,
# border_mode=border_mode,
# augs=augs,
# )


def pytti_classic(
# self,
input: torch.Tensor,
side_x,
side_y,
cut_size,
padding,
cutn,
cut_pow,
border_mode,
augs,
noise_fac,
device,
) -> Tuple[list, list, list]:
"""
This is the cutout method that was already in use in the original pytti.
"""
min_size = min(side_x, side_y, cut_size)
max_size = min(side_x, side_y)
paddingx = min(round(side_x * padding), side_x)
paddingy = min(round(side_y * padding), side_y)
cutouts = []
offsets = []
sizes = []
for _ in range(cutn):
# mean is 0.8
# std is 0.3
size = int(
max_size
* (
torch.zeros(
1,
)
.normal_(mean=0.8, std=0.3)
.clip(cut_size / max_size, 1.0)
** cut_pow
)
)
offsetx_max = side_x - size + 1
offsety_max = side_y - size + 1
if border_mode == "clamp":
offsetx = torch.clamp(
(torch.rand([]) * (offsetx_max + 2 * paddingx) - paddingx)
.floor()
.int(),
0,
offsetx_max,
)
offsety = torch.clamp(
(torch.rand([]) * (offsety_max + 2 * paddingy) - paddingy)
.floor()
.int(),
0,
offsety_max,
)
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
else:
px = min(size, paddingx)
py = min(size, paddingy)
offsetx = (torch.rand([]) * (offsetx_max + 2 * px) - px).floor().int()
offsety = (torch.rand([]) * (offsety_max + 2 * py) - py).floor().int()
cutout = input[
:,
:,
paddingy + offsety : paddingy + offsety + size,
paddingx + offsetx : paddingx + offsetx + size,
]
cutouts.append(F.adaptive_avg_pool2d(cutout, cut_size))
offsets.append(
torch.as_tensor([[offsetx / side_x, offsety / side_y]]).to(device)
)
sizes.append(torch.as_tensor([[size / side_x, size / side_y]]).to(device))
cutouts = augs(torch.cat(cutouts))
offsets = torch.cat(offsets)
sizes = torch.cat(sizes)
if noise_fac:
facs = cutouts.new_empty([cutn, 1, 1, 1]).uniform_(0, noise_fac)
cutouts.add_(facs * torch.randn_like(cutouts))
return cutouts, offsets, sizes
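As the docstring above suggests, the sampler is usable outside the embedder. Each cutout's side length is drawn as `int(max_size * clip(N(0.8, 0.3), cut_size / max_size, 1) ** cut_pow)`, so `cut_pow < 1` biases the draw toward larger crops and `cut_pow > 1` toward smaller ones. A minimal notebook-style sketch, with a random tensor in place of a decoded image and illustrative (not default) hyperparameters:

```python
import torch
from pytti.Perceptor.cutouts import augs, samplers

img = torch.rand(1, 3, 448, 512)  # NCHW: side_y=448, side_x=512

cutouts, offsets, sizes = samplers.pytti_classic(
    input=img,
    side_x=512,
    side_y=448,
    cut_size=224,  # e.g. a CLIP perceptor's input resolution
    padding=0.25,
    cutn=8,
    cut_pow=0.5,
    border_mode="clamp",
    augs=augs.pytti_classic(),
    noise_fac=0.1,
    device="cpu",
)
print(cutouts.shape, offsets.shape, sizes.shape)
# torch.Size([8, 3, 224, 224]) torch.Size([8, 2]) torch.Size([8, 2])
```

The offsets and sizes come back normalized by the image dimensions, which is what lets downstream consumers reason about cutout placement independently of resolution.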
24 changes: 14 additions & 10 deletions src/pytti/image_models/differentiable_image.py
@@ -4,6 +4,10 @@
from PIL import Image
from pytti.tensor_tools import named_rearrange

# for typing
import torch
from typing import Type
from pytti.LossAug.BaseLossClass import Loss

SUPPORTED_MODES = ["L", "RGB", "I", "F"]


@@ -25,13 +29,13 @@ def __init__(self, width: int, height: int, pixel_format: str = "RGB"):
self.lr = 0.02
self.latent_strength = 0

def decode_training_tensor(self):
def decode_training_tensor(self) -> torch.Tensor:
"""
returns a decoded tensor of this image for training
"""
return self.decode_tensor()

def get_image_tensor(self):
def get_image_tensor(self) -> torch.Tensor:
"""
optional method: returns an [n x w_i x h_i] tensor representing the local image data
those data will be used for animation if available
@@ -41,26 +45,26 @@ def get_image_tensor(self):
def clone(self):
raise NotImplementedError

def get_latent_tensor(self, detach=False):
def get_latent_tensor(self, detach=False) -> torch.Tensor:
if detach:
return self.get_image_tensor().detach()
else:
return self.get_image_tensor()

def set_image_tensor(self, tensor):
def set_image_tensor(self, tensor: torch.Tensor):
"""
optional method: accepts an [n x w_i x h_i] tensor representing the local image data
those data will be by the animation system
"""
raise NotImplementedError

def decode_tensor(self):
def decode_tensor(self) -> torch.Tensor:
"""
returns a decoded tensor of this image
"""
raise NotImplementedError

def encode_image(self, pil_image):
def encode_image(self, pil_image: Image.Image):
"""
overwrites this image with the input image
pil_image: (Image) input image
@@ -79,7 +83,7 @@ def update(self):
"""
pass

def make_latent(self, pil_image):
def make_latent(self, pil_image: Image.Image) -> torch.Tensor:
try:
dummy = self.clone()
except NotImplementedError:
@@ -88,15 +92,15 @@ def make_latent(self, pil_image):
return dummy.get_latent_tensor(detach=True)

@classmethod
def get_preferred_loss(cls):
def get_preferred_loss(cls) -> Type[Loss]:
from pytti.LossAug.HSVLossClass import HSVLoss

return HSVLoss

def image_loss(self):
return []

def decode_image(self):
def decode_image(self) -> Image.Image:
"""
render a PIL Image version of this image
"""
@@ -112,7 +116,7 @@ def decode_image(self):
)
return Image.fromarray(array)

def forward(self):
def forward(self) -> torch.Tensor:
"""
returns a decoded tensor of this image
"""
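To see the annotated interface from the consumer side, here is a hypothetical minimal subclass; `FlatRGBImage` and its attributes are invented for illustration, and the sketch assumes `DifferentiableImage` is an `nn.Module` (it defines `forward`):

```python
import torch
from torch import nn
from pytti.image_models.differentiable_image import DifferentiableImage

class FlatRGBImage(DifferentiableImage):
    """Toy image model: one learnable RGB color broadcast over the canvas."""

    def __init__(self, width: int, height: int):
        super().__init__(width, height, "RGB")
        self._w, self._h = width, height  # hypothetical storage of the canvas size
        self.color = nn.Parameter(torch.rand(3))

    def decode_tensor(self) -> torch.Tensor:
        # [1, 3, H, W] canvas filled with the (clamped) learnable color.
        return self.color.clamp(0, 1).view(1, 3, 1, 1).expand(1, 3, self._h, self._w)

    def set_image_tensor(self, tensor: torch.Tensor):
        # Collapse incoming [3 x w x h] image data to its per-channel mean.
        with torch.no_grad():
            self.color.copy_(tensor.reshape(3, -1).mean(dim=1))
```

Because the base class routes `forward` and `decode_training_tensor` through `decode_tensor`, implementing that one method (plus the optional setter) is enough for the subclass to participate in training and animation.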