diff --git a/src/pytti/LossAug/DepthLossClass.py b/src/pytti/LossAug/DepthLossClass.py index 0428806..501a31d 100644 --- a/src/pytti/LossAug/DepthLossClass.py +++ b/src/pytti/LossAug/DepthLossClass.py @@ -29,7 +29,6 @@ def init_AdaBins(device=None): class DepthLoss(MSELoss): @torch.no_grad() def set_comp(self, pil_image): - # pil_image = pil_image.resize(self.image_shape, Image.LANCZOS) self.comp.set_(DepthLoss.make_comp(pil_image)) if self.use_mask and self.mask.shape[-2:] != self.comp.shape[-2:]: self.mask.set_(TF.resize(self.mask, self.comp.shape[-2:])) diff --git a/src/pytti/LossAug/LatentLossClass.py b/src/pytti/LossAug/LatentLossClass.py index 8f267b8..a2612b3 100644 --- a/src/pytti/LossAug/LatentLossClass.py +++ b/src/pytti/LossAug/LatentLossClass.py @@ -25,37 +25,13 @@ def __init__( TF.resize(comp.clone(), (h, w)), weight, stop, name, image_shape ) + # Comp and mask should live on the image representation, not the loss class. @torch.no_grad() def set_comp(self, pil_image, device=DEVICE): self.pil_image = pil_image self.has_latent = False self.direct_loss.set_comp(pil_image.resize(self.image_shape, Image.LANCZOS)) - @classmethod - @vram_usage_mode("Latent Image Loss") - @torch.no_grad() - def TargetImage( - cls, prompt_string, image_shape, pil_image=None, is_path=False, device=DEVICE - ): - text, weight, stop = parse( - prompt_string, r"(? Loss: - """ - Given a weight name, weight, name, image, and target image, returns a loss object + loss_augs.extend(optical_flows) - :param weight_name: The name of the loss function - :param weight: The weight of the loss - :param name: The name of the loss function - :param img: The image to be optimized - :param pil_target: The target image - :return: The loss function. - """ - Loss = self.loss_factory - out = Loss.TargetImage( - f"{self.weight_category} {self.name}:{self.weight}", - self.img.image_shape, - self.pil_target, - ) - out.set_enabled(self.pil_target is not None) - return out + return img, loss_augs, optical_flows def _standardize_null(weight): @@ -198,189 +177,3 @@ def _standardize_null(weight): if float(weight) == 0: weight = "" return weight - - -class LossConfigurator: - """ - Groups together procedures for initializing losses - """ - - def __init__( - self, - init_image_pil: Image.Image, - restore: bool, - img: PixelImage, - embedder, - prompts, - # params, - ######## - direct_image_prompts, - semantic_stabilization_weight, - init_image, - semantic_init_weight, - animation_mode, - flow_stabilization_weight, - flow_long_term_samples, - smoothing_weight, - ########### - direct_init_weight, - direct_stabilization_weight, - depth_stabilization_weight, - edge_stabilization_weight, - ): - self.init_image_pil = init_image_pil - self.img = img - self.embedder = embedder - self.prompts = prompts - - self.init_augs = [] - self.loss_augs = [] - self.optical_flows = [] - self.last_frame_semantic = None - self.semantic_init_prompt = None - - # self.params = params - self.restore = restore - - ### params - self.direct_image_prompts = direct_image_prompts - self.semantic_stabilization_weight = _standardize_null( - semantic_stabilization_weight - ) - self.init_image = init_image - self.semantic_init_weight = _standardize_null(semantic_init_weight) - self.animation_mode = animation_mode - self.flow_stabilization_weight = _standardize_null(flow_stabilization_weight) - self.flow_long_term_samples = flow_long_term_samples - self.smoothing_weight = _standardize_null(smoothing_weight) - - ###### - self.direct_init_weight = _standardize_null(direct_init_weight) - self.direct_stabilization_weight = _standardize_null( - direct_stabilization_weight - ) - self.depth_stabilization_weight = _standardize_null(depth_stabilization_weight) - self.edge_stabilization_weight = _standardize_null(edge_stabilization_weight) - - def process_direct_image_prompts(self): - # prompt parsing shouldn't go here. - self.loss_augs.extend( - type(self.img) - .get_preferred_loss() - .TargetImage(p.strip(), self.img.image_shape, is_path=True) - for p in self.direct_image_prompts.split("|") - if p.strip() - ) - - def process_semantic_stabilization(self): - last_frame_pil = self.init_image_pil - if not last_frame_pil: - last_frame_pil = self.img.decode_image() - self.last_frame_semantic = parse_prompt( - self.embedder, - f"stabilization:{self.semantic_stabilization_weight}", - last_frame_pil, - ) - self.last_frame_semantic.set_enabled(self.init_image_pil is not None) - for scene in self.prompts: - scene.append(self.last_frame_semantic) - - def configure_losses(self): - if self.init_image_pil is not None: - self.configure_init_image() - self.process_direct_image_prompts() - if self.semantic_stabilization_weight: - self.process_semantic_stabilization() - self.configure_stabilization_augs() - self.configure_optical_flows() - self.configure_aesthetic_losses() - - return ( - self.loss_augs, - self.init_augs, - self.stabilization_augs, - self.optical_flows, - self.semantic_init_prompt, - self.last_frame_semantic, - self.img, - ) - - def configure_init_image(self): - - if not self.restore: - # move these logging statements into .encode_image() - logger.info("Encoding image...") - self.img.encode_image(self.init_image_pil) - logger.info("Encoded Image:") - # pretty sure this assumes we're in a notebook - display.display(self.img.decode_image()) - - ## wrap this for the flexibility that the loop is pretending to provide... - # set up init image prompt - if self.direct_init_weight: - init_aug = LossBuilder( - "direct_init_weight", - self.direct_init_weight, - f"init image ({self.init_image})", - self.img, - self.init_image_pil, - ).build_loss() - self.loss_augs.append(init_aug) - self.init_augs.append(init_aug) - - ######## - if self.semantic_init_weight: - self.semantic_init_prompt = parse_prompt( - self.embedder, - f"init image [{self.init_image}]:{self.semantic_init_weight}", - self.init_image_pil, - ) - self.prompts[0].append(self.semantic_init_prompt) - - # stabilization - def configure_stabilization_augs(self): - d_augs = { - "direct_stabilization_weight": self.direct_stabilization_weight, - "depth_stabilization_weight": self.depth_stabilization_weight, - "edge_stabilization_weight": self.edge_stabilization_weight, - } - stabilization_augs = [ - LossBuilder( - k, v, "stabilization", self.img, self.init_image_pil - ).build_loss() - for k, v in d_augs.items() - if v - ] - self.stabilization_augs = stabilization_augs - self.loss_augs.extend(stabilization_augs) - - def configure_optical_flows(self): - optical_flows = None - - if self.animation_mode == "Video Source": - if self.flow_stabilization_weight == "": - self.flow_stabilization_weight = "0" - optical_flows = [ - OpticalFlowLoss.TargetImage( - f"optical flow stabilization (frame {-2**i}):{self.flow_stabilization_weight}", - self.img.image_shape, - ) - for i in range(self.flow_long_term_samples + 1) - ] - - elif self.animation_mode == "3D" and self.flow_stabilization_weight: - optical_flows = [ - TargetFlowLoss.TargetImage( - f"optical flow stabilization:{self.flow_stabilization_weight}", - self.img.image_shape, - ) - ] - - if optical_flows is not None: - for optical_flow in optical_flows: - optical_flow.set_enabled(False) - self.loss_augs.extend(optical_flows) - - def configure_aesthetic_losses(self): - if self.smoothing_weight != 0: - self.loss_augs.append(TVLoss(weight=self.smoothing_weight)) diff --git a/src/pytti/LossAug/MSELossClass.py b/src/pytti/LossAug/MSELossClass.py index ff1e5dc..8a1f33e 100644 --- a/src/pytti/LossAug/MSELossClass.py +++ b/src/pytti/LossAug/MSELossClass.py @@ -4,9 +4,9 @@ from torch.nn import functional as F from pytti.LossAug.BaseLossClass import Loss -# from pytti.Notebook import Rotoscoper from pytti.rotoscoper import Rotoscoper -from pytti import fetch, parse, vram_usage_mode +from pytti import fetch, vram_usage_mode +from pytti.eval_tools import parse_subprompt import torch @@ -30,37 +30,6 @@ def __init__( self.register_buffer("mask", torch.ones(1, 1, 1, 1, device=self.device)) self.use_mask = False - @classmethod - @vram_usage_mode("Loss Augs") - @torch.no_grad() - def TargetImage( - cls, prompt_string, image_shape, pil_image=None, is_path=False, device=None - ): - # Why is this prompt parsing stuff here? Deprecate in favor of centralized - # parsing functions (if feasible) - text, weight, stop = parse( - prompt_string, r"(?