diff --git a/blehconfig.example.yaml b/blehconfig.example.yaml
index 3594ad6..510dfc2 100644
--- a/blehconfig.example.yaml
+++ b/blehconfig.example.yaml
@@ -38,6 +38,12 @@ betterTaesdPreviews:
# Setting it to "vae" will use whatever dtype ComfyUI is set to use for VAE.
preview_dtype: null
+ # Uses non-blocking transfers for previews when the device supports it.
+ # Not recommended as it is extremely likely to corrupt previews, especially if the previewer
+ # is relatively slow or the latent is large (video models, Chroma Radiance). However,
+ # it might decrease the performance impact of previewing.
+ preview_non_blocking: false
+
# Allows skipping upscale layers in the TAESD model, may increase performance when previewing large images or batches.
# May be set to -1 (conservative) or -2 (aggressive) to automatically calculate how many to skip. See README.md for details.
skip_upscale_layers: 0
diff --git a/py/better_previews/base.py b/py/better_previews/base.py
index 04b3897..f08df3c 100644
--- a/py/better_previews/base.py
+++ b/py/better_previews/base.py
@@ -4,43 +4,86 @@
from comfy import latent_formats
+from .tae_vid import TAEVid, TAEVidBase, TAEVidLTX2
+
if TYPE_CHECKING:
from pathlib import Path
class VideoModelInfo(NamedTuple):
+ name: str
latent_format: latent_formats.LatentFormat
- fps: int = 24
+ fps: int | float = 24
temporal_compression: int = 8
+ temporal_layers: int = 0
patch_size: int = 1
+ nested_tensor_index: int = 0
tae_model: str | Path | None = None
+ tae_class: TAEVidBase | None = TAEVid
VIDEO_FORMATS = {
- "mochi": VideoModelInfo(
- latent_formats.Mochi,
- temporal_compression=6,
- tae_model="taem1.pth",
- ),
- "hunyuanvideo": VideoModelInfo(
- latent_formats.HunyuanVideo,
- temporal_compression=4,
- tae_model="taehv.pth",
- ),
- "cosmos1cv8x8x8": VideoModelInfo(latent_formats.Cosmos1CV8x8x8),
- "wan21": VideoModelInfo(
- latent_formats.Wan21,
- fps=16,
- temporal_compression=4,
- tae_model="taew2_1.pth",
- ),
- "wan22": VideoModelInfo(
- latent_formats.Wan22,
- fps=24,
- temporal_compression=4,
- patch_size=2,
- tae_model="taew2_2.pth",
- ),
+ vmi.name: vmi
+ for vmi in (
+ VideoModelInfo(
+ "mochi",
+ latent_formats.Mochi,
+ temporal_compression=6,
+ tae_model="taem1.pth",
+ ),
+ VideoModelInfo(
+ "hunyuanvideo",
+ latent_formats.HunyuanVideo,
+ temporal_compression=4,
+ tae_model="taehv.pth",
+ ),
+ VideoModelInfo(
+ "hunyuanvideo15",
+ latent_formats.HunyuanVideo15,
+ temporal_compression=4,
+ patch_size=2,
+ tae_model="taehv1_5.pth",
+ ),
+ VideoModelInfo(
+ "cosmos1cv8x8x8",
+ latent_formats.Cosmos1CV8x8x8,
+ ),
+ VideoModelInfo(
+ "wan21",
+ latent_formats.Wan21,
+ fps=16,
+ temporal_compression=4,
+ temporal_layers=2,
+ tae_model="taew2_1.pth",
+ ),
+ VideoModelInfo(
+ "wan22",
+ latent_formats.Wan22,
+ fps=24,
+ temporal_compression=4,
+ temporal_layers=2,
+ patch_size=2,
+ tae_model="taew2_2.pth",
+ ),
+ VideoModelInfo(
+ "ltxv",
+ latent_formats.LTXV,
+ fps=24,
+ patch_size=4,
+ temporal_layers=3,
+ tae_model="taeltx_2.pth",
+ tae_class=TAEVidLTX2,
+ ),
+ VideoModelInfo(
+ "ltxav",
+ latent_formats.LTXV,
+ fps=24,
+ patch_size=4,
+ temporal_layers=3,
+ tae_model="taeltx_2.pth",
+ tae_class=TAEVidLTX2,
+ ),
+ )
}
diff --git a/py/better_previews/previewer.py b/py/better_previews/previewer.py
index 45c2369..4f90575 100644
--- a/py/better_previews/previewer.py
+++ b/py/better_previews/previewer.py
@@ -1,17 +1,22 @@
from __future__ import annotations
import math
+from io import BytesIO
from time import time
from typing import TYPE_CHECKING
+import comfy.utils as comfy_utils
import folder_paths
import latent_preview
import torch
+from aiohttp import web
+from comfy import latent_formats
from comfy.cli_args import LatentPreviewMethod
from comfy.cli_args import args as comfy_args
from comfy.model_management import device_supports_non_blocking, vae_dtype
from comfy.taesd.taesd import TAESD
from PIL import Image
+from server import PromptServer
from tqdm import tqdm
from ..settings import SETTINGS # noqa: TID252
@@ -19,14 +24,22 @@
from .tae_vid import TAEVid
if TYPE_CHECKING:
+ from collections.abc import Callable
+
import numpy as np
from comfy import latent_formats
+
+class BlehPreviewerState:
+ last_latent_shapes: tuple | None = None
+ fps_override: float | None = None
+
+
+PREVIEWER_STATE = BlehPreviewerState()
+
_ORIG_PREVIEWER = latent_preview.TAESDPreviewerImpl
_ORIG_GET_PREVIEWER = latent_preview.get_previewer
-LAST_LATENT_FORMAT = None
-
# Referenced from https://github.com/learnables/learn2learn/blob/752200384c3ca8caeb8487b5dd1afd6568e8ec01/learn2learn/utils/__init__.py#L51
def clone_module(module, *, memo: dict | None = None) -> torch.nn.Module:
@@ -81,21 +94,73 @@ def normalize_to_scale(latent, target_min, target_max, *, dim=(-3, -2, -1)):
)
+class LastPreview:
+ image: bytes | None
+ stamp: float | None
+ content_type: str | None
+
+    dum_page = """<!DOCTYPE html>
+<html>
+<head>
+<title>bleh preview</title>
+</head>
+<body>
+<img
+  src="/bleh/last_preview"
+  onload="setTimeout(() => { this.src = '/bleh/last_preview?' + Date.now(); }, 250)"
+  onerror="setTimeout(() => { this.src = '/bleh/last_preview?' + Date.now(); }, 1000)"
+>
+</body>
+</html>"""
+
+ def __init__(self):
+ self.image = None
+ self.stamp = None
+ self.content_type = None
+
+ def update(
+ self, *, image_bytes: bytes, content_type: str, stamp: float | None = None
+ ):
+ self.image = image_bytes
+ self.stamp = time() if stamp is None else stamp
+ self.content_type = content_type
+
+ async def __call__(self, request: web.Request):
+ if request.path.endswith(".html"):
+ return web.Response(body=self.dum_page, content_type="text/html")
+ if self.image is None or self.content_type is None:
+ raise web.HTTPNotFound(reason="OHNO")
+ return web.Response(body=self.image, content_type=self.content_type)
+
+
+LAST_PREVIEW = LastPreview()
+PromptServer.instance.routes.get("/bleh/last_preview")(LAST_PREVIEW)
+PromptServer.instance.routes.get("/bleh/last_preview.html")(LAST_PREVIEW)
+
+
class ImageWrapper:
- def __init__(self, frames: tuple, frame_duration: int):
- self._frames = frames
+ def __init__(self, frames: tuple | Image, frame_duration: int = 250):
+ self._frames = (frames,) if not isinstance(frames, (tuple, list)) else frames
self._frame_duration = frame_duration
def save(self, fp, format: str | None, **kwargs: dict): # noqa: A002
- if len(self._frames) == 1:
+ if len(self._frames) > 1:
+ kwargs |= {
+ "loop": 0,
+ "save_all": True,
+ "append_images": self._frames[1:],
+ "duration": self._frame_duration,
+ }
+ format = "webp"
+ if not SETTINGS.btp_publish_last_preview:
return self._frames[0].save(fp, format, **kwargs)
- kwargs |= {
- "loop": 0,
- "save_all": True,
- "append_images": self._frames[1:],
- "duration": self._frame_duration,
- }
- return self._frames[0].save(fp, "webp", **kwargs)
+ buf = BytesIO()
+ result = self._frames[0].save(buf, format, **kwargs)
+        # NOTE: encodes into an in-memory buffer so the bytes can be both published and written to fp.
+ image_bytes = buf.getvalue()
+ LAST_PREVIEW.update(image_bytes=image_bytes, content_type=f"image/{format}")
+ fp.write(image_bytes)
+ return result
def resize(self, *args: list, **kwargs: dict) -> ImageWrapper:
return ImageWrapper(
@@ -156,18 +221,26 @@ def __init__(
*,
dtype: torch.dtype,
device: torch.device,
+ height_factor: int = 4,
+ width_factor: int = 1,
normalize_dims: tuple = (-1,),
):
super().__init__()
self.dtype = dtype
self.device = device
self.normalize_dims = normalize_dims
+ self.height_factor = height_factor
+ self.width_factor = width_factor
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch, temporal = x.shape[0], x.shape[-1]
x = normalize_to_scale(x, 0.0, 1.0, dim=self.normalize_dims) * 255.0
x = x.reshape(batch, -1, temporal)
+ if self.height_factor > 1:
+ x = x.repeat_interleave(dim=1, repeats=self.height_factor)
+ if self.width_factor > 1:
+            x = x.repeat_interleave(dim=-1, repeats=self.width_factor)
return x[..., None].expand(*x.shape, 3)
@@ -179,7 +252,10 @@ def __init__(
latent_format: latent_formats.LatentFormat,
vid_info: VideoModelInfo | None = None,
):
- self.latent_format = latent_format
+ self.orig_latent_format = latent_format
+ self.latent_format = (
+ latent_format if vid_info is None else vid_info.latent_format
+ )
self.latent_format_name = (
"unknown"
if latent_format is None
@@ -406,7 +482,8 @@ def prepare_previewer(
return x0.to(
device=pdevice,
dtype=pdtype,
- non_blocking=device_supports_non_blocking(x0.device),
+ non_blocking=SETTINGS.btp_preview_non_blocking
+ and device_supports_non_blocking(x0.device),
)
def _decode_latent_taevid(self, x0: torch.Tensor) -> tuple[torch.Tensor, int, int]:
@@ -473,12 +550,25 @@ def calc_cols_rows(
rows = math.ceil(batch_size / cols)
return cols, rows
- @classmethod
- def decoded_to_animation(cls, samples: np.ndarray) -> ImageWrapper:
+ def decoded_to_animation(
+ self,
+ samples: np.ndarray,
+ video_frames: int,
+ ) -> ImageWrapper:
batch = samples.shape[0]
+ fps_override = PREVIEWER_STATE.fps_override
+ if self.vid_info is None or not video_frames:
+ frame_duration = 250 if not fps_override else 1000 / fps_override
+ else:
+ time_factor = self.vid_info.temporal_compression / max(
+ 1,
+ self.previewer_model.t_upscale,
+ )
+ ms_frame = 1000.0 / (fps_override or self.vid_info.fps)
+ frame_duration = ms_frame * time_factor
return ImageWrapper(
tuple(Image.fromarray(samples[idx]) for idx in range(batch)),
- frame_duration=250,
+ frame_duration=max(1, int(frame_duration)),
)
def decoded_to_image(
@@ -487,22 +577,23 @@ def decoded_to_image(
cols: int,
rows: int,
*,
- is_video=False,
+ video_frames: int = 0,
) -> Image | ImageWrapper:
batch, (height, width) = samples.shape[0], samples.shape[-3:-1]
samples = samples.to(
device="cpu",
dtype=torch.uint8,
- non_blocking=device_supports_non_blocking(samples.device),
+ non_blocking=SETTINGS.btp_preview_non_blocking
+ and device_supports_non_blocking(samples.device),
).numpy()
if batch == 1:
- self.cached = Image.fromarray(samples[0])
+ self.cached = ImageWrapper((Image.fromarray(samples[0]),))
return self.cached
if SETTINGS.btp_animate_preview == "both" or (
- is_video,
+ video_frames != 0,
SETTINGS.btp_animate_preview,
) in {(True, "video"), (False, "batch")}:
- return self.decoded_to_animation(samples)
+ return self.decoded_to_animation(samples, video_frames=video_frames)
cols, rows = self.calc_cols_rows(batch, width, height)
img_size = (width * cols, height * rows)
if self.cached is not None and self.cached.size == img_size:
@@ -514,7 +605,7 @@ def decoded_to_image(
Image.fromarray(samples[idx]),
box=((idx % cols) * width, ((idx // cols) % rows) * height),
)
- return result
+ return ImageWrapper((result,))
@torch.no_grad()
def init_fallback_previewer(self, device: torch.device, dtype: torch.dtype) -> bool:
@@ -526,7 +617,7 @@ def init_fallback_previewer(self, device: torch.device, dtype: torch.dtype) -> b
and self.fallback_previewer_model.device == device
):
return True
- if self.latent_format_name == "aceaudio":
+ if self.latent_format_name in {"aceaudio", "aceaudio15"}:
self.fallback_previewer_model = ACEStepsPreviewerModel(
device=device,
dtype=dtype,
@@ -563,11 +654,52 @@ def fallback_previewer(self, x0: torch.Tensor, *, quiet=False) -> Image:
except torch.OutOfMemoryError:
return self.blank
+ def ensure_x0_shape(self, x0: torch.Tensor) -> tuple[torch.Tensor, bool]: # noqa: PLR0911
+ expected_channels = self.latent_format.latent_channels
+ expected_ndim = 2 + self.latent_format.latent_dimensions
+ if x0.shape[0] == 0:
+ return x0, False
+ if self.latent_format_name == "aceaudio15" and x0.ndim == expected_ndim + 1:
+ expected_ndim += 1
+ if (
+ x0.ndim > 1
+ and x0.ndim == expected_ndim
+ and x0.shape[1] == expected_channels
+ ):
+ return x0, True
+ last_shapes = PREVIEWER_STATE.last_latent_shapes
+ if not last_shapes or not hasattr(comfy_utils, "unpack_latents"):
+ return x0, False
+ last_numel = sum(math.prod(tshape) for tshape in last_shapes)
+ if last_numel != x0.numel():
+ return x0, False
+ nest_idx = self.vid_info.nested_tensor_index if self.vid_info else 0
+ target_shape = None if len(last_shapes) <= nest_idx else last_shapes[nest_idx]
+ if (
+ # Have to have a nest shape
+ target_shape is None
+ # with at least a channel dimension,
+ or len(target_shape) < 2
+ # with the expected number of dims,
+ or len(target_shape) != expected_ndim
+ # And the correct number of channels.
+ or target_shape[1] != expected_channels
+ ):
+ return x0, False
+ unpacked_latents = comfy_utils.unpack_latents(x0, last_shapes)
+ target_latent = (
+ None if len(unpacked_latents) <= nest_idx else unpacked_latents[nest_idx]
+ )
+ if target_latent is None or target_latent.shape != target_shape:
+ return x0, False
+ return target_latent.reshape(*target_shape), True
+
def decode_latent_to_preview(self, x0: torch.Tensor) -> Image:
if self.check_use_cached():
return self.cached
- if x0.shape[0] == 0:
- return self.blank # Shouldn't actually be possible.
+ x0, can_preview = self.ensure_x0_shape(x0)
+ if not can_preview:
+ return self.blank
if (self.oom_count and not self.oom_retry) or self.previewer_model is None:
return self.fallback_previewer(x0, quiet=True)
is_video = x0.ndim == 5
@@ -579,7 +711,10 @@ def decode_latent_to_preview(self, x0: torch.Tensor) -> Image:
if is_video
else self._decode_latent_taesd(x0)
)
- result = self.decoded_to_image(*dargs, is_video=is_video)
+ result = self.decoded_to_image(
+ *dargs,
+ video_frames=x0.shape[2] if is_video else 0,
+ )
except torch.OutOfMemoryError:
used_fallback = True
result = self.fallback_previewer(x0)
@@ -601,7 +736,11 @@ def orig_get_previewer():
preview_method = comfy_args.preview_method
- if preview_method == LatentPreviewMethod.NoPreviews:
+ if preview_method not in {
+ LatentPreviewMethod.TAESD,
+ LatentPreviewMethod.Auto,
+ LatentPreviewMethod.Latent2RGB,
+ }:
return orig_get_previewer()
format_name = latent_format.__class__.__name__.lower()
@@ -611,29 +750,34 @@ def orig_get_previewer():
or (SETTINGS.btp_whitelist and format_name not in SETTINGS.btp_whitelist)
):
return orig_get_previewer()
+ if format_name in {"aceaudio", "aceaudio15"}:
+ return BetterPreviewer(latent_format=latent_format)
+ vid_info = VIDEO_FORMATS.get(format_name)
+ eff_latent_format = (
+ vid_info.latent_format if vid_info is not None else latent_format
+ )
tae_model = None
if preview_method in {LatentPreviewMethod.TAESD, LatentPreviewMethod.Auto}:
- vid_info = VIDEO_FORMATS.get(format_name)
- if vid_info is not None and vid_info.tae_model is not None:
+ if (
+ vid_info is not None
+ and vid_info.tae_model is not None
+ and vid_info.tae_class is not None
+ ):
tae_model_path = folder_paths.get_full_path(
"vae_approx",
vid_info.tae_model,
)
- tupscale_limit = SETTINGS.btp_video_temporal_upscale_level
- decoder_time_upscale = tuple(
- i < tupscale_limit for i in range(TAEVid.temporal_upscale_blocks)
- )
tae_model = (
- TAEVid(
+ vid_info.tae_class(
checkpoint_path=tae_model_path,
vmi=vid_info,
device=torch.device("cpu"),
- decoder_time_upscale=decoder_time_upscale,
+ decoder_time_upscale_level=SETTINGS.btp_video_temporal_upscale_level,
)
if tae_model_path is not None
else None
)
- if tae_model is None and latent_format.taesd_decoder_name is not None:
+ elif vid_info is None and latent_format.taesd_decoder_name is not None:
taesd_path = folder_paths.get_full_path(
"vae_approx",
f"{latent_format.taesd_decoder_name}.pth",
@@ -653,7 +797,8 @@ def orig_get_previewer():
latent_format=latent_format,
vid_info=vid_info,
)
- if format_name == "aceaudio" or latent_format.latent_rgb_factors is not None:
+ # Using Latent2RGB either via setting or because no preview model.
+ if eff_latent_format.latent_rgb_factors is not None:
return BetterPreviewer(latent_format=latent_format)
return orig_get_previewer()
diff --git a/py/better_previews/tae_vid.py b/py/better_previews/tae_vid.py
index 325f2f0..f676f11 100644
--- a/py/better_previews/tae_vid.py
+++ b/py/better_previews/tae_vid.py
@@ -53,6 +53,10 @@ def forward(self, x: torch.Tensor, past: torch.Tensor) -> torch.Tensor:
return self.act(self.conv(torch.cat((x, past), 1)) + self.skip(x))
+def make_memblocks(n: int, *, count: int = 3) -> tuple[MemBlock, ...]:
+ return tuple(MemBlock(n, n) for _ in range(count))
+
+
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
@@ -183,8 +187,8 @@ def apply(self, x: torch.Tensor, *, show_progress=False) -> torch.Tensor:
return torch.stack(out, 1)
-class TAEVid(nn.Module):
- temporal_upscale_blocks = 2
+class TAEVidBase(nn.Module):
+ temporal_upscale_blocks = 3
spatial_upscale_blocks = 3
_nf = (256, 128, 64, 64)
@@ -195,61 +199,33 @@ def __init__(
vmi: VideoModelInfo,
image_channels: int = 3,
device="cpu",
- decoder_time_upscale=(True, True),
- decoder_space_upscale=(True, True, True),
+ encoder_time_downscale_level: int = 3,
+ decoder_time_upscale_level: int = 3,
+ decoder_space_upscale_level: int = 3,
):
- n_f = self._nf
super().__init__()
self.vmi = vmi
- self.latent_channels = vmi.latent_format.latent_channels
self.image_channels = image_channels
+ self.latent_channels = vmi.latent_format.latent_channels
self.patch_size = vmi.patch_size
- self.encoder = nn.Sequential(
- conv(image_channels * self.patch_size**2, 64),
- nn.ReLU(inplace=True),
- TPool(64, 2),
- conv(64, 64, stride=2, bias=False),
- MemBlock(64, 64),
- MemBlock(64, 64),
- MemBlock(64, 64),
- TPool(64, 2),
- conv(64, 64, stride=2, bias=False),
- MemBlock(64, 64),
- MemBlock(64, 64),
- MemBlock(64, 64),
- TPool(64, 1),
- conv(64, 64, stride=2, bias=False),
- MemBlock(64, 64),
- MemBlock(64, 64),
- MemBlock(64, 64),
- conv(64, vmi.latent_format.latent_channels),
+ encoder_time_downscale = self._get_encoder_flags(
+ time_level=encoder_time_downscale_level,
)
- self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
- self.decoder = nn.Sequential(
- Clamp(),
- conv(vmi.latent_format.latent_channels, n_f[0]),
- nn.ReLU(inplace=True),
- MemBlock(n_f[0], n_f[0]),
- MemBlock(n_f[0], n_f[0]),
- MemBlock(n_f[0], n_f[0]),
- nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
- TGrow(n_f[0], 1),
- conv(n_f[0], n_f[1], bias=False),
- MemBlock(n_f[1], n_f[1]),
- MemBlock(n_f[1], n_f[1]),
- MemBlock(n_f[1], n_f[1]),
- nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
- TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
- conv(n_f[1], n_f[2], bias=False),
- MemBlock(n_f[2], n_f[2]),
- MemBlock(n_f[2], n_f[2]),
- MemBlock(n_f[2], n_f[2]),
- nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
- TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
- conv(n_f[2], n_f[3], bias=False),
- nn.ReLU(inplace=True),
- conv(n_f[3], image_channels * self.patch_size**2),
+ decoder_time_upscale, decoder_space_upscale = self._get_decoder_flags(
+ time_level=decoder_time_upscale_level,
+ space_level=decoder_space_upscale_level,
+ )
+ encoder_strides = tuple(1 + int(flag) for flag in encoder_time_downscale)
+ decoder_strides = tuple(1 + int(flag) for flag in decoder_time_upscale)
+ decoder_scale_factors = tuple(1 + int(flag) for flag in decoder_space_upscale)
+ self.encoder = self._build_encoder(strides=encoder_strides)
+ self.decoder = self._build_decoder(
+ strides=decoder_strides,
+ scale_factors=decoder_scale_factors,
)
+ self.t_upscale = 2 ** sum(decoder_time_upscale)
+ self.t_downscale = 2 ** sum(encoder_time_downscale)
+ self.frames_to_trim = self.t_upscale - 1
if checkpoint_path is None:
return
self.load_state_dict(
@@ -258,6 +234,66 @@ def __init__(
),
)
+ def _get_decoder_flags(
+ self,
+ *,
+ time_level: int = 3,
+ space_level: int = 3,
+ ) -> tuple[tuple[bool, ...], tuple[bool, ...]]:
+ decoder_time_upscale = tuple(i < time_level for i in range(3))
+ decoder_space_upscale = tuple(i < space_level for i in range(3))
+ return decoder_time_upscale, decoder_space_upscale
+
+ def _get_encoder_flags(
+ self,
+ *,
+ time_level: int = 3,
+ ) -> tuple[bool, ...]:
+ return tuple(i < time_level for i in range(3))
+
+ def _build_decoder(
+ self,
+ *,
+ strides: tuple[int, ...],
+ scale_factors: tuple[int, ...],
+ ) -> nn.Module:
+ n_f = self._nf
+ return nn.Sequential(
+ Clamp(),
+ conv(self.latent_channels, n_f[0]),
+ nn.ReLU(inplace=True),
+ *make_memblocks(n_f[0]),
+ nn.Upsample(scale_factor=scale_factors[0]),
+ TGrow(n_f[0], strides[0]),
+ conv(n_f[0], n_f[1], bias=False),
+ *make_memblocks(n_f[1]),
+ nn.Upsample(scale_factor=scale_factors[1]),
+ TGrow(n_f[1], strides[1]),
+ conv(n_f[1], n_f[2], bias=False),
+ *make_memblocks(n_f[2]),
+ nn.Upsample(scale_factor=scale_factors[2]),
+ TGrow(n_f[2], strides[2]),
+ conv(n_f[2], n_f[3], bias=False),
+ nn.ReLU(inplace=True),
+ conv(n_f[3], self.image_channels * self.patch_size**2),
+ )
+
+ def _build_encoder(self, *, strides: tuple[int, ...]) -> nn.Module:
+ return nn.Sequential(
+ conv(self.image_channels * self.patch_size**2, 64),
+ nn.ReLU(inplace=True),
+ TPool(64, strides[0]),
+ conv(64, 64, stride=2, bias=False),
+ *make_memblocks(64),
+ TPool(64, strides[1]),
+ conv(64, 64, stride=2, bias=False),
+ *make_memblocks(64),
+ TPool(64, strides[2]),
+ conv(64, 64, stride=2, bias=False),
+ *make_memblocks(64),
+ conv(64, self.latent_channels),
+ )
+
def patch_tgrow_layers(self, sd: dict) -> dict:
new_sd = self.state_dict()
for i, layer in enumerate(self.decoder):
@@ -304,8 +340,15 @@ def apply(
show_progress=False,
) -> torch.Tensor:
model = self.decoder if decode else self.encoder
- if not decode and self.vmi.patch_size > 1:
- x = F.pixel_unshuffle(x, self.patch_size)
+ if not decode:
+ if self.vmi.patch_size > 1:
+ x = F.pixel_unshuffle(x, self.patch_size)
+ # Pad handling copied from https://github.com/madebyollin
+ if x.shape[1] % self.t_downscale != 0:
+ # pad at end to multiple of self.t_downscale
+ n_pad = self.t_downscale - x.shape[1] % self.t_downscale
+ padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
+ x = torch.cat([x, padding], 1)
if parallel:
result = self.apply_parallel(x, model, show_progress=show_progress)
else:
@@ -324,3 +367,45 @@ def encode(self, *args: list, **kwargs: dict) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.c(x)
+
+
+class TAEVid(TAEVidBase):
+ def _get_decoder_flags(
+ self,
+ *,
+ time_level: int = 3,
+ space_level: int = 3,
+ ) -> tuple[tuple[bool, ...], tuple[bool, ...]]:
+ tu, su = super()._get_decoder_flags(
+ time_level=time_level,
+ space_level=space_level,
+ )
+ return (False, *tu[:2]), su
+
+ def _get_encoder_flags(
+ self,
+ *,
+ time_level: int = 3,
+ ) -> tuple[bool, ...]:
+ return (*super()._get_encoder_flags(time_level=time_level)[:2], False)
+
+
+class TAEVidLTX2(TAEVidBase):
+ def _get_decoder_flags(
+ self,
+ *,
+ time_level: int = 3,
+ space_level: int = 3,
+ ) -> tuple[tuple[bool, ...], tuple[bool, ...]]:
+ _tu, su = super()._get_decoder_flags(
+ time_level=time_level,
+ space_level=space_level,
+ )
+ return (True, True, True), su
+
+ def _get_encoder_flags(
+ self,
+ *,
+ time_level: int = 3, # noqa: ARG002
+ ) -> tuple[bool, ...]:
+ return (True, True, True)
diff --git a/py/latent_utils.py b/py/latent_utils.py
index a75364d..55c460d 100644
--- a/py/latent_utils.py
+++ b/py/latent_utils.py
@@ -5,13 +5,20 @@
import math
import os
from functools import partial
-from typing import ClassVar
+
+from typing import TYPE_CHECKING, Any
import kornia.filters as kf
import numpy as np
import torch
import torch.nn.functional as nnf
from torch import FloatTensor, LongTensor, fft
+from tqdm import tqdm
+
+from . import wavelet_functions as wavef
+
+if TYPE_CHECKING:
+ from collections.abc import Callable, Sequence
OVERRIDE_NO_SCALE = "COMFYUI_BLEH_OVERRIDE_NO_SCALE" in os.environ
USE_ORIG_NORMALIZE = "COMFYUI_BLEH_ORIG_NORMALIZE" in os.environ
@@ -798,6 +805,347 @@ def blend_blend(
)
+def ortho_blend(
+ a: torch.Tensor,
+ b: torch.Tensor,
+ t: torch.Tensor,
+ *,
+ blend_mode: str | Callable | None = None,
+ proj_scale: float = -1.0,
+ ortho_scale: float = 1.0,
+ start_dim: int = 1,
+ end_dim: int = -1,
+ rescale_limit: float = 0.0,
+ # a, b, blend or None
+ rescale_result_mode: str | None = None,
+ # When rescale_target mode is blend, will use blend_mode if None.
+ rescale_result_blend_mode: str | Callable | None = None,
+ # LERP if None.
+ dyn_result_blend_mode: str | Callable | None = None,
+ dyn_ortho_mode: bool = False,
+ dyn_min_scale: float = 0.0,
+ dyn_max_scale: float = 1.0,
+ # Can only be used when the flattened tensor has 4 dimensions left.
+ smooth_factor_kernel_size: int | tuple[int, int] = 0,
+ ortho_verbose: bool = False,
+ eps: float = 1e-06,
+) -> torch.Tensor:
+ orig_shape = a.shape
+ ndim = a.ndim
+ if start_dim < 0:
+ start_dim = max(0, min(ndim + start_dim, ndim - 1))
+ if end_dim < 0:
+ end_dim = max(0, min(ndim + end_dim, ndim - 1))
+ if start_dim > end_dim:
+ start_dim, end_dim = end_dim, start_dim
+ sync_t = t.ndim == ndim
+ if sync_t:
+ t = t.flatten(start_dim=start_dim, end_dim=end_dim)
+ a = a.flatten(start_dim=start_dim, end_dim=end_dim)
+ b = b.flatten(start_dim=start_dim, end_dim=end_dim)
+ if end_dim != ndim - 1:
+ a = a.movedim(start_dim, -1)
+ b = b.movedim(start_dim, -1)
+ if sync_t:
+ t = t.movedim(start_dim, -1)
+ if start_dim == 0:
+ a = a.unsqueeze(0)
+ b = b.unsqueeze(0)
+ if sync_t:
+ t = t.unsqueeze(0)
+ b_normed = b.norm(dim=-1, keepdim=True) if rescale_limit else None
+ dot_ba = (b * a).sum(dim=-1, keepdim=True)
+ dot_aa = (a**2).sum(dim=-1, keepdim=True)
+ proj = (dot_ba / (dot_aa + eps)) * a
+ proj *= proj_scale
+ b_ortho = proj.add_(b if ortho_scale == 1.0 else b * ortho_scale)
+ if b_normed is not None:
+ rescale_limit = abs(rescale_limit)
+ if rescale_limit == 1:
+ rescale_limit += eps
+ b_ortho_normed = b_ortho.norm(dim=-1, keepdim=True)
+ b_ortho_normed += eps
+ b_normed /= b_ortho_normed
+ b_normed = b_normed.clamp_(-rescale_limit, rescale_limit)
+ b_ortho *= b_normed
+ if blend_mode is None:
+
+ def blend_function(a, b, t):
+ return (b * t).add_(a)
+ else:
+ blend_function = (
+ BLENDING_MODES[blend_mode] if isinstance(blend_mode, str) else blend_mode
+ )
+ ortho_result = blend_function(a, b_ortho, t)
+ if rescale_result_mode == "a":
+ rescale_result_target = a
+ elif rescale_result_mode == "b":
+ rescale_result_target = b
+ elif rescale_result_mode == "blend":
+ rr_blend_function = (
+ blend_function
+ if rescale_result_blend_mode is None
+ else (
+ BLENDING_MODES[rescale_result_blend_mode]
+ if isinstance(rescale_result_blend_mode, str)
+ else rescale_result_blend_mode
+ )
+ )
+ rescale_result_target = rr_blend_function(a, b, t)
+ else:
+ rescale_result_target = None
+ if rescale_result_target is not None:
+ result_norm = ortho_result.norm(dim=-1, keepdim=True).add_(eps)
+ target_norm = rescale_result_target.norm(dim=-1, keepdim=True)
+ target_norm /= result_norm
+ ortho_result *= target_norm
+ if b_normed is not None and dyn_ortho_mode:
+ vanilla_result = (
+            blend_function(a, b, t)
+ if rescale_result_mode != "blend"
+ else rescale_result_target
+ )
+ dyn_blend_function = (
+ torch.lerp
+ if dyn_result_blend_mode is None
+ else (
+ BLENDING_MODES[dyn_result_blend_mode]
+ if isinstance(dyn_result_blend_mode, str)
+ else dyn_result_blend_mode
+ )
+ )
+ ortho_factor = (
+ (1.0 - ((b_normed - 1.0) / (rescale_limit - 1.0)).clamp_(0.0, 1.0))
+            .mul_(dyn_max_scale - dyn_min_scale)
+            .add_(dyn_min_scale)
+ )
+ if smooth_factor_kernel_size != 0:
+ if ortho_factor.ndim < 4:
+ raise ValueError(
+ f"Can't use smooth_factor_kernel_size when ortho_factor has less than 4 dimensions. It has shape: {ortho_factor.shape}",
+ )
+ ortho_factor = (
+ torch.nn.functional.avg_pool2d(
+ ortho_factor.movedim(-1, -3),
+ kernel_size=smooth_factor_kernel_size,
+ stride=1,
+ padding=1,
+ )
+ .movedim(-3, -1)
+ .clamp_(dyn_min_scale, dyn_max_scale)
+ )
+ ortho_result = dyn_blend_function(vanilla_result, ortho_result, ortho_factor)
+ if ortho_verbose:
+ tqdm.write(
+ f"ORTHO BLEND: b_norm min/max={b_normed.aminmax()}, avg: {ortho_factor.mean().item():.5f}, min: {ortho_factor.min().item():.5f}, max: {ortho_factor.max().item():.5f}",
+ )
+ if end_dim != ndim - 1:
+ ortho_result = ortho_result.movedim(-1, start_dim)
+ return ortho_result.reshape(orig_shape)
+
+
+def symmetric_ortho_blend(
+ a: torch.Tensor,
+ b: torch.Tensor,
+ t: torch.Tensor,
+ *,
+ symmetric_strength: float = 1.0,
+ symmetric_deduce_mode: bool = False,
+ **kwargs: dict,
+) -> torch.Tensor:
+ blended = ortho_blend(a, b, t, **kwargs)
+ if symmetric_strength == 0.0:
+ return blended
+ b_ortho = blended.sub_(a)
+ if symmetric_deduce_mode:
+ b_proj = b - b_ortho
+ # Projection would theoretically be the same for both, in the simple case at least?
+ # Actually, probably not. Oh well, this is here as an option now.
+ a_ortho = a - b_proj
+ else:
+ a_ortho = ortho_blend(b, a, a.new_tensor(1.0), **kwargs) - b
+ a_proj = a - a_ortho
+ return a_proj.mul_(1.0 - symmetric_strength).add_(a_ortho).add_(b_ortho)
+
+
+class WaveletBlend:
+ wavelet: wavef.Wavelet | None = None
+ use_float64: bool = False
+
+ def __init__(
+ self,
+ *,
+ device: str | torch.device | None = None,
+ use_float64: bool = False,
+ **kwargs: dict,
+ ):
+ self.device = device
+ self.wavelet_kwargs = kwargs
+ self.use_float64 = use_float64
+
+ def get_wavelet(self, *, device: str | torch.device | None = None) -> wavef.Wavelet:
+ if self.wavelet is None:
+ self.wavelet = wavef.Wavelet(
+ device=device if device is not None else self.device,
+ **self.wavelet_kwargs,
+ ).to(dtype=torch.float64 if self.use_float64 else torch.float32)
+ self.device = device
+ return self.wavelet
+ if device is not None and self.wavelet.device != device:
+ self.wavelet = self.wavelet.to(device=device)
+ self.device = device
+ return self.wavelet
+
+ @staticmethod
+ def maybe_offset(
+ yl: torch.Tensor,
+ yh: Sequence[torch.Tensor],
+ offset_yl: float | torch.Tensor | None,
+ offset_yh: float | Sequence[float | Sequence[float]] | None,
+ *,
+ in_place: bool = False,
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]:
+ if offset_yl in {None, 1.0} and offset_yh in {None, 1.0}:
+ return (yl, tuple(yh))
+ return wavef.wavelet_scaling(
+ yl,
+ yh,
+ yl_scale=offset_yl if offset_yl is not None else 1.0,
+ yh_scales=offset_yh,
+ in_place=in_place,
+ )
+
+ def wavelet_blend(
+ self,
+ a: torch.Tensor,
+ b: torch.Tensor,
+ t: float | torch.Tensor,
+ *,
+ blend_mode_yl: str | Callable = torch.lerp,
+ blend_mode_yh: str | Callable | None = None,
+ a_offset_yl: float | torch.Tensor | None = None,
+ a_offset_yh: float | Sequence[float | Sequence[float]] | None = None,
+ b_offset_yl: float | torch.Tensor | None = None,
+ b_offset_yh: float | Sequence[float | Sequence[float]] | None = None,
+ out_offset_yl: float | torch.Tensor | None = None,
+ out_offset_yh: float | Sequence[float | Sequence[float]] | None = None,
+ blend_yl_offset: float = 1.0,
+ blend_yh_offset: float | torch.Tensor = 1.0,
+ two_step_inverse: bool = False,
+ in_place_offset: bool = True,
+ ) -> torch.Tensor:
+ if isinstance(blend_mode_yl, str):
+ blend_mode_yl = BLENDING_MODES[blend_mode_yl]
+ if blend_mode_yh is None:
+ blend_mode_yh = blend_mode_yl
+ elif isinstance(blend_mode_yh, str):
+ blend_mode_yh = BLENDING_MODES[blend_mode_yh]
+ wavelet = self.get_wavelet(device=a.device)
+ dtype = a.dtype
+ if a.ndim != b.ndim:
+ raise ValueError(
+ f"Tensor a ndim ({a.ndim}) must match tensor b ndim ({b.ndim})"
+ )
+ orig_shape = a.shape
+ # FIXME: This reshaping logic is almost certainly not reliable.
+ if a.ndim > 4:
+ a = a.reshape(a.shape[0], -1, *a.shape[-2:])
+ if b.ndim > 4:
+            b = b.reshape(b.shape[0], -1, *b.shape[-2:])
+ a = a.to(dtype=torch.float64 if self.use_float64 else torch.float32)
+ b = b.to(a)
+ t = a.new_tensor(t) if not isinstance(t, torch.Tensor) else t.to(a)
+ if t.ndim > 4:
+            t = t.reshape(t.shape[0], -1, *t.shape[-2:])
+ aw_l, aw_h = self.maybe_offset(
+ *wavelet.forward(a),
+ a_offset_yl,
+ a_offset_yh,
+ in_place=in_place_offset,
+ )
+ bw_l, bw_h = self.maybe_offset(
+ *wavelet.forward(b),
+ b_offset_yl,
+ b_offset_yh,
+ in_place=in_place_offset,
+ )
+ blend_yl_offset = t if blend_yl_offset == 1 else t * blend_yl_offset
+ blend_yh_offset = t if blend_yh_offset == 1 else t * blend_yh_offset
+ outw_l, outw_h = self.maybe_offset(
+ *wavef.wavelet_blend(
+ (aw_l, aw_h),
+ (bw_l, bw_h),
+ yl_factor=blend_yl_offset,
+ yh_factor=blend_yh_offset,
+ blend_function=blend_mode_yl,
+ yh_blend_function=blend_mode_yh,
+ ),
+ offset_yl=out_offset_yl,
+ offset_yh=out_offset_yh,
+ in_place=in_place_offset,
+ )
+ result = wavelet.inverse(outw_l, outw_h, two_step_inverse=two_step_inverse)
+ result = result[tuple(slice(None, dsize) for dsize in a.shape)]
+ return result.to(dtype=dtype).reshape(orig_shape)
+
+
+WAVELET_BLEND_CACHE: dict[frozenset[tuple[str, Any]], WaveletBlend] = {}
+
+
+def wavelet_blend(
+ a: torch.Tensor,
+ b: torch.Tensor,
+ t: float | torch.Tensor,
+ *,
+ blend_mode_yl: str | Callable = torch.lerp,
+ blend_mode_yh: str | Callable | None = None,
+ **kwargs: dict[str, Any],
+) -> torch.Tensor:
+ if isinstance(blend_mode_yl, str):
+ blend_mode_yl = BLENDING_MODES[blend_mode_yl]
+ if blend_mode_yh is None:
+ blend_mode_yh = blend_mode_yl
+ _ = kwargs.pop("device", None)
+ wavelet_kwargs = {
+ k: kwargs.pop(k)
+ for k in (
+ "wave",
+ "level",
+ "mode",
+ "use_1d_dwt",
+ "use_dtcwt",
+ "biort",
+ "qshift",
+ "inv_wave",
+ "inv_mode",
+ "inv_biort",
+ "inv_qshift",
+ "two_step_inverse",
+ "use_float64",
+ )
+ if k in kwargs
+ }
+ cache_key = frozenset(
+ (
+ wavelet_kwargs
+ | {"blend_mode_yl": blend_mode_yl, "blend_mode_yh": blend_mode_yh}
+ ).items(),
+ )
+
+ wb = WAVELET_BLEND_CACHE.get(cache_key)
+ if wb is None:
+ wb = WaveletBlend(device=a.device, **wavelet_kwargs)
+ WAVELET_BLEND_CACHE[cache_key] = wb
+ return wb.wavelet_blend(
+ a,
+ b,
+ t,
+ blend_mode_yl=blend_mode_yl,
+ blend_mode_yh=blend_mode_yh,
+ **kwargs,
+ )
+
+
class BlendMode:
__slots__ = (
"allow_scale",
@@ -805,24 +1153,36 @@ class BlendMode:
"f_kwargs",
"f_raw",
"force_rescale",
+ "fork_rng",
+ "invert_scale",
"norm",
"norm_dims",
"rescale_dims",
+ "rescale_max",
+ "rescale_min",
"rev",
+ "scale_multiplier",
+ "visible",
)
class _Empty:
pass
- def __init__( # noqa: PLR0917
+ def __init__(
self,
f,
norm=None,
- norm_dims=(-3, -2, -1),
- rev=False,
- allow_scale=True,
- rescale_dims=(-3, -2, -1),
- force_rescale=False,
+ norm_dims: tuple = (-3, -2, -1),
+ rev: bool = False,
+ allow_scale: bool = True,
+ rescale_dims: tuple = (-3, -2, -1),
+ rescale_min: float = 0.0,
+ rescale_max: float = 1.0,
+ force_rescale: bool = False,
+ fork_rng: bool = False,
+ invert_scale: float | None = None,
+ scale_multiplier: float = 1.0,
+ visible: bool = True,
**kwargs: dict,
):
self.f_raw = f
@@ -837,37 +1197,35 @@ def __init__( # noqa: PLR0917
self.rev = rev
self.allow_scale = allow_scale
self.rescale_dims = rescale_dims
+ self.rescale_min = rescale_min
+ self.rescale_max = rescale_max
self.force_rescale = force_rescale
+ self.fork_rng = fork_rng
+ self.invert_scale = invert_scale
+ self.scale_multiplier = scale_multiplier
+ self.visible = visible
- def edited(
- self,
- *,
- f=_Empty,
- norm=_Empty,
- norm_dims=_Empty,
- rev=_Empty,
- allow_scale=_Empty,
- rescale_dims=_Empty,
- force_rescale=_Empty,
- preserve_kwargs=True,
- **kwargs: dict,
- ) -> object:
+ def edited(self, *, f=_Empty, preserve_kwargs=True, **kwargs: dict) -> BlendMode:
empty = self._Empty
kwargs = (self.f_kwargs | kwargs) if preserve_kwargs else kwargs
- return self.__class__(
- f if f is not empty else self.f_raw,
- norm=norm if norm is not empty else self.norm,
- norm_dims=norm_dims if norm_dims is not empty else self.norm_dims,
- rev=rev if rev is not empty else self.rev,
- allow_scale=allow_scale if allow_scale is not empty else self.allow_scale,
- rescale_dims=rescale_dims
- if rescale_dims is not empty
- else self.rescale_dims,
- force_rescale=force_rescale
- if force_rescale is not empty
- else self.force_rescale,
- **kwargs,
- )
+ kwargs |= {
+ k: v if (v := kwargs.get(k, empty)) is not empty else getattr(self, k)
+ for k in (
+ "norm",
+ "norm_dims",
+ "rev",
+ "allow_scale",
+ "rescale_dims",
+ "rescale_min",
+ "rescale_max",
+ "force_rescale",
+ "fork_rng",
+ "invert_scale",
+ "scale_multiplier",
+ "visible",
+ )
+ }
+ return self.__class__(f if f is not empty else self.f_raw, **kwargs)
def rescale(self, t, *, rescale_dims=_Empty):
if t.ndim > 2:
@@ -879,24 +1237,47 @@ def rescale(self, t, *, rescale_dims=_Empty):
rescale_dims = -1
tmin = torch.amin(t, keepdim=True, dim=rescale_dims)
tmax = torch.amax(t, keepdim=True, dim=rescale_dims)
- return (t - tmin).div_(tmax - tmin).clamp_(0, 1), tmin, tmax
+ return (
+ (t - tmin).div_(tmax - tmin).clamp_(self.rescale_min, self.rescale_max),
+ tmin,
+ tmax,
+ )
- def __call__(self, a, b, t, *, norm_dims=_Empty) -> torch.Tensor:
+ def __call__(
+ self,
+ a: torch.Tensor,
+ b: torch.Tensor,
+ t: torch.Tensor | float,
+ *,
+ norm_dims=_Empty,
+ ) -> torch.Tensor:
if not self.force_rescale:
return self.__call__internal(a, b, t, norm_dims=norm_dims)
a, amin, amax = self.rescale(a)
b, bmin, bmax = self.rescale(b)
- result = self.__call__internal(a, b, t, norm_dims=norm_dims)
+ with torch.random.fork_rng(devices=(a.device, b.device), enabled=self.fork_rng):
+ result = self.__call__internal(a, b, t, norm_dims=norm_dims)
del a, b
rmin, rmax = torch.lerp(amin, bmin, 0.5), torch.lerp(amax, bmax, 0.5)
del amin, amax, bmin, bmax
return result.mul_(rmax.sub_(rmin)).add_(rmin)
- def __call__internal(self, a, b, t, *, norm_dims=_Empty) -> torch.Tensor:
+ def __call__internal(
+ self,
+ a: torch.Tensor,
+ b: torch.Tensor,
+ t: torch.Tensor | float,
+ *,
+ norm_dims=_Empty,
+ ) -> torch.Tensor:
if not isinstance(t, torch.Tensor) and isinstance(a, torch.Tensor):
t = a.new_full((1,), t)
if self.rev:
a, b = b, a
+ if self.invert_scale is not None:
+ t = self.invert_scale - t
+ if self.scale_multiplier != 1.0:
+ t = t * self.scale_multiplier
if self.norm is None:
return self.f(a, b, t)
return self.norm(
@@ -907,11 +1288,36 @@ def __call__internal(self, a, b, t, *, norm_dims=_Empty) -> torch.Tensor:
class BlendingModes:
+ BLEH = True
+
def __init__(self, builtins=None):
self.builtins = {} if builtins is None else builtins
self.cache = {}
- def get(self, k: str, default=None):
+ def get_dict_key(self, k: dict):
+ ds = frozenset(k.items())
+ cached = self.cache.get(ds)
+ if cached is not None:
+ return cached
+ name = k.get("name")
+ if name is None:
+ raise ValueError(
+ "When passing a blend mode key as dict, a string 'name' key must exist."
+ )
+ name = name.strip()
+ base_bm = self.builtins.get(name)
+ if base_bm is None:
+ errstr = f"Unknown mode {name} for extended blend specification"
+ raise ValueError(errstr)
+ bm_kwargs = k.copy()
+ del bm_kwargs["name"]
+ bm = base_bm.edited(**bm_kwargs)
+ self.cache[ds] = bm
+ return bm
+
+ def get(self, k: str | dict, default=None):
+ if isinstance(k, dict):
+ return self.get_dict_key(k)
result = self.builtins.get(k)
if result is not None:
return result
@@ -995,16 +1401,16 @@ def try_extended(self, k: str, default=None) -> object:
return bm
def items(self):
- return self.builtins.items()
+ return ((k, v) for k, v in self.builtins.items() if v.visible)
def values(self):
- return self.builtins.values()
+ return (v for v in self.builtins.values() if v.visible)
def __contains__(self, k: str) -> bool:
return self.get(k) is not None
def __iter__(self):
- return self.builtins.__iter__()
+ return (k for k, _v in self.items())
keys = __iter__
@@ -1054,7 +1460,8 @@ def copy(self):
"a_only": BlendMode(lambda a, _b, t: a * t, allow_scale=False),
"b_only": BlendMode(lambda _a, b, t: b * t, allow_scale=False),
# Interpolates between tensors a and b using normalized linear interpolation.
- "bislerp": BlendMode(
+ # This definitely isn't biSLERP.
+ "bislerp_wrong": BlendMode(
lambda a, b, t: ((1 - t) * a).add_(t * b),
normalize,
),
@@ -1118,6 +1525,8 @@ def copy(self):
"inject_copysign_b": BlendMode(lambda a, b, t: (b * t).add_(a).copysign_(b)),
"inject_avoidsign_a": BlendMode(lambda a, b, t: (b * t).add_(a).copysign_(a.neg())),
"inject_avoidsign_b": BlendMode(lambda a, b, t: (b * t).add_(a).copysign_(b.neg())),
+ "cfg": BlendMode(lambda a, b, t: (a - b).mul_(t).add_(b)),
+ "cfg_base_a": BlendMode(lambda a, b, t: (a - b).mul_(t).add_(a)),
# Interpolates between tensors a and b using linear interpolation.
# "lerp": BlendMode(lambda a, b, t: ((1.0 - t) * a).add_(t * b)),
"lerp": BlendMode(torch.lerp),
@@ -1138,6 +1547,9 @@ def copy(self):
"lerp_avoidsign_b": BlendMode(
lambda a, b, t: ((1.0 - t) * a).add_(t * b).copysign_(b.neg()),
),
+ "weighted_average": BlendMode(
+ lambda a, b, t: (b * t).add_(a) / (1.0 + abs(t)),
+ ),
# Simulates a brightening effect by adding tensor b to tensor a, scaled by t.
"lineardodge": BlendMode(lambda a, b, t: (b * t).add_(a)),
"copysign": BlendMode(lambda a, b, _t: torch.copysign(a, b)),
@@ -1251,6 +1663,10 @@ def copy(self):
normalize,
allow_scale=False,
),
+ "multiply_by_b": BlendMode(
+ lambda a, b, _t: a * b,
+ allow_scale=False,
+ ),
"overlay": BlendMode(
lambda a, b, t: (2 * a * b + a**2 - 2 * a * b * a) * t
if torch.all(b < 0.5)
@@ -1296,6 +1712,65 @@ def copy(self):
allow_scale=False,
force_rescale=True,
),
+ "wavelet_b_hi_100_lo_0": BlendMode(
+ f=wavelet_blend,
+ blend_yl_offset=0.0,
+ blend_yh_offset=1.0,
+ wave="db4",
+ level=8,
+ ),
+ "wavelet_b_hi_0_lo_100": BlendMode(
+ f=wavelet_blend,
+ blend_yl_offset=1.0,
+ blend_yh_offset=0.0,
+ wave="db4",
+ level=8,
+ ),
+ "ortho": BlendMode(ortho_blend),
+ "ortho_rescaled": BlendMode(ortho_blend, rescale_limit=2.0),
+ "ortho_rescaled_lerpish": BlendMode(
+ ortho_blend,
+ rescale_limit=2.0,
+ rescale_result_blend_mode="lerp",
+ rescale_result_mode="blend",
+ ),
+ "ortho_lerp": BlendMode(ortho_blend, blend_mode="lerp"),
+ "ortho_dyn_lerp": BlendMode(
+ ortho_blend,
+ blend_mode="lerp",
+ rescale_result_mode="blend",
+ rescale_limit=4.0,
+ dyn_ortho_mode=True,
+ ),
+ "ortho_dyn_lerp_inverted": BlendMode(
+ ortho_blend,
+ blend_mode="lerp",
+ rescale_result_mode="blend",
+ rescale_limit=2.0,
+ dyn_ortho_mode=True,
+ rev=True,
+ invert_scale=1.0,
+ ),
+ "ortho_lerp_rescaled": BlendMode(
+ ortho_blend,
+ blend_mode="lerp",
+ rescale_result_mode="blend",
+ rescale_limit=2.0,
+ ),
+ "ortho_cfg": BlendMode(
+ lambda a, b, t, **kwargs: ortho_blend(b, a - b, t, **kwargs),
+ ),
+ "ortho_cfg_base_a": BlendMode(
+ lambda a, b, t, **kwargs: ortho_blend(a, a - b, t, **kwargs),
+ ),
+ "symmetric_ortho": BlendMode(symmetric_ortho_blend),
+ "symmetric_ortho_rescaled": BlendMode(symmetric_ortho_blend, rescale_limit=2.0),
+ "symmetric_ortho_cfg": BlendMode(
+ lambda a, b, t, **kwargs: symmetric_ortho_blend(b, a - b, t, **kwargs),
+ ),
+ "symmetric_ortho_cfg_base_a": BlendMode(
+ lambda a, b, t, **kwargs: symmetric_ortho_blend(a, a - b, t, **kwargs),
+ ),
}
BLENDING_MODES |= {
@@ -1319,10 +1794,10 @@ def copy(self):
"bislerp": slerp_orig,
"altbislerp": altslerp,
"revaltbislerp": lambda a, b, t: altslerp(b, a, t),
- "bibislerp": BLENDING_MODES["bislerp"].edited(norm_dims=0),
+ "bibislerp": BLENDING_MODES["bislerp_wrong"].edited(norm_dims=0),
"revhslerp": lambda a, b, t: hslerp_alt(b, a, t),
"revbislerp": lambda a, b, t: slerp_orig(b, a, t),
- "revbibislerp": BLENDING_MODES["revbislerp"].edited(norm_dims=0),
+ "revbibislerp": BLENDING_MODES["revbislerp_wrong"].edited(norm_dims=0),
}
FILTER_PRESETS = {
@@ -1611,7 +2086,9 @@ def biderp(samples, width, height, mode="bislerp", mode_h=None): # noqa: PLR091
mode_h = mode
derp_w = (BIDERP_MODES if ":" not in mode else BLENDING_MODES).get(mode, slerp_orig)
- derp_h = (BIDERP_MODES if ":" not in mode_h else BLENDING_MODES).get(mode_h, slerp_orig)
+ derp_h = (BIDERP_MODES if ":" not in mode_h else BLENDING_MODES).get(
+ mode_h, slerp_orig
+ )
def generate_bilinear_data(length_old, length_new, device):
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape(
diff --git a/py/nodes/__init__.py b/py/nodes/__init__.py
index f8358ec..8f6ca9f 100644
--- a/py/nodes/__init__.py
+++ b/py/nodes/__init__.py
@@ -36,10 +36,14 @@
"BlehPlug": misc.BlehPlug,
"BlehRefinerAfter": refinerAfter.BlehRefinerAfter,
"BlehSageAttentionSampler": sageAttention.BlehSageAttentionSampler,
+ "BlehAdvancedAttentionSampler": sageAttention.BlehAdvancedAttentionSampler,
"BlehSetSamplerPreset": samplers.BlehSetSamplerPreset,
"BlehSetSigmas": misc.BlehSetSigmas,
"BlehTAEVideoDecode": taevid.TAEVideoDecode,
"BlehTAEVideoEncode": taevid.TAEVideoEncode,
+ "BlehModelProcessLatentIn": misc.BlehModelProcessLatentIn,
+ "BlehModelProcessLatentOut": misc.BlehModelProcessLatentOut,
+ "BlehFixGuiderPreviewing": misc.BlehFixGuiderPreviewing,
}
NODE_DISPLAY_NAME_MAPPINGS = {
diff --git a/py/nodes/blockCFG.py b/py/nodes/blockCFG.py
index 164769e..5918e2a 100644
--- a/py/nodes/blockCFG.py
+++ b/py/nodes/blockCFG.py
@@ -1,4 +1,60 @@
-from functools import partial
+import math
+from dataclasses import dataclass
+from enum import Enum, auto
+from functools import partial, reduce
+
+import torch
+from tqdm import tqdm
+
+
+class BlockType(Enum):
+ INPUT = auto()
+ OUTPUT = auto()
+ MIDDLE = auto()
+ ATTN_Q = auto()
+ ATTN_K = auto()
+ ATTN_V = auto()
+ ATTN = auto()
+
+
+class BlendType(Enum):
+ DIFF = auto()
+ RESULT = auto()
+
+
+class CondType(Enum):
+ COND = auto()
+ UNCOND = auto()
+ BOTH = auto()
+
+
+@dataclass
+class BlockCFGItem:
+ start_sigma: float
+ end_sigma: float
+ block_type: BlockType
+ target_type: BlendType
+ cond_type: CondType
+ block_num: int
+ scale: float
+ skip_mode: bool
+
+
+class BlockCFG:
+ start_sigma: float
+ end_sigma: float
+ block_types: frozenset[BlockType]
+ verbose: bool = True
+
+ def __init__(self, items: tuple[BlockCFGItem, ...]):
+ self.start_sigma, self.end_sigma = reduce(
+ lambda old, new: (min(old[0], new[0]), max(old[1], new[1])),
+ ((i.start_sigma, i.end_sigma) for i in items),
+ (math.inf, math.inf * -1),
+ )
+ self.block_types = frozenset(i.block_type for i in items)
+
+ # def check_applies(self,
class BlockCFGBleh:
@@ -123,6 +179,7 @@ def patch(
reverse = apply_to != "cond"
def check_applies(block_list, transformer_options):
+ # tqdm.write(f"* BLOCKCFG: tf={transformer_options}")
cond_or_uncond = transformer_options["cond_or_uncond"]
if (
not (0 in cond_or_uncond and 1 in cond_or_uncond)
@@ -140,6 +197,13 @@ def check_applies(block_list, transformer_options):
return -1 in block_list
return block_def in {-1, transformer_options.get("transformer_index")}
+ def apply_cfg_fun_(tensor: torch.Tensor, primary_offset: int) -> torch.Tensor:
+ full_batch = tensor.shape[0]
+ if full_batch % 2:
+ raise RuntimeError("Batch size must be multiple of 2")
+ batch = full_batch // 2
+ diff = tensor[:batch, ...] - tensor[batch:, ...]
+
def apply_cfg_fun(tensor, primary_offset):
secondary_offset = 0 if primary_offset == 1 else 1
if reverse:
@@ -156,12 +220,15 @@ def apply_cfg_fun(tensor, primary_offset):
).mul_(scale)
return result
+ mid_patch = None
+
def non_output_block_patch(h, transformer_options, *, block_list):
+ nonlocal mid_patch
+ # print("\nSET????", mid_patch)
+ if mid_patch is not None:
+ mid_patch._bleh_set_topts(transformer_options)
cond_or_uncond = transformer_options["cond_or_uncond"]
- if not check_applies(
- block_list,
- transformer_options,
- ):
+ if not block_list or not check_applies(block_list, transformer_options):
return h
return apply_cfg_fun(h, cond_or_uncond.index(0))
@@ -180,7 +247,55 @@ def output_block_patch(h, hsp, transformer_options, *, block_list):
)
m = model.clone()
- if input_blocks:
+
+ if middle_blocks:
+ # print("******** MIDDLE")
+ try:
+ mb = model.get_model_object("diffusion_model.middle_block.0")
+ except AttributeError:
+ mb = None
+ orig_forward = getattr(mb, "forward", None)
+ if mb is None or orig_forward is None:
+ raise ValueError("Could not get middle block or forward")
+
+ class MBForward:
+ def __init__(self, orig_forward):
+ real_orig_forward = orig_forward
+ while temp := getattr(
+ real_orig_forward, "_bleh_orig_forward", None
+ ):
+ real_orig_forward = temp
+ orig_forward = real_orig_forward
+ self._bleh_orig_forward = orig_forward
+ self._bleh_topts = None
+
+ def _bleh_set_topts(self, transformer_options: dict) -> None:
+ # if self._bleh_topts:
+ # return
+ cond_or_uncond = transformer_options["cond_or_uncond"]
+ self._bleh_topts = {
+ "cond_or_uncond": cond_or_uncond.clone()
+ if isinstance(cond_or_uncond, torch.Tensor)
+ else cond_or_uncond,
+ "sigmas": transformer_options["sigmas"].clone(),
+ "block": ("middle", 0),
+ }
+
+ def __call__(self, *args: list, **kwargs: dict) -> torch.Tensor:
+ result = self._bleh_orig_forward(*args, **kwargs)
+ try:
+ return non_output_block_patch(
+ result,
+ self._bleh_topts,
+ block_list=middle_blocks,
+ )
+ finally:
+ self._bleh_topts = None
+
+ mid_patch = MBForward(orig_forward)
+ m.add_object_patch("diffusion_model.middle_block.0.forward", mid_patch)
+
+ if input_blocks or middle_blocks:
(
m.set_model_input_block_patch_after_skip
if skip_mode
@@ -188,11 +303,7 @@ def output_block_patch(h, hsp, transformer_options, *, block_list):
)(
partial(non_output_block_patch, block_list=input_blocks),
)
- if middle_blocks:
- m.set_model_patch(
- partial(non_output_block_patch, block_list=middle_blocks),
- "middle_block_patch",
- )
+
if output_blocks:
m.set_model_output_block_patch(
partial(output_block_patch, block_list=output_blocks),
diff --git a/py/nodes/misc.py b/py/nodes/misc.py
index 61c55f4..4dfac74 100644
--- a/py/nodes/misc.py
+++ b/py/nodes/misc.py
@@ -2,6 +2,7 @@
from __future__ import annotations
import contextlib
+import math
import operator
import random
from decimal import Decimal
@@ -11,7 +12,7 @@
from comfy import model_management
from comfy.model_management import throw_exception_if_processing_interrupted
-from ..better_previews.previewer import ensure_previewer
+from ..better_previews.previewer import PREVIEWER_STATE, ensure_previewer
from ..latent_utils import normalize_to_scale
@@ -482,3 +483,159 @@ def output_block_patch(h, hsp, _transformer_options):
m.set_model_output_block_patch(output_block_patch)
return (m,)
+
+
+class BlehModelProcessLatentIn:
+ DESCRIPTION = "Advanced node that can be used to scale a raw latent for model input. Generally only needed if you're doing something that bypasses the normal latent input mechanisms."
+ RETURN_TYPES = ("LATENT",)
+ FUNCTION = "go"
+ CATEGORY = "latent/advanced"
+
+ @classmethod
+ def INPUT_TYPES(cls) -> dict:
+ return {
+ "required": {
+ "model": ("MODEL",),
+ "latent": ("LATENT",),
+ },
+ }
+
+ @classmethod
+ def go(cls, *, model, latent: dict) -> tuple[dict]:
+ latent_format = model.model.latent_format
+ samples = (
+ latent["samples"].detach().to(device="cpu", dtype=torch.float32, copy=True)
+ )
+ return (latent | {"samples": latent_format.process_in(samples)},)
+
+
+class BlehModelProcessLatentOut:
+ DESCRIPTION = "Advanced node that can be used to scale a latent to the correct range for output. Generally only needed if you're doing something that bypasses the normal latent output mechanisms."
+ RETURN_TYPES = ("LATENT",)
+ FUNCTION = "go"
+ CATEGORY = "latent/advanced"
+
+ @classmethod
+ def INPUT_TYPES(cls) -> dict:
+ return {
+ "required": {
+ "model": ("MODEL",),
+ "latent": ("LATENT",),
+ },
+ }
+
+ @classmethod
+ def go(cls, *, model, latent: dict) -> tuple[dict]:
+ latent_format = model.model.latent_format
+ samples = (
+ latent["samples"].detach().to(device="cpu", dtype=torch.float32, copy=True)
+ )
+ return (latent | {"samples": latent_format.process_out(samples)},)
+
+
+class PreviewFixGuider:
+ def __init__(self, guider, *, fps_override: int | float | None = None):
+ self.__guider = guider
+ self.__fps_override = fps_override
+
+ def __getattr__(self, k):
+ return getattr(self.__guider, k)
+
+ def sample(self, noise, latent_image, *args, **kwargs):
+ latent_shapes = (
+ (tuple(latent_image.shape),)
+ if not latent_image.is_nested
+ else tuple(tuple(t.shape) for t in latent_image.unbind())
+ )
+ PREVIEWER_STATE.last_latent_shapes = latent_shapes
+ fps_override = self.__fps_override
+ if fps_override:
+ PREVIEWER_STATE.fps_override = fps_override
+ try:
+ return self.__guider.sample(noise, latent_image, *args, **kwargs)
+ finally:
+ PREVIEWER_STATE.last_latent_shapes = None
+ PREVIEWER_STATE.fps_override = None
+
+ # def sample(self, noise, latent_image, *args, **kwargs):
+ # nest_index = self.__nest_index
+ # orig_callback = kwargs.get("callback")
+ # sample = partial(self.__guider.sample, noise, latent_image, *args)
+ # if not (latent_image.is_nested and orig_callback is not None):
+ # # Either not nested or no callback, so no need to fix previewing.
+ # return sample(**kwargs)
+ # latent_part_sizes = tuple(
+ # t.shape if isinstance(t, torch.Tensor) and not t.is_nested else None
+ # for t in latent_image.unbind()
+ # )
+ # if not (
+ # len(latent_part_sizes) >= nest_index
+ # and all(ps is not None for ps in latent_part_sizes)
+ # ):
+ # # Multiple levels of nesting not yet implemented.
+ # return sample(**kwargs)
+ # offset = 0
+ # # ComfyUI preserves the batch dimension and smashes everything else together.
+ # for ps in latent_part_sizes[:nest_index]:
+ # offset += math.prod(ps[1:])
+ # orig_shape = latent_part_sizes[nest_index]
+ # orig_nelems = math.prod(orig_shape[1:])
+
+ # def cb_wrapper(i, denoised, x, *args, **kwargs) -> None:
+ # print(
+ # f"\n\nCB SHAPES: {denoised.shape}, {x.shape}, orig {orig_shape}, orig elems {orig_nelems}",
+ # )
+ # denoised, x = (
+ # t[:, :, offset : offset + orig_nelems].reshape(
+ # t.shape[0],
+ # *orig_shape[1:],
+ # )
+ # for t in (denoised, x)
+ # )
+ # print(
+ # f"\n\nFIXED CB SHAPES: {denoised.shape}, {x.shape}",
+ # )
+ # return orig_callback(i, denoised, x, *args, **kwargs)
+
+ # kwargs["callback"] = cb_wrapper
+ # return sample(**kwargs)
+
+
+class BlehFixGuiderPreviewing:
+ DESCRIPTION = "Wraps a guider to give the Bleh previewing system a hint about the latent shapes. Only necessary for models like LTX-2 which use nested tensors."
+ FUNCTION = "go"
+ OUTPUT_NODE = False
+ CATEGORY = "hacks"
+
+ RETURN_TYPES = ("GUIDER",)
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "guider": ("GUIDER",),
+ "fps_override": (
+ "FLOAT",
+ {
+ "default": 0.0,
+ "min": 0.0,
+ "max": 9999.0,
+ "tooltip": "Can be used to override the FPS when previewing with video models. Disabled if set to 0.",
+ },
+ ),
+ },
+ }
+
+ @classmethod
+ def go(
+ cls,
+ *,
+ guider,
+ fps_override: float | None = None,
+ ) -> tuple:
+ return (
+ PreviewFixGuider(
+ guider,
+ fps_override=fps_override or None,
+ ),
+ )
diff --git a/py/nodes/sageAttention.py b/py/nodes/sageAttention.py
index af1aa46..3067b49 100644
--- a/py/nodes/sageAttention.py
+++ b/py/nodes/sageAttention.py
@@ -3,12 +3,18 @@
import contextlib
import importlib
-from typing import TYPE_CHECKING
+import math
+from enum import Enum, auto
+from functools import partial, update_wrapper
+from typing import TYPE_CHECKING, Any, NamedTuple
import comfy.ldm.modules.attention as comfyattn
import torch
import yaml
from comfy.samplers import KSAMPLER
+from tqdm import tqdm
+
+from ..latent_utils import BLENDING_MODES
try:
import sageattention
@@ -51,6 +57,22 @@
HAVE_ATTN_OVERRIDE = hasattr(comfyattn, "register_attention_function")
+# class AttnsConfig(NamedTuple):
+# name: str
+# version: str
+# supported_head_sizes: collections.abc.Collection
+
+# class AttentionRule(NamedTuple):
+
+
+# class AttentionRules(NamedTuple):
+# orig_attn: Callable
+# start_sigma: float = math.inf
+# end_sigma: float = 0.0
+# verbose: bool = False
+# rules: tuple[AttentionRule, ...] = ()
+
+
def attention_bleh( # noqa: PLR0914
q: torch.Tensor,
k: torch.Tensor,
@@ -182,10 +204,14 @@ def make_attn_wrapper(
raise ValueError(
"SpargeAttention is not available, make sure you have the spas_sage_attn Python package installed",
)
- if sageattn_function == "sparge":
+ if sageattn_function in {"sparge", "sparge2"}:
sageattn_function = spas_sage_attn.spas_sage2_attn_meansim_cuda
+ elif sageattn_function == "sparge2_topk":
+ sageattn_function = spas_sage_attn.spas_sage2_attn_meansim_topk_cuda
elif sageattn_function == "sparge1":
sageattn_function = spas_sage_attn.spas_sage_attn_meansim_cuda
+ elif sageattn_function == "sparge1_topk":
+ sageattn_function = spas_sage_attn.spas_sage_attn_meansim_topk_cuda
else:
sageattn_function = getattr(sageattention, sageattn_function)
@@ -331,6 +357,19 @@ def go(
return (model,)
+class TimeMode(Enum):
+ PERCENT = auto()
+ SIGMA = auto()
+
+
+class SageAttnOptions(NamedTuple):
+ sampler: object
+ attn_kwargs: dict[str, Any]
+ start_time: float = math.inf
+ end_time: float = 0.0
+ time_mode: TimeMode = TimeMode.PERCENT
+
+
class BlehModelWrapper:
def __init__(self, model: object, model_call: Callable):
self.__bleh_model = model
@@ -344,20 +383,25 @@ def __getattr__(self, k: str):
def sageattn_sampler(
+ config: SageAttnOptions,
model: object,
x: torch.Tensor,
sigmas: torch.Tensor,
- *,
- sageattn_sampler_options: tuple,
+ # *,
+ # sageattn_sampler_options: tuple,
**kwargs: dict,
) -> torch.Tensor:
- sampler, start_percent, end_percent, sageattn_kwargs = sageattn_sampler_options
- ms = model.inner_model.inner_model.model_sampling
- start_sigma, end_sigma = (
- round(ms.percent_to_sigma(start_percent), 4),
- round(ms.percent_to_sigma(end_percent), 4),
- )
- del ms
+ # sampler, start_percent, end_percent, sageattn_kwargs = sageattn_sampler_options
+ if config.time_mode == TimeMode.PERCENT:
+ ms = model.inner_model.inner_model.model_sampling
+ start_sigma, end_sigma = (
+ round(ms.percent_to_sigma(config.start_time), 4),
+ round(ms.percent_to_sigma(config.end_time), 4),
+ )
+ del ms
+ else:
+ start_sigma = config.start_time
+ end_sigma = config.end_time
def model_call(
model: object,
@@ -369,7 +413,7 @@ def model_call(
enabled = end_sigma <= sigma_float <= start_sigma
with sageattn_context(
enabled=enabled,
- **sageattn_kwargs,
+ **config.attn_kwargs,
) as attn_override:
if enabled and HAVE_ATTN_OVERRIDE:
model_options = kwargs.pop("model_options", {}).copy()
@@ -381,12 +425,12 @@ def model_call(
kwargs["model_options"] = model_options
return model(x, sigma, **kwargs)
- return sampler.sampler_function(
+ return config.sampler.sampler_function(
BlehModelWrapper(model, model_call),
x,
sigmas,
**kwargs,
- **sampler.extra_options,
+ **config.sampler.extra_options,
)
@@ -450,14 +494,383 @@ def go(
)
return (
KSAMPLER(
- sageattn_sampler,
- extra_options={
- "sageattn_sampler_options": (
- sampler,
- start_percent,
- end_percent,
- get_yaml_parameters(yaml_parameters),
+ update_wrapper(
+ partial(
+ sageattn_sampler,
+ SageAttnOptions(
+ start_time=start_percent,
+ end_time=end_percent,
+ time_mode=TimeMode.PERCENT,
+ sampler=sampler,
+ attn_kwargs=get_yaml_parameters(yaml_parameters),
+ ),
+ ),
+ sampler.sampler_function,
+ ),
+ ),
+ )
+
+
+class AdvancedAttnRule(NamedTuple):
+ attn_function: Callable | None
+ attn_kwargs: dict[str, Any]
+ blend_function: Callable | None = None
+ start_sigma: float = math.inf
+ end_sigma: float = 0.0
+ check_nan: bool = False
+ q_multiplier: float = 1.0
+ k_multiplier: float = 1.0
+ v_multiplier: float = 1.0
+ output_multiplier: float = 1.0
+ device: torch.device | str | None = None
+ dtype: torch.dtype | str | None = None
+ blend: float = 1.0
+ op_q: str | None = None
+ op_k: str | None = None
+ op_v: str | None = None
+ op_current_result_preblend: str | None = None
+ op_result_preblend: str | None = None
+ op_result_postblend: str | None = None
+ op_result_postblend_diff: str | None = None
+
+ @classmethod
+ def build(
+ cls,
+ *,
+ attn_function: str | Callable | None = None,
+ blend_mode: str | None = None,
+ device=None,
+ dtype=None,
+ **kwargs,
+ ) -> NamedTuple:
+ blend_function = BLENDING_MODES[blend_mode] if blend_mode is not None else None
+ my_kwargs = {k: kwargs.pop(k) for k in cls._fields if k in kwargs}
+ if attn_function == "default":
+ attn_function = None
+ elif isinstance(attn_function, str):
+ kwargs["sageattn_function"] = attn_function
+ attn_function = make_attn_wrapper(orig_attn=None, **kwargs)
+ if isinstance(dtype, str):
+ dtype = {
+ "float32": torch.float32,
+ "float16": torch.float16,
+ "bfloat16": torch.bfloat16,
+ "float64": torch.float64,
+ }.get(dtype)
+ return cls(
+ device=device,
+ dtype=dtype,
+ blend_function=blend_function,
+ attn_function=attn_function,
+ attn_kwargs=kwargs,
+ **my_kwargs,
+ )
+
+
+class AdvancedAttnConfig(NamedTuple):
+ sampler: object
+ verbose: bool = False
+ rules: tuple[AdvancedAttnRule, ...] = ()
+ start_time: float = math.inf
+ end_time: float = 0.0
+ call_indexes: frozenset[float] = frozenset()
+ time_mode: TimeMode = TimeMode.PERCENT
+ min_cond_batch: int = 0
+ batch_slice: tuple | str | None = None
+ max_idx: int = -1
+ delegate_override: bool = True
+ op_result: str | None = None
+ latent_ops: dict[str, Callable] = {}
+
+ @classmethod
+ def build(
+ cls,
+ *,
+ rules=(),
+ time_mode: str | TimeMode | None = None,
+ call_indexes=(),
+ **kwargs,
+ ) -> NamedTuple:
+ fs = frozenset(cls._fields)
+ rules = tuple(AdvancedAttnRule.build(**r) for r in rules)
+ if isinstance(time_mode, str):
+ time_mode = getattr(TimeMode, time_mode.strip().upper())
+ call_indexes = frozenset(
+ i
+ if math.isnan(i) or i == math.inf or not isinstance(i, float)
+ else int(i) + 0.5
+ for i in call_indexes
+ )
+ kwargs = {k: v for k, v in kwargs.items() if k in fs}
+ return cls(
+ rules=rules,
+ time_mode=time_mode,
+ call_indexes=call_indexes,
+ **kwargs,
+ )
+
+ def call_op(
+ self, op_key: str | None, t: torch.Tensor, *, sigma: float
+ ) -> torch.Tensor:
+ op = None if op_key is None else self.latent_ops.get(op_key)
+ if op is None:
+ return t
+ return (
+ op(t)
+ if not hasattr(op, "EXTENDED_LATENT_OPERATION")
+ else op(t, sigma=sigma)
+ )
+
+ def attn_wrapper(
+ self,
+ sigma: float,
+ currattncall: CurrAttnCall,
+ old_override: Callable | None,
+ comfy_orig_attn: Callable,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ fallback_attn = (
+ partial(old_override, comfy_orig_attn)
+ if old_override is not None and self.delegate_override
+ else comfy_orig_attn
+ )
+ result = None
+ rules = self.rules
+ call_idx = currattncall.idx
+ rev_idx = (
+ -abs(currattncall.max_idx - currattncall.idx)
+ if currattncall.max_idx >= 0
+ else math.nan
+ )
+ ci = self.call_indexes
+ fallthrough_match = math.inf in ci
+ exclude = (call_idx + 0.5) in ci or (
+ not math.isnan(rev_idx) and (rev_idx + 0.5) in ci
+ )
+ matched = not exclude and (fallthrough_match or call_idx in ci or rev_idx in ci)
+ if self.verbose:
+ tqdm.write(
+ f"[BLEH] AdvancedAttn wrapper({currattncall.idx:<3}): sigma={sigma:.4f}, max_idx={currattncall.max_idx:<3}, rev_idx={rev_idx:<3}, matched={matched}, exclude={exclude}, q.shape={q.shape}",
+ )
+ currattncall.idx += 1
+ rules = self.rules if matched else ()
+ for rule in rules:
+ if not rule.end_sigma <= sigma <= rule.start_sigma:
+ continue
+ currq, currk, currv = q, k, v
+ if rule.q_multiplier != 1:
+ currq = currq * rule.q_multiplier
+ if rule.k_multiplier != 1:
+ currk = currk * rule.k_multiplier
+ if rule.v_multiplier != 1:
+ currv = currv * rule.v_multiplier
+ if rule.dtype is not None or rule.device is not None:
+ currq = currq.to(device=rule.device, dtype=rule.dtype)
+ currk = currk.to(device=rule.device, dtype=rule.dtype)
+ currv = currv.to(device=rule.device, dtype=rule.dtype)
+ currq = self.call_op(rule.op_q, currq, sigma=sigma)
+ currk = self.call_op(rule.op_k, currk, sigma=sigma)
+ currv = self.call_op(rule.op_v, currv, sigma=sigma)
+ attn_function = (
+ partial(rule.attn_function, fallback_attn)
+ if rule.attn_function
+ else fallback_attn
+ )
+ curr_result = attn_function(currq, currk, currv, *args, **kwargs)
+ del currq, currk, currv
+ if rule.check_nan and curr_result.isnan().any():
+ del curr_result
+ continue
+ if rule.output_multiplier != 1:
+ curr_result *= rule.output_multiplier
+ curr_result = self.call_op(
+ rule.op_current_result_preblend,
+ curr_result,
+ sigma=sigma,
+ )
+ if curr_result.dtype != q.dtype or curr_result.device != q.device:
+ curr_result = curr_result.to(q)
+ if result is not None:
+ result = self.call_op(rule.op_result_preblend, result, sigma=sigma)
+ if result is None or rule.blend_function is None:
+ result = curr_result
+ del curr_result
+ continue
+ prev_result = result
+ result = rule.blend_function(result, curr_result, rule.blend)
+ if rule.op_result_postblend_diff is not None:
+ result = prev_result + self.call_op(
+ rule.op_result_postblend_diff,
+ result - prev_result,
+ sigma=sigma,
+ )
+ del curr_result, prev_result
+ result = self.call_op(rule.op_result_postblend, result, sigma=sigma)
+ if result is None:
+ result = fallback_attn(q, k, v, *args, **kwargs)
+ return self.call_op(self.op_result, result, sigma=sigma)
+
+
+class CurrAttnCall:
+ def __init__(self, idx: int = 0, max_idx: int = -1):
+ self.idx = idx
+ self.max_idx = max_idx
+
+
+def advancedattn_sampler(
+ config: AdvancedAttnConfig,
+ model: object,
+ x: torch.Tensor,
+ sigmas: torch.Tensor,
+ **kwargs: dict,
+) -> torch.Tensor:
+ if config.time_mode == TimeMode.PERCENT:
+ ms = model.inner_model.inner_model.model_sampling
+ start_sigma, end_sigma = (
+ round(ms.percent_to_sigma(config.start_time), 4),
+ round(ms.percent_to_sigma(config.end_time), 4),
+ )
+ del ms
+ else:
+ start_sigma = config.start_time
+ end_sigma = config.end_time
+
+ max_idx = -1
+
+ def model_call(
+ model: object,
+ x: torch.Tensor,
+ sigma: torch.Tensor,
+ **kwargs: dict[str, Any],
+ ) -> torch.Tensor:
+ nonlocal max_idx
+ sigma_float = float(sigma.max().detach().cpu())
+ enabled = end_sigma <= sigma_float <= start_sigma
+ if not enabled:
+ return model(x, sigma, **kwargs)
+ calltracker = CurrAttnCall(idx=0, max_idx=max_idx)
+ if config.verbose:
+ tqdm.write(f"[BLEH] AdvancedAttn: Config: {config}")
+ model_options = kwargs.pop("model_options", {}).copy()
+ transformer_options = model_options.pop("transformer_options", {}).copy()
+ old_override = transformer_options.pop("optimized_attention_override", None)
+ attn_override = partial(
+ config.attn_wrapper,
+ sigma_float,
+ calltracker,
+ old_override,
+ )
+ transformer_options["optimized_attention_override"] = attn_override
+ model_options["transformer_options"] = transformer_options
+ kwargs["model_options"] = model_options
+ result = model(x, sigma, **kwargs)
+ max_idx = max(max_idx, calltracker.idx)
+ return result
+
+ return config.sampler.sampler_function(
+ BlehModelWrapper(model, model_call),
+ x,
+ sigmas,
+ **kwargs,
+ **config.sampler.extra_options,
+ )
+
+
+class BlehAdvancedAttentionSampler:
+ DESCRIPTION = "TBD"
+ CATEGORY = "sampling/custom_sampling/samplers"
+ RETURN_TYPES = ("SAMPLER",)
+ FUNCTION = "go"
+
+ DEFAULT_YAML_PARAMS = """verbose: false
+start_time: 0.0
+end_time: 1.0
+# One of: percent, sigma
+time_mode: percent
+# .inf means match everything. Whole float values exclude an index. I.E 2.0
+# Call index as in the Nth call to attention this model evaluation.
+# Negative indexes count from the end but can only match after a pass through the model.
+call_indexes: [.inf]
+# Can be set to null (everything), cond, uncond or a list.
+batch_slice: null
+# Requires cond batch information to be passed and at least this many items.
+min_cond_batch: 0
+rules:
+ # Passed as sageattn_function unless set to default.
+ # Keys not in this list are passed through like with the SageAttention node:
+ # attn_function, blend_mode, blend
+ - attn_function: default
+ # You can set whatever other keys you want here.
+ - attn_function: sageattn
+ # Blends target the last attention result and are ignored
+ # if it's missing.
+ # The default blend means:
+ # sageattn + (defaultattn - sageattn) * 2
+ blend_mode: cfg
+ blend: 2.0
+"""
+
+ @classmethod
+ def INPUT_TYPES(cls) -> dict:
+ return {
+ "required": {
+ "sampler": ("SAMPLER",),
+ "yaml_parameters": (
+ "STRING",
+ {
+ "default": cls.DEFAULT_YAML_PARAMS,
+ "tooltip": "Allows specifying custom parameters via YAML. These are mostly passed directly to the SageAttention function with no error checking. Must be empty or a YAML object.",
+ "dynamicPrompts": False,
+ "multiline": True,
+ "defaultInput": True,
+ },
+ ),
+ },
+ "optional": {
+ "op_0": ("LATENT_OPERATION",),
+ "op_1": ("LATENT_OPERATION",),
+ "op_2": ("LATENT_OPERATION",),
+ "op_3": ("LATENT_OPERATION",),
+ "op_4": ("LATENT_OPERATION",),
+ "op_5": ("LATENT_OPERATION",),
+ "op_6": ("LATENT_OPERATION",),
+ "op_7": ("LATENT_OPERATION",),
+ "op_8": ("LATENT_OPERATION",),
+ "op_9": ("LATENT_OPERATION",),
+ },
+ }
+
+ @classmethod
+ def go(
+ cls,
+ sampler: object,
+ yaml_parameters: str,
+ **kwargs: dict,
+ ) -> tuple:
+ if sageattention is None:
+ raise RuntimeError(
+ "sageattention not installed to Python environment: SageAttention feature unavailable",
+ )
+ if not HAVE_ATTN_OVERRIDE:
+ raise RuntimeError(
+ "This node only supports recent ComfyUI versions that support attention overrides.",
+ )
+ params = get_yaml_parameters(yaml_parameters)
+ params["latent_ops"] = {
+ k: v for k, v in kwargs.items() if k.startswith("op_") and v is not None
+ }
+ return (
+ KSAMPLER(
+ update_wrapper(
+ partial(
+ advancedattn_sampler,
+ AdvancedAttnConfig.build(sampler=sampler, **params),
),
- },
+ sampler.sampler_function,
+ ),
),
)
diff --git a/py/nodes/taevid.py b/py/nodes/taevid.py
index 0892342..ad3ef8a 100644
--- a/py/nodes/taevid.py
+++ b/py/nodes/taevid.py
@@ -18,7 +18,19 @@ class TAEVideoNodeBase:
def INPUT_TYPES(cls) -> dict:
return {
"required": {
- "latent_type": (("wan21", "wan22", "hunyuanvideo", "mochi"),),
+ "latent_type": (
+ (
+ "wan21",
+ "wan22",
+ "hunyuanvideo",
+ "hunyuanvideo15",
+ "mochi",
+ "ltxv",
+ ),
+ {
+ "tooltip": "Use ltxv for LTX-2 AV.",
+ },
+ ),
"parallel_mode": (
"BOOLEAN",
{
@@ -45,6 +57,8 @@ def get_taevid_model(
model_src = "taew2_2.pth from https://github.com/madebyollin/taehv"
elif latent_type == "hunyuanvideo":
model_src = "taehv.pth from https://github.com/madebyollin/taehv"
+ elif latent_type == "ltxv":
+ model_src = "taeltx_2.pth from https://github.com/madebyollin/taehv"
else:
model_src = "taem1.pth from https://github.com/madebyollin/taem1"
err_string = f"Missing TAE video model. Download {model_src} and place it in the models/vae_approx directory"
@@ -66,7 +80,7 @@ def go(cls, *, latent, latent_type: str, parallel_mode: bool) -> tuple:
class TAEVideoDecode(TAEVideoNodeBase):
RETURN_TYPES = ("IMAGE",)
CATEGORY = "latent"
- DESCRIPTION = "Fast decoding of Wan, Hunyuan and Mochi video latents with the video equivalent of TAESD."
+ DESCRIPTION = "Fast decoding of Wan, Hunyuan, Mochi and LTX video latents with the video equivalent of TAESD."
@classmethod
def INPUT_TYPES(cls) -> dict:
@@ -100,7 +114,7 @@ def go(cls, *, latent: dict, latent_type: str, parallel_mode: bool) -> tuple:
class TAEVideoEncode(TAEVideoNodeBase):
RETURN_TYPES = ("LATENT",)
CATEGORY = "latent"
- DESCRIPTION = "Fast encoding of Wan, Hunyuan and Mochi video latents with the video equivalent of TAESD."
+ DESCRIPTION = "Fast encoding of Wan, Hunyuan, Mochi and LTX video latents with the video equivalent of TAESD."
@classmethod
def INPUT_TYPES(cls) -> dict:
diff --git a/py/settings.py b/py/settings.py
index 327e93d..2ded21f 100644
--- a/py/settings.py
+++ b/py/settings.py
@@ -1,4 +1,7 @@
+from __future__ import annotations
+
from pathlib import Path
+from typing import NamedTuple
class Settings:
@@ -7,6 +10,8 @@ def __init__(self):
def update(self, obj):
btp = obj.get("betterTaesdPreviews", None)
+ if btp is None:
+ btp = obj.get("previews", None)
self.btp_enabled = btp is not None and btp.get("enabled", True) is True
if not self.btp_enabled:
return
@@ -23,6 +28,7 @@ def update(self, obj):
self.btp_preview_device = btp.get("preview_device")
# default, keep, float32, float16, bfloat16
self.btp_preview_dtype = btp.get("preview_dtype")
+ self.btp_preview_non_blocking = bool(btp.get("preview_non_blocking", False))
self.btp_maxed_batch_step_mode = btp.get("maxed_batch_step_mode", False)
self.btp_compile_previewer = btp.get("compile_previewer", False)
self.btp_oom_fallback = btp.get("oom_fallback", "latent2rgb")
@@ -37,6 +43,7 @@ def update(self, obj):
)
self.btp_animate_preview = btp.get("animate_preview", "none")
self.btp_verbose = btp.get("verbose", False)
+ self.btp_publish_last_preview = btp.get("publish_last_preview", False)
@staticmethod
def get_cfg_path(filename) -> Path:
diff --git a/py/wavelet_functions.py b/py/wavelet_functions.py
new file mode 100644
index 0000000..b993b41
--- /dev/null
+++ b/py/wavelet_functions.py
@@ -0,0 +1,244 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable
+
+import torch
+
+if TYPE_CHECKING:
+ from collections.abc import Sequence
+
+try:
+ import pytorch_wavelets as ptwav
+ import pywt
+
+ HAVE_WAVELETS = True
+except ImportError:
+ ptwav = None
+ pywt = None
+ HAVE_WAVELETS = False
+
+
+def fallback[V, D](val: V | D, default: D = None) -> V | D:
+ return val if val is not None else default
+
+
class Wavelet:
    """Thin wrapper pairing a forward and an inverse pytorch_wavelets transform.

    Supports three transform flavors: the 2D DWT (default), the 1D DWT
    (``use_1d_dwt=True``) and the dual-tree complex wavelet transform
    (``use_dtcwt=True``, which takes precedence). The ``inv_*`` parameters
    default to their forward counterparts when not supplied.
    """

    DEFAULT_MODE = "symmetric"
    DEFAULT_LEVEL = 3
    DEFAULT_WAVE = "db4"
    DEFAULT_USE_1D_DWT = False
    DEFAULT_USE_DTCWT = False
    DEFAULT_QSHIFT = "qshift_a"
    DEFAULT_BIORT = "near_sym_a"

    def __init__(
        self,
        *,
        wave: str = DEFAULT_WAVE,
        level: int = DEFAULT_LEVEL,
        mode: str = DEFAULT_MODE,
        use_1d_dwt: bool = DEFAULT_USE_1D_DWT,
        use_dtcwt: bool = DEFAULT_USE_DTCWT,
        biort: str = DEFAULT_BIORT,
        qshift: str = DEFAULT_QSHIFT,
        inv_wave: str | None = None,
        inv_mode: str | None = None,
        inv_biort: str | None = None,
        inv_qshift: str | None = None,
        device: str | torch.device | None = None,
    ):
        """Build the forward/inverse transform pair.

        Raises:
            RuntimeError: if pytorch_wavelets is not installed.
        """
        if not HAVE_WAVELETS:
            raise RuntimeError(
                "Wavelet use requires the pytorch_wavelets package to be installed in your Python environment",
            )
        inv_wave = fallback(inv_wave, wave)
        inv_mode = fallback(inv_mode, mode)
        inv_biort = fallback(inv_biort, biort)
        inv_qshift = fallback(inv_qshift, qshift)
        # DTCWT is parameterized by biort/qshift filter names, while the DWT
        # variants take a wavelet name. (Previously the flavor was branched
        # on twice; one if/elif/else chain is enough.)
        if use_dtcwt:
            self._wavelet_forward = ptwav.DTCWTForward(
                J=level,
                mode=mode,
                biort=biort,
                qshift=qshift,
            )
            self._wavelet_inverse = ptwav.DTCWTInverse(
                mode=inv_mode,
                biort=inv_biort,
                qshift=inv_qshift,
            )
        elif use_1d_dwt:
            self._wavelet_forward = ptwav.DWT1DForward(J=level, wave=wave, mode=mode)
            self._wavelet_inverse = ptwav.DWT1DInverse(wave=inv_wave, mode=inv_mode)
        else:
            self._wavelet_forward = ptwav.DWTForward(J=level, wave=wave, mode=mode)
            self._wavelet_inverse = ptwav.DWTInverse(wave=inv_wave, mode=inv_mode)
        self.device = device
        if device is not None:
            self._wavelet_forward = self._wavelet_forward.to(device=device)
            self._wavelet_inverse = self._wavelet_inverse.to(device=device)

    def forward(
        self,
        t: torch.Tensor,
        *,
        forward_function: Callable | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Run the forward transform on *t*, returning a (yl, yh) pair.

        *forward_function* overrides the built transform when supplied.
        """
        return fallback(forward_function, self._wavelet_forward)(t)

    def inverse(
        self,
        yl: torch.Tensor,
        yh: tuple[torch.Tensor, ...],
        *,
        inverse_function: Callable | None = None,
        two_step_inverse: bool = False,
    ) -> torch.Tensor:
        """Run the inverse transform on a (yl, yh) decomposition.

        With *two_step_inverse*, the high-pass and low-pass contributions are
        reconstructed in two separate passes (zeroing out the other part each
        time) and summed instead of a single combined inverse pass.
        """
        inverse_function = fallback(inverse_function, self._wavelet_inverse)
        if not two_step_inverse:
            return inverse_function((yl, yh))
        result = inverse_function((torch.zeros_like(yl), yh))
        result += inverse_function(
            (
                yl,
                tuple(torch.zeros_like(yh_band) for yh_band in yh),
            )
        )
        return result

    def to(self, *args: list, copy: bool = False, **kwargs: dict) -> Wavelet:
        """Move/convert the underlying transform modules (torch .to semantics).

        With ``copy=True`` a new Wavelet is returned and ``self`` is left
        untouched; otherwise ``self`` is updated in place and returned.
        """
        o = Wavelet.__new__(Wavelet) if copy else self
        o._wavelet_forward = self._wavelet_forward.to(*args, **kwargs)  # noqa: SLF001
        o._wavelet_inverse = self._wavelet_inverse.to(*args, **kwargs)  # noqa: SLF001
        # Bugfix: this used to be kwargs.get("device"), which reset .device
        # to None whenever .to() was called without a device keyword (e.g. a
        # dtype-only conversion). Preserve the current device in that case.
        # NOTE(review): a device passed positionally is still not tracked.
        o.device = kwargs.get("device", self.device)
        return o

    @staticmethod
    def wavelist() -> tuple:
        """Available DWT wavelet names (empty when pytorch_wavelets is absent)."""
        return tuple(pywt.wavelist()) if HAVE_WAVELETS else ()

    @staticmethod
    def biortlist() -> tuple:
        """Available DTCWT level-1 (biorthogonal) filter names."""
        return (
            ("near_sym_a", "near_sym_b", "antonini", "legall") if HAVE_WAVELETS else ()
        )

    @staticmethod
    def qshiftlist() -> tuple:
        """Available DTCWT level>=2 (q-shift) filter names."""
        return (
            ("qshift_a", "qshift_b", "qshift_c", "qshift_d", "qshift_06")
            if HAVE_WAVELETS
            else ()
        )

    @staticmethod
    def modelist() -> tuple:
        """Available boundary-padding mode names."""
        return (
            (
                "symmetric",
                "zero",
                "reflect",
                "replicate",
                "periodization",
                "periodic",
                "constant",
            )
            if HAVE_WAVELETS
            else ()
        )
+
+
def expand_yh_scales(
    yh: Sequence,
    *,
    yh_scales: float | Sequence = 1.0,
) -> float | tuple:
    """Normalize *yh_scales* into one tuple of per-orientation scales per band.

    A bare scalar becomes the same scale for every orientation of every band.
    A sequence is normalized entry by entry: scalar entries are broadcast over
    the orientation dimension, sequence entries are truncated/padded (with 1.0)
    to the orientation count. A single "fill" marker repeats the entry before
    it to cover any remaining bands. The result is truncated to the number of
    bands in *yh*.
    """
    num_bands = len(yh)
    band_shape = yh[0].shape
    # Doesn't make sense to target orientations for 1D DWT (3D here).
    orientations = band_shape[2] if len(band_shape) > 3 else 1
    if isinstance(yh_scales, (float, int)):
        return ((float(yh_scales),) * orientations,) * num_bands

    def _normalize(entry: object) -> object:
        if isinstance(entry, (float, int)):
            return (float(entry),) * orientations
        if isinstance(entry, (tuple, list)):
            vals = tuple(float(v) for v in entry[:orientations])
            return (*vals, *((1.0,) * (orientations - len(vals))))
        # Anything else (notably the "fill" marker) passes through untouched.
        return entry

    yh_scales = tuple(_normalize(entry) for entry in yh_scales)
    if "fill" in yh_scales:
        fill_pos = yh_scales.index("fill")
        if "fill" in yh_scales[fill_pos + 1 :]:
            raise ValueError("Only one fill allowed.")
        if fill_pos == 0 or len(yh_scales) < 2:
            raise ValueError(
                "Invalid fill value, cannot be in the first position or the only item.",
            )
        if len(yh_scales) - 1 < num_bands:
            # Short of num_bands: repeat the entry preceding "fill" as padding.
            padding = (yh_scales[fill_pos - 1],) * (num_bands - (len(yh_scales) - 1))
            yh_scales = (*yh_scales[:fill_pos], *padding, *yh_scales[fill_pos + 1 :])
        else:
            # Enough entries already: just drop the marker.
            yh_scales = (*yh_scales[:fill_pos], *yh_scales[fill_pos + 1 :])
    return yh_scales[:num_bands]
+
+
def wavelet_scaling(
    yl: torch.Tensor,
    yh: Sequence[torch.Tensor],
    yl_scale: float | torch.Tensor,
    yh_scales: float | Sequence[float | Sequence[float]] | None,
    *,
    in_place: bool = False,
) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]:
    """Scale the low-pass band and each high-pass band of a decomposition.

    *yh_scales* may be None (no-op scale of 1.0), a scalar, or a per-band
    sequence as understood by expand_yh_scales. Pass ``in_place=True`` to
    mutate the input tensors instead of cloning them first.
    """
    if not in_place:
        yl = yl.clone()
        yh = tuple(band.clone() for band in yh)
    if yl_scale != 1.0:
        yl *= yl_scale
    expanded = expand_yh_scales(
        yh,
        yh_scales=1.0 if yh_scales is None else yh_scales,
    )
    for band_scales, band in zip(expanded, yh):
        if isinstance(band_scales, (int, float)):
            # Scalar entry: scale the whole band at once.
            band *= band_scales
            continue
        # Per-orientation scaling along dim 2.
        for oidx in range(min(band.shape[2], len(band_scales))):
            band[:, :, oidx] *= band_scales[oidx]
    return (yl, yh)
+
+
def wavelet_blend(
    a: tuple,
    b: tuple,
    *,
    yl_factor: torch.Tensor | float,
    blend_function: Callable,
    yh_factor: torch.Tensor | float | None = None,
    yh_blend_function: Callable | None = None,
) -> tuple:
    """Blend two wavelet decompositions band by band.

    *a* and *b* are (yl, yh) pairs. Scalar factors are promoted to
    single-element tensors allocated on a's low-pass band; the high-pass
    factor and blend function fall back to the low-pass ones when omitted.
    """
    ref = a[0]
    if not isinstance(yl_factor, torch.Tensor):
        yl_factor = ref.new_full((1,), yl_factor)
    if yh_factor is None:
        yh_factor = yl_factor
    elif not isinstance(yh_factor, torch.Tensor):
        yh_factor = ref.new_full((1,), yh_factor)
    yh_blend_function = fallback(yh_blend_function, blend_function)
    blended_yl = blend_function(a[0], b[0], yl_factor)
    blended_yh = tuple(
        yh_blend_function(band_a, band_b, yh_factor)
        for band_a, band_b in zip(a[1], b[1])
    )
    return (blended_yl, blended_yh)