diff --git a/blehconfig.example.yaml b/blehconfig.example.yaml index 3594ad6..510dfc2 100644 --- a/blehconfig.example.yaml +++ b/blehconfig.example.yaml @@ -38,6 +38,12 @@ betterTaesdPreviews: # Setting it to "vae" will use whatever dtype ComfyUI is set to use for VAE. preview_dtype: null + # Uses non-blocking transfers for previews when the device supports it. + # Not recommended as it is extremely likely to corrupt previews, especially if the previewer + # is relatively slow or the latent is large (video models, Chroma Radiance). However, + # it might decrease the performance impact of previewing. + preview_non_blocking: false + # Allows skipping upscale layers in the TAESD model, may increase performance when previewing large images or batches. # May be set to -1 (conservative) or -2 (aggressive) to automatically calculate how many to skip. See README.md for details. skip_upscale_layers: 0 diff --git a/py/better_previews/base.py b/py/better_previews/base.py index 04b3897..f08df3c 100644 --- a/py/better_previews/base.py +++ b/py/better_previews/base.py @@ -4,43 +4,86 @@ from comfy import latent_formats +from .tae_vid import TAEVid, TAEVidBase, TAEVidLTX2 + if TYPE_CHECKING: from pathlib import Path class VideoModelInfo(NamedTuple): + name: str latent_format: latent_formats.LatentFormat - fps: int = 24 + fps: int | float = 24 temporal_compression: int = 8 + temporal_layers: int = 0 patch_size: int = 1 + nested_tensor_index: int = 0 tae_model: str | Path | None = None + tae_class: TAEVidBase | None = TAEVid VIDEO_FORMATS = { - "mochi": VideoModelInfo( - latent_formats.Mochi, - temporal_compression=6, - tae_model="taem1.pth", - ), - "hunyuanvideo": VideoModelInfo( - latent_formats.HunyuanVideo, - temporal_compression=4, - tae_model="taehv.pth", - ), - "cosmos1cv8x8x8": VideoModelInfo(latent_formats.Cosmos1CV8x8x8), - "wan21": VideoModelInfo( - latent_formats.Wan21, - fps=16, - temporal_compression=4, - tae_model="taew2_1.pth", - ), - "wan22": VideoModelInfo( - 
latent_formats.Wan22, - fps=24, - temporal_compression=4, - patch_size=2, - tae_model="taew2_2.pth", - ), + vmi.name: vmi + for vmi in ( + VideoModelInfo( + "mochi", + latent_formats.Mochi, + temporal_compression=6, + tae_model="taem1.pth", + ), + VideoModelInfo( + "hunyuanvideo", + latent_formats.HunyuanVideo, + temporal_compression=4, + tae_model="taehv.pth", + ), + VideoModelInfo( + "hunyuanvideo15", + latent_formats.HunyuanVideo15, + temporal_compression=4, + patch_size=2, + tae_model="taehv1_5.pth", + ), + VideoModelInfo( + "cosmos1cv8x8x8", + latent_formats.Cosmos1CV8x8x8, + ), + VideoModelInfo( + "wan21", + latent_formats.Wan21, + fps=16, + temporal_compression=4, + temporal_layers=2, + tae_model="taew2_1.pth", + ), + VideoModelInfo( + "wan22", + latent_formats.Wan22, + fps=24, + temporal_compression=4, + temporal_layers=2, + patch_size=2, + tae_model="taew2_2.pth", + ), + VideoModelInfo( + "ltxv", + latent_formats.LTXV, + fps=24, + patch_size=4, + temporal_layers=3, + tae_model="taeltx_2.pth", + tae_class=TAEVidLTX2, + ), + VideoModelInfo( + "ltxav", + latent_formats.LTXV, + fps=24, + patch_size=4, + temporal_layers=3, + tae_model="taeltx_2.pth", + tae_class=TAEVidLTX2, + ), + ) } diff --git a/py/better_previews/previewer.py b/py/better_previews/previewer.py index 45c2369..4f90575 100644 --- a/py/better_previews/previewer.py +++ b/py/better_previews/previewer.py @@ -1,17 +1,22 @@ from __future__ import annotations import math +from io import BytesIO from time import time from typing import TYPE_CHECKING +import comfy.utils as comfy_utils import folder_paths import latent_preview import torch +from aiohttp import web +from comfy import latent_formats from comfy.cli_args import LatentPreviewMethod from comfy.cli_args import args as comfy_args from comfy.model_management import device_supports_non_blocking, vae_dtype from comfy.taesd.taesd import TAESD from PIL import Image +from server import PromptServer from tqdm import tqdm from ..settings import SETTINGS 
# noqa: TID252 @@ -19,14 +24,22 @@ from .tae_vid import TAEVid if TYPE_CHECKING: + from collections.abc import Callable + import numpy as np from comfy import latent_formats + +class BlehPreviewerState: + last_latent_shapes: tuple | None = None + fps_override: float | None = None + + +PREVIEWER_STATE = BlehPreviewerState() + _ORIG_PREVIEWER = latent_preview.TAESDPreviewerImpl _ORIG_GET_PREVIEWER = latent_preview.get_previewer -LAST_LATENT_FORMAT = None - # Referenced from https://github.com/learnables/learn2learn/blob/752200384c3ca8caeb8487b5dd1afd6568e8ec01/learn2learn/utils/__init__.py#L51 def clone_module(module, *, memo: dict | None = None) -> torch.nn.Module: @@ -81,21 +94,73 @@ def normalize_to_scale(latent, target_min, target_max, *, dim=(-3, -2, -1)): ) +class LastPreview: + image: bytes | None + stamp: float | None + content_type: str | None + + dum_page = """ + + + bleh preview + + + + + + + + + """ + + def __init__(self): + self.image = None + self.stamp = None + self.content_type = None + + def update( + self, *, image_bytes: bytes, content_type: str, stamp: float | None = None + ): + self.image = image_bytes + self.stamp = time() if stamp is None else stamp + self.content_type = content_type + + async def __call__(self, request: web.Request): + if request.path.endswith(".html"): + return web.Response(body=self.dum_page, content_type="text/html") + if self.image is None or self.content_type is None: + raise web.HTTPNotFound(reason="OHNO") + return web.Response(body=self.image, content_type=self.content_type) + + +LAST_PREVIEW = LastPreview() +PromptServer.instance.routes.get("/bleh/last_preview")(LAST_PREVIEW) +PromptServer.instance.routes.get("/bleh/last_preview.html")(LAST_PREVIEW) + + class ImageWrapper: - def __init__(self, frames: tuple, frame_duration: int): - self._frames = frames + def __init__(self, frames: tuple | Image, frame_duration: int = 250): + self._frames = (frames,) if not isinstance(frames, (tuple, list)) else frames 
self._frame_duration = frame_duration def save(self, fp, format: str | None, **kwargs: dict): # noqa: A002 - if len(self._frames) == 1: + if len(self._frames) > 1: + kwargs |= { + "loop": 0, + "save_all": True, + "append_images": self._frames[1:], + "duration": self._frame_duration, + } + format = "webp" + if not SETTINGS.btp_publish_last_preview: return self._frames[0].save(fp, format, **kwargs) - kwargs |= { - "loop": 0, - "save_all": True, - "append_images": self._frames[1:], - "duration": self._frame_duration, - } - return self._frames[0].save(fp, "webp", **kwargs) + buf = BytesIO() + result = self._frames[0].save(buf, format, **kwargs) + # FIXME + image_bytes = buf.getvalue() + LAST_PREVIEW.update(image_bytes=image_bytes, content_type=f"image/{format}") + fp.write(image_bytes) + return result def resize(self, *args: list, **kwargs: dict) -> ImageWrapper: return ImageWrapper( @@ -156,18 +221,26 @@ def __init__( *, dtype: torch.dtype, device: torch.device, + height_factor: int = 4, + width_factor: int = 1, normalize_dims: tuple = (-1,), ): super().__init__() self.dtype = dtype self.device = device self.normalize_dims = normalize_dims + self.height_factor = height_factor + self.width_factor = width_factor @torch.no_grad() def forward(self, x: torch.Tensor) -> torch.Tensor: batch, temporal = x.shape[0], x.shape[-1] x = normalize_to_scale(x, 0.0, 1.0, dim=self.normalize_dims) * 255.0 x = x.reshape(batch, -1, temporal) + if self.height_factor > 1: + x = x.repeat_interleave(dim=1, repeats=self.height_factor) + if self.width_factor > 1: + x = x.repeat_interleave(dim=1, repeats=self.width_factor) return x[..., None].expand(*x.shape, 3) @@ -179,7 +252,10 @@ def __init__( latent_format: latent_formats.LatentFormat, vid_info: VideoModelInfo | None = None, ): - self.latent_format = latent_format + self.orig_latent_format = latent_format + self.latent_format = ( + latent_format if vid_info is None else vid_info.latent_format + ) self.latent_format_name = ( "unknown" if 
latent_format is None @@ -406,7 +482,8 @@ def prepare_previewer( return x0.to( device=pdevice, dtype=pdtype, - non_blocking=device_supports_non_blocking(x0.device), + non_blocking=SETTINGS.btp_preview_non_blocking + and device_supports_non_blocking(x0.device), ) def _decode_latent_taevid(self, x0: torch.Tensor) -> tuple[torch.Tensor, int, int]: @@ -473,12 +550,25 @@ def calc_cols_rows( rows = math.ceil(batch_size / cols) return cols, rows - @classmethod - def decoded_to_animation(cls, samples: np.ndarray) -> ImageWrapper: + def decoded_to_animation( + self, + samples: np.ndarray, + video_frames: int, + ) -> ImageWrapper: batch = samples.shape[0] + fps_override = PREVIEWER_STATE.fps_override + if self.vid_info is None or not video_frames: + frame_duration = 250 if not fps_override else 1000 / fps_override + else: + time_factor = self.vid_info.temporal_compression / max( + 1, + self.previewer_model.t_upscale, + ) + ms_frame = 1000.0 / (fps_override or self.vid_info.fps) + frame_duration = ms_frame * time_factor return ImageWrapper( tuple(Image.fromarray(samples[idx]) for idx in range(batch)), - frame_duration=250, + frame_duration=max(1, int(frame_duration)), ) def decoded_to_image( @@ -487,22 +577,23 @@ def decoded_to_image( cols: int, rows: int, *, - is_video=False, + video_frames: int = 0, ) -> Image | ImageWrapper: batch, (height, width) = samples.shape[0], samples.shape[-3:-1] samples = samples.to( device="cpu", dtype=torch.uint8, - non_blocking=device_supports_non_blocking(samples.device), + non_blocking=SETTINGS.btp_preview_non_blocking + and device_supports_non_blocking(samples.device), ).numpy() if batch == 1: - self.cached = Image.fromarray(samples[0]) + self.cached = ImageWrapper((Image.fromarray(samples[0]),)) return self.cached if SETTINGS.btp_animate_preview == "both" or ( - is_video, + video_frames != 0, SETTINGS.btp_animate_preview, ) in {(True, "video"), (False, "batch")}: - return self.decoded_to_animation(samples) + return 
self.decoded_to_animation(samples, video_frames=video_frames) cols, rows = self.calc_cols_rows(batch, width, height) img_size = (width * cols, height * rows) if self.cached is not None and self.cached.size == img_size: @@ -514,7 +605,7 @@ def decoded_to_image( Image.fromarray(samples[idx]), box=((idx % cols) * width, ((idx // cols) % rows) * height), ) - return result + return ImageWrapper((result,)) @torch.no_grad() def init_fallback_previewer(self, device: torch.device, dtype: torch.dtype) -> bool: @@ -526,7 +617,7 @@ def init_fallback_previewer(self, device: torch.device, dtype: torch.dtype) -> b and self.fallback_previewer_model.device == device ): return True - if self.latent_format_name == "aceaudio": + if self.latent_format_name in {"aceaudio", "aceaudio15"}: self.fallback_previewer_model = ACEStepsPreviewerModel( device=device, dtype=dtype, @@ -563,11 +654,52 @@ def fallback_previewer(self, x0: torch.Tensor, *, quiet=False) -> Image: except torch.OutOfMemoryError: return self.blank + def ensure_x0_shape(self, x0: torch.Tensor) -> tuple[torch.Tensor, bool]: # noqa: PLR0911 + expected_channels = self.latent_format.latent_channels + expected_ndim = 2 + self.latent_format.latent_dimensions + if x0.shape[0] == 0: + return x0, False + if self.latent_format_name == "aceaudio15" and x0.ndim == expected_ndim + 1: + expected_ndim += 1 + if ( + x0.ndim > 1 + and x0.ndim == expected_ndim + and x0.shape[1] == expected_channels + ): + return x0, True + last_shapes = PREVIEWER_STATE.last_latent_shapes + if not last_shapes or not hasattr(comfy_utils, "unpack_latents"): + return x0, False + last_numel = sum(math.prod(tshape) for tshape in last_shapes) + if last_numel != x0.numel(): + return x0, False + nest_idx = self.vid_info.nested_tensor_index if self.vid_info else 0 + target_shape = None if len(last_shapes) <= nest_idx else last_shapes[nest_idx] + if ( + # Have to have a nest shape + target_shape is None + # with at least a channel dimension, + or len(target_shape) < 2 
+ # with the expected number of dims, + or len(target_shape) != expected_ndim + # And the correct number of channels. + or target_shape[1] != expected_channels + ): + return x0, False + unpacked_latents = comfy_utils.unpack_latents(x0, last_shapes) + target_latent = ( + None if len(unpacked_latents) <= nest_idx else unpacked_latents[nest_idx] + ) + if target_latent is None or target_latent.shape != target_shape: + return x0, False + return target_latent.reshape(*target_shape), True + def decode_latent_to_preview(self, x0: torch.Tensor) -> Image: if self.check_use_cached(): return self.cached - if x0.shape[0] == 0: - return self.blank # Shouldn't actually be possible. + x0, can_preview = self.ensure_x0_shape(x0) + if not can_preview: + return self.blank if (self.oom_count and not self.oom_retry) or self.previewer_model is None: return self.fallback_previewer(x0, quiet=True) is_video = x0.ndim == 5 @@ -579,7 +711,10 @@ def decode_latent_to_preview(self, x0: torch.Tensor) -> Image: if is_video else self._decode_latent_taesd(x0) ) - result = self.decoded_to_image(*dargs, is_video=is_video) + result = self.decoded_to_image( + *dargs, + video_frames=x0.shape[2] if is_video else 0, + ) except torch.OutOfMemoryError: used_fallback = True result = self.fallback_previewer(x0) @@ -601,7 +736,11 @@ def orig_get_previewer(): preview_method = comfy_args.preview_method - if preview_method == LatentPreviewMethod.NoPreviews: + if preview_method not in { + LatentPreviewMethod.TAESD, + LatentPreviewMethod.Auto, + LatentPreviewMethod.Latent2RGB, + }: return orig_get_previewer() format_name = latent_format.__class__.__name__.lower() @@ -611,29 +750,34 @@ def orig_get_previewer(): or (SETTINGS.btp_whitelist and format_name not in SETTINGS.btp_whitelist) ): return orig_get_previewer() + if format_name in {"aceaudio", "aceaudio15"}: + return BetterPreviewer(latent_format=latent_format) + vid_info = VIDEO_FORMATS.get(format_name) + eff_latent_format = ( + vid_info.latent_format if vid_info 
is not None else latent_format + ) tae_model = None if preview_method in {LatentPreviewMethod.TAESD, LatentPreviewMethod.Auto}: - vid_info = VIDEO_FORMATS.get(format_name) - if vid_info is not None and vid_info.tae_model is not None: + if ( + vid_info is not None + and vid_info.tae_model is not None + and vid_info.tae_class is not None + ): tae_model_path = folder_paths.get_full_path( "vae_approx", vid_info.tae_model, ) - tupscale_limit = SETTINGS.btp_video_temporal_upscale_level - decoder_time_upscale = tuple( - i < tupscale_limit for i in range(TAEVid.temporal_upscale_blocks) - ) tae_model = ( - TAEVid( + vid_info.tae_class( checkpoint_path=tae_model_path, vmi=vid_info, device=torch.device("cpu"), - decoder_time_upscale=decoder_time_upscale, + decoder_time_upscale_level=SETTINGS.btp_video_temporal_upscale_level, ) if tae_model_path is not None else None ) - if tae_model is None and latent_format.taesd_decoder_name is not None: + elif vid_info is None and latent_format.taesd_decoder_name is not None: taesd_path = folder_paths.get_full_path( "vae_approx", f"{latent_format.taesd_decoder_name}.pth", @@ -653,7 +797,8 @@ def orig_get_previewer(): latent_format=latent_format, vid_info=vid_info, ) - if format_name == "aceaudio" or latent_format.latent_rgb_factors is not None: + # Using Latent2RGB either via setting or because no preview model. 
+ if eff_latent_format.latent_rgb_factors is not None: return BetterPreviewer(latent_format=latent_format) return orig_get_previewer() diff --git a/py/better_previews/tae_vid.py b/py/better_previews/tae_vid.py index 325f2f0..f676f11 100644 --- a/py/better_previews/tae_vid.py +++ b/py/better_previews/tae_vid.py @@ -53,6 +53,10 @@ def forward(self, x: torch.Tensor, past: torch.Tensor) -> torch.Tensor: return self.act(self.conv(torch.cat((x, past), 1)) + self.skip(x)) +def make_memblocks(n: int, *, count: int = 3) -> tuple[MemBlock, ...]: + return tuple(MemBlock(n, n) for _ in range(count)) + + class TPool(nn.Module): def __init__(self, n_f, stride): super().__init__() @@ -183,8 +187,8 @@ def apply(self, x: torch.Tensor, *, show_progress=False) -> torch.Tensor: return torch.stack(out, 1) -class TAEVid(nn.Module): - temporal_upscale_blocks = 2 +class TAEVidBase(nn.Module): + temporal_upscale_blocks = 3 spatial_upscale_blocks = 3 _nf = (256, 128, 64, 64) @@ -195,61 +199,33 @@ def __init__( vmi: VideoModelInfo, image_channels: int = 3, device="cpu", - decoder_time_upscale=(True, True), - decoder_space_upscale=(True, True, True), + encoder_time_downscale_level: int = 3, + decoder_time_upscale_level: int = 3, + decoder_space_upscale_level: int = 3, ): - n_f = self._nf super().__init__() self.vmi = vmi - self.latent_channels = vmi.latent_format.latent_channels self.image_channels = image_channels + self.latent_channels = vmi.latent_format.latent_channels self.patch_size = vmi.patch_size - self.encoder = nn.Sequential( - conv(image_channels * self.patch_size**2, 64), - nn.ReLU(inplace=True), - TPool(64, 2), - conv(64, 64, stride=2, bias=False), - MemBlock(64, 64), - MemBlock(64, 64), - MemBlock(64, 64), - TPool(64, 2), - conv(64, 64, stride=2, bias=False), - MemBlock(64, 64), - MemBlock(64, 64), - MemBlock(64, 64), - TPool(64, 1), - conv(64, 64, stride=2, bias=False), - MemBlock(64, 64), - MemBlock(64, 64), - MemBlock(64, 64), - conv(64, vmi.latent_format.latent_channels), + 
encoder_time_downscale = self._get_encoder_flags( + time_level=encoder_time_downscale_level, ) - self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1 - self.decoder = nn.Sequential( - Clamp(), - conv(vmi.latent_format.latent_channels, n_f[0]), - nn.ReLU(inplace=True), - MemBlock(n_f[0], n_f[0]), - MemBlock(n_f[0], n_f[0]), - MemBlock(n_f[0], n_f[0]), - nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), - TGrow(n_f[0], 1), - conv(n_f[0], n_f[1], bias=False), - MemBlock(n_f[1], n_f[1]), - MemBlock(n_f[1], n_f[1]), - MemBlock(n_f[1], n_f[1]), - nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), - TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), - conv(n_f[1], n_f[2], bias=False), - MemBlock(n_f[2], n_f[2]), - MemBlock(n_f[2], n_f[2]), - MemBlock(n_f[2], n_f[2]), - nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), - TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), - conv(n_f[2], n_f[3], bias=False), - nn.ReLU(inplace=True), - conv(n_f[3], image_channels * self.patch_size**2), + decoder_time_upscale, decoder_space_upscale = self._get_decoder_flags( + time_level=decoder_time_upscale_level, + space_level=decoder_space_upscale_level, + ) + encoder_strides = tuple(1 + int(flag) for flag in encoder_time_downscale) + decoder_strides = tuple(1 + int(flag) for flag in decoder_time_upscale) + decoder_scale_factors = tuple(1 + int(flag) for flag in decoder_space_upscale) + self.encoder = self._build_encoder(strides=encoder_strides) + self.decoder = self._build_decoder( + strides=decoder_strides, + scale_factors=decoder_scale_factors, ) + self.t_upscale = 2 ** sum(decoder_time_upscale) + self.t_downscale = 2 ** sum(encoder_time_downscale) + self.frames_to_trim = self.t_upscale - 1 if checkpoint_path is None: return self.load_state_dict( @@ -258,6 +234,66 @@ def __init__( ), ) + def _get_decoder_flags( + self, + *, + time_level: int = 3, + space_level: int = 3, + ) -> tuple[tuple[bool, ...], tuple[bool, ...]]: + decoder_time_upscale 
= tuple(i < time_level for i in range(3)) + decoder_space_upscale = tuple(i < space_level for i in range(3)) + return decoder_time_upscale, decoder_space_upscale + + def _get_encoder_flags( + self, + *, + time_level: int = 3, + ) -> tuple[bool, ...]: + return tuple(i < time_level for i in range(3)) + + def _build_decoder( + self, + *, + strides: tuple[int, ...], + scale_factors: tuple[int, ...], + ) -> nn.Module: + n_f = self._nf + return nn.Sequential( + Clamp(), + conv(self.latent_channels, n_f[0]), + nn.ReLU(inplace=True), + *make_memblocks(n_f[0]), + nn.Upsample(scale_factor=scale_factors[0]), + TGrow(n_f[0], strides[0]), + conv(n_f[0], n_f[1], bias=False), + *make_memblocks(n_f[1]), + nn.Upsample(scale_factor=scale_factors[1]), + TGrow(n_f[1], strides[1]), + conv(n_f[1], n_f[2], bias=False), + *make_memblocks(n_f[2]), + nn.Upsample(scale_factor=scale_factors[2]), + TGrow(n_f[2], strides[2]), + conv(n_f[2], n_f[3], bias=False), + nn.ReLU(inplace=True), + conv(n_f[3], self.image_channels * self.patch_size**2), + ) + + def _build_encoder(self, *, strides: tuple[int, ...]) -> nn.Module: + return nn.Sequential( + conv(self.image_channels * self.patch_size**2, 64), + nn.ReLU(inplace=True), + TPool(64, strides[0]), + conv(64, 64, stride=2, bias=False), + *make_memblocks(64), + TPool(64, strides[1]), + conv(64, 64, stride=2, bias=False), + *make_memblocks(64), + TPool(64, strides[2]), + conv(64, 64, stride=2, bias=False), + *make_memblocks(64), + conv(64, self.latent_channels), + ) + def patch_tgrow_layers(self, sd: dict) -> dict: new_sd = self.state_dict() for i, layer in enumerate(self.decoder): @@ -304,8 +340,15 @@ def apply( show_progress=False, ) -> torch.Tensor: model = self.decoder if decode else self.encoder - if not decode and self.vmi.patch_size > 1: - x = F.pixel_unshuffle(x, self.patch_size) + if not decode: + if self.vmi.patch_size > 1: + x = F.pixel_unshuffle(x, self.patch_size) + # Pad handling copied from https://github.com/madebyollin + if x.shape[1] 
% self.t_downscale != 0: + # pad at end to multiple of self.t_downscale + n_pad = self.t_downscale - x.shape[1] % self.t_downscale + padding = x[:, -1:].repeat_interleave(n_pad, dim=1) + x = torch.cat([x, padding], 1) if parallel: result = self.apply_parallel(x, model, show_progress=show_progress) else: @@ -324,3 +367,45 @@ def encode(self, *args: list, **kwargs: dict) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: return self.c(x) + + +class TAEVid(TAEVidBase): + def _get_decoder_flags( + self, + *, + time_level: int = 3, + space_level: int = 3, + ) -> tuple[tuple[bool, ...], tuple[bool, ...]]: + tu, su = super()._get_decoder_flags( + time_level=time_level, + space_level=space_level, + ) + return (False, *tu[:2]), su + + def _get_encoder_flags( + self, + *, + time_level: int = 3, + ) -> tuple[bool, ...]: + return (*super()._get_encoder_flags(time_level=time_level)[:2], False) + + +class TAEVidLTX2(TAEVidBase): + def _get_decoder_flags( + self, + *, + time_level: int = 3, + space_level: int = 3, + ) -> tuple[tuple[bool, ...], tuple[bool, ...]]: + _tu, su = super()._get_decoder_flags( + time_level=time_level, + space_level=space_level, + ) + return (True, True, True), su + + def _get_encoder_flags( + self, + *, + time_level: int = 3, # noqa: ARG002 + ) -> tuple[bool, ...]: + return (True, True, True) diff --git a/py/latent_utils.py b/py/latent_utils.py index a75364d..55c460d 100644 --- a/py/latent_utils.py +++ b/py/latent_utils.py @@ -5,13 +5,20 @@ import math import os from functools import partial -from typing import ClassVar +from tokenize import triple_quoted +from typing import TYPE_CHECKING, Any import kornia.filters as kf import numpy as np import torch import torch.nn.functional as nnf from torch import FloatTensor, LongTensor, fft +from tqdm import tqdm + +from . 
import wavelet_functions as wavef + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence OVERRIDE_NO_SCALE = "COMFYUI_BLEH_OVERRIDE_NO_SCALE" in os.environ USE_ORIG_NORMALIZE = "COMFYUI_BLEH_ORIG_NORMALIZE" in os.environ @@ -798,6 +805,347 @@ def blend_blend( ) +def ortho_blend( + a: torch.Tensor, + b: torch.Tensor, + t: torch.Tensor, + *, + blend_mode: str | Callable | None = None, + proj_scale: float = -1.0, + ortho_scale: float = 1.0, + start_dim: int = 1, + end_dim: int = -1, + rescale_limit: float = 0.0, + # a, b, blend or None + rescale_result_mode: str | None = None, + # When rescale_target mode is blend, will use blend_mode if None. + rescale_result_blend_mode: str | Callable | None = None, + # LERP if None. + dyn_result_blend_mode: str | Callable | None = None, + dyn_ortho_mode: bool = False, + dyn_min_scale: float = 0.0, + dyn_max_scale: float = 1.0, + # Can only be used when the flattened tensor has 4 dimensions left. + smooth_factor_kernel_size: int | tuple[int, int] = 0, + ortho_verbose: bool = False, + eps: float = 1e-06, +) -> torch.Tensor: + orig_shape = a.shape + ndim = a.ndim + if start_dim < 0: + start_dim = max(0, min(ndim + start_dim, ndim - 1)) + if end_dim < 0: + end_dim = max(0, min(ndim + end_dim, ndim - 1)) + if start_dim > end_dim: + start_dim, end_dim = end_dim, start_dim + sync_t = t.ndim == ndim + if sync_t: + t = t.flatten(start_dim=start_dim, end_dim=end_dim) + a = a.flatten(start_dim=start_dim, end_dim=end_dim) + b = b.flatten(start_dim=start_dim, end_dim=end_dim) + if end_dim != ndim - 1: + a = a.movedim(start_dim, -1) + b = b.movedim(start_dim, -1) + if sync_t: + t = t.movedim(start_dim, -1) + if start_dim == 0: + a = a.unsqueeze(0) + b = b.unsqueeze(0) + if sync_t: + t = t.unsqueeze(0) + b_normed = b.norm(dim=-1, keepdim=True) if rescale_limit else None + dot_ba = (b * a).sum(dim=-1, keepdim=True) + dot_aa = (a**2).sum(dim=-1, keepdim=True) + proj = (dot_ba / (dot_aa + eps)) * a + proj *= proj_scale + b_ortho = 
proj.add_(b if ortho_scale == 1.0 else b * ortho_scale) + if b_normed is not None: + rescale_limit = abs(rescale_limit) + if rescale_limit == 1: + rescale_limit += eps + b_ortho_normed = b_ortho.norm(dim=-1, keepdim=True) + b_ortho_normed += eps + b_normed /= b_ortho_normed + b_normed = b_normed.clamp_(-rescale_limit, rescale_limit) + b_ortho *= b_normed + if blend_mode is None: + + def blend_function(a, b, t): + return (b * t).add_(a) + else: + blend_function = ( + BLENDING_MODES[blend_mode] if isinstance(blend_mode, str) else blend_mode + ) + ortho_result = blend_function(a, b_ortho, t) + if rescale_result_mode == "a": + rescale_result_target = a + elif rescale_result_mode == "b": + rescale_result_target = b + elif rescale_result_mode == "blend": + rr_blend_function = ( + blend_function + if rescale_result_blend_mode is None + else ( + BLENDING_MODES[rescale_result_blend_mode] + if isinstance(rescale_result_blend_mode, str) + else rescale_result_blend_mode + ) + ) + rescale_result_target = rr_blend_function(a, b, t) + else: + rescale_result_target = None + if rescale_result_target is not None: + result_norm = ortho_result.norm(dim=-1, keepdim=True).add_(eps) + target_norm = rescale_result_target.norm(dim=-1, keepdim=True) + target_norm /= result_norm + ortho_result *= target_norm + if b_normed is not None and dyn_ortho_mode: + vanilla_result = ( + rr_blend_function(a, b, t) + if rescale_result_mode != "blend" + else rescale_result_target + ) + dyn_blend_function = ( + torch.lerp + if dyn_result_blend_mode is None + else ( + BLENDING_MODES[dyn_result_blend_mode] + if isinstance(dyn_result_blend_mode, str) + else dyn_result_blend_mode + ) + ) + ortho_factor = ( + (1.0 - ((b_normed - 1.0) / (rescale_limit - 1.0)).clamp_(0.0, 1.0)) + .add_(dyn_min_scale) + .mul_(dyn_max_scale - dyn_min_scale) + ) + if smooth_factor_kernel_size != 0: + if ortho_factor.ndim < 4: + raise ValueError( + f"Can't use smooth_factor_kernel_size when ortho_factor has less than 4 dimensions. 
It has shape: {ortho_factor.shape}", + ) + ortho_factor = ( + torch.nn.functional.avg_pool2d( + ortho_factor.movedim(-1, -3), + kernel_size=smooth_factor_kernel_size, + stride=1, + padding=1, + ) + .movedim(-3, -1) + .clamp_(dyn_min_scale, dyn_max_scale) + ) + ortho_result = dyn_blend_function(vanilla_result, ortho_result, ortho_factor) + if ortho_verbose: + tqdm.write( + f"ORTHO BLEND: b_norm min/max={b_normed.aminmax()}, avg: {ortho_factor.mean().item():.5f}, min: {ortho_factor.min().item():.5f}, max: {ortho_factor.max().item():.5f}", + ) + if end_dim != ndim - 1: + ortho_result = ortho_result.movedim(-1, start_dim) + return ortho_result.reshape(orig_shape) + + +def symmetric_ortho_blend( + a: torch.Tensor, + b: torch.Tensor, + t: torch.Tensor, + *, + symmetric_strength: float = 1.0, + symmetric_deduce_mode: bool = False, + **kwargs: dict, +) -> torch.Tensor: + blended = ortho_blend(a, b, t, **kwargs) + if symmetric_strength == 0.0: + return blended + b_ortho = blended.sub_(a) + if symmetric_deduce_mode: + b_proj = b - b_ortho + # Projection would theoretically be the same for both, in the simple case at least? + # Actually, probably not. Oh well, this is here as an option now. 
+ a_ortho = a - b_proj + else: + a_ortho = ortho_blend(b, a, a.new_tensor(1.0), **kwargs) - b + a_proj = a - a_ortho + return a_proj.mul_(1.0 - symmetric_strength).add_(a_ortho).add_(b_ortho) + + +class WaveletBlend: + wavelet: wavef.Wavelet | None = None + use_float64: bool = False + + def __init__( + self, + *, + device: str | torch.device | None = None, + use_float64: bool = False, + **kwargs: dict, + ): + self.device = device + self.wavelet_kwargs = kwargs + self.use_float64 = use_float64 + + def get_wavelet(self, *, device: str | torch.device | None = None) -> wavef.Wavelet: + if self.wavelet is None: + self.wavelet = wavef.Wavelet( + device=device if device is not None else self.device, + **self.wavelet_kwargs, + ).to(dtype=torch.float64 if self.use_float64 else torch.float32) + self.device = device + return self.wavelet + if device is not None and self.wavelet.device != device: + self.wavelet = self.wavelet.to(device=device) + self.device = device + return self.wavelet + + @staticmethod + def maybe_offset( + yl: torch.Tensor, + yh: Sequence[torch.Tensor], + offset_yl: float | torch.Tensor | None, + offset_yh: float | Sequence[float | Sequence[float]] | None, + *, + in_place: bool = False, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + if offset_yl in {None, 1.0} and offset_yh in {None, 1.0}: + return (yl, tuple(yh)) + return wavef.wavelet_scaling( + yl, + yh, + yl_scale=offset_yl if offset_yl is not None else 1.0, + yh_scales=offset_yh, + in_place=in_place, + ) + + def wavelet_blend( + self, + a: torch.Tensor, + b: torch.Tensor, + t: float | torch.Tensor, + *, + blend_mode_yl: str | Callable = torch.lerp, + blend_mode_yh: str | Callable | None = None, + a_offset_yl: float | torch.Tensor | None = None, + a_offset_yh: float | Sequence[float | Sequence[float]] | None = None, + b_offset_yl: float | torch.Tensor | None = None, + b_offset_yh: float | Sequence[float | Sequence[float]] | None = None, + out_offset_yl: float | torch.Tensor | None = None, + 
out_offset_yh: float | Sequence[float | Sequence[float]] | None = None, + blend_yl_offset: float = 1.0, + blend_yh_offset: float | torch.Tensor = 1.0, + two_step_inverse: bool = False, + in_place_offset: bool = True, + ) -> torch.Tensor: + if isinstance(blend_mode_yl, str): + blend_mode_yl = BLENDING_MODES[blend_mode_yl] + if blend_mode_yh is None: + blend_mode_yh = blend_mode_yl + elif isinstance(blend_mode_yh, str): + blend_mode_yh = BLENDING_MODES[blend_mode_yh] + wavelet = self.get_wavelet(device=a.device) + dtype = a.dtype + if a.ndim != b.ndim: + raise ValueError( + f"Tensor a ndim ({a.ndim}) must match tensor b ndim ({b.ndim})" + ) + orig_shape = a.shape + # FIXME: This reshaping logic is almost certainly not reliable. + if a.ndim > 4: + a = a.reshape(a.shape[0], -1, *a.shape[-2:]) + if b.ndim > 4: + b = a.reshape(b.shape[0], -1, *b.shape[-2:]) + a = a.to(dtype=torch.float64 if self.use_float64 else torch.float32) + b = b.to(a) + t = a.new_tensor(t) if not isinstance(t, torch.Tensor) else t.to(a) + if t.ndim > 4: + t = a.reshape(t.shape[0], -1, *t.shape[-2:]) + aw_l, aw_h = self.maybe_offset( + *wavelet.forward(a), + a_offset_yl, + a_offset_yh, + in_place=in_place_offset, + ) + bw_l, bw_h = self.maybe_offset( + *wavelet.forward(b), + b_offset_yl, + b_offset_yh, + in_place=in_place_offset, + ) + blend_yl_offset = t if blend_yl_offset == 1 else t * blend_yl_offset + blend_yh_offset = t if blend_yh_offset == 1 else t * blend_yh_offset + outw_l, outw_h = self.maybe_offset( + *wavef.wavelet_blend( + (aw_l, aw_h), + (bw_l, bw_h), + yl_factor=blend_yl_offset, + yh_factor=blend_yh_offset, + blend_function=blend_mode_yl, + yh_blend_function=blend_mode_yh, + ), + offset_yl=out_offset_yl, + offset_yh=out_offset_yh, + in_place=in_place_offset, + ) + result = wavelet.inverse(outw_l, outw_h, two_step_inverse=two_step_inverse) + result = result[tuple(slice(None, dsize) for dsize in a.shape)] + return result.to(dtype=dtype).reshape(orig_shape) + + +WAVELET_BLEND_CACHE: 
dict[frozenset[tuple[str, Any]], WaveletBlend] = {} + + +def wavelet_blend( + a: torch.Tensor, + b: torch.Tensor, + t: float | torch.Tensor, + *, + blend_mode_yl: str | Callable = torch.lerp, + blend_mode_yh: str | Callable | None = None, + **kwargs: dict[str, Any], +) -> torch.Tensor: + if isinstance(blend_mode_yl, str): + blend_mode_yl = BLENDING_MODES[blend_mode_yl] + if blend_mode_yh is None: + blend_mode_yh = blend_mode_yl + _ = kwargs.pop("device", None) + wavelet_kwargs = { + k: kwargs.pop(k) + for k in ( + "wave", + "level", + "mode", + "use_1d_dwt", + "use_dtcwt", + "biort", + "qshift", + "inv_wave", + "inv_mode", + "inv_biort", + "inv_qshift", + "two_step_inverse", + "use_float64", + ) + if k in kwargs + } + cache_key = frozenset( + ( + wavelet_kwargs + | {"blend_mode_yl": blend_mode_yl, "blend_mode_yh": blend_mode_yh} + ).items(), + ) + print(f"\nWAVELET BLEND: cache key: {cache_key}") + wb = WAVELET_BLEND_CACHE.get(cache_key) + if wb is None: + wb = WaveletBlend(device=a.device, **wavelet_kwargs) + WAVELET_BLEND_CACHE[cache_key] = wb + return wb.wavelet_blend( + a, + b, + t, + blend_mode_yl=blend_mode_yl, + blend_mode_yh=blend_mode_yh, + **kwargs, + ) + + class BlendMode: __slots__ = ( "allow_scale", @@ -805,24 +1153,36 @@ class BlendMode: "f_kwargs", "f_raw", "force_rescale", + "fork_rng", + "invert_scale", "norm", "norm_dims", "rescale_dims", + "rescale_max", + "rescale_min", "rev", + "scale_multiplier", + "visible", ) class _Empty: pass - def __init__( # noqa: PLR0917 + def __init__( self, f, norm=None, - norm_dims=(-3, -2, -1), - rev=False, - allow_scale=True, - rescale_dims=(-3, -2, -1), - force_rescale=False, + norm_dims: tuple = (-3, -2, -1), + rev: bool = False, + allow_scale: bool = True, + rescale_dims: tuple = (-3, -2, -1), + rescale_min: float = 0.0, + rescale_max: float = 1.0, + force_rescale: bool = False, + fork_rng: bool = False, + invert_scale: float | None = None, + scale_multiplier: float = 1.0, + visible: bool = True, **kwargs: dict, 
): self.f_raw = f @@ -837,37 +1197,35 @@ def __init__( # noqa: PLR0917 self.rev = rev self.allow_scale = allow_scale self.rescale_dims = rescale_dims + self.rescale_min = rescale_min + self.rescale_max = rescale_max self.force_rescale = force_rescale + self.fork_rng = fork_rng + self.invert_scale = invert_scale + self.scale_multiplier = scale_multiplier + self.visible = visible - def edited( - self, - *, - f=_Empty, - norm=_Empty, - norm_dims=_Empty, - rev=_Empty, - allow_scale=_Empty, - rescale_dims=_Empty, - force_rescale=_Empty, - preserve_kwargs=True, - **kwargs: dict, - ) -> object: + def edited(self, *, f=_Empty, preserve_kwargs=True, **kwargs: dict) -> BlendMode: empty = self._Empty kwargs = (self.f_kwargs | kwargs) if preserve_kwargs else kwargs - return self.__class__( - f if f is not empty else self.f_raw, - norm=norm if norm is not empty else self.norm, - norm_dims=norm_dims if norm_dims is not empty else self.norm_dims, - rev=rev if rev is not empty else self.rev, - allow_scale=allow_scale if allow_scale is not empty else self.allow_scale, - rescale_dims=rescale_dims - if rescale_dims is not empty - else self.rescale_dims, - force_rescale=force_rescale - if force_rescale is not empty - else self.force_rescale, - **kwargs, - ) + kwargs |= { + k: v if (v := kwargs.get(k, empty)) is not empty else getattr(self, k) + for k in ( + "norm", + "norm_dims", + "rev", + "allow_scale", + "rescale_dims", + "rescale_min", + "rescale_max", + "force_rescale", + "fork_rng", + "invert_scale", + "scale_multiplier", + "visible", + ) + } + return self.__class__(f if f is not empty else self.f_raw, **kwargs) def rescale(self, t, *, rescale_dims=_Empty): if t.ndim > 2: @@ -879,24 +1237,47 @@ def rescale(self, t, *, rescale_dims=_Empty): rescale_dims = -1 tmin = torch.amin(t, keepdim=True, dim=rescale_dims) tmax = torch.amax(t, keepdim=True, dim=rescale_dims) - return (t - tmin).div_(tmax - tmin).clamp_(0, 1), tmin, tmax + return ( + (t - tmin).div_(tmax - 
tmin).clamp_(self.rescale_min, self.rescale_max),
+            tmin,
+            tmax,
+        )
 
-    def __call__(self, a, b, t, *, norm_dims=_Empty) -> torch.Tensor:
+    def __call__(
+        self,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        t: torch.Tensor | float,
+        *,
+        norm_dims=_Empty,
+    ) -> torch.Tensor:
         if not self.force_rescale:
             return self.__call__internal(a, b, t, norm_dims=norm_dims)
         a, amin, amax = self.rescale(a)
         b, bmin, bmax = self.rescale(b)
-        result = self.__call__internal(a, b, t, norm_dims=norm_dims)
+        with torch.random.fork_rng(devices=(a.device, b.device), enabled=self.fork_rng):
+            result = self.__call__internal(a, b, t, norm_dims=norm_dims)
         del a, b
         rmin, rmax = torch.lerp(amin, bmin, 0.5), torch.lerp(amax, bmax, 0.5)
         del amin, amax, bmin, bmax
         return result.mul_(rmax.sub_(rmin)).add_(rmin)
 
-    def __call__internal(self, a, b, t, *, norm_dims=_Empty) -> torch.Tensor:
+    def __call__internal(
+        self,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        t: torch.Tensor | float,
+        *,
+        norm_dims=_Empty,
+    ) -> torch.Tensor:
         if not isinstance(t, torch.Tensor) and isinstance(a, torch.Tensor):
             t = a.new_full((1,), t)
         if self.rev:
             a, b = b, a
+        if self.invert_scale is not None:
+            t = self.invert_scale - t
+        if self.scale_multiplier != 1.0:
+            t = t * self.scale_multiplier
         if self.norm is None:
             return self.f(a, b, t)
         return self.norm(
@@ -907,11 +1288,36 @@ def __call__internal(self, a, b, t, *, norm_dims=_Empty) -> torch.Tensor:
 
 
 class BlendingModes:
+    BLEH = True
+
     def __init__(self, builtins=None):
         self.builtins = {} if builtins is None else builtins
         self.cache = {}
 
-    def get(self, k: str, default=None):
+    def get_dict_key(self, k: dict):
+        ds = frozenset(k.items())
+        cached = self.cache.get(ds)
+        if cached is not None:
+            return cached
+        name = k.get("name")
+        if name is None:
+            raise ValueError(
+                "When passing a blend mode key as dict, a string 'name' key must exist."
+            )
+        name = name.strip()
+        base_bm = self.builtins.get(name)
+        if base_bm is None:
+            errstr = f"Unknown mode {name} for extended blend specification"
+            raise ValueError(errstr)
+        bm_kwargs = k.copy()
+        del bm_kwargs["name"]
+        bm = base_bm.edited(**bm_kwargs)
+        self.cache[ds] = bm
+        return bm
+
+    def get(self, k: str | dict, default=None):
+        if isinstance(k, dict):
+            return self.get_dict_key(k)
         result = self.builtins.get(k)
         if result is not None:
             return result
@@ -995,16 +1401,16 @@ def try_extended(self, k: str, default=None) -> object:
         return bm
 
     def items(self):
-        return self.builtins.items()
+        return ((k, v) for k, v in self.builtins.items() if v.visible)
 
     def values(self):
-        return self.builtins.values()
+        return (v for v in self.builtins.values() if v.visible)
 
     def __contains__(self, k: str) -> bool:
         return self.get(k) is not None
 
     def __iter__(self):
-        return self.builtins.__iter__()
+        return (k for k, _v in self.items())
 
     keys = __iter__
 
@@ -1054,7 +1460,8 @@ def copy(self):
     "a_only": BlendMode(lambda a, _b, t: a * t, allow_scale=False),
     "b_only": BlendMode(lambda _a, b, t: b * t, allow_scale=False),
     # Interpolates between tensors a and b using normalized linear interpolation.
-    "bislerp": BlendMode(
+    # This definitely isn't biSLERP.
+    "bislerp_wrong": BlendMode(
         lambda a, b, t: ((1 - t) * a).add_(t * b),
         normalize,
     ),
@@ -1118,6 +1525,8 @@ def copy(self):
     "inject_copysign_b": BlendMode(lambda a, b, t: (b * t).add_(a).copysign_(b)),
     "inject_avoidsign_a": BlendMode(lambda a, b, t: (b * t).add_(a).copysign_(a.neg())),
     "inject_avoidsign_b": BlendMode(lambda a, b, t: (b * t).add_(a).copysign_(b.neg())),
+    "cfg": BlendMode(lambda a, b, t: (a - b).mul_(t).add_(b)),
+    "cfg_base_a": BlendMode(lambda a, b, t: (a - b).mul_(t).add_(a)),
     # Interpolates between tensors a and b using linear interpolation.
# "lerp": BlendMode(lambda a, b, t: ((1.0 - t) * a).add_(t * b)), "lerp": BlendMode(torch.lerp), @@ -1138,6 +1547,9 @@ def copy(self): "lerp_avoidsign_b": BlendMode( lambda a, b, t: ((1.0 - t) * a).add_(t * b).copysign_(b.neg()), ), + "weighted_average": BlendMode( + lambda a, b, t: (b * t).add_(a) / (1.0 + abs(t)), + ), # Simulates a brightening effect by adding tensor b to tensor a, scaled by t. "lineardodge": BlendMode(lambda a, b, t: (b * t).add_(a)), "copysign": BlendMode(lambda a, b, _t: torch.copysign(a, b)), @@ -1251,6 +1663,10 @@ def copy(self): normalize, allow_scale=False, ), + "multiply_by_b": BlendMode( + lambda a, b, _t: a * b, + allow_scale=False, + ), "overlay": BlendMode( lambda a, b, t: (2 * a * b + a**2 - 2 * a * b * a) * t if torch.all(b < 0.5) @@ -1296,6 +1712,65 @@ def copy(self): allow_scale=False, force_rescale=True, ), + "wavelet_b_hi_100_lo_0": BlendMode( + f=wavelet_blend, + blend_yl_offset=0.0, + blend_yh_offset=1.0, + wave="db4", + level=8, + ), + "wavelet_b_hi_0_lo_100": BlendMode( + f=wavelet_blend, + blend_yl_offset=1.0, + blend_yh_offset=0.0, + wave="db4", + level=8, + ), + "ortho": BlendMode(ortho_blend), + "ortho_rescaled": BlendMode(ortho_blend, rescale_limit=2.0), + "ortho_rescaled_lerpish": BlendMode( + ortho_blend, + rescale_limit=2.0, + rescale_result_blend_mode="lerp", + rescale_result_mode="blend", + ), + "ortho_lerp": BlendMode(ortho_blend, blend_mode="lerp"), + "ortho_dyn_lerp": BlendMode( + ortho_blend, + blend_mode="lerp", + rescale_result_mode="blend", + rescale_limit=4.0, + dyn_ortho_mode=True, + ), + "ortho_dyn_lerp_inverted": BlendMode( + ortho_blend, + blend_mode="lerp", + rescale_result_mode="blend", + rescale_limit=2.0, + dyn_ortho_mode=True, + rev=True, + invert_scale=1.0, + ), + "ortho_lerp_rescaled": BlendMode( + ortho_blend, + blend_mode="lerp", + rescale_result_mode="blend", + rescale_limit=2.0, + ), + "ortho_cfg": BlendMode( + lambda a, b, t, **kwargs: ortho_blend(b, a - b, t, **kwargs), + ), + 
"ortho_cfg_base_a": BlendMode( + lambda a, b, t, **kwargs: ortho_blend(a, a - b, t, **kwargs), + ), + "symmetric_ortho": BlendMode(symmetric_ortho_blend), + "symmetric_ortho_rescaled": BlendMode(symmetric_ortho_blend, rescale_limit=2.0), + "symmetric_ortho_cfg": BlendMode( + lambda a, b, t, **kwargs: symmetric_ortho_blend(b, a - b, t, **kwargs), + ), + "symmetric_ortho_cfg_base_a": BlendMode( + lambda a, b, t, **kwargs: symmetric_ortho_blend(a, a - b, t, **kwargs), + ), } BLENDING_MODES |= { @@ -1319,10 +1794,10 @@ def copy(self): "bislerp": slerp_orig, "altbislerp": altslerp, "revaltbislerp": lambda a, b, t: altslerp(b, a, t), - "bibislerp": BLENDING_MODES["bislerp"].edited(norm_dims=0), + "bibislerp": BLENDING_MODES["bislerp_wrong"].edited(norm_dims=0), "revhslerp": lambda a, b, t: hslerp_alt(b, a, t), "revbislerp": lambda a, b, t: slerp_orig(b, a, t), - "revbibislerp": BLENDING_MODES["revbislerp"].edited(norm_dims=0), + "revbibislerp": BLENDING_MODES["revbislerp_wrong"].edited(norm_dims=0), } FILTER_PRESETS = { @@ -1611,7 +2086,9 @@ def biderp(samples, width, height, mode="bislerp", mode_h=None): # noqa: PLR091 mode_h = mode derp_w = (BIDERP_MODES if ":" not in mode else BLENDING_MODES).get(mode, slerp_orig) - derp_h = (BIDERP_MODES if ":" not in mode_h else BLENDING_MODES).get(mode_h, slerp_orig) + derp_h = (BIDERP_MODES if ":" not in mode_h else BLENDING_MODES).get( + mode_h, slerp_orig + ) def generate_bilinear_data(length_old, length_new, device): coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape( diff --git a/py/nodes/__init__.py b/py/nodes/__init__.py index f8358ec..8f6ca9f 100644 --- a/py/nodes/__init__.py +++ b/py/nodes/__init__.py @@ -36,10 +36,14 @@ "BlehPlug": misc.BlehPlug, "BlehRefinerAfter": refinerAfter.BlehRefinerAfter, "BlehSageAttentionSampler": sageAttention.BlehSageAttentionSampler, + "BlehAdvancedAttentionSampler": sageAttention.BlehAdvancedAttentionSampler, "BlehSetSamplerPreset": samplers.BlehSetSamplerPreset, 
"BlehSetSigmas": misc.BlehSetSigmas, "BlehTAEVideoDecode": taevid.TAEVideoDecode, "BlehTAEVideoEncode": taevid.TAEVideoEncode, + "BlehModelProcessLatentIn": misc.BlehModelProcessLatentIn, + "BlehModelProcessLatentOut": misc.BlehModelProcessLatentOut, + "BlehFixGuiderPreviewing": misc.BlehFixGuiderPreviewing, } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/py/nodes/blockCFG.py b/py/nodes/blockCFG.py index 164769e..5918e2a 100644 --- a/py/nodes/blockCFG.py +++ b/py/nodes/blockCFG.py @@ -1,4 +1,60 @@ -from functools import partial +import math +from dataclasses import dataclass +from enum import Enum, auto +from functools import partial, reduce + +import torch +from tqdm import tqdm + + +class BlockType(Enum): + INPUT = auto() + OUTPUT = auto() + MIDDLE = auto() + ATTN_Q = auto() + ATTN_K = auto() + ATTN_V = auto() + ATTN = auto() + + +class BlendType(Enum): + DIFF = auto() + RESULT = auto() + + +class CondType(Enum): + COND = auto() + UNCOND = auto() + BOTH = auto() + + +@dataclass +class BlockCFGItem: + start_sigma: float + end_sigma: float + block_type: BlockType + target_type: BlendType + cond_type: CondType + block_num: int + scale: float + skip_mode: bool + + +class BlockCFG: + start_sigma: float + end_sigma: float + block_types: frozenset[BlockType] + verbose: bool = True + + def __init__(self, items: tuple[BlockCFGItem, ...]): + self.start_sigma, self.end_sigma = reduce( + lambda old, new: (min(old[0], new[0]), max(old[1], new[1])), + ((i.start_sigma, i.end_sigma) for i in items), + (math.inf, math.inf * -1), + ) + self.block_types = frozenset(i.block_type for i in items) + + # def check_applies(self, class BlockCFGBleh: @@ -123,6 +179,7 @@ def patch( reverse = apply_to != "cond" def check_applies(block_list, transformer_options): + tqdm.write(f"* BLOCKCFG: tf={transformer_options}") cond_or_uncond = transformer_options["cond_or_uncond"] if ( not (0 in cond_or_uncond and 1 in cond_or_uncond) @@ -140,6 +197,13 @@ def check_applies(block_list, 
transformer_options): return -1 in block_list return block_def in {-1, transformer_options.get("transformer_index")} + def apply_cfg_fun_(tensor: torch.Tensor, primary_offset: int) -> torch.Tensor: + full_batch = tensor.shape[0] + if full_batch % 2: + raise RuntimeError("Batch size must be multiple of 2") + batch = full_batch // 2 + diff = tensor[:batch, ...] - tensor[batch:, ...] + def apply_cfg_fun(tensor, primary_offset): secondary_offset = 0 if primary_offset == 1 else 1 if reverse: @@ -156,12 +220,15 @@ def apply_cfg_fun(tensor, primary_offset): ).mul_(scale) return result + mid_patch = None + def non_output_block_patch(h, transformer_options, *, block_list): + nonlocal mid_patch + # print("\nSET????", mid_patch) + if mid_patch is not None: + mid_patch._bleh_set_topts(transformer_options) cond_or_uncond = transformer_options["cond_or_uncond"] - if not check_applies( - block_list, - transformer_options, - ): + if not block_list or not check_applies(block_list, transformer_options): return h return apply_cfg_fun(h, cond_or_uncond.index(0)) @@ -180,7 +247,55 @@ def output_block_patch(h, hsp, transformer_options, *, block_list): ) m = model.clone() - if input_blocks: + + if middle_blocks: + # print("******** MIDDLE") + try: + mb = model.get_model_object("diffusion_model.middle_block.0") + except AttributeError: + mb = None + orig_forward = getattr(mb, "forward", None) + if mb is None or orig_forward is None: + raise ValueError("Could not get middle block or forward") + + class MBForward: + def __init__(self, orig_forward): + real_orig_forward = orig_forward + while temp := getattr( + real_orig_forward, "_bleh_orig_forward", None + ): + real_orig_forward = temp + orig_forward = real_orig_forward + self._bleh_orig_forward = orig_forward + self._bleh_topts = None + + def _bleh_set_topts(self, transformer_options: dict) -> None: + # if self._bleh_topts: + # return + cond_or_uncond = transformer_options["cond_or_uncond"] + self._bleh_topts = { + "cond_or_uncond": 
cond_or_uncond.clone() + if isinstance(cond_or_uncond, torch.Tensor) + else cond_or_uncond, + "sigmas": transformer_options["sigmas"].clone(), + "block": ("middle", 0), + } + + def __call__(self, *args: list, **kwargs: dict) -> torch.Tensor: + result = self._bleh_orig_forward(*args, **kwargs) + try: + return non_output_block_patch( + result, + self._bleh_topts, + block_list=middle_blocks, + ) + finally: + self._bleh_topts = None + + mid_patch = MBForward(orig_forward) + m.add_object_patch("diffusion_model.middle_block.0.forward", mid_patch) + + if input_blocks or middle_blocks: ( m.set_model_input_block_patch_after_skip if skip_mode @@ -188,11 +303,7 @@ def output_block_patch(h, hsp, transformer_options, *, block_list): )( partial(non_output_block_patch, block_list=input_blocks), ) - if middle_blocks: - m.set_model_patch( - partial(non_output_block_patch, block_list=middle_blocks), - "middle_block_patch", - ) + if output_blocks: m.set_model_output_block_patch( partial(output_block_patch, block_list=output_blocks), diff --git a/py/nodes/misc.py b/py/nodes/misc.py index 61c55f4..4dfac74 100644 --- a/py/nodes/misc.py +++ b/py/nodes/misc.py @@ -2,6 +2,7 @@ from __future__ import annotations import contextlib +import math import operator import random from decimal import Decimal @@ -11,7 +12,7 @@ from comfy import model_management from comfy.model_management import throw_exception_if_processing_interrupted -from ..better_previews.previewer import ensure_previewer +from ..better_previews.previewer import PREVIEWER_STATE, ensure_previewer from ..latent_utils import normalize_to_scale @@ -482,3 +483,159 @@ def output_block_patch(h, hsp, _transformer_options): m.set_model_output_block_patch(output_block_patch) return (m,) + + +class BlehModelProcessLatentIn: + DESCRIPTION = "Advanced node that can be used to scale a raw latent for model input. Generally only needed if you're doing something that bypasses the normal latent input mechanisms." 
+ RETURN_TYPES = ("LATENT",) + FUNCTION = "go" + CATEGORY = "latent/advanced" + + @classmethod + def INPUT_TYPES(cls) -> dict: + return { + "required": { + "model": ("MODEL",), + "latent": ("LATENT",), + }, + } + + @classmethod + def go(cls, *, model, latent: dict) -> tuple[dict]: + latent_format = model.model.latent_format + samples = ( + latent["samples"].detach().to(device="cpu", dtype=torch.float32, copy=True) + ) + return (latent | {"samples": latent_format.process_in(samples)},) + + +class BlehModelProcessLatentOut: + DESCRIPTION = "Advanced node that can be used to scale a latent to the correct range for output. Generally only needed if you're doing something that bypasses the normal latent output mechanisms." + RETURN_TYPES = ("LATENT",) + FUNCTION = "go" + CATEGORY = "latent/advanced" + + @classmethod + def INPUT_TYPES(cls) -> dict: + return { + "required": { + "model": ("MODEL",), + "latent": ("LATENT",), + }, + } + + @classmethod + def go(cls, *, model, latent: dict) -> tuple[dict]: + latent_format = model.model.latent_format + samples = ( + latent["samples"].detach().to(device="cpu", dtype=torch.float32, copy=True) + ) + return (latent | {"samples": latent_format.process_out(samples)},) + + +class PreviewFixGuider: + def __init__(self, guider, *, fps_override: int | float | None = None): + self.__guider = guider + self.__fps_override = fps_override + + def __getattr__(self, k): + return getattr(self.__guider, k) + + def sample(self, noise, latent_image, *args, **kwargs): + latent_shapes = ( + (tuple(latent_image.shape),) + if not latent_image.is_nested + else tuple(tuple(t.shape) for t in latent_image.unbind()) + ) + PREVIEWER_STATE.last_latent_shapes = latent_shapes + fps_override = self.__fps_override + if fps_override: + PREVIEWER_STATE.fps_override = fps_override + try: + return self.__guider.sample(noise, latent_image, *args, **kwargs) + finally: + PREVIEWER_STATE.last_latent_shapes = None + PREVIEWER_STATE.fps_override = None + + # def 
sample(self, noise, latent_image, *args, **kwargs): + # nest_index = self.__nest_index + # orig_callback = kwargs.get("callback") + # sample = partial(self.__guider.sample, noise, latent_image, *args) + # if not (latent_image.is_nested and orig_callback is not None): + # # Either not nested or no callback, so no need to fix previewing. + # return sample(**kwargs) + # latent_part_sizes = tuple( + # t.shape if isinstance(t, torch.Tensor) and not t.is_nested else None + # for t in latent_image.unbind() + # ) + # if not ( + # len(latent_part_sizes) >= nest_index + # and all(ps is not None for ps in latent_part_sizes) + # ): + # # Multiple levels of nesting not yet implemented. + # return sample(**kwargs) + # offset = 0 + # # ComfyUI preserves the batch dimension and smashes everything else together. + # for ps in latent_part_sizes[:nest_index]: + # offset += math.prod(ps[1:]) + # orig_shape = latent_part_sizes[nest_index] + # orig_nelems = math.prod(orig_shape[1:]) + + # def cb_wrapper(i, denoised, x, *args, **kwargs) -> None: + # print( + # f"\n\nCB SHAPES: {denoised.shape}, {x.shape}, orig {orig_shape}, orig elems {orig_nelems}", + # ) + # denoised, x = ( + # t[:, :, offset : offset + orig_nelems].reshape( + # t.shape[0], + # *orig_shape[1:], + # ) + # for t in (denoised, x) + # ) + # print( + # f"\n\nFIXED CB SHAPES: {denoised.shape}, {x.shape}", + # ) + # return orig_callback(i, denoised, x, *args, **kwargs) + + # kwargs["callback"] = cb_wrapper + # return sample(**kwargs) + + +class BlehFixGuiderPreviewing: + DESCRIPTION = "Wraps a guider to give the Bleh previewing system a hint about the latent shapes. Only necessary for models like LTX-2 which use nested tensors." 
+ FUNCTION = "go" + OUTPUT_NODE = False + CATEGORY = "hacks" + + RETURN_TYPES = ("GUIDER",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "guider": ("GUIDER",), + "fps_override": ( + "FLOAT", + { + "default": 0.0, + "min": 0.0, + "max": 9999.0, + "tooltip": "Can be used to override the FPS when previewing with video models. Disabled if set to 0.", + }, + ), + }, + } + + @classmethod + def go( + cls, + *, + guider, + fps_override: float | None = None, + ) -> tuple: + return ( + PreviewFixGuider( + guider, + fps_override=fps_override or None, + ), + ) diff --git a/py/nodes/sageAttention.py b/py/nodes/sageAttention.py index af1aa46..3067b49 100644 --- a/py/nodes/sageAttention.py +++ b/py/nodes/sageAttention.py @@ -3,12 +3,18 @@ import contextlib import importlib -from typing import TYPE_CHECKING +import math +from enum import Enum, auto +from functools import partial, update_wrapper +from typing import TYPE_CHECKING, Any, NamedTuple import comfy.ldm.modules.attention as comfyattn import torch import yaml from comfy.samplers import KSAMPLER +from tqdm import tqdm + +from ..latent_utils import BLENDING_MODES try: import sageattention @@ -51,6 +57,22 @@ HAVE_ATTN_OVERRIDE = hasattr(comfyattn, "register_attention_function") +# class AttnsConfig(NamedTuple): +# name: str +# version: str +# supported_head_sizes: collections.abc.Collection + +# class AttentionRule(NamedTuple): + + +# class AttentionRules(NamedTuple): +# orig_attn: Callable +# start_sigma: float = math.inf +# end_sigma: float = 0.0 +# verbose: bool = False +# rules: tuple[AttentionRule, ...] 
= () + + def attention_bleh( # noqa: PLR0914 q: torch.Tensor, k: torch.Tensor, @@ -182,10 +204,14 @@ def make_attn_wrapper( raise ValueError( "SpargeAttention is not available, make sure you have the spas_sage_attn Python package installed", ) - if sageattn_function == "sparge": + if sageattn_function in {"sparge", "sparge2"}: sageattn_function = spas_sage_attn.spas_sage2_attn_meansim_cuda + elif sageattn_function == "sparge2_topk": + sageattn_function = spas_sage_attn.spas_sage2_attn_meansim_topk_cuda elif sageattn_function == "sparge1": sageattn_function = spas_sage_attn.spas_sage_attn_meansim_cuda + elif sageattn_function == "sparge1_topk": + sageattn_function = spas_sage_attn.spas_sage_attn_meansim_topk_cuda else: sageattn_function = getattr(sageattention, sageattn_function) @@ -331,6 +357,19 @@ def go( return (model,) +class TimeMode(Enum): + PERCENT = auto() + SIGMA = auto() + + +class SageAttnOptions(NamedTuple): + sampler: object + attn_kwargs: dict[str, Any] + start_time: float = math.inf + end_time: float = 0.0 + time_mode: TimeMode = TimeMode.PERCENT + + class BlehModelWrapper: def __init__(self, model: object, model_call: Callable): self.__bleh_model = model @@ -344,20 +383,25 @@ def __getattr__(self, k: str): def sageattn_sampler( + config: SageAttnOptions, model: object, x: torch.Tensor, sigmas: torch.Tensor, - *, - sageattn_sampler_options: tuple, + # *, + # sageattn_sampler_options: tuple, **kwargs: dict, ) -> torch.Tensor: - sampler, start_percent, end_percent, sageattn_kwargs = sageattn_sampler_options - ms = model.inner_model.inner_model.model_sampling - start_sigma, end_sigma = ( - round(ms.percent_to_sigma(start_percent), 4), - round(ms.percent_to_sigma(end_percent), 4), - ) - del ms + # sampler, start_percent, end_percent, sageattn_kwargs = sageattn_sampler_options + if config.time_mode == TimeMode.PERCENT: + ms = model.inner_model.inner_model.model_sampling + start_sigma, end_sigma = ( + round(ms.percent_to_sigma(config.start_time), 4), + 
round(ms.percent_to_sigma(config.end_time), 4), + ) + del ms + else: + start_sigma = config.start_time + end_sigma = config.end_time def model_call( model: object, @@ -369,7 +413,7 @@ def model_call( enabled = end_sigma <= sigma_float <= start_sigma with sageattn_context( enabled=enabled, - **sageattn_kwargs, + **config.attn_kwargs, ) as attn_override: if enabled and HAVE_ATTN_OVERRIDE: model_options = kwargs.pop("model_options", {}).copy() @@ -381,12 +425,12 @@ def model_call( kwargs["model_options"] = model_options return model(x, sigma, **kwargs) - return sampler.sampler_function( + return config.sampler.sampler_function( BlehModelWrapper(model, model_call), x, sigmas, **kwargs, - **sampler.extra_options, + **config.sampler.extra_options, ) @@ -450,14 +494,383 @@ def go( ) return ( KSAMPLER( - sageattn_sampler, - extra_options={ - "sageattn_sampler_options": ( - sampler, - start_percent, - end_percent, - get_yaml_parameters(yaml_parameters), + update_wrapper( + partial( + sageattn_sampler, + SageAttnOptions( + start_time=start_percent, + end_time=end_percent, + time_mode=TimeMode.PERCENT, + sampler=sampler, + attn_kwargs=get_yaml_parameters(yaml_parameters), + ), + ), + sampler.sampler_function, + ), + ), + ) + + +class AdvancedAttnRule(NamedTuple): + attn_function: Callable | None + attn_kwargs: dict[str, Any] + blend_function: Callable | None = None + start_sigma: float = math.inf + end_sigma: float = 0.0 + check_nan: bool = False + q_multiplier: float = 1.0 + k_multiplier: float = 1.0 + v_multiplier: float = 1.0 + output_multiplier: float = 1.0 + device: torch.device | str | None = None + dtype: torch.dtype | str | None = None + blend: float = 1.0 + op_q: str | None = None + op_k: str | None = None + op_v: str | None = None + op_current_result_preblend: str | None = None + op_result_preblend: str | None = None + op_result_postblend: str | None = None + op_result_postblend_diff: str | None = None + + @classmethod + def build( + cls, + *, + attn_function: str | 
Callable | None = None, + blend_mode: str | None = None, + device=None, + dtype=None, + **kwargs, + ) -> NamedTuple: + blend_function = BLENDING_MODES[blend_mode] if blend_mode is not None else None + my_kwargs = {k: kwargs.pop(k) for k in cls._fields if k in kwargs} + if attn_function == "default": + attn_function = None + elif isinstance(attn_function, str): + kwargs["sageattn_function"] = attn_function + attn_function = make_attn_wrapper(orig_attn=None, **kwargs) + if isinstance(dtype, str): + dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float64": torch.float64, + }.get(dtype) + return cls( + device=device, + dtype=dtype, + blend_function=blend_function, + attn_function=attn_function, + attn_kwargs=kwargs, + **my_kwargs, + ) + + +class AdvancedAttnConfig(NamedTuple): + sampler: object + verbose: bool = False + rules: tuple[AdvancedAttnRule, ...] = () + start_time: float = math.inf + end_time: float = 0.0 + call_indexes: frozenset[float] = frozenset() + time_mode: TimeMode = TimeMode.PERCENT + min_cond_batch: int = 0 + batch_slice: tuple | str | None = None + max_idx: int = -1 + delegate_override: bool = True + op_result: str | None = None + latent_ops: dict[str, Callable] = {} + + @classmethod + def build( + cls, + *, + rules=(), + time_mode: str | TimeMode | None = None, + call_indexes=(), + **kwargs, + ) -> NamedTuple: + fs = frozenset(cls._fields) + rules = tuple(AdvancedAttnRule.build(**r) for r in rules) + if isinstance(time_mode, str): + time_mode = getattr(TimeMode, time_mode.strip().upper()) + call_indexes = frozenset( + i + if math.isnan(i) or i == math.inf or not isinstance(i, float) + else int(i) + 0.5 + for i in call_indexes + ) + kwargs = {k: v for k, v in kwargs.items() if k in fs} + return cls( + rules=rules, + time_mode=time_mode, + call_indexes=call_indexes, + **kwargs, + ) + + def call_op( + self, op_key: str | None, t: torch.Tensor, *, sigma: float + ) -> torch.Tensor: + op = None if op_key 
is None else self.latent_ops.get(op_key) + if op is None: + return t + return ( + op(t) + if not hasattr(op, "EXTENDED_LATENT_OPERATION") + else op(t, sigma=sigma) + ) + + def attn_wrapper( + self, + sigma: float, + currattncall: CurrAttnCall, + old_override: Callable | None, + comfy_orig_attn: Callable, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + fallback_attn = ( + partial(old_override, comfy_orig_attn) + if old_override is not None and self.delegate_override + else comfy_orig_attn + ) + result = None + rules = self.rules + call_idx = currattncall.idx + rev_idx = ( + -abs(currattncall.max_idx - currattncall.idx) + if currattncall.max_idx >= 0 + else math.nan + ) + ci = self.call_indexes + fallthrough_match = math.inf in ci + exclude = (call_idx + 0.5) in ci or ( + not math.isnan(rev_idx) and (rev_idx + 0.5) in ci + ) + matched = not exclude and (fallthrough_match or call_idx in ci or rev_idx in ci) + if self.verbose: + tqdm.write( + f"[BLEH] AdvancedAttn wrapper({currattncall.idx:<3}): sigma={sigma:.4f}, max_idx={currattncall.max_idx:<3}, rev_idx={rev_idx:<3}, matched={matched}, exclude={exclude}, q.shape={q.shape}", + ) + currattncall.idx += 1 + rules = self.rules if matched else () + for rule in rules: + if not rule.end_sigma <= sigma <= rule.start_sigma: + continue + currq, currk, currv = q, k, v + if rule.q_multiplier != 1: + currq = currq * rule.q_multiplier + if rule.k_multiplier != 1: + currk = currk * rule.k_multiplier + if rule.v_multiplier != 1: + currv = currv * rule.v_multiplier + if rule.dtype is not None or rule.device is not None: + currq = currq.to(device=rule.device, dtype=rule.dtype) + currk = currk.to(device=rule.device, dtype=rule.dtype) + currv = currv.to(device=rule.device, dtype=rule.dtype) + currq = self.call_op(rule.op_q, currq, sigma=sigma) + currk = self.call_op(rule.op_k, currk, sigma=sigma) + currv = self.call_op(rule.op_v, currv, sigma=sigma) + attn_function = ( + 
partial(rule.attn_function, fallback_attn) + if rule.attn_function + else fallback_attn + ) + curr_result = attn_function(currq, currk, currv, *args, **kwargs) + del currq, currk, currv + if rule.check_nan and curr_result.isnan().any(): + del curr_result + continue + if rule.output_multiplier != 1: + curr_result *= rule.output_multiplier + curr_result = self.call_op( + rule.op_current_result_preblend, + curr_result, + sigma=sigma, + ) + if curr_result.dtype != q.dtype or curr_result.device != q.device: + curr_result = curr_result.to(q) + if result is not None: + result = self.call_op(rule.op_result_preblend, result, sigma=sigma) + if result is None or rule.blend_function is None: + result = curr_result + del curr_result + continue + prev_result = result + result = rule.blend_function(result, curr_result, rule.blend) + if rule.op_result_postblend_diff is not None: + result = prev_result + self.call_op( + rule.op_result_postblend_diff, + result - prev_result, + sigma=sigma, + ) + del curr_result, prev_result + result = self.call_op(rule.op_result_postblend, result, sigma=sigma) + if result is None: + result = fallback_attn(q, k, v, *args, **kwargs) + return self.call_op(self.op_result, result, sigma=sigma) + + +class CurrAttnCall: + def __init__(self, idx: int = 0, max_idx: int = -1): + self.idx = idx + self.max_idx = max_idx + + +def advancedattn_sampler( + config: AdvancedAttnConfig, + model: object, + x: torch.Tensor, + sigmas: torch.Tensor, + **kwargs: dict, +) -> torch.Tensor: + if config.time_mode == TimeMode.PERCENT: + ms = model.inner_model.inner_model.model_sampling + start_sigma, end_sigma = ( + round(ms.percent_to_sigma(config.start_time), 4), + round(ms.percent_to_sigma(config.end_time), 4), + ) + del ms + else: + start_sigma = config.start_time + end_sigma = config.end_time + + max_idx = -1 + + def model_call( + model: object, + x: torch.Tensor, + sigma: torch.Tensor, + **kwargs: dict[str, Any], + ) -> torch.Tensor: + nonlocal max_idx + sigma_float = 
float(sigma.max().detach().cpu()) + enabled = end_sigma <= sigma_float <= start_sigma + if not enabled: + return model(x, sigma, **kwargs) + calltracker = CurrAttnCall(idx=0, max_idx=max_idx) + if config.verbose: + tqdm.write(f"[BLEH] AdvancedAttn: Config: {config}") + model_options = kwargs.pop("model_options", {}).copy() + transformer_options = model_options.pop("transformer_options", {}).copy() + old_override = transformer_options.pop("optimized_attention_override", None) + attn_override = partial( + config.attn_wrapper, + sigma_float, + calltracker, + old_override, + ) + transformer_options["optimized_attention_override"] = attn_override + model_options["transformer_options"] = transformer_options + kwargs["model_options"] = model_options + result = model(x, sigma, **kwargs) + max_idx = max(max_idx, calltracker.idx) + return result + + return config.sampler.sampler_function( + BlehModelWrapper(model, model_call), + x, + sigmas, + **kwargs, + **config.sampler.extra_options, + ) + + +class BlehAdvancedAttentionSampler: + DESCRIPTION = "TBD" + CATEGORY = "sampling/custom_sampling/samplers" + RETURN_TYPES = ("SAMPLER",) + FUNCTION = "go" + + DEFAULT_YAML_PARAMS = """verbose: false +start_time: 0.0 +end_time: 1.0 +# One of: percent, sigma +time_mode: percent +# .inf means match everything. Whole float values exclude an index. I.E 2.0 +# Call index as in the Nth call to attention this model evaluation. +# Negative indexes count from the end but can only match after a pass through the model. +call_indexes: [.inf] +# Can be set to null (everything), cond, uncond or a list. +batch_slice: null +# Requires cond batch information to be passed and at least this many items. +min_cond_batch: 0 +rules: + # Passed as sageattn_function unless set to default. + # Keys not in this list are passed through like with the SageAttention node: + # attn_function, blend_mode, blend + - attn_function: default + # You can set whatever other keys you want here. 
+ - attn_function: sageattn + # Blends target the last attention result and are ignored + # if it's missing. + # The default blend means: + # sageattn + (defaultattn - sageattn) * 2 + blend_mode: cfg + blend: 2.0 +""" + + @classmethod + def INPUT_TYPES(cls) -> dict: + return { + "required": { + "sampler": ("SAMPLER",), + "yaml_parameters": ( + "STRING", + { + "default": cls.DEFAULT_YAML_PARAMS, + "tooltip": "Allows specifying custom parameters via YAML. These are mostly passed directly to the SageAttention function with no error checking. Must be empty or a YAML object.", + "dynamicPrompts": False, + "multiline": True, + "defaultInput": True, + }, + ), + }, + "optional": { + "op_0": ("LATENT_OPERATION",), + "op_1": ("LATENT_OPERATION",), + "op_2": ("LATENT_OPERATION",), + "op_3": ("LATENT_OPERATION",), + "op_4": ("LATENT_OPERATION",), + "op_5": ("LATENT_OPERATION",), + "op_6": ("LATENT_OPERATION",), + "op_7": ("LATENT_OPERATION",), + "op_8": ("LATENT_OPERATION",), + "op_9": ("LATENT_OPERATION",), + }, + } + + @classmethod + def go( + cls, + sampler: object, + yaml_parameters: str, + **kwargs: dict, + ) -> tuple: + if sageattention is None: + raise RuntimeError( + "sageattention not installed to Python environment: SageAttention feature unavailable", + ) + if not HAVE_ATTN_OVERRIDE: + raise RuntimeError( + "This node only supports recent ComfyUI versions that support attention overrides.", + ) + params = get_yaml_parameters(yaml_parameters) + params["latent_ops"] = { + k: v for k, v in kwargs.items() if k.startswith("op_") and v is not None + } + return ( + KSAMPLER( + update_wrapper( + partial( + advancedattn_sampler, + AdvancedAttnConfig.build(sampler=sampler, **params), ), - }, + sampler.sampler_function, + ), ), ) diff --git a/py/nodes/taevid.py b/py/nodes/taevid.py index 0892342..ad3ef8a 100644 --- a/py/nodes/taevid.py +++ b/py/nodes/taevid.py @@ -18,7 +18,19 @@ class TAEVideoNodeBase: def INPUT_TYPES(cls) -> dict: return { "required": { - "latent_type": 
(("wan21", "wan22", "hunyuanvideo", "mochi"),), + "latent_type": ( + ( + "wan21", + "wan22", + "hunyuanvideo", + "hunyuanvideo15", + "mochi", + "ltxv", + ), + { + "tooltip": "Use ltxv for LTX-2 AV.", + }, + ), "parallel_mode": ( "BOOLEAN", { @@ -45,6 +57,8 @@ def get_taevid_model( model_src = "taew2_2.pth from https://github.com/madebyollin/taehv" elif latent_type == "hunyuanvideo": model_src = "taehv.pth from https://github.com/madebyollin/taehv" + elif latent_type == "ltxv": + model_src = "taeltx_2.pth from https://github.com/madebyollin/taehv" else: model_src = "taem1.pth from https://github.com/madebyollin/taem1" err_string = f"Missing TAE video model. Download {model_src} and place it in the models/vae_approx directory" @@ -66,7 +80,7 @@ def go(cls, *, latent, latent_type: str, parallel_mode: bool) -> tuple: class TAEVideoDecode(TAEVideoNodeBase): RETURN_TYPES = ("IMAGE",) CATEGORY = "latent" - DESCRIPTION = "Fast decoding of Wan, Hunyuan and Mochi video latents with the video equivalent of TAESD." + DESCRIPTION = "Fast decoding of Wan, Hunyuan, Mochi and LTX video latents with the video equivalent of TAESD." @classmethod def INPUT_TYPES(cls) -> dict: @@ -100,7 +114,7 @@ def go(cls, *, latent: dict, latent_type: str, parallel_mode: bool) -> tuple: class TAEVideoEncode(TAEVideoNodeBase): RETURN_TYPES = ("LATENT",) CATEGORY = "latent" - DESCRIPTION = "Fast encoding of Wan, Hunyuan and Mochi video latents with the video equivalent of TAESD." + DESCRIPTION = "Fast encoding of Wan, Hunyuan, Mochi and LTX video latents with the video equivalent of TAESD." 
@classmethod def INPUT_TYPES(cls) -> dict: diff --git a/py/settings.py b/py/settings.py index 327e93d..2ded21f 100644 --- a/py/settings.py +++ b/py/settings.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from pathlib import Path +from typing import NamedTuple class Settings: @@ -7,6 +10,8 @@ def __init__(self): def update(self, obj): btp = obj.get("betterTaesdPreviews", None) + if btp is None: + btp = obj.get("previews", None) self.btp_enabled = btp is not None and btp.get("enabled", True) is True if not self.btp_enabled: return @@ -23,6 +28,7 @@ def update(self, obj): self.btp_preview_device = btp.get("preview_device") # default, keep, float32, float16, bfloat16 self.btp_preview_dtype = btp.get("preview_dtype") + self.btp_preview_non_blocking = bool(btp.get("preview_non_blocking", False)) self.btp_maxed_batch_step_mode = btp.get("maxed_batch_step_mode", False) self.btp_compile_previewer = btp.get("compile_previewer", False) self.btp_oom_fallback = btp.get("oom_fallback", "latent2rgb") @@ -37,6 +43,7 @@ def update(self, obj): ) self.btp_animate_preview = btp.get("animate_preview", "none") self.btp_verbose = btp.get("verbose", False) + self.btp_publish_last_preview = btp.get("publish_last_preview", False) @staticmethod def get_cfg_path(filename) -> Path: diff --git a/py/wavelet_functions.py b/py/wavelet_functions.py new file mode 100644 index 0000000..b993b41 --- /dev/null +++ b/py/wavelet_functions.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +import torch + +if TYPE_CHECKING: + from collections.abc import Sequence + +try: + import pytorch_wavelets as ptwav + import pywt + + HAVE_WAVELETS = True +except ImportError: + ptwav = None + pywt = None + HAVE_WAVELETS = False + + +def fallback[V, D](val: V | D, default: D = None) -> V | D: + return val if val is not None else default + + +class Wavelet: + DEFAULT_MODE = "symmetric" + DEFAULT_LEVEL = 3 + DEFAULT_WAVE = "db4" + DEFAULT_USE_1D_DWT = False + 
DEFAULT_USE_DTCWT = False + DEFAULT_QSHIFT = "qshift_a" + DEFAULT_BIORT = "near_sym_a" + + def __init__( + self, + *, + wave: str = DEFAULT_WAVE, + level: int = DEFAULT_LEVEL, + mode: str = DEFAULT_MODE, + use_1d_dwt: bool = DEFAULT_USE_1D_DWT, + use_dtcwt: bool = DEFAULT_USE_DTCWT, + biort: str = DEFAULT_BIORT, + qshift: str = DEFAULT_QSHIFT, + inv_wave: str | None = None, + inv_mode: str | None = None, + inv_biort: str | None = None, + inv_qshift=None, + device: str | torch.device | None = None, + ): + if not HAVE_WAVELETS: + raise RuntimeError( + "Wavelet use requires the pytorch_wavelets package to be installed in your Python environment", + ) + inv_wave = fallback(inv_wave, wave) + inv_mode = fallback(inv_mode, mode) + inv_biort = fallback(inv_biort, biort) + inv_qshift = fallback(inv_qshift, qshift) + if use_dtcwt: + fwdfun, invfun = ptwav.DTCWTForward, ptwav.DTCWTInverse + elif use_1d_dwt: + fwdfun, invfun = ptwav.DWT1DForward, ptwav.DWT1DInverse + else: + fwdfun, invfun = ptwav.DWTForward, ptwav.DWTInverse + if use_dtcwt: + self._wavelet_forward = fwdfun( + J=level, + mode=mode, + biort=biort, + qshift=qshift, + ) + self._wavelet_inverse = invfun( + mode=inv_mode, + biort=inv_biort, + qshift=inv_qshift, + ) + else: + self._wavelet_forward = fwdfun(J=level, wave=wave, mode=mode) + self._wavelet_inverse = invfun(wave=inv_wave, mode=inv_mode) + self.device = device + if device is not None: + self._wavelet_forward = self._wavelet_forward.to(device=device) + self._wavelet_inverse = self._wavelet_inverse.to(device=device) + + def forward( + self, + t: torch.Tensor, + *, + forward_function: Callable | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + return fallback(forward_function, self._wavelet_forward)(t) + + def inverse( + self, + yl: torch.Tensor, + yh: tuple[torch.Tensor, ...], + *, + inverse_function: Callable | None = None, + two_step_inverse: bool = False, + ) -> torch.Tensor: + inverse_function = fallback(inverse_function, 
self._wavelet_inverse) + if not two_step_inverse: + return inverse_function((yl, yh)) + result = inverse_function((torch.zeros_like(yl), yh)) + result += inverse_function( + ( + yl, + tuple(torch.zeros_like(yh_band) for yh_band in yh), + ) + ) + return result + + def to(self, *args: list, copy: bool = False, **kwargs: dict) -> Wavelet: + o = Wavelet.__new__(Wavelet) if copy else self + o._wavelet_forward = self._wavelet_forward.to(*args, **kwargs) # noqa: SLF001 + o._wavelet_inverse = self._wavelet_inverse.to(*args, **kwargs) # noqa: SLF001 + o.device = kwargs.get("device") + return o + + @staticmethod + def wavelist() -> tuple: + return tuple(pywt.wavelist()) if HAVE_WAVELETS else () + + @staticmethod + def biortlist() -> tuple: + return ( + ("near_sym_a", "near_sym_b", "antonini", "legall") if HAVE_WAVELETS else () + ) + + @staticmethod + def qshiftlist() -> tuple: + return ( + ("qshift_a", "qshift_b", "qshift_c", "qshift_d", "qshift_06") + if HAVE_WAVELETS + else () + ) + + @staticmethod + def modelist() -> tuple: + return ( + ( + "symmetric", + "zero", + "reflect", + "replicate", + "periodization", + "periodic", + "constant", + ) + if HAVE_WAVELETS + else () + ) + + +def expand_yh_scales( + yh: Sequence, + *, + yh_scales: float | Sequence = 1.0, +) -> float | tuple: + yhlen = len(yh) + yh_shape = yh[0].shape + # Doesn't make sense to target orientations for 1D DWD (3D here). 
+ olen = yh_shape[2] if len(yh_shape) > 3 else 1 + # print(f"\nSIZES: yhlen={yhlen}, olen={olen}, yh_shape={yh[0].shape}") + if isinstance(yh_scales, (float, int)): + return ((float(yh_scales),) * olen,) * yhlen + otemplate = (1.0,) * olen + yh_scales = tuple( + (float(band),) * olen + if isinstance(band, (float, int)) + else ( + ( + *(float(i) for i in band[:olen]), + *otemplate[: olen - len(band[:olen])], + ) + if isinstance(band, (tuple, list)) + else band + ) + for band in yh_scales + ) + if "fill" in yh_scales: + fillidx = yh_scales.index("fill") + if "fill" in yh_scales[fillidx + 1 :]: + raise ValueError("Only one fill allowed.") + if fillidx == 0 or len(yh_scales) < 2: + raise ValueError( + "Invalid fill value, cannot be in the first position or the only item.", + ) + yhslen = len(yh_scales) + if yhslen - 1 < yhlen: + # Need to pad. + fill = (yh_scales[fillidx - 1],) * (yhlen - (len(yh_scales) - 1)) + yh_scales = (*yh_scales[:fillidx], *fill, *yh_scales[fillidx + 1 :]) + else: + # Just remove the "fill". 
+ yh_scales = (*yh_scales[:fillidx], *yh_scales[fillidx + 1 :]) + return yh_scales[:yhlen] + + +def wavelet_scaling( + yl: torch.Tensor, + yh: Sequence[torch.Tensor], + yl_scale: float | torch.Tensor, + yh_scales: float | Sequence[float | Sequence[float]] | None, + *, + in_place: bool = False, +) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + if not in_place: + yl = yl.clone() + yh = tuple(yhband.clone() for yhband in yh) + if yl_scale != 1.0: + yl *= yl_scale + yh_scales = expand_yh_scales( + yh, + yh_scales=yh_scales if yh_scales is not None else 1.0, + ) + for hscale, ht in zip(yh_scales, yh): + if isinstance(hscale, (int, float)): + ht *= hscale # noqa: PLW2901 + continue + for lidx in range(min(ht.shape[2], len(hscale))): + ht[:, :, lidx] *= hscale[lidx] + return (yl, yh) + + +def wavelet_blend( + a: tuple, + b: tuple, + *, + yl_factor: torch.Tensor | float, + blend_function: Callable, + yh_factor: torch.Tensor | float | None = None, + yh_blend_function: Callable | None = None, +) -> tuple: + if not isinstance(yl_factor, torch.Tensor): + yl_factor = a[0].new_full((1,), yl_factor) + if yh_factor is None: + yh_factor = yl_factor + elif not isinstance(yh_factor, torch.Tensor): + yh_factor = a[0].new_full((1,), yh_factor) + yh_blend_function = fallback(yh_blend_function, blend_function) + return ( + blend_function(a[0], b[0], yl_factor), + tuple(yh_blend_function(ta, tb, yh_factor) for ta, tb in zip(a[1], b[1])), + )