From 0ed392b6269769cd097c2dba3b87bb2708194897 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Mon, 25 Aug 2025 13:09:33 -0400 Subject: [PATCH] stuff --- configs/dit_v4_dmd.yml | 9 +- inference/build_cache_object.py | 84 +++++++++++++++ inference/causvid_pipeline.py | 8 +- inference/decoder_test.py | 96 +++++++++++++++++ inference/model_cache_test.py | 155 ++++++++++++++++++++++++++ inference/model_fp8_test.py | 174 ++++++++++++++++++++++++++++++ inference/model_test.py | 116 ++++++++++++++++++++ inference/test_sampling.py | 53 +++------ owl_wms/__init__.py | 6 +- owl_wms/nn/attn.py | 16 +-- owl_wms/nn/kv_cache.py | 75 +++++++++++++ owl_wms/sampling/av_caching_v2.py | 5 +- owl_wms/utils/owl_vae_bridge.py | 1 - 13 files changed, 734 insertions(+), 64 deletions(-) create mode 100644 inference/build_cache_object.py create mode 100644 inference/decoder_test.py create mode 100644 inference/model_cache_test.py create mode 100644 inference/model_fp8_test.py create mode 100644 inference/model_test.py diff --git a/configs/dit_v4_dmd.yml b/configs/dit_v4_dmd.yml index 2196b383..c7bd52df 100644 --- a/configs/dit_v4_dmd.yml +++ b/configs/dit_v4_dmd.yml @@ -72,9 +72,8 @@ train: sampler_id: av_caching sampler_kwargs: n_steps: 2 - #window_length: 16 cfg_scale: 1.0 - num_frames: 600 + num_frames: 60 noise_prev: 0.2 custom_schedule: [1.0, 0.5] @@ -82,11 +81,11 @@ train: eval_sample_dir: null vae_id: null - vae_batch_size: 4 + vae_batch_size: 1 vae_scale: 0.63 - vae_cfg_path: /mnt/data/shahbuland/owl-vaes/configs/cod_yt_v2/base.yml - vae_ckpt_path: /mnt/data/checkpoints/owl_vaes/cod_yt_v2/cod_yt_v2_515k_ema_decoder.pt + vae_cfg_path: ./dec_dist.yml + vae_ckpt_path: ./dec_distill_cod_v2.pt update_ratio: 5 ts_shift: 8 diff --git a/inference/build_cache_object.py b/inference/build_cache_object.py new file mode 100644 index 00000000..dc2f6e09 --- /dev/null +++ b/inference/build_cache_object.py @@ -0,0 +1,84 @@ +import joblib + +import torch +from owl_wms.configs import Config +from 
owl_wms import from_pretrained +from owl_wms.nn.rope import RoPE +from owl_wms.nn.kv_cache import KVCache + +def cast_rope_buffers_to_fp32(module): + """Cast RoPE buffers to fp32 for numerical stability""" + for submodule in module.modules(): + if isinstance(submodule, RoPE): + if hasattr(submodule, "cos"): + submodule.cos = submodule.cos.float() + if hasattr(submodule, "sin"): + submodule.sin = submodule.sin.float() + +@torch.no_grad() +def build_cache(): + # Configuration + cfg_path = "configs/dit_v4_dmd.yml" + ckpt_path = "vid_dit_v4_dmd_7k.pt" + + # Load model (no decoder needed) + print("Loading model...") + model = from_pretrained(cfg_path, ckpt_path, return_decoder=False) + model = model.core.eval().cuda().bfloat16() + + # Cast RoPE buffers to fp32 + cast_rope_buffers_to_fp32(model) + + print("Model loaded successfully!") + + # Load data cache + data = torch.load("data_cache.pt") + vid = data["vid"] + mouse = data["mouse"] + btn = data["btn"] + + batch_size = vid.size(0) + init_len = vid.size(1) + + print(f"Cache input shapes:") + print(f" vid: {vid.shape}") + print(f" mouse: {mouse.shape}") + print(f" btn: {btn.shape}") + + # Initialize KV cache + kv_cache = KVCache(model.config) + kv_cache.reset(batch_size) + + # Build cache with context frames + noise_prev = 0.2 + vid_noisy = vid * (1. 
- noise_prev) + torch.randn_like(vid) * noise_prev + t_noisy = vid.new_full((batch_size, init_len), noise_prev) + + init_len = 60 + vid_noisy = vid_noisy[:,:init_len] + t_noisy = t_noisy[:vid_noisy.size(0),:init_len] + mouse = mouse[:vid_noisy.size(0),:init_len] + btn = btn[:vid_noisy.size(0),:init_len] + + + print("Building KV cache...") + kv_cache.enable_cache_updates() + _ = model( + vid_noisy, + t_noisy, + mouse, + btn, + kv_cache=kv_cache + ) + kv_cache.disable_cache_updates() + + # Save the cache object + cache_object = kv_cache + + print("Saving cache object...") + joblib.dump(cache_object, 'kv_cache_object.pkl') + print("Cache object saved to 'kv_cache_object.pkl'") + +if __name__ == "__main__": + build_cache() + diff --git a/inference/causvid_pipeline.py b/inference/causvid_pipeline.py index 23f52ef6..7ff9b735 100644 --- a/inference/causvid_pipeline.py +++ b/inference/causvid_pipeline.py @@ -26,7 +26,7 @@ def to_bgr_uint8(frame, target_size=(1080,1920)): return frame class CausvidPipeline: - def __init__(self, cfg_path="configs/causvid.yml", ckpt_path="causvid_ema.pt"): + def __init__(self, cfg_path="configs/dit_v4_dmd.yml", ckpt_path="vid_dit_v4_dmd_7k.pt"): cfg = Config.from_yaml(cfg_path) model_cfg = cfg.model train_cfg = cfg.train @@ -42,12 +42,6 @@ def __init__(self, cfg_path="configs/causvid.yml", ckpt_path="causvid_ema.pt"): ) self.frame_decoder = self.frame_decoder.cuda().bfloat16().eval() - #audio_decoder = get_decoder_only( - # None, - # train_cfg.audio_vae_cfg_path, - # train_cfg.audio_vae_ckpt_path - #) - # Store scales as instance variables self.frame_scale = train_cfg.vae_scale self.audio_scale = train_cfg.audio_vae_scale diff --git a/inference/decoder_test.py b/inference/decoder_test.py new file mode 100644 index 00000000..b7757e42 --- /dev/null +++ b/inference/decoder_test.py @@ -0,0 +1,96 @@ +import torch +from torch import nn +from owl_wms import from_pretrained +import gc + + +import torch, torch_tensorrt +from torch_tensorrt.dynamo 
import compile + +cfg_path = "configs/dit_v4_dmd.yml" +ckpt_path = "vid_dit_v4_dmd_7k.pt" + +_, decoder = from_pretrained(cfg_path, ckpt_path, return_decoder=True) +decoder = decoder.eval().cuda().bfloat16() +#decoder = torch.compile(decoder, mode = 'max-autotune', fullgraph = True, dynamic = False) +# TODO: the TensorRT-compiled decoder assignment was left unfinished here (a bare "decoder =" is a SyntaxError) +# Clear cache +torch.cuda.empty_cache() +gc.collect() + +# Configuration +BATCH_SIZE = 1 + +@torch.no_grad() +def test_decoder(): + print(f"Testing decoder with batch size {BATCH_SIZE}...") + + print("Decoder loaded successfully!") + print(f"Decoder parameters: {sum(p.numel() for p in decoder.parameters()) / 1e6:.1f}M") + + def create_test_inputs(): + """Create randomized test inputs""" + x = torch.randn(BATCH_SIZE, 128, 8, 8, device='cuda', dtype=torch.bfloat16) + return x + + # Create initial inputs to show shapes + x = create_test_inputs() + + print(f"Input shapes:") + print(f" x: {x.shape}") + + # Check initial VRAM + torch.cuda.empty_cache() + initial_memory = torch.cuda.memory_allocated() / 1024**3 + print(f"Initial VRAM usage: {initial_memory:.2f} GB") + + # Warmup runs + print("Running 5 warmup iterations...") + for i in range(5): + x = create_test_inputs() + output = decoder(x) + print(f" Warmup {i+1}/5 completed") + + # Timing runs with CUDA events + print("Running 10 timed iterations...") + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + times = [] + for i in range(10): + x = create_test_inputs() + + torch.cuda.synchronize() + start_event.record() + + output = decoder(x) + + end_event.record() + torch.cuda.synchronize() + + elapsed_time = start_event.elapsed_time(end_event) # milliseconds + times.append(elapsed_time) + print(f" Iteration {i+1}/10: {elapsed_time:.2f} ms") + + # Calculate statistics + avg_time = sum(times) / len(times) + min_time = min(times) + max_time = max(times) + fps = 1000.0 / avg_time # Convert ms to fps + + peak_memory = torch.cuda.max_memory_allocated() / 1024**3 +
current_memory = torch.cuda.memory_allocated() / 1024**3 + + print(f"\nTiming Results:") + print(f" Average time: {avg_time:.2f} ms") + print(f" Min time: {min_time:.2f} ms") + print(f" Max time: {max_time:.2f} ms") + print(f" Average FPS: {fps:.2f}") + print(f"\nMemory Results:") + print(f" Output shape: {output.shape}") + print(f" Peak VRAM usage: {peak_memory:.2f} GB") + print(f" Current VRAM usage: {current_memory:.2f} GB") + print(f" VRAM increase: {current_memory - initial_memory:.2f} GB") + +if __name__ == "__main__": + test_decoder() \ No newline at end of file diff --git a/inference/model_cache_test.py b/inference/model_cache_test.py new file mode 100644 index 00000000..d37da05a --- /dev/null +++ b/inference/model_cache_test.py @@ -0,0 +1,155 @@ +import torch +import time +from owl_wms.configs import Config +from owl_wms import from_pretrained +from owl_wms.nn.rope import RoPE +from owl_wms.nn.kv_cache import KVCache, InferenceKVCache + +# Configuration +N_FRAMES_CACHE = 60 +N_FRAMES = 1 +cfg_path = "configs/dit_v4_dmd.yml" +ckpt_path = "vid_dit_v4_dmd_7k.pt" + +def cast_rope_buffers_to_fp32(module): + """Cast RoPE buffers to fp32 for numerical stability""" + for submodule in module.modules(): + if isinstance(submodule, RoPE): + if hasattr(submodule, "cos"): + submodule.cos = submodule.cos.float() + if hasattr(submodule, "sin"): + submodule.sin = submodule.sin.float() + +@torch.no_grad() +def test_model_forward(): + print(f"Testing model forward pass with {N_FRAMES_CACHE} cache frames + {N_FRAMES} new frames...") + + # Load model + print("Loading model...") + model, vae = from_pretrained(cfg_path, ckpt_path, return_decoder=True) + model = model.core.eval().cuda().bfloat16() + vae = vae.eval().cuda().bfloat16() + + + # Cast RoPE buffers to fp32 + cast_rope_buffers_to_fp32(model) + + # Load config for model dimensions + cfg = Config.from_yaml(cfg_path) + + print("Model loaded successfully!") + print(f"Model parameters: {sum(p.numel() for p in 
model.parameters()) / 1e6:.1f}M") + + # Create test inputs + batch_size = 1 + + def create_cache_inputs(): + """Create randomized cache inputs""" + video = torch.randn(batch_size, N_FRAMES_CACHE, 128, 8, 8, device='cuda', dtype=torch.bfloat16) + ts = torch.randn(batch_size, N_FRAMES_CACHE, device='cuda', dtype=torch.bfloat16) + mouse = torch.randn(batch_size, N_FRAMES_CACHE, 2, device='cuda', dtype=torch.bfloat16) + btn = torch.randn(batch_size, N_FRAMES_CACHE, 11, device='cuda', dtype=torch.bfloat16) + return video, ts, mouse, btn + + def create_test_inputs(): + """Create randomized test inputs""" + video = torch.randn(batch_size, N_FRAMES, 128, 8, 8, device='cuda', dtype=torch.bfloat16) + ts = torch.randn(batch_size, N_FRAMES, device='cuda', dtype=torch.bfloat16) + mouse = torch.randn(batch_size, N_FRAMES, 2, device='cuda', dtype=torch.bfloat16) + btn = torch.randn(batch_size, N_FRAMES, 11, device='cuda', dtype=torch.bfloat16) + return video, ts, mouse, btn + + # Initialize KV cache + kv_cache = KVCache(model.config) + kv_cache.reset(batch_size) + + # Create initial cache inputs to show shapes + cache_video, cache_ts, cache_mouse, cache_btn = create_cache_inputs() + video, ts, mouse, btn = create_test_inputs() + + print(f"Cache input shapes:") + print(f" cache_video: {cache_video.shape}") + print(f" cache_ts: {cache_ts.shape}") + print(f" cache_mouse: {cache_mouse.shape}") + print(f" cache_btn: {cache_btn.shape}") + + print(f"New input shapes:") + print(f" video: {video.shape}") + print(f" ts: {ts.shape}") + print(f" mouse: {mouse.shape}") + print(f" btn: {btn.shape}") + + # Check initial VRAM + torch.cuda.empty_cache() + initial_memory = torch.cuda.memory_allocated() / 1024**3 + print(f"Initial VRAM usage: {initial_memory:.2f} GB") + + # Initialize cache with initial context + print("Initializing KV cache with context...") + kv_cache.enable_cache_updates() + _ = model(cache_video, cache_ts, cache_mouse, cache_btn, kv_cache=kv_cache) + 
kv_cache.disable_cache_updates() + + kv_cache = InferenceKVCache(kv_cache) + model.transformer.enable_decoding() + + model = torch.compile(model, mode = 'max-autotune', dynamic=False, fullgraph=True) + + def call_fn(model, video, ts, mouse, btn, kv_cache): + for _ in range(2): + pred = model(video, ts, mouse, btn, kv_cache=kv_cache) + video = video - pred * 0.5 + + kv_cache.enable_cache_updates() + _ = model(video, ts, mouse, btn, kv_cache = kv_cache) + kv_cache.disable_cache_updates() + + # Warmup runs + print("Running 5 warmup iterations...") + for i in range(5): + video, ts, mouse, btn = create_test_inputs() + call_fn(model, video, ts, mouse, btn, kv_cache) + print(f" Warmup {i+1}/5 completed") + + # Timing runs with CUDA events + print("Running 10 timed iterations...") + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + times = [] + for i in range(10): + video, ts, mouse, btn = create_test_inputs() + + torch.cuda.synchronize() + start_event.record() + + call_fn(model, video, ts, mouse, btn, kv_cache) + + end_event.record() + torch.cuda.synchronize() + + elapsed_time = start_event.elapsed_time(end_event) # milliseconds + times.append(elapsed_time) + print(f" Iteration {i+1}/10: {elapsed_time:.2f} ms") + + # Calculate statistics + avg_time = sum(times) / len(times) + min_time = min(times) + max_time = max(times) + fps = 1000.0 / avg_time # Convert ms to fps + + peak_memory = torch.cuda.max_memory_allocated() / 1024**3 + current_memory = torch.cuda.memory_allocated() / 1024**3 + + print(f"\nTiming Results:") + print(f" Average time: {avg_time:.2f} ms") + print(f" Min time: {min_time:.2f} ms") + print(f" Max time: {max_time:.2f} ms") + print(f" Average FPS: {fps:.2f}") + print(f"\nMemory Results:") + print(f" Peak VRAM usage: {peak_memory:.2f} GB") + print(f" Current VRAM usage: {current_memory:.2f} GB") + print(f" VRAM increase: {current_memory - initial_memory:.2f} GB") + +if __name__ == "__main__": + 
test_model_forward() diff --git a/inference/model_fp8_test.py b/inference/model_fp8_test.py new file mode 100644 index 00000000..aec7ed61 --- /dev/null +++ b/inference/model_fp8_test.py @@ -0,0 +1,174 @@ +import torch +import torch.nn as nn +import time +from owl_wms.configs import Config +from owl_wms import from_pretrained +from owl_wms.nn.rope import RoPE + +import transformer_engine.pytorch as te +from transformer_engine.pytorch.fp8 import fp8_autocast + +# Configuration +N_FRAMES = 1 +cfg_path = "configs/dit_v4_dmd.yml" +ckpt_path = "vid_dit_v4_dmd_7k.pt" + +def cast_rope_buffers_to_fp32(module): + """Cast RoPE buffers to fp32 for numerical stability""" + for submodule in module.modules(): + if isinstance(submodule, RoPE): + if hasattr(submodule, "cos"): + submodule.cos = submodule.cos.float() + if hasattr(submodule, "sin"): + submodule.sin = submodule.sin.float() + +def _swap_linear(module: nn.Module, replace_layernorm: bool = False): + """ + Recursively replace nn.Linear with te.Linear. + If replace_layernorm=True, also replace nn.LayerNorm with te.LayerNorm. 
+ """ + for name, child in list(module.named_children()): + # Recurse first + _swap_linear(child, replace_layernorm) + + # Swap Linear + if isinstance(child, nn.Linear): + new = te.Linear( + in_features=child.in_features, + out_features=child.out_features, + bias=(child.bias is not None), + ).to(device=child.weight.device, dtype=child.weight.dtype) + + with torch.no_grad(): + new.weight.copy_(child.weight) + if child.bias is not None: + new.bias.copy_(child.bias) + + setattr(module, name, new) + + # (Optional) Swap LayerNorm – enable only if you want TE’s fused LN + elif replace_layernorm and isinstance(child, nn.LayerNorm): + # TE’s LayerNorm takes the hidden size (tuple OK) and eps + new_ln = te.LayerNorm( + child.normalized_shape, + eps=child.eps, + elementwise_affine=True, + ).to(device=child.weight.device, dtype=child.weight.dtype) + + with torch.no_grad(): + if child.weight is not None: + new_ln.weight.copy_(child.weight) + if child.bias is not None: + new_ln.bias.copy_(child.bias) + + setattr(module, name, new_ln) + +def convert_to_te_linears(model: nn.Module, replace_layernorm: bool = False) -> nn.Module: + model = model # in-place by default; return for convenience + _swap_linear(model, replace_layernorm=replace_layernorm) + return model + +@torch.no_grad() +def test_model_forward(): + print(f"Testing model forward pass with {N_FRAMES} frames...") + + # Load model + print("Loading model...") + model = from_pretrained(cfg_path, ckpt_path, return_decoder=False) + + # Convert to bfloat16 first, then replace linear layers + model = model.core.eval().cuda().bfloat16() + model.transformer.enable_decoding() + + # Replace attn blocks + n_layers = model.config.n_layers + config = model.config + + convert_to_te_linears(model) + + print("Done") + # Cast RoPE buffers to fp32 + cast_rope_buffers_to_fp32(model) + + # Load config for model dimensions + cfg = Config.from_yaml(cfg_path) + + print("Model loaded successfully!") + print(f"Model parameters: {sum(p.numel() for 
p in model.parameters()) / 1e6:.1f}M") + + # Create test inputs + batch_size = 1 + + def create_test_inputs(): + """Create randomized test inputs""" + video = torch.randn(batch_size, N_FRAMES, 128, 8, 8, device='cuda', dtype=torch.bfloat16) + ts = torch.randn(batch_size, N_FRAMES, device='cuda', dtype=torch.bfloat16) + mouse = torch.randn(batch_size, N_FRAMES, 2, device='cuda', dtype=torch.bfloat16) + btn = torch.randn(batch_size, N_FRAMES, 11, device='cuda', dtype=torch.bfloat16) + return video, ts, mouse, btn + + # Create initial inputs to show shapes + video, ts, mouse, btn = create_test_inputs() + + print(f"Input shapes:") + print(f" video: {video.shape}") + print(f" ts: {ts.shape}") + print(f" mouse: {mouse.shape}") + print(f" btn: {btn.shape}") + + # Check initial VRAM + torch.cuda.empty_cache() + initial_memory = torch.cuda.memory_allocated() / 1024**3 + print(f"Initial VRAM usage: {initial_memory:.2f} GB") + + # Warmup runs + print("Running 5 warmup iterations...") + for i in range(5): + video, ts, mouse, btn = create_test_inputs() + output = model(video, ts, mouse, btn) + print(f" Warmup {i+1}/5 completed") + + # Timing runs with CUDA events + print("Running 10 timed iterations...") + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + times = [] + for i in range(10): + video, ts, mouse, btn = create_test_inputs() + + torch.cuda.synchronize() + start_event.record() + + with fp8_autocast(enabled=True): + output = model(video, ts, mouse, btn) + + end_event.record() + torch.cuda.synchronize() + + elapsed_time = start_event.elapsed_time(end_event) # milliseconds + times.append(elapsed_time) + print(f" Iteration {i+1}/10: {elapsed_time:.2f} ms") + + # Calculate statistics + avg_time = sum(times) / len(times) + min_time = min(times) + max_time = max(times) + fps = 1000.0 / avg_time # Convert ms to fps + + peak_memory = torch.cuda.max_memory_allocated() / 1024**3 + current_memory = 
torch.cuda.memory_allocated() / 1024**3 + + print(f"\nTiming Results:") + print(f" Average time: {avg_time:.2f} ms") + print(f" Min time: {min_time:.2f} ms") + print(f" Max time: {max_time:.2f} ms") + print(f" Average FPS: {fps:.2f}") + print(f"\nMemory Results:") + print(f" Output shape: {output.shape}") + print(f" Peak VRAM usage: {peak_memory:.2f} GB") + print(f" Current VRAM usage: {current_memory:.2f} GB") + print(f" VRAM increase: {current_memory - initial_memory:.2f} GB") + +if __name__ == "__main__": + test_model_forward() diff --git a/inference/model_test.py b/inference/model_test.py new file mode 100644 index 00000000..c45613dd --- /dev/null +++ b/inference/model_test.py @@ -0,0 +1,116 @@ +import torch +import time +from owl_wms.configs import Config +from owl_wms import from_pretrained +from owl_wms.nn.rope import RoPE + +# Configuration +N_FRAMES = 1 +cfg_path = "configs/dit_v4_dmd.yml" +ckpt_path = "vid_dit_v4_dmd_7k.pt" + +def cast_rope_buffers_to_fp32(module): + """Cast RoPE buffers to fp32 for numerical stability""" + for submodule in module.modules(): + if isinstance(submodule, RoPE): + if hasattr(submodule, "cos"): + submodule.cos = submodule.cos.float() + if hasattr(submodule, "sin"): + submodule.sin = submodule.sin.float() + +@torch.no_grad() +def test_model_forward(): + print(f"Testing model forward pass with {N_FRAMES} frames...") + + # Load model + print("Loading model...") + model, decoder = from_pretrained(cfg_path, ckpt_path, return_decoder=True) + model = model.core.eval().cuda().bfloat16() + model = torch.compile(model) + + decoder = decoder.eval().cuda().bfloat16() + + # Cast RoPE buffers to fp32 + cast_rope_buffers_to_fp32(model) + + # Load config for model dimensions + cfg = Config.from_yaml(cfg_path) + + print("Model loaded successfully!") + print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M") + + # Create test inputs + batch_size = 1 + + def create_test_inputs(): + """Create randomized test inputs""" + 
video = torch.randn(batch_size, N_FRAMES, 128, 8, 8, device='cuda', dtype=torch.bfloat16) + ts = torch.randn(batch_size, N_FRAMES, device='cuda', dtype=torch.bfloat16) + mouse = torch.randn(batch_size, N_FRAMES, 2, device='cuda', dtype=torch.bfloat16) + btn = torch.randn(batch_size, N_FRAMES, 11, device='cuda', dtype=torch.bfloat16) + return video, ts, mouse, btn + + # Create initial inputs to show shapes + video, ts, mouse, btn = create_test_inputs() + + print(f"Input shapes:") + print(f" video: {video.shape}") + print(f" ts: {ts.shape}") + print(f" mouse: {mouse.shape}") + print(f" btn: {btn.shape}") + + # Check initial VRAM + torch.cuda.empty_cache() + initial_memory = torch.cuda.memory_allocated() / 1024**3 + print(f"Initial VRAM usage: {initial_memory:.2f} GB") + + # Warmup runs + print("Running 5 warmup iterations...") + for i in range(5): + video, ts, mouse, btn = create_test_inputs() + output = model(video, ts, mouse, btn) + print(f" Warmup {i+1}/5 completed") + + # Timing runs with CUDA events + print("Running 10 timed iterations...") + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + times = [] + for i in range(10): + video, ts, mouse, btn = create_test_inputs() + + torch.cuda.synchronize() + start_event.record() + + output = model(video, ts, mouse, btn) + + end_event.record() + torch.cuda.synchronize() + + elapsed_time = start_event.elapsed_time(end_event) # milliseconds + times.append(elapsed_time) + print(f" Iteration {i+1}/10: {elapsed_time:.2f} ms") + + # Calculate statistics + avg_time = sum(times) / len(times) + min_time = min(times) + max_time = max(times) + fps = 1000.0 / avg_time # Convert ms to fps + + peak_memory = torch.cuda.max_memory_allocated() / 1024**3 + current_memory = torch.cuda.memory_allocated() / 1024**3 + + print(f"\nTiming Results:") + print(f" Average time: {avg_time:.2f} ms") + print(f" Min time: {min_time:.2f} ms") + print(f" Max time: {max_time:.2f} ms") + print(f" 
Average FPS: {fps:.2f}") + print(f"\nMemory Results:") + print(f" Output shape: {output.shape}") + print(f" Peak VRAM usage: {peak_memory:.2f} GB") + print(f" Current VRAM usage: {current_memory:.2f} GB") + print(f" VRAM increase: {current_memory - initial_memory:.2f} GB") + +if __name__ == "__main__": + test_model_forward() diff --git a/inference/test_sampling.py b/inference/test_sampling.py index cbcfd9c9..b2bb60e8 100644 --- a/inference/test_sampling.py +++ b/inference/test_sampling.py @@ -8,9 +8,10 @@ from owl_wms.nn.rope import RoPE import torch +import gc -cfg_path = "configs/dit_v4.yml" -ckpt_path = "/mnt/data/lapp0/checkpoints/89220499/checkpoints/step_120000.pt" +cfg_path = "configs/dit_v4_dmd.yml" +ckpt_path = "vid_dit_v4_dmd_7k.pt" model, decoder = from_pretrained(cfg_path, ckpt_path, return_decoder=True) model = model.core.eval().cuda().bfloat16() @@ -46,47 +47,18 @@ def cast_rope_buffers_to_fp32(module): ) from owl_wms.sampling import get_sampler_cls - -only_return_generated = train_cfg.sampler_kwargs.pop("only_return_generated") import os sampler = get_sampler_cls(train_cfg.sampler_id)(**train_cfg.sampler_kwargs) -cache_path = "test_sampling_cache.pt" +data = torch.load("data_cache.pt") +vid = data["vid"] +mouse = data["mouse"] +btn = data["btn"] -if os.path.exists(cache_path): - print(f"Loading cached data from {cache_path}") - cache = torch.load(cache_path) - vid = cache["vid"] - mouse = cache["mouse"] - btn = cache["btn"] -else: - loader = get_loader( - train_cfg.data_id, - 1, # batch size must be 1 for the loader - **train_cfg.sample_data_kwargs - ) - - loader = iter(loader) - vids, mouses, btns, doc_ids = [], [], [], [] - for _ in range(16): - vid, mouse, btn, doc_id = [t.bfloat16().cuda() for t in next(loader)] - vids.append(vid) - mouses.append(mouse) - btns.append(btn) - doc_ids.append(doc_id) - # Stack along batch dimension - vids = torch.cat(vids, dim=0) - mouses = torch.cat(mouses, dim=0) - btns = torch.cat(btns, dim=0) - # Only use the 
first video, but all mouse/btn for batch_permute_to_length - vid = vids[:1] - mouse, btn = batch_permute_to_length(mouses, btns, sampler.num_frames + vid.size(1)) - mouse = mouse[:1] - btn = btn[:1] - # Save to cache - torch.save({"vid": vid, "mouse": mouse, "btn": btn}, cache_path) - print(f"Saved data to cache at {cache_path}") +vid = vid[:1] +mouse = mouse[:1] +btn = btn[:1] with torch.no_grad(): @@ -97,7 +69,10 @@ def cast_rope_buffers_to_fp32(module): btn = btn[:, vid.size(1):] del model - + # Collect garbage first so freed blocks can be released, then clear the CUDA cache + gc.collect() + torch.cuda.empty_cache() + video = decode_fn(latent_vid * train_cfg.vae_scale) wandb_av_out = to_wandb_av(video, None, mouse, btn) diff --git a/owl_wms/__init__.py b/owl_wms/__init__.py index c339b906..3ff7dbc3 100644 --- a/owl_wms/__init__.py +++ b/owl_wms/__init__.py @@ -9,7 +9,11 @@ def from_pretrained(cfg_path, ckpt_path, return_decoder=False): cfg = Config.from_yaml(cfg_path) model = get_model_cls(cfg.model.model_id)(cfg.model) - model.load_state_dict(versatile_load(ckpt_path)) + + try: + model.load_state_dict(versatile_load(ckpt_path)) + except RuntimeError: + model.core.load_state_dict(versatile_load(ckpt_path)) if not return_decoder: return model diff --git a/owl_wms/nn/attn.py b/owl_wms/nn/attn.py index 57c97feb..557a029a 100644 --- a/owl_wms/nn/attn.py +++ b/owl_wms/nn/attn.py @@ -76,6 +76,10 @@ def __init__(self, config, layer_idx, local = False): self.local = local self.local_offset = config.local_window * config.tokens_per_frame + def cache_update(self, k, v, kv_cache = None): + if kv_cache is not None and kv_cache.should_update: + kv_cache.update(k, v, self.layer_idx) + def forward(self, x, block_mask, kv_cache=None): B, L, _ = x.shape @@ -88,15 +92,11 @@ def forward(self, x, block_mask, kv_cache=None): q = self.rope(q, offset=offset) k = self.rope(k, offset=offset) - # prepend cached values - if offset > 0: - old_k, old_v = kv_cache.get(self.layer_idx) - k = torch.cat([old_k, k], dim=2) - v = torch.cat([old_v, v],
dim=2) - # update cache - if kv_cache is not None and kv_cache.should_update: - kv_cache.update(k, v, self.layer_idx) + self.cache_update(k, v, kv_cache) + + if offset > 0: + k, v = kv_cache.get(self.layer_idx) # NOTE: Using block_mask = None to mark decoding, probably need something more explicit in future if self.local and block_mask is None: diff --git a/owl_wms/nn/kv_cache.py b/owl_wms/nn/kv_cache.py index b4c7628c..89f9b2d4 100644 --- a/owl_wms/nn/kv_cache.py +++ b/owl_wms/nn/kv_cache.py @@ -102,3 +102,78 @@ def detach(self): @property def shape(self): return self.cache[0][0].shape + +class InferenceKVCache: + """ + Optimized for inference with static shapes and in-place updates + """ + def __init__(self, kv_cache : SingleKVCache): + self.cache = kv_cache.cache + self.offsets = kv_cache.offsets + + self.tokens_per_frame = kv_cache.config.tokens_per_frame + self.n_layers = kv_cache.config.n_layers + self.should_update = False + + def enable_cache_updates(self): + self.should_update = True + + def disable_cache_updates(self): + self.should_update = False + + def to(self, device = 'cuda', dtype = torch.bfloat16): + self.device = device + self.dtype = dtype + return self + + def get(self, layer_ind): + assert self.cache is not None, "Must reset cache before using" + k,v = self.cache[layer_ind] + return k,v + + def update(self, new_k, new_v, layer_ind): + # Assumes new_k/new_v span tokens_per_frame; the shifted source overlaps the destination, so clone it before writing + self.cache[layer_ind][0][:,:,:-self.tokens_per_frame] = self.cache[layer_ind][0][:,:,self.tokens_per_frame:].clone() + self.cache[layer_ind][1][:,:,:-self.tokens_per_frame] = self.cache[layer_ind][1][:,:,self.tokens_per_frame:].clone() + + self.cache[layer_ind][0][:,:,-self.tokens_per_frame:] = new_k + self.cache[layer_ind][1][:,:,-self.tokens_per_frame:] = new_v + + self.offsets[layer_ind] += self.tokens_per_frame + + def _update(self, new_k, new_v, layer_ind): + self.cache[layer_ind][0] = torch.cat([self.cache[layer_ind][0][:,:,self.tokens_per_frame:], new_k], dim=2) +
self.cache[layer_ind][1] = torch.cat([self.cache[layer_ind][1][:,:,self.tokens_per_frame:], new_v], dim=2) + self.offsets[layer_ind] += self.tokens_per_frame + + def truncate(self, *args, **kwargs): + pass + + def length_at(self, idx): + return self.cache[idx][0].shape[2] + + def get_offset(self, idx=0): + return self.offsets[idx] + + def __len__(self): + assert self.cache is not None, "Must reset cache before using" + return self.cache[0][0].shape[2] + + def n_frames(self): + assert len(self) % self.tokens_per_frame == 0 + return len(self) // self.tokens_per_frame + + def clone(self): + # Clones all tensors for max-autotune to work properly + for i in range(self.n_layers): + self.cache[i] = (self.cache[i][0].clone(), self.cache[i][1].clone()) + return self + + def detach(self): + for i in range(self.n_layers): + self.cache[i] = (self.cache[i][0].detach(), self.cache[i][1].detach()) + return self + + @property + def shape(self): + return self.cache[0][0].shape diff --git a/owl_wms/sampling/av_caching_v2.py b/owl_wms/sampling/av_caching_v2.py index 29e4dd3c..fc6f7d53 100644 --- a/owl_wms/sampling/av_caching_v2.py +++ b/owl_wms/sampling/av_caching_v2.py @@ -2,7 +2,7 @@ from tqdm import tqdm import gc -from ..nn.kv_cache import KVCache +from ..nn.kv_cache import KVCache, InferenceKVCache from .schedulers import get_sd3_euler @@ -74,6 +74,7 @@ def __call__(self, model, x, mouse, btn, compile_on_decode = False): kv_cache=kv_cache ) kv_cache.disable_cache_updates() + kv_cache = InferenceKVCache(kv_cache) def new_xt(): return torch.randn_like(prev_x[:,:1]), prev_t.new_ones(batch_size, 1) @@ -133,8 +134,6 @@ def new_xt(): kv_cache=kv_cache ) kv_cache.disable_cache_updates() - if self.max_window is not None and len(latents) > self.max_window: - kv_cache.truncate(1, front=False) # Eject oldest gc.collect() torch.cuda.empty_cache() diff --git a/owl_wms/utils/owl_vae_bridge.py b/owl_wms/utils/owl_vae_bridge.py index de985ca4..9a0f119f 100644 ---
a/owl_wms/utils/owl_vae_bridge.py +++ b/owl_wms/utils/owl_vae_bridge.py @@ -2,7 +2,6 @@ import os import torch -from diffusers import AutoencoderDC sys.path.append("./owl-vaes") from owl_vaes.utils.proxy_init import load_proxy_model