Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions configs/dit_v4_dmd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,20 @@ train:
sampler_id: av_caching
sampler_kwargs:
n_steps: 2
#window_length: 16
cfg_scale: 1.0
num_frames: 600
num_frames: 60
noise_prev: 0.2
custom_schedule: [1.0, 0.5]

n_samples: 4
eval_sample_dir: null

vae_id: null
vae_batch_size: 4
vae_batch_size: 1

vae_scale: 0.63
vae_cfg_path: /mnt/data/shahbuland/owl-vaes/configs/cod_yt_v2/base.yml
vae_ckpt_path: /mnt/data/checkpoints/owl_vaes/cod_yt_v2/cod_yt_v2_515k_ema_decoder.pt
vae_cfg_path: ./dec_dist.yml
vae_ckpt_path: ./dec_distill_cod_v2.pt

update_ratio: 5
ts_shift: 8
Expand Down
84 changes: 84 additions & 0 deletions inference/build_cache_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import joblib

import torch
from owl_wms.configs import Config
from owl_wms import from_pretrained
from owl_wms.nn.rope import RoPE
from owl_wms.nn.kv_cache import KVCache

def cast_rope_buffers_to_fp32(module):
    """Promote the ``cos``/``sin`` tables of every RoPE submodule to float32.

    Walks ``module.modules()`` and, for each ``RoPE`` instance found, replaces
    its ``cos`` and ``sin`` attributes (whichever exist) with the result of
    calling ``.float()`` on them. The module is modified in place and nothing
    is returned.
    """
    for layer in module.modules():
        if not isinstance(layer, RoPE):
            continue
        # Same order as before: cos first, then sin.
        for table_name in ("cos", "sin"):
            if hasattr(layer, table_name):
                setattr(layer, table_name, getattr(layer, table_name).float())

@torch.no_grad()
def build_cache(
    cfg_path="configs/dit_v4_dmd.yml",
    ckpt_path="vid_dit_v4_dmd_7k.pt",
    data_path="data_cache.pt",
    out_path="kv_cache_object.pkl",
    noise_prev=0.2,
    init_len=60,
):
    """Pre-fill a KV cache from stored context frames and save it to disk.

    Loads the model core (no decoder), runs one forward pass over the first
    ``init_len`` frames of the stored clip with cache updates enabled, and
    dumps the resulting ``KVCache`` object with joblib.

    Args:
        cfg_path: YAML config passed to ``from_pretrained``.
        ckpt_path: Checkpoint passed to ``from_pretrained``.
        data_path: ``torch.save``'d dict with keys ``"vid"``, ``"mouse"`` and
            ``"btn"``. Each is indexed as (batch, frame, ...) below.
        out_path: Destination of the joblib dump.
        noise_prev: Blend weight of the Gaussian noise mixed into the context
            frames; the same value is passed to the model as the per-frame
            second argument (presumably the timestep / noise level -- confirm
            against the model's forward signature).
        init_len: Maximum number of leading frames used as context.

    Returns:
        None. The cache is written to ``out_path``.
    """
    # Load model (no decoder needed)
    print("Loading model...")
    model = from_pretrained(cfg_path, ckpt_path, return_decoder=False)
    model = model.core.eval().cuda().bfloat16()

    # .bfloat16() above also converted the RoPE tables; put them back in fp32.
    cast_rope_buffers_to_fp32(model)

    print("Model loaded successfully!")

    # The model lives on the current CUDA device, so make sure the stored
    # tensors land there too regardless of where they were saved from.
    data = torch.load(data_path, map_location="cuda")
    vid = data["vid"]
    mouse = data["mouse"]
    btn = data["btn"]

    batch_size = vid.size(0)

    print("Cache input shapes:")
    print(f" vid: {vid.shape}")
    print(f" mouse: {mouse.shape}")
    print(f" btn: {btn.shape}")

    # Initialize KV cache
    kv_cache = KVCache(model.config)
    kv_cache.reset(batch_size)

    # Truncate to the context window first so noise is only generated for
    # the frames that are actually fed to the model.
    vid = vid[:, :init_len]
    mouse = mouse[:batch_size, :init_len]
    btn = btn[:batch_size, :init_len]

    # Linear blend of the clean frames with Gaussian noise at level noise_prev.
    vid_noisy = vid * (1. - noise_prev) + torch.randn_like(vid) * noise_prev
    t_noisy = vid.new_full((batch_size, vid.size(1)), noise_prev)

    print("Building KV cache...")
    kv_cache.enable_cache_updates()
    # Output is discarded: the pass is run only for its effect on kv_cache.
    _ = model(
        vid_noisy,
        t_noisy,
        mouse,
        btn,
        kv_cache=kv_cache,
    )
    kv_cache.disable_cache_updates()

    print("Saving cache object...")
    joblib.dump(kv_cache, out_path)
    print(f"Cache object saved to '{out_path}'")


# Script entry point: build the cache with the default paths and settings.
if __name__ == "__main__":
    build_cache()

8 changes: 1 addition & 7 deletions inference/causvid_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def to_bgr_uint8(frame, target_size=(1080,1920)):
return frame

class CausvidPipeline:
def __init__(self, cfg_path="configs/causvid.yml", ckpt_path="causvid_ema.pt"):
def __init__(self, cfg_path="configs/dit_v4_dmd.yml", ckpt_path="vid_dit_v4_dmd_7k.pt"):
cfg = Config.from_yaml(cfg_path)
model_cfg = cfg.model
train_cfg = cfg.train
Expand All @@ -42,12 +42,6 @@ def __init__(self, cfg_path="configs/causvid.yml", ckpt_path="causvid_ema.pt"):
)
self.frame_decoder = self.frame_decoder.cuda().bfloat16().eval()

#audio_decoder = get_decoder_only(
# None,
# train_cfg.audio_vae_cfg_path,
# train_cfg.audio_vae_ckpt_path
#)

# Store scales as instance variables
self.frame_scale = train_cfg.vae_scale
self.audio_scale = train_cfg.audio_vae_scale
Expand Down
96 changes: 96 additions & 0 deletions inference/decoder_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import torch
from torch import nn
from owl_wms import from_pretrained
import gc


import torch, torch_tensorrt
from torch_tensorrt.dynamo import compile

# Config and checkpoint that supply the decoder under test.
cfg_path = "configs/dit_v4_dmd.yml"
ckpt_path = "vid_dit_v4_dmd_7k.pt"

# Only the decoder is benchmarked; the first return value is discarded.
_, decoder = from_pretrained(cfg_path, ckpt_path, return_decoder=True)
decoder = decoder.eval().cuda().bfloat16()
#decoder = torch.compile(decoder, mode = 'max-autotune', fullgraph = True, dynamic = False)
# TODO: optionally wrap `decoder` in a compiled backend here (torch.compile
# as above, or torch_tensorrt). Until that call is filled in the decoder
# runs in eager mode.

# Collect Python garbage first so that any tensors it frees can actually be
# returned to the driver by empty_cache().
gc.collect()
torch.cuda.empty_cache()

# Configuration
BATCH_SIZE = 1

@torch.no_grad()
def test_decoder():
    """Benchmark the module-level ``decoder`` on random inputs.

    Runs 5 warmup passes followed by 10 passes timed with CUDA events, then
    prints per-iteration latency, average/min/max/FPS, the shape of the last
    output, and CUDA memory figures. Reads the module-level ``decoder`` and
    ``BATCH_SIZE``; returns nothing.
    """
    bytes_per_gib = 1024**3

    def random_latent():
        """Draw a fresh bf16 CUDA tensor of shape (BATCH_SIZE, 128, 8, 8)."""
        return torch.randn(BATCH_SIZE, 128, 8, 8, device='cuda', dtype=torch.bfloat16)

    print(f"Testing decoder with batch size {BATCH_SIZE}...")
    print("Decoder loaded successfully!")
    n_params = sum(p.numel() for p in decoder.parameters())
    print(f"Decoder parameters: {n_params / 1e6:.1f}M")

    # One sample up front, used only to report the input shape.
    sample = random_latent()
    print("Input shapes:")
    print(f" x: {sample.shape}")

    # Baseline memory reading, taken after releasing cached blocks.
    torch.cuda.empty_cache()
    initial_memory = torch.cuda.memory_allocated() / bytes_per_gib
    print(f"Initial VRAM usage: {initial_memory:.2f} GB")

    print("Running 5 warmup iterations...")
    for warm_no in range(1, 6):
        output = decoder(random_latent())
        print(f" Warmup {warm_no}/5 completed")

    print("Running 10 timed iterations...")
    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)

    times = []
    for run_no in range(1, 11):
        latent = random_latent()

        # Input creation is finished before the start event is recorded.
        torch.cuda.synchronize()
        tic.record()

        output = decoder(latent)

        toc.record()
        torch.cuda.synchronize()

        elapsed_ms = tic.elapsed_time(toc)  # milliseconds
        times.append(elapsed_ms)
        print(f" Iteration {run_no}/10: {elapsed_ms:.2f} ms")

    # Summary statistics over the timed runs.
    avg_time = sum(times) / len(times)
    fastest = min(times)
    slowest = max(times)
    fps = 1000.0 / avg_time  # ms per call -> calls per second

    peak_memory = torch.cuda.max_memory_allocated() / bytes_per_gib
    current_memory = torch.cuda.memory_allocated() / bytes_per_gib
    vram_delta = current_memory - initial_memory

    print("\nTiming Results:")
    print(f" Average time: {avg_time:.2f} ms")
    print(f" Min time: {fastest:.2f} ms")
    print(f" Max time: {slowest:.2f} ms")
    print(f" Average FPS: {fps:.2f}")
    print("\nMemory Results:")
    print(f" Output shape: {output.shape}")
    print(f" Peak VRAM usage: {peak_memory:.2f} GB")
    print(f" Current VRAM usage: {current_memory:.2f} GB")
    print(f" VRAM increase: {vram_delta:.2f} GB")


# Script entry point: run the decoder benchmark.
if __name__ == "__main__":
    test_decoder()
155 changes: 155 additions & 0 deletions inference/model_cache_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import torch
import time
from owl_wms.configs import Config
from owl_wms import from_pretrained
from owl_wms.nn.rope import RoPE
from owl_wms.nn.kv_cache import KVCache, InferenceKVCache

# Configuration
# Number of context frames pushed through the model to pre-fill the KV cache.
N_FRAMES_CACHE = 60
# Number of new frames passed in each benchmarked forward call.
N_FRAMES = 1
# Model config and checkpoint handed to from_pretrained().
cfg_path = "configs/dit_v4_dmd.yml"
ckpt_path = "vid_dit_v4_dmd_7k.pt"

def cast_rope_buffers_to_fp32(module):
    """Convert the ``cos`` and ``sin`` attributes of each RoPE submodule to fp32.

    Every submodule yielded by ``module.modules()`` that is a ``RoPE``
    instance has its ``cos`` and ``sin`` attributes, when present, replaced
    by their ``.float()`` versions. Operates in place; returns nothing.
    """
    for layer in module.modules():
        if not isinstance(layer, RoPE):
            continue
        # cos is handled before sin, as in the straight-line version.
        for table_name in ("cos", "sin"):
            if hasattr(layer, table_name):
                setattr(layer, table_name, getattr(layer, table_name).float())

@torch.no_grad()
def test_model_forward():
    """Benchmark cached single-step generation with the compiled model core.

    Steps:
      1. Load the model core (the decoder is not needed and is not loaded,
         so it does not count towards the reported VRAM figures).
      2. Pre-fill a ``KVCache`` with ``N_FRAMES_CACHE`` random context frames.
      3. Wrap the cache in ``InferenceKVCache``, switch the transformer to
         decoding mode and compile the model with ``torch.compile``.
      4. Run 5 warmup and 10 CUDA-event-timed calls of ``call_fn`` on
         ``N_FRAMES`` random new frames, then print latency and memory stats.

    All inputs are random, so only speed and memory are measured, not output
    quality. Reads the module-level ``cfg_path``, ``ckpt_path``,
    ``N_FRAMES_CACHE`` and ``N_FRAMES``; returns nothing.
    """
    print(f"Testing model forward pass with {N_FRAMES_CACHE} cache frames + {N_FRAMES} new frames...")

    # Load model
    print("Loading model...")
    model = from_pretrained(cfg_path, ckpt_path, return_decoder=False)
    model = model.core.eval().cuda().bfloat16()

    # .bfloat16() above also converted the RoPE tables; put them back in fp32.
    cast_rope_buffers_to_fp32(model)

    print("Model loaded successfully!")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

    batch_size = 1

    def make_inputs(n_frames):
        """Create random bf16 CUDA inputs (video, ts, mouse, btn) for n_frames frames."""
        video = torch.randn(batch_size, n_frames, 128, 8, 8, device='cuda', dtype=torch.bfloat16)
        ts = torch.randn(batch_size, n_frames, device='cuda', dtype=torch.bfloat16)
        mouse = torch.randn(batch_size, n_frames, 2, device='cuda', dtype=torch.bfloat16)
        btn = torch.randn(batch_size, n_frames, 11, device='cuda', dtype=torch.bfloat16)
        return video, ts, mouse, btn

    # Initialize KV cache
    kv_cache = KVCache(model.config)
    kv_cache.reset(batch_size)

    # Create one set of each input kind up front to report shapes.
    cache_video, cache_ts, cache_mouse, cache_btn = make_inputs(N_FRAMES_CACHE)
    video, ts, mouse, btn = make_inputs(N_FRAMES)

    print("Cache input shapes:")
    print(f" cache_video: {cache_video.shape}")
    print(f" cache_ts: {cache_ts.shape}")
    print(f" cache_mouse: {cache_mouse.shape}")
    print(f" cache_btn: {cache_btn.shape}")

    print("New input shapes:")
    print(f" video: {video.shape}")
    print(f" ts: {ts.shape}")
    print(f" mouse: {mouse.shape}")
    print(f" btn: {btn.shape}")

    # Baseline memory reading, taken after releasing cached blocks.
    torch.cuda.empty_cache()
    initial_memory = torch.cuda.memory_allocated() / 1024**3
    print(f"Initial VRAM usage: {initial_memory:.2f} GB")

    # Pre-fill the cache: output is discarded, the pass is run only for its
    # effect on kv_cache.
    print("Initializing KV cache with context...")
    kv_cache.enable_cache_updates()
    _ = model(cache_video, cache_ts, cache_mouse, cache_btn, kv_cache=kv_cache)
    kv_cache.disable_cache_updates()

    # Switch to the inference-time cache wrapper and decoding mode before
    # compiling. NOTE(review): the exact effect of enable_decoding() is
    # defined in owl_wms and not visible here.
    kv_cache = InferenceKVCache(kv_cache)
    model.transformer.enable_decoding()

    model = torch.compile(model, mode='max-autotune', dynamic=False, fullgraph=True)

    def call_fn(model, video, ts, mouse, btn, kv_cache):
        """Run two update passes on `video`, then one cache-writing pass."""
        # Two passes with the cache read-only; each subtracts half of the
        # model output from the frame.
        for _ in range(2):
            pred = model(video, ts, mouse, btn, kv_cache=kv_cache)
            video = video - pred * 0.5

        # Final pass with updates enabled so the new frame enters the cache.
        kv_cache.enable_cache_updates()
        _ = model(video, ts, mouse, btn, kv_cache=kv_cache)
        kv_cache.disable_cache_updates()

    # Warmup runs (also absorbs torch.compile's compilation cost).
    print("Running 5 warmup iterations...")
    for i in range(5):
        video, ts, mouse, btn = make_inputs(N_FRAMES)
        call_fn(model, video, ts, mouse, btn, kv_cache)
        print(f" Warmup {i+1}/5 completed")

    # Timing runs with CUDA events
    print("Running 10 timed iterations...")
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    times = []
    for i in range(10):
        video, ts, mouse, btn = make_inputs(N_FRAMES)

        # Input creation is finished before the start event is recorded.
        torch.cuda.synchronize()
        start_event.record()

        call_fn(model, video, ts, mouse, btn, kv_cache)

        end_event.record()
        torch.cuda.synchronize()

        elapsed_time = start_event.elapsed_time(end_event)  # milliseconds
        times.append(elapsed_time)
        print(f" Iteration {i+1}/10: {elapsed_time:.2f} ms")

    # Calculate statistics
    avg_time = sum(times) / len(times)
    min_time = min(times)
    max_time = max(times)
    fps = 1000.0 / avg_time  # ms per call -> calls per second

    peak_memory = torch.cuda.max_memory_allocated() / 1024**3
    current_memory = torch.cuda.memory_allocated() / 1024**3

    print("\nTiming Results:")
    print(f" Average time: {avg_time:.2f} ms")
    print(f" Min time: {min_time:.2f} ms")
    print(f" Max time: {max_time:.2f} ms")
    print(f" Average FPS: {fps:.2f}")
    print("\nMemory Results:")
    print(f" Peak VRAM usage: {peak_memory:.2f} GB")
    print(f" Current VRAM usage: {current_memory:.2f} GB")
    print(f" VRAM increase: {current_memory - initial_memory:.2f} GB")


# Script entry point: run the cached-forward benchmark.
if __name__ == "__main__":
    test_model_forward()
Loading