diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..86a9f629 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +checkpoints/* +venv/ +__pycache__/* +generated_videos/ +owl-vaes/ +*.pt +*.env +*.pyc +.vscode/* \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..702b781d --- /dev/null +++ b/Dockerfile @@ -0,0 +1,57 @@ +# Use CUDA 12.8 runtime as base image for lightweight deployment +FROM nvidia/cuda:12.8.1-runtime-ubuntu22.04 + +# Set environment variables +ENV DEBIAN_FRONTEND=noninteractive +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONPATH=/app + +# Install system dependencies (without python3.12 first) +RUN apt-get update && apt-get install -y --no-install-recommends \ + wget \ + curl \ + python3-pip \ + git \ + software-properties-common \ + && rm -rf /var/lib/apt/lists/* + +# Add deadsnakes PPA and install Python 3.12 +RUN add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update \ + && apt-get install -y --no-install-recommends \ + python3.12 \ + python3.12-dev \ + python3.12-venv \ + && rm -rf /var/lib/apt/lists/* + +# Set Python 3.12 as default +RUN ln -sf /usr/bin/python3.12 /usr/bin/python3 && \ + ln -sf /usr/bin/python3 /usr/bin/python + +# Install uv +RUN curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR=/root/.cargo/bin sh +ENV PATH="/root/.cargo/bin:${PATH}" + +# Create working directory +WORKDIR /app + +# Copy requirements file first for better layer caching +COPY requirements.txt . + +# Install PyTorch with CUDA 12.8 support and sm120 architecture support +RUN uv pip install --system torch torchvision --index-url https://download.pytorch.org/whl/cu128 + +# Install other requirements from requirements.txt +RUN uv pip install --system -r requirements.txt + +RUN git submodule update --init --recursive + +# Copy the entire application +COPY . /app + +# Expose the port that the FastAPI server runs on +EXPOSE 8000 + +# Set the default command to run the web server +CMD ["python3", "webapp/server.py", "--port", "8000", "--no-debug"] diff --git a/checkpoints/wm/dcae_hf_cod/basic.yml b/checkpoints/wm/dcae_hf_cod/basic.yml new file mode 100644 index 00000000..541a3991 --- /dev/null +++ b/checkpoints/wm/dcae_hf_cod/basic.yml @@ -0,0 +1,67 @@ +# Config for a simple 256 -> 16 autoencoder +model: + model_id: game_rft_core + sample_size: 4 + channels: 128 + + n_layers: 17 + n_heads: 16 + d_model: 1024 + + tokens_per_frame: 16 + n_buttons: 11 + n_mouse_axes: 2 + + cfg_prob: 0.1 + n_frames: 60 + + causal: false + +train: + trainer_id: rft + data_id: cod_latent + data_kwargs: + window_length: 60 + root: ../cod_data/BlackOpsColdWar + add_optical_flow: false + + target_batch_size: 320 + batch_size: 40 + + epochs: 200 + + opt: Muon + opt_kwargs: + lr: 1.0e-3 + momentum: 0.95 + adamw_lr: 1.0e-4 + adamw_wd: 1.0e-4 + adamw_eps: 1.0e-15 + adamw_betas: [0.9, 0.95] + adamw_keys: [core.proj_in, core.proj_out.proj] + + scheduler: null + + checkpoint_dir: checkpoints/v2 + resume_ckpt: checkpoints/v2/step_165000.pt + + sample_interval: 1000 + save_interval: 5000 + + sampler_id: window + sampler_kwargs: + n_steps: 32 + cfg_scale: 1.3 + window_length: 60 + num_frames: 120 + noise_prev: 0.2 + only_return_generated: true + + vae_batch_size: 16 + vae_scale: 2.17 + n_samples: 8 + +wandb: + name: shahbuland + project: video_models + run_name: v2 \ No newline at end of file diff --git a/configs/360p_v2.yml b/configs/360p_v2.yml new file mode 100644 index 00000000..9a3ced68 --- /dev/null +++ b/configs/360p_v2.yml @@ -0,0 +1,70 @@ +# Config for a simple 256 -> 16 autoencoder +model: + model_id: game_rft + sample_size: 4 + channels: 128 + + n_layers: 13 + n_heads: 16 + d_model: 1024 + + tokens_per_frame: 16 + n_buttons: 11 + n_mouse_axes: 2 + + cfg_prob: 0.1 + n_frames: 30 + + causal: false + +train: + trainer_id: rft + data_id: cod_s3 + data_kwargs: + window_length: 30 + bucket_name: cod-data-latent-360x640to4x4 + include_keyframe: false + + target_batch_size: 256 + batch_size: 32 + + epochs: 200 + + opt: Muon + opt_kwargs: + lr: 1.0e-3 + momentum: 0.95 + adamw_lr: 1.0e-4 + adamw_wd: 1.0e-4 + adamw_eps: 1.0e-15 + adamw_betas: [0.9, 0.95] + adamw_keys: [core.proj_in, core.proj_out.proj] + + scheduler: null + + checkpoint_dir: checkpoints/360p + + sample_interval: 1000 + save_interval: 5000 + + sampler_id: window + sampler_kwargs: + n_steps: 10 + cfg_scale: 1.3 + window_length: 30 + num_frames: 60 + noise_prev: 0.2 + only_return_generated: false + + n_samples: 8 + + vae_id: 720pr3dc + vae_batch_size: 4 + vae_scale: 0.13 + vae_cfg_path: configs/owl_vaes/cod_128x.yml + vae_ckpt_path: checkpoints/owl_vaes/cod_128x_30k_ema.pt + +wandb: + name: shahbuland + project: video_models + run_name: v3 \ No newline at end of file diff --git a/configs/av.yml b/configs/av.yml new file mode 100644 index 00000000..72811dfb --- /dev/null +++ b/configs/av.yml @@ -0,0 +1,75 @@ +model: + model_id: game_rft_audio + sample_size: 4 + channels: 128 + audio_channels: 64 + + n_layers: 13 + n_heads: 16 + d_model: 1024 + + tokens_per_frame: 17 + n_buttons: 11 + n_mouse_axes: 2 + + cfg_prob: 0.1 + n_frames: 30 + + causal: false + +train: + trainer_id: av + data_id: cod_s3_audio + data_kwargs: + window_length: 30 + bucket_name: cod-data-latent-360x640to4x4 + + target_batch_size: 256 + batch_size: 32 + + epochs: 200 + + opt: Muon + opt_kwargs: + lr: 1.0e-3 + momentum: 0.95 + adamw_lr: 1.0e-4 + adamw_wd: 1.0e-4 + adamw_eps: 1.0e-15 + adamw_betas: [0.9, 0.95] + adamw_keys: [core.proj_in, core.proj_out.proj] + + scheduler: null + + checkpoint_dir: checkpoints/360p + + sample_interval: 1000 + save_interval: 5000 + + sampler_id: av_window + sampler_kwargs: + n_steps: 10 + cfg_scale: 1.3 + window_length: 30 + num_frames: 60 + noise_prev: 0.2 + only_return_generated: false + + n_samples: 8 + + vae_id: null + vae_batch_size: 4 + vae_scale: 0.13 + audio_vae_scale: 0.17 + + vae_cfg_path: configs/owl_vaes/cod_128x.yml + vae_ckpt_path: checkpoints/owl_vaes/cod_128x_30k_ema.pt + + audio_vae_id: null + audio_vae_cfg_path: configs/owl_vaes/cod_audio.yml + audio_vae_ckpt_path: checkpoints/owl_vaes/cod_audio_20k_ema.pt + +wandb: + name: shahbuland + project: video_models + run_name: av \ No newline at end of file diff --git a/configs/causvid.yml b/configs/causvid.yml new file mode 100644 index 00000000..065a7f9b --- /dev/null +++ b/configs/causvid.yml @@ -0,0 +1,69 @@ +# Config for a simple 256 -> 16 autoencoder +model: + model_id: game_rft + sample_size: 4 + channels: 128 + + n_layers: 17 + n_heads: 16 + d_model: 1024 + + tokens_per_frame: 16 + n_buttons: 11 + n_mouse_axes: 2 + + cfg_prob: 0.0 + n_frames: 30 + + causal: false + +train: + trainer_id: causvid + data_id: cod_latent + data_kwargs: + window_length: 30 + root: ../cod_data/BlackOpsColdWar + add_optical_flow: false + + target_batch_size: 256 + batch_size: 32 + + epochs: 200 + + opt: AdamW + opt_kwargs: + lr: 2.0e-6 + weight_decay: 1.0e-4 + eps: 1.0e-15 + betas: [0.9, 0.95] + + scheduler: null + + checkpoint_dir: checkpoints/360p + + sample_interval: 1000 + save_interval: 5000 + + sampler_id: window + sampler_kwargs: + n_steps: 20 + cfg_scale: 1.3 + window_length: 30 + num_frames: 60 + noise_prev: 0.2 + only_return_generated: true + + n_samples: 8 + + vae_id: 720pr3dc + vae_batch_size: 4 + vae_scale: 0.35 + vae_cfg_path: configs/owl_vaes/128x_cod_stage2.yml + vae_ckpt_path: 720p_cod_vae_30m_35k_steps.pt + + teacher_ckpt: null # Set later TODO + +wandb: + name: shahbuland + project: video_models + run_name: v2 \ No newline at end of file diff --git a/configs/owl_vaes/cod_128x.yml b/configs/owl_vaes/cod_128x.yml new file mode 100644 index 00000000..76502ca9 --- /dev/null +++ b/configs/owl_vaes/cod_128x.yml @@ -0,0 +1,53 @@ +# Config for a simple 256 -> 16 autoencoder +model: + model_id: dcae + sample_size: [360,640] + channels: 3 + latent_size: 4 + latent_channels: 128 + + noise_decoder_inputs: 0.0 + ch_0: 128 + ch_max: 1024 + + encoder_blocks_per_stage: [3, 3, 3, 3, 3, 3, 3, 3] + decoder_blocks_per_stage: [3, 3, 3, 3, 3, 3, 3, 3] + + checkpoint_grads: true + +train: + trainer_id: rec + data_id: s3_cod + target_batch_size: 128 + batch_size: 16 + + epochs: 200 + + opt: AdamW + opt_kwargs: + lr: 1.0e-4 + weight_decay: 1.0e-4 + betas: [0.9, 0.95] + eps: 1.0e-15 + + lpips_type: convnext + loss_weights: + latent_reg: 1.0e-6 + lpips: 10.0 + se_reg: 0.0 + + scheduler: LinearWarmup + scheduler_kwargs: + warmup_steps: 3000 + min_lr: 5.0e-6 + + checkpoint_dir: checkpoints/cod_128x + resume_ckpt: null #checkpoints/2d_64x/step_10000.pt + + sample_interval: 1000 + save_interval: 5000 + +wandb: + name: ${env:WANDB_USER_NAME} + project: new_vaes + run_name: 128x_cod \ No newline at end of file diff --git a/configs/owl_vaes/cod_audio.yml b/configs/owl_vaes/cod_audio.yml new file mode 100644 index 00000000..ad198e7c --- /dev/null +++ b/configs/owl_vaes/cod_audio.yml @@ -0,0 +1,57 @@ +model: + model_id: audio_ae + + channels: 2 + latent_channels: 64 + ch_0: 128 + ch_max: 512 + + strides: [3, 5, 7, 7, 1] + + eq: true + checkpoint_grads: true + +train: + trainer_id: audio_rec + data_id: local_cod_audio + data_kwargs: + window_length: 88200 + root: "../cod_download/raw" + + target_batch_size: 128 + batch_size: 16 + epochs: 100 + + opt: AdamW + opt_kwargs: + lr: 1.0e-4 + eps: 1.0e-15 + betas: [0.9, 0.95] + weight_decay: 1.0e-4 + + loss_weights: + recon: 2.5 + stft: 1.5 + kl: 1.0e-5 + lr_ms_ratio: 0.5 + hubert: 0.0 + crt: 4.0 + + sample_rate: 44100 + n_fft_list: [1024, 2048, 512] + + scheduler: LinearWarmup + scheduler_kwargs: + warmup_steps: 1500 + min_lr: 1.0e-6 + + checkpoint_dir: checkpoints/audio_ae + sample_interval: 500 + save_interval: 5000 + + resume_ckpt: null + +wandb: + name: ${env:WANDB_USER_NAME} + project: owl_audio_vaes + run_name: audio_ae_baseline \ No newline at end of file diff --git a/configs/self_forcing.yaml b/configs/self_forcing.yaml new file mode 100644 index 00000000..9b951031 --- /dev/null +++ b/configs/self_forcing.yaml @@ -0,0 +1,77 @@ +# Config for Self-Forcing training with autoregressive rollout +model: + model_id: game_rft + sample_size: [5,8] + channels: 64 + + n_layers: 17 + n_heads: 20 + d_model: 1280 + + tokens_per_frame: 40 + n_buttons: 11 + n_mouse_axes: 2 + + cfg_prob: 0.1 + n_frames: 30 + + causal: true # Enable causal attention + +train: + trainer_id: self_forcing + data_id: cod_latent + data_kwargs: + window_length: 30 + root: ../cod_data/BlackOpsColdWar + add_optical_flow: false + + target_batch_size: 256 + batch_size: 32 + + epochs: 200 + + opt: AdamW + opt_kwargs: + lr: 1.0e-4 + weight_decay: 1.0e-4 + eps: 1.0e-15 + betas: [0.9, 0.95] + + scheduler: null + + checkpoint_dir: checkpoints/self_forcing + + sample_interval: 1000 + save_interval: 5000 + + sampler_id: window + sampler_kwargs: + n_steps: 20 + cfg_scale: 1.3 + window_length: 30 + num_frames: 60 + noise_prev: 0.2 + only_return_generated: true + + n_samples: 8 + + vae_id: dcae + vae_batch_size: 4 + vae_scale: 0.35 + vae_cfg_path: configs/owl_vaes/128x_cod_stage2.yml + vae_ckpt_path: 720p_cod_vae_30m_35k_steps.pt + + # Self-forcing specific parameters + # teacher_ckpt: checkpoints/bidirectional_teacher/best.pt # Pretrained bidirectional model + teacher_ckpt: null # Pretrained bidirectional model + loss_type: dmd # Options: dmd, sid, gan + gradient_steps: 2 # Number of steps to backprop through + rollout_steps: 5 # Total autoregressive rollout length + stochastic_steps: true # Random gradient truncation + update_ratio: 5 # Critic updates per generator update + cfg_scale: 1.3 # Classifier-free guidance scale + +wandb: + name: samibg + project: video_models + run_name: self_forcing_dmd \ No newline at end of file diff --git a/configs/shortcut.yml b/configs/shortcut.yml new file mode 100644 index 00000000..891a9707 --- /dev/null +++ b/configs/shortcut.yml @@ -0,0 +1,64 @@ +# Config for a simple 256 -> 16 autoencoder +model: + model_id: game_rft_shortcut + sample_size: 5 + channels: 64 + + n_layers: 5 + n_heads: 6 + d_model: 384 + + tokens_per_frame: 40 + n_buttons: 11 + n_mouse_axes: 2 + + cfg_prob: 0.1 + n_frames: 60 + + causal: false + +train: + trainer_id: shortcut + data_id: cod_s3 + data_kwargs: + window_length: 60 + bucket_name: cod-data-latent-360x640to5x8 + include_keyframe: true + + target_batch_size: 16 + batch_size: 16 + + epochs: 200 + + opt: AdamW + opt_kwargs: + lr: 1.0e-4 + weight_decay: 0.1 + eps: 1.0e-15 + betas: [0.9, 0.95] + + scheduler: null + + checkpoint_dir: checkpoints/360p + + sample_interval: 1000 + save_interval: 5000 + + sampler_id: shortcut + sampler_kwargs: + window_length: 60 + num_frames: 60 + only_return_generated: true + + n_samples: 8 + + vae_id: 720pr3dc + vae_batch_size: 4 + vae_scale: 0.35 + vae_cfg_path: configs/owl_vaes/128x_cod_stage2.yml + vae_ckpt_path: 720p_cod_vae_30m_35k_steps.pt + +wandb: + name: shahbuland + project: video_models + run_name: v2 diff --git a/configs/shortcut_2.yml b/configs/shortcut_2.yml new file mode 100644 index 00000000..b6a79a58 --- /dev/null +++ b/configs/shortcut_2.yml @@ -0,0 +1,70 @@ +# Config for a simple 256 -> 16 autoencoder +model: + model_id: shortcut_2 + sample_size: 4 + channels: 128 + + n_layers: 13 + n_heads: 16 + d_model: 1024 + + tokens_per_frame: 16 + n_buttons: 11 + n_mouse_axes: 2 + + cfg_prob: 0.1 + n_frames: 30 + + causal: false + +train: + trainer_id: shortcut_2 + data_id: cod_s3 + data_kwargs: + window_length: 30 + bucket_name: cod-data-latent-360x640to4x4 + include_keyframe: false + + target_batch_size: 256 + batch_size: 32 + + epochs: 200 + + opt: Muon + opt_kwargs: + lr: 1.0e-3 + momentum: 0.95 + adamw_lr: 1.0e-4 + adamw_wd: 0.1 + adamw_eps: 1.0e-15 + adamw_betas: [0.9, 0.95] + adamw_keys: [ + core.proj_in, + core.proj_out.proj + ] + + scheduler: null + + checkpoint_dir: checkpoints/360p + + sample_interval: 1000 + save_interval: 5000 + + sampler_id: shortcut_2 + sampler_kwargs: + window_length: 30 + num_frames: 120 + only_return_generated: true + + n_samples: 8 + + vae_id: 720pr3dc + vae_batch_size: 4 + vae_scale: 0.13 + vae_cfg_path: configs/owl_vaes/cod_128x.yml + vae_ckpt_path: checkpoints/owl_vaes/cod_128x_30k_ema.pt + +wandb: + name: shahbuland + project: video_models + run_name: bidir_shortcut diff --git a/configs/webapp/config.yaml b/configs/webapp/config.yaml new file mode 100644 index 00000000..1de97cc9 --- /dev/null +++ b/configs/webapp/config.yaml @@ -0,0 +1,25 @@ +model_checkpoint_path: "checkpoints/wm/dcae_hf_cod/ckpt_165k_ema.pt" +run_config_path: "checkpoints/wm/dcae_hf_cod/basic.yml" +device: "cuda" + +stream_config: + fps: 20 + frames_per_batch: 60 + window_length: 60 + device: "cuda" + n_buttons: 11 + n_mouse_axes: 2 + mouse_range: [-1.0, 1.0] + action_margin_px_height: 150 + +sampling_config: + sampling_steps: 20 + vae_scale: 1.0 + cfg_scale: 1.3 + window_length: 60 + num_frames: 60 + noise_prev: 0.2 + +run_config: null # loaded at runtime from model_config_path, and used to access model and train config + + diff --git a/owl_wms/configs.py b/owl_wms/configs.py index d572a3c9..7b463e00 100644 --- a/owl_wms/configs.py +++ b/owl_wms/configs.py @@ -1,26 +1,30 @@ -from dataclasses import dataclass, field -from typing import List, Optional import yaml from omegaconf import OmegaConf +from dataclasses import dataclass + @dataclass class TransformerConfig: model_id : str = None + channels : int = 128 + sample_size : int = 16 + patch_size : int = 1 n_layers : int = 12 n_heads : int = 12 d_model : int = 384 - patch_size : int = 1 - channels : int = 128 - sample_size : int = 16 + audio_channels : int = 64 cfg_prob : float = 0.1 n_buttons : int = 8 tokens_per_frame : int = 16 + audio_tokens : int = 0 + n_frames : int = 120 causal : bool = False + @dataclass class TrainingConfig: trainer_id : str = None @@ -60,6 +64,11 @@ class TrainingConfig: vae_scale : float = 0.34 vae_batch_size: int = 4 + audio_vae_id : str = None + audio_vae_cfg_path : str = None + audio_vae_ckpt_path : str = None + audio_vae_scale : float = 0.17 + @dataclass class WANDBConfig: name : str = None @@ -78,4 +87,5 @@ def from_yaml(cls, path): raw_cfg = yaml.safe_load(f) cfg = OmegaConf.create(raw_cfg) - return OmegaConf.structured(cls(**cfg)) \ No newline at end of file + return OmegaConf.structured(cls(**cfg)) + diff --git a/owl_wms/data/__init__.py b/owl_wms/data/__init__.py index 564abd25..ce809017 100644 --- a/owl_wms/data/__init__.py +++ b/owl_wms/data/__init__.py @@ -8,4 +8,7 @@ def get_loader(data_id, batch_size, **data_kwargs): return local_cod_latent.get_loader(batch_size, **data_kwargs) elif data_id == "cod_s3": from . import s3_cod_latent - return s3_cod_latent.get_loader(batch_size, **data_kwargs) \ No newline at end of file + return s3_cod_latent.get_loader(batch_size, **data_kwargs) + elif data_id == "cod_s3_audio": + from . import s3_cod_latent_audio + return s3_cod_latent_audio.get_loader(batch_size, **data_kwargs) \ No newline at end of file diff --git a/owl_wms/data/s3_cod_latent.py b/owl_wms/data/s3_cod_latent.py index d223d315..f7eb59bb 100644 --- a/owl_wms/data/s3_cod_latent.py +++ b/owl_wms/data/s3_cod_latent.py @@ -33,13 +33,15 @@ def pop(self): BUCKET_NAME="cod-data-latent-360x640to5x8" class S3CoDLatentDataset(IterableDataset): - def __init__(self, window_length=120, file_share_max=20, rank=0, world_size=1): + def __init__(self, window_length=120, file_share_max=20, rank=0, world_size=1, bucket_name = BUCKET_NAME, include_keyframe = False): super().__init__() self.window = window_length self.file_share_max = file_share_max self.rank = rank self.world_size = world_size + self.include_keyframe = include_keyframe + self.bucket_name = bucket_name # Queue parameters self.max_tars = 2 @@ -79,7 +81,7 @@ def background_download_tars(self): tar_path = self.random_sample_prefix() try: # Download tar directly to memory - response = self.s3_client.get_object(Bucket=BUCKET_NAME, Key=tar_path) + response = self.s3_client.get_object(Bucket=self.bucket_name, Key=tar_path) tar_data = response['Body'].read() self.tar_queue.add(tar_data) except Exception as e: @@ -140,8 +142,23 @@ def background_load_data(self): latent_slice = latent[window_start:window_start+self.window].float() mouse_slice = mouse[window_start:window_start+self.window] button_slice = button[window_start:window_start+self.window] - - self.data_queue.add((latent_slice, mouse_slice, button_slice)) + + if self.include_keyframe: + # Sample keyframe from nearby in video but not in window + buffer = 400 + valid_range_start = max(0, window_start - buffer) + valid_range_end = min(len(latent), window_start + self.window + buffer) + + # Exclude the actual window frames + valid_frames = list(range(valid_range_start, window_start)) + \ + list(range(window_start + self.window, valid_range_end)) + + if valid_frames: + keyframe_idx = random.choice(valid_frames) + latent_keyframe = latent[keyframe_idx].float().unsqueeze(0) + self.data_queue.add((latent_slice, latent_keyframe, mouse_slice, button_slice)) + else: + self.data_queue.add((latent_slice, mouse_slice, button_slice)) except Exception as e: print(f"Error processing tar: {e}") @@ -157,12 +174,25 @@ def __iter__(self): time.sleep(0.1) def collate_fn(batch): - # batch is list of triples - latents, mouses, buttons = zip(*batch) - latents = torch.stack(latents) # [b,n,c,h,w] - mouses = torch.stack(mouses) # [b,n,2] - buttons = torch.stack(buttons) # [b,n,n_buttons] - return latents, mouses, buttons + # batch is list of triples or quads + items = zip(*batch) + items = list(items) + + if len(items) == 3: + # No keyframe case + latents, mouses, buttons = items + latents = torch.stack(latents) # [b,n,c,h,w] + mouses = torch.stack(mouses) # [b,n,2] + buttons = torch.stack(buttons) # [b,n,n_buttons] + return latents, mouses, buttons + else: + # With keyframe case + latents, keyframes, mouses, buttons = items + latents = torch.stack(latents) # [b,n,c,h,w] + keyframes = torch.stack(keyframes) # [b,1,c,h,w] + mouses = torch.stack(mouses) # [b,n,2] + buttons = torch.stack(buttons) # [b,n,n_buttons] + return latents, keyframes, mouses, buttons def get_loader(batch_size, **data_kwargs): if dist.is_initialized(): diff --git a/owl_wms/data/s3_cod_latent_audio.py b/owl_wms/data/s3_cod_latent_audio.py new file mode 100644 index 00000000..4c401b92 --- /dev/null +++ b/owl_wms/data/s3_cod_latent_audio.py @@ -0,0 +1,213 @@ +import boto3 +import threading +from dotenv import load_dotenv +import os + +load_dotenv() + +import torch +import random +from torch.utils.data import IterableDataset, DataLoader +import torch.distributed as dist +import tarfile +import io +import time + +class RandomizedQueue: + def __init__(self): + self.items = [] + + def add(self, item): + idx = random.randint(0, len(self.items)) + self.items.insert(idx, item) + + def pop(self): + if not self.items: + return None + idx = random.randint(0, len(self.items) - 1) + return self.items.pop(idx) + +TOTAL_SHARDS = 1 +NUM_SUBDIRS=1 +NUM_TARS=9 +BUCKET_NAME="cod-data-latent-360x640to4x4" + +class S3CoDLatentAudioDataset(IterableDataset): + def __init__(self, window_length=120, file_share_max=20, rank=0, world_size=1, bucket_name = BUCKET_NAME): + super().__init__() + + self.window = window_length + self.file_share_max = file_share_max + self.rank = rank + self.world_size = world_size + self.bucket_name = bucket_name + + # Queue parameters + self.max_tars = 2 + self.max_data = 1000 + + # Initialize queues + self.tar_queue = RandomizedQueue() + self.data_queue = RandomizedQueue() + + # Setup S3 client + self.s3_client = boto3.client( + 's3', + endpoint_url=os.environ['AWS_ENDPOINT_URL_S3'], + aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'], + aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'], + region_name=os.environ['AWS_REGION'], + ) + + # Start background threads + self.tar_thread = threading.Thread(target=self.background_download_tars, daemon=True) + self.data_thread = threading.Thread(target=self.background_load_data, daemon=True) + self.tar_thread.start() + self.data_thread.start() + + def random_sample_prefix(self): + # For now just 2 shards (00, 01) + shard = random.randint(0, TOTAL_SHARDS-1) + # Each shard has 1000 subdirs + subdir = random.randint(0, NUM_SUBDIRS-1) + # Each subdir has multiple tars + tar_num = random.randint(0, NUM_TARS-1) + return f"{shard:02d}/{subdir:04d}/{tar_num:04d}.tar" + + def background_download_tars(self): + while True: + if len(self.tar_queue.items) < self.max_tars: + tar_path = self.random_sample_prefix() + try: + # Download tar directly to memory + response = self.s3_client.get_object(Bucket=self.bucket_name, Key=tar_path) + tar_data = response['Body'].read() + self.tar_queue.add(tar_data) + except Exception as e: + print(f"Error downloading tar {tar_path}: {e}") + else: + time.sleep(1) + + def process_tensor_file(self, tar, base_name, suffix): + try: + f = tar.extractfile(f"{base_name}.{suffix}.pt") + if f is not None: + tensor_data = f.read() + tensor = torch.load(io.BytesIO(tensor_data)) + return tensor + except: + return None + return None + + def background_load_data(self): + while True: + if len(self.data_queue.items) < self.max_data: + tar_data = self.tar_queue.pop() + if tar_data is None: + time.sleep(1) + continue + + try: + tar_file = io.BytesIO(tar_data) + with tarfile.open(fileobj=tar_file) as tar: + members = tar.getmembers() + base_names = set() + + # Get unique base names + for member in members: + if member.name.endswith('.latent.pt'): + base_names.add(member.name.split('.')[0]) + + for base_name in base_names: + # Load all tensors for this base name + latent = self.process_tensor_file(tar, base_name, "latent") + mouse = self.process_tensor_file(tar, base_name, "mouse") + button = self.process_tensor_file(tar, base_name, "buttons") + audio = self.process_tensor_file(tar, base_name, "audiolatent") + + if all(t is not None for t in [latent, mouse, button, audio]): + min_len = min(len(latent), len(mouse), len(button), len(audio)) + + # Sample multiple windows if requested + for _ in range(self.file_share_max): + if len(self.data_queue.items) >= self.max_data: + break + + max_start = min_len - self.window + if max_start <= 0: + continue + + window_start = random.randint(0, max_start) + + latent_slice = latent[window_start:window_start+self.window].float() + mouse_slice = mouse[window_start:window_start+self.window] + button_slice = button[window_start:window_start+self.window] + audio_slice = audio[window_start:window_start+self.window] + + self.data_queue.add((latent_slice, mouse_slice, button_slice, audio_slice)) + + except Exception as e: + print(f"Error processing tar: {e}") + else: + time.sleep(1) + + def __iter__(self): + while True: + item = self.data_queue.pop() + if item is not None: + yield item + else: + time.sleep(0.1) + +def collate_fn(batch): + # batch is list of quadruples + latents, mouses, buttons, audios = zip(*batch) + + latents = torch.stack(latents) # [b,n,c,h,w] + mouses = torch.stack(mouses) # [b,n,2] + buttons = torch.stack(buttons) # [b,n,n_buttons] + audios = torch.stack(audios) # [b,n,d] + + return latents, audios, mouses, buttons + +def get_loader(batch_size, **data_kwargs): + if dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + + ds = S3CoDLatentAudioDataset(rank=rank, world_size=world_size, **data_kwargs) + return DataLoader(ds, batch_size=batch_size, collate_fn=collate_fn) + +if __name__ == "__main__": + import time + loader = get_loader(1, window_length = 30, file_share_max = 20) + + start = time.time() + batch = next(iter(loader)) + video_latent, audio_latent, mouse, button = batch + for i in range(2, 6): + BASE_IDX=i + basedir = f"/home/louis/owl-wms/webapp/static/histories/base{BASE_IDX}" + os.makedirs(basedir, exist_ok=True) + torch.save(video_latent, f"/home/louis/owl-wms/webapp/static/histories/base{BASE_IDX}/video_latent.pt") + torch.save(audio_latent, f"/home/louis/owl-wms/webapp/static/histories/base{BASE_IDX}/audio_latent.pt") + torch.save(mouse, f"/home/louis/owl-wms/webapp/static/histories/base{BASE_IDX}/mouse.pt") + torch.save(button, f"/home/louis/owl-wms/webapp/static/histories/base{BASE_IDX}/buttons.pt") + end = time.time() + first_time = end - start + + start = time.time() + batch = next(iter(loader)) + end = time.time() + second_time = end - start + + x,y,z = batch + print(f"Time to load first batch: {first_time:.2f}s") + print(f"Time to load second batch: {second_time:.2f}s") + print(f"Video shape: {x.shape}") + print(x.std()) + print(f"Mouse shape: {y.shape}") + print(f"Button shape: {z.shape}") diff --git a/owl_wms/models/__init__.py b/owl_wms/models/__init__.py index 34fcc572..e1b51392 100644 --- a/owl_wms/models/__init__.py +++ b/owl_wms/models/__init__.py @@ -1,7 +1,18 @@ -from .gamerft import GameRFT - def get_model_cls(model_id): if model_id == "game_rft": + from .gamerft import GameRFT + return GameRFT + if model_id == "game_rft_shortcut": + from .gamerft_shortcut import ShortcutGameRFT + return ShortcutGameRFT + if model_id == "causal_game_rft": + from .causal_gamerft import CausalGameRFT + return CausalGameRFT + if model_id == "shortcut_2": + from .gamerft_shortcut_simple import ShortcutGameRFT + return ShortcutGameRFT + if model_id == "game_rft_audio": + from .gamerft_audio import GameRFT return GameRFT diff --git a/owl_wms/models/causal_gamerft.py b/owl_wms/models/causal_gamerft.py new file mode 100644 index 00000000..8d1de154 --- /dev/null +++ b/owl_wms/models/causal_gamerft.py @@ -0,0 +1,298 @@ +""" +Causal GameRFT model with KV cache support for efficient autoregressive generation +""" + +import torch +from torch import nn +import torch.nn.functional as F +import einops as eo + +from ..nn.embeddings import ( + TimestepEmbedding, + ControlEmbedding, + LearnedPosEnc +) +from ..nn.attn import UViT, FinalLayer +from ..nn.kv_cache import KVCache + +class CausalGameRFTCore(nn.Module): + """ + Core model with causal attention support and KV caching + """ + def __init__(self, config): + super().__init__() + + # Modify config for causal attention + self.causal = config.causal + self.tokens_per_frame = config.tokens_per_frame + self.n_frames = config.n_frames + + # Initialize transformer with causal flag + config_copy = deepcopy(config) + config_copy.causal = self.causal + self.transformer = UViT(config_copy) + + self.control_embed = ControlEmbedding(config.n_buttons, config.d_model) + self.t_embed = TimestepEmbedding(config.d_model) + + self.proj_in = nn.Linear(config.channels, config.d_model, bias=False) + self.proj_out = FinalLayer(config.sample_size, config.d_model, config.channels) + + self.pos_enc = LearnedPosEnc(config.tokens_per_frame * config.n_frames, config.d_model) + + # For caching context frames + self.cached_frames = None + self.cached_frame_count = 0 + + def forward(self, x, t, mouse, btn, kv_cache=None, use_cache=False): + """ + Forward pass with optional KV caching + + Args: + x: Input frames [b,n,c,h,w] + t: Timesteps [b,n] + mouse: Mouse inputs [b,n,2] + btn: Button inputs [b,n,n_buttons] + kv_cache: Optional KVCache object for caching + use_cache: Whether to use cached context + """ + # Handle caching for autoregressive generation + if use_cache and self.cached_frames is not None: + # Concatenate cached frames with new input + x = torch.cat([self.cached_frames, x], dim=1) + t = torch.cat([self.cached_timesteps, t], dim=1) + mouse = torch.cat([self.cached_mouse, mouse], dim=1) + btn = torch.cat([self.cached_btn, btn], dim=1) + + # Update cache with new frames + self.cached_frames = x + self.cached_timesteps = t + self.cached_mouse = mouse + self.cached_btn = btn + elif use_cache: + # Initialize cache + self.cached_frames = x + self.cached_timesteps = t + self.cached_mouse = mouse + self.cached_btn = btn + + # Standard forward pass + ctrl_cond = self.control_embed(mouse, btn) + t_cond = self.t_embed(t) + + cond = ctrl_cond + t_cond # [b,n,d] + + b, n, c, h, w = x.shape + x = eo.rearrange(x, 'b n c h w -> b (n h w) c') + + x = self.proj_in(x) + x = self.pos_enc(x) + + # Pass KV cache to transformer if provided + x = self.transformer(x, cond, kv_cache=kv_cache) + + x = self.proj_out(x, cond) # -> [b,n*hw,c] + x = eo.rearrange(x, 'b (n h w) c -> b n c h w', n=n, h=h, w=w) + + # If using cache, only return the newly generated frames + if use_cache and self.cached_frame_count > 0: + x = x[:, -1:] # Return only last frame + + return x + + def reset_cache(self): + """Reset the cached context frames""" + self.cached_frames = None + self.cached_timesteps = None + self.cached_mouse = None + self.cached_btn = None + self.cached_frame_count = 0 + +class CausalGameRFT(nn.Module): + """ + Causal GameRFT model for autoregressive video generation + """ + def __init__(self, config): + super().__init__() + + self.core = CausalGameRFTCore(config) + self.cfg_prob = config.cfg_prob + self.causal = config.causal + + def forward(self, x, mouse, btn, return_dict=False, cfg_prob=None, kv_cache=None, use_cache=False): + """ + Forward pass with diffusion loss computation + + For training: Standard diffusion loss + For generation: Can use KV cache for efficiency + """ + b, n, c, h, w = x.shape + + # Apply classifier-free guidance dropout + if cfg_prob is None: + cfg_prob = self.cfg_prob + if cfg_prob > 0.0 and self.training: + mask = torch.rand(b, device=x.device) <= cfg_prob + null_mouse = torch.zeros_like(mouse) + null_btn = torch.zeros_like(btn) + + # Where mask is True, replace with zeros + mouse = torch.where(mask.unsqueeze(-1).unsqueeze(-1), null_mouse, mouse) + btn = torch.where(mask.unsqueeze(-1).unsqueeze(-1), null_btn, btn) + + if self.training: + # Standard diffusion training + with torch.no_grad(): + ts = torch.rand(b, n, device=x.device, dtype=x.dtype).sigmoid() + + ts_exp = eo.repeat(ts, 'b n -> b n 1 1 1') + z = torch.randn_like(x) + + lerpd = x * (1. - ts_exp) + z * ts_exp + target = z - x + + pred = self.core(lerpd, ts, mouse, btn) + diff_loss = F.mse_loss(pred, target) + + if not return_dict: + return diff_loss + else: + return { + 'diffusion_loss': diff_loss, + 'lerpd': lerpd, + 'pred': pred, + 'ts': ts, + 'z': z + } + else: + # Generation mode - can use KV cache + return self.core(x, torch.zeros(b, n, device=x.device), mouse, btn, + kv_cache=kv_cache, use_cache=use_cache) + + def generate_next_frame(self, context, mouse, btn, num_steps=50, cfg_scale=1.3, kv_cache=None): + """ + Generate the next frame given context frames + + Args: + context: Context frames [b, n_context, c, h, w] + mouse: Mouse inputs for context + 1 new frame [b, n_context+1, 2] + btn: Button inputs for context + 1 new frame [b, n_context+1, n_buttons] + num_steps: Number of denoising steps + cfg_scale: Classifier-free guidance scale + kv_cache: Optional KV cache for efficiency + + Returns: + next_frame: Generated next frame [b, 1, c, h, w] + """ + b = context.shape[0] + device = context.device + + # Initialize next frame with noise + next_frame = torch.randn(b, 1, *context.shape[2:], device=device) + + # Combine context and noisy next frame + full_sequence = torch.cat([context, next_frame], dim=1) + + # Denoising loop + for step in range(num_steps): + t = (1.0 - step / num_steps) * torch.ones(b, full_sequence.shape[1], device=device) + + # Zero timestep for context frames (they're clean) + t[:, :-1] = 0.0 + + with torch.no_grad(): + # Conditional prediction + pred_cond = self.core(full_sequence, t, mouse, btn, kv_cache=kv_cache) + + if cfg_scale > 1.0: + # Unconditional prediction for CFG + null_mouse = torch.zeros_like(mouse) + null_btn = torch.zeros_like(btn) + pred_uncond = self.core(full_sequence, t, null_mouse, null_btn, kv_cache=kv_cache) + + # Apply CFG + pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond) + else: + pred = pred_cond + + # Update only the last frame + noise_level = t[:, -1:, None, None, None] + full_sequence[:, -1:] = full_sequence[:, -1:] - pred[:, -1:] * noise_level / num_steps + + return full_sequence[:, -1:] + + def generate_sequence(self, initial_frame, mouse, btn, num_frames, + num_steps=50, cfg_scale=1.3, use_kv_cache=True): + """ + Generate a full sequence autoregressively + + Args: + initial_frame: Starting frame [b, 1, c, h, w] + mouse: Mouse inputs for full sequence [b, num_frames, 2] + btn: Button inputs for full sequence [b, num_frames, n_buttons] + num_frames: Number of frames to generate + num_steps: Denoising steps per frame + cfg_scale: CFG scale + use_kv_cache: Whether to use KV caching for efficiency + + Returns: + sequence: Generated sequence [b, num_frames, c, h, w] + """ + generated = [initial_frame] + + # Initialize KV cache if requested + kv_cache = KVCache(self.core.transformer.config) if use_kv_cache else None + if kv_cache: + kv_cache.reset(initial_frame.shape[0]) + + # Reset any internal caching + self.core.reset_cache() + + for frame_idx in range(1, num_frames): + # Get context and actions up to current frame + context = torch.cat(generated, dim=1) + frame_mouse = mouse[:, :frame_idx+1] + frame_btn = btn[:, :frame_idx+1] + + # Generate next frame + next_frame = self.generate_next_frame( + context, frame_mouse, frame_btn, + num_steps=num_steps, cfg_scale=cfg_scale, + kv_cache=kv_cache + ) + + generated.append(next_frame) + + return torch.cat(generated, dim=1) + + +if __name__ == "__main__": + from ..configs import Config + from copy import deepcopy + + # Test causal model + cfg = Config.from_yaml("configs/basic.yml").model + cfg.causal = True + model = CausalGameRFT(cfg).cuda().bfloat16() + + # Test training forward pass + with torch.no_grad(): + x = torch.randn(2, 30, 128, 4, 4, device='cuda', dtype=torch.bfloat16) + mouse = torch.randn(2, 30, 2, device='cuda', dtype=torch.bfloat16) + btn = torch.randn(2, 30, 11, device='cuda', dtype=torch.bfloat16) + + loss = model(x, mouse, btn) + print(f"Training loss: {loss.item()}") + + # Test generation + model.eval() + with torch.no_grad(): + initial = torch.randn(2, 1, 128, 4, 4, device='cuda', dtype=torch.bfloat16) + mouse_seq = torch.randn(2, 10, 2, device='cuda', dtype=torch.bfloat16) + btn_seq = torch.randn(2, 10, 11, device='cuda', dtype=torch.bfloat16) + + generated = model.generate_sequence( + initial, mouse_seq, btn_seq, + num_frames=10, num_steps=20, cfg_scale=1.3 + ) + print(f"Generated sequence shape: {generated.shape}") \ No newline at end of file diff --git a/owl_wms/models/gamerft.py b/owl_wms/models/gamerft.py index ade60112..515a3047 100644 --- a/owl_wms/models/gamerft.py +++ b/owl_wms/models/gamerft.py @@ -18,7 +18,7 @@ class GameRFTCore(nn.Module): def __init__(self, config): super().__init__() - + self.config = config self.transformer = UViT(config) self.control_embed = ControlEmbedding(config.n_buttons, config.d_model) self.t_embed = TimestepEmbedding(config.d_model) @@ -57,15 +57,17 @@ def __init__(self, config): self.core = GameRFTCore(config) self.cfg_prob = config.cfg_prob - def forward(self, x, mouse, btn): + def forward(self, x, mouse, btn, return_dict = False, cfg_prob = None): # x is [b,n,c,h,w] # mouse is [b,n,2] # btn is [b,n,n_buttons] b,n,c,h,w = x.shape # Apply classifier-free guidance dropout - if self.cfg_prob > 0.: - mask = torch.rand(b, device=x.device) < self.cfg_prob + if cfg_prob is None: + cfg_prob = self.cfg_prob + if cfg_prob > 0.0: + mask = torch.rand(b, device=x.device) <= self.cfg_prob null_mouse = torch.zeros_like(mouse) null_btn = torch.zeros_like(btn) @@ -85,7 +87,16 @@ def forward(self, x, mouse, btn): pred = self.core(lerpd, ts, mouse, btn) diff_loss = F.mse_loss(pred, target) - return diff_loss + if not return_dict: + return diff_loss + else: + return { + 'diffusion_loss' : diff_loss, + 'lerpd' : lerpd, + 'pred' : pred, + 'ts': ts, + 'z': z + } if __name__ == "__main__": from ..configs import Config diff --git a/owl_wms/models/gamerft_audio.py b/owl_wms/models/gamerft_audio.py new file mode 100644 index 00000000..d28026b3 --- /dev/null +++ b/owl_wms/models/gamerft_audio.py @@ -0,0 +1,148 @@ +""" +GameRFT with Audio +""" + +import torch +from torch import nn +import torch.nn.functional as F + +import einops as eo + +from ..nn.embeddings import ( + TimestepEmbedding, + ControlEmbedding, + LearnedPosEnc +) +from ..nn.attn import UViT, FinalLayer + +class GameRFTCore(nn.Module): + def __init__(self, config): + super().__init__() + + self.transformer = UViT(config) + self.control_embed = ControlEmbedding(config.n_buttons, config.d_model) + self.t_embed = TimestepEmbedding(config.d_model) + + self.proj_in = nn.Linear(config.channels, config.d_model, bias = False) + self.proj_out = FinalLayer(config.sample_size, config.d_model, config.channels) + + self.audio_proj_in = nn.Linear(config.audio_channels, config.d_model, bias=False) + self.audio_proj_out = FinalLayer(None, config.d_model, config.audio_channels) + + self.pos_enc = LearnedPosEnc(config.tokens_per_frame * config.n_frames, config.d_model) + + def forward(self, x, audio, t, mouse, btn): + # x is [b,n,c,h,w] + # audio is [b,n,c] + # t is [b,n] + # mouse is [b,n,2] + # btn is [b,n,n_buttons] + + ctrl_cond = self.control_embed(mouse, btn) + t_cond = self.t_embed(t) + + cond = ctrl_cond + t_cond # [b,n,d] + + b,n,c,h,w = x.shape + x = eo.rearrange(x, 'b n c h w -> b (n h w) c') + + x = self.proj_in(x) + audio = self.audio_proj_in(audio).unsqueeze(-2) # [b,n,1,d] + + x = eo.rearrange(x, 'b (n f) d -> b n f d', n = n) + x = torch.cat([x, audio], dim = -2) + x = eo.rearrange(x, 'b n f d -> b (n f) d') + + x = self.pos_enc(x) + x = self.transformer(x, cond) + + # Split into video and audio tokens + x = eo.rearrange(x, 'b (n f) d -> b n f d', n=n) + video, audio = x[...,:-1,:], x[...,-1:,:] + + # Project video tokens + video = eo.rearrange(video, 'b n f d -> b (n f) d') + video = self.proj_out(video, cond) + video = eo.rearrange(video, 'b (n h w) c -> b n c h w', n=n, h=h, w=w) + + # Project audio tokens + audio = eo.rearrange(audio, 'b n 1 d -> b n d') + audio = self.audio_proj_out(audio, cond) + + return video, audio + +class GameRFT(nn.Module): + def __init__(self, config): + super().__init__() + + self.core = GameRFTCore(config) + self.cfg_prob = config.cfg_prob + + def forward(self, x, audio, mouse, btn, return_dict = False, cfg_prob = None): + # x is [b,n,c,h,w] + # audio is [b,n,c] + # mouse is [b,n,2] + # btn is [b,n,n_buttons] + b,n,c,h,w = x.shape + + # Apply classifier-free guidance dropout + if cfg_prob is None: + cfg_prob = self.cfg_prob + if cfg_prob > 0.0: + mask = torch.rand(b, device=x.device) <= self.cfg_prob + null_mouse = torch.zeros_like(mouse) + null_btn = torch.zeros_like(btn) + + # Where mask is True, replace with zeros + mouse = torch.where(mask.unsqueeze(-1).unsqueeze(-1), null_mouse, mouse) + btn = torch.where(mask.unsqueeze(-1).unsqueeze(-1), null_btn, btn) + + with torch.no_grad(): + ts = torch.randn(b,n,device=x.device,dtype=x.dtype).sigmoid() + + # Video noise + ts_exp = eo.repeat(ts, 'b n -> b n 1 1 1') + z_video = torch.randn_like(x) + lerpd_video = x * (1. - ts_exp) + z_video * ts_exp + target_video = z_video - x + + # Audio noise + ts_exp_audio = ts.unsqueeze(-1) + z_audio = torch.randn_like(audio) + lerpd_audio = audio * (1. - ts_exp_audio) + z_audio * ts_exp_audio + target_audio = z_audio - audio + + pred_video, pred_audio = self.core(lerpd_video, lerpd_audio, ts, mouse, btn) + video_loss = F.mse_loss(pred_video, target_video) + audio_loss = F.mse_loss(pred_audio, target_audio) + diff_loss = video_loss + audio_loss + + if not return_dict: + return diff_loss + else: + return { + 'diffusion_loss': diff_loss, + 'video_loss': video_loss, + 'audio_loss': audio_loss, + 'lerpd_video': lerpd_video, + 'lerpd_audio': lerpd_audio, + 'pred_video': pred_video, + 'pred_audio': pred_audio, + 'ts': ts, + 'z_video': z_video, + 'z_audio': z_audio + } + +if __name__ == "__main__": + from ..configs import Config + + cfg = Config.from_yaml("configs/basic.yml").model + model = GameRFT(cfg).cuda().bfloat16() + + with torch.no_grad(): + x = torch.randn(1, 128, 16, 256, device='cuda', dtype=torch.bfloat16) + mouse = torch.randn(1, 128, 2, device='cuda', dtype=torch.bfloat16) + btn = torch.randn(1, 128, 11, device='cuda', dtype=torch.bfloat16) + + loss = model(x, mouse, btn) + print(f"Loss: {loss.item()}") \ No newline at end of file diff --git a/owl_wms/models/gamerft_shortcut.py b/owl_wms/models/gamerft_shortcut.py new file mode 100644 index 00000000..d3f56c7d --- /dev/null +++ b/owl_wms/models/gamerft_shortcut.py @@ -0,0 +1,300 @@ +""" +Causal-First RFT With Shortcut objective +""" + +import torch +from torch import nn +import torch.nn.functional as F + +import einops as eo + +from ..nn.embeddings import ( + TimestepEmbedding, + StepEmbedding, + ControlEmbedding, + LearnedPosEnc +) +from ..nn.attn import UViT, FinalLayer +from ..nn.mmattn import MMUViT +from ..utils import freeze + +class ShortcutGameRFTCore(nn.Module): + def __init__(self, config): + super().__init__() + + self.transformer = MMUViT(config) + self.control_embed = ControlEmbedding(config.n_buttons, config.d_model) + + self.step_embed = StepEmbedding(config.d_model) + self.t_embed = TimestepEmbedding(config.d_model) + + self.proj_in = nn.Linear(config.channels, config.d_model, bias = False) + self.proj_out = FinalLayer(config.sample_size, config.d_model, config.channels) + + self.proj_y_in = nn.Linear(config.channels, config.d_model, bias = False) + self.pos_enc_y = LearnedPosEnc(config.tokens_per_frame, config.d_model) + + self.config = config + + def sample(self, x, y, mouse, btn, kv_cache = None, t = None, d = None): + """ + This is a function that largely abstracts + away most things for the specific case where + you are only generating the one next token + + The return is one step sample always + """ + if x is None: + x = torch.randn_like(y) + + b,n,c,h,w = x.shape + if t is None: + t = torch.ones_like(x[:,:,0,0,0]) + if d is None: + d = torch.ones_like(x[:,:,0,0,0]) + + return x - self.forward(x, y, t, mouse, btn, d, kv_cache) + + def forward(self, x, y, t, mouse, btn, d, kv_cache = None): + # x is [b,n,c,h,w] + # y is [b,1,c,h,w] + # t is [b,n] + # d is [b,n] + # mouse is [b,n,2] + # btn is [b,n,n_buttons] + + ctrl_cond = self.control_embed(mouse, btn) + t_cond = self.t_embed(t) + d_cond = self.step_embed(d) + + cond = ctrl_cond + t_cond + d_cond # [b,n,d] + + b,n,c,h,w = x.shape + x = eo.rearrange(x, 'b n c h w -> b (n h w) c') + y = eo.rearrange(y, 'b n c h w -> b (n h w) c') + + x = self.proj_in(x) + + y = self.proj_y_in(y) + y = self.pos_enc_y(y) + + x = self.transformer(x, y, cond, kv_cache) + x = self.proj_out(x, cond) # -> [b,n*hw,c] + x = eo.rearrange(x, 'b (n h w) c -> b n c h w', n=n,h=h,w=w) + + return x + +def sample_discrete_timesteps(steps, eps = 1.0e-6): + # steps is Tensor([1,4,2,64,16]) as an example + b,n = steps.shape + + ts_list = [] + ts = torch.rand(b, n, device=steps.device, dtype=steps.dtype) * (steps - eps) + ts = ts.clamp(eps).ceil() / steps + """ + Example, if d was all 2, ts would be [0,2] + so do clamp, then ceil will be 1 or 2 (0, 2] + then do t / 2 and get 0.5 or 1.0, our desired timesteps + """ + return ts + +def sample_steps(b, n, device, dtype, min_val = 0): + valid = torch.tensor([2**i for i in range(min_val, 8)]) # [1,2,...,128] + inds = torch.randint(low=0,high=len(valid), size = (b,n)) + steps = valid[inds].to(device=device,dtype=dtype) + return steps + +#@torch.compile() +@torch.no_grad() +def get_sc_targets(ema, x, y, mouse, btn, cfg_scale): + steps_slow = sample_steps(x.shape[0], x.shape[1], x.device, x.dtype, min_val = 1) + steps_fast = steps_slow / 2 + + dt_slow = 1./steps_slow + dt_fast = 1./steps_fast + + def expand(t): + #b,c,h,w = x.shape + #t = eo.repeat(t,'b -> b c h w',c=c,h=h,w=w) + #return t + return t[:,:,None,None,None] + + ts = sample_discrete_timesteps(steps_fast) + cfg_mask = torch.isclose(steps_slow, torch.ones_like(steps_slow)*128) + cfg_mask = expand(cfg_mask) # -> [b,n,1,1,1] + + null_mouse = torch.zeros_like(mouse) + null_btn = torch.zeros_like(btn) + + pred_1_uncond = ema(x, y, ts, null_mouse, null_btn, steps_slow) + pred_1_cond = ema(x, y, ts, mouse, btn, steps_slow) + pred_1_cfg = pred_1_uncond + cfg_scale * (pred_1_cond - pred_1_uncond) + pred_1 = torch.where(cfg_mask, pred_1_cfg, pred_1_cond) + + x_new = x - pred_1 * expand(dt_slow) + ts_new = ts - dt_slow + + pred_2_uncond = ema(x_new, y, ts_new, null_mouse, null_btn, steps_slow) + pred_2_cond = ema(x_new, y, ts_new, mouse, btn, steps_slow) + pred_2_cfg = pred_2_uncond + cfg_scale * (pred_2_cond - pred_2_uncond) + pred_2 = torch.where(cfg_mask, pred_2_cfg, pred_2_cond) + + pred = 0.5 * (pred_1 + pred_2) + return pred, steps_fast, ts + +class ShortcutGameRFT(nn.Module): + def __init__(self, config): + super().__init__() + + self.core = ShortcutGameRFTCore(config) + self.cfg_prob = config.cfg_prob + + self.sc_frac = 0.25 + self.sc_max_steps = 128 + self.cfg_scale = 1.3 + + self.config = config + + def get_sc_loss(self, x, y, mouse, btn, ema): + target, steps, ts = get_sc_targets(ema, x, y, mouse, btn, self.cfg_scale) + pred = self.core(x, y, ts, mouse, btn, steps) + sc_loss = F.mse_loss(pred, target) + return sc_loss + + def forward(self, x, y, mouse, btn, ema): + # x is [b,n,c,h,w] + # y (seed frame) is [b,1,c,h,w] + # mouse is [b,n,2] + # btn is [b,n,n_buttons] + with torch.no_grad(): + _,n,c,h,w = x.shape + + # Split batches between consistency/rf + b = int(len(x) * (1 - self.sc_frac)) + x,x_sc = x[:b], x[b:] + y,y_sc = y[:b], y[b:] + mouse,mouse_sc = mouse[:b], mouse[b:] + btn,btn_sc = btn[:b], btn[b:] + + # Apply classifier-free guidance dropout + if self.cfg_prob > 0.0: + mask = torch.rand(b, device=x.device) <= self.cfg_prob + null_mouse = torch.zeros_like(mouse) + null_btn = torch.zeros_like(btn) + + # Where mask is True, replace with zeros + mouse = torch.where(mask.unsqueeze(-1).unsqueeze(-1), null_mouse, mouse) + btn = torch.where(mask.unsqueeze(-1).unsqueeze(-1), null_btn, btn) + + d = torch.ones_like(x[:,:,0,0,0])*self.sc_max_steps + ts = sample_discrete_timesteps(d) + ts = torch.randn(b,n,device=x.device,dtype=x.dtype).sigmoid() + + ts_exp = eo.repeat(ts, 'b n -> b n 1 1 1') + z = torch.randn_like(x) + + lerpd = x * (1. - ts_exp) + z * ts_exp + target = z - x + + pred = self.core(lerpd, y, ts, mouse, btn, d) + diff_loss = F.mse_loss(pred, target) + sc_loss = self.get_sc_loss(x_sc, y_sc, mouse_sc, btn_sc, ema) + + return diff_loss, sc_loss + +def test_inference_cache(): + from ..configs import TransformerConfig + from ..nn.kv_cache import KVCache + + cfg = TransformerConfig( + None, # model_id + 6, # n_layers + 6, # n_heads + 384, # d_model + 1, # patch_size + 128, # channels + 16, # sample_size + 0.1, # cfg_prob + 11, # n_buttons + 16, # tokens_per_frame + 10, # n_frames + True # causal + ) + + model = ShortcutGameRFTCore(cfg).bfloat16().cuda() + + NUM_FRAMES = 10 + x = torch.randn(1, NUM_FRAMES, 128, 4, 4).bfloat16().cuda() + y = torch.randn(1, 1, 128, 4, 4).bfloat16().cuda() + mouse = torch.randn(1, NUM_FRAMES, 2).bfloat16().cuda() + btn = torch.randn(1, NUM_FRAMES, 11).bfloat16().cuda() + t = torch.full((1, NUM_FRAMES), 0.25, device='cuda', dtype=torch.bfloat16) + d = torch.full((1, NUM_FRAMES), 4, device='cuda', dtype=torch.bfloat16) + + cache = KVCache(cfg).to(device='cuda', dtype=torch.bfloat16) + cache.reset(1) + + with torch.no_grad(): + # First pass - generate cache for all frames + cache.enable_cache_updates() + out = model(x, y, t, mouse, btn, d, cache) + print(f"Initial cache length: {len(cache)}") + print(f"Initial cache shape: {cache.cache[0][0].shape}") + + # Generate single new frame with t=1, d=1 + new_x = torch.randn(1, 1, 128, 4, 4).bfloat16().cuda() + new_mouse = torch.randn(1, 1, 2).bfloat16().cuda() + new_btn = torch.randn(1, 1, 11).bfloat16().cuda() + new_t = torch.ones(1, 1, device='cuda', dtype=torch.bfloat16) + new_d = torch.ones(1, 1, device='cuda', dtype=torch.bfloat16) + + # Disable cache updates for inference + cache.disable_cache_updates() + new_out = model(new_x, y, new_t, new_mouse, new_btn, new_d, cache) + print(f"After inference cache length: {len(cache)}") + print(f"After inference cache shape: {cache.cache[0][0].shape}") + + # Re-enable cache updates and update cache with t=0.25, d=4 + cache.enable_cache_updates() + new_t = torch.full((1, 1), 0.25, device='cuda', dtype=torch.bfloat16) + new_d = torch.full((1, 1), 4, device='cuda', dtype=torch.bfloat16) + new_out = model(new_x, y, new_t, new_mouse, new_btn, new_d, cache) + print(f"Final cache length: {len(cache)}") + print(f"Final cache shape: {cache.cache[0][0].shape}") + +def test_wrapper(): + from ..configs import TransformerConfig + from ema_pytorch import EMA + from copy import deepcopy + + cfg = TransformerConfig( + None, # model_id + 6, # n_layers + 6, # n_heads + 384, # d_model + 1, # patch_size + 128, # channels + 16, # sample_size + 0.1, # cfg_prob + 11, # n_buttons + 16, # tokens_per_frame + 10, # n_frames + True # causal + ) + + model = ShortcutGameRFT(cfg).bfloat16().cuda() + ema = EMA(model, beta=0.999,update_after_step=0,update_every=1) + model.set_ema(ema) + + NUM_FRAMES = 10 + x = torch.randn(4, NUM_FRAMES, 128, 4, 4).bfloat16().cuda() + y = torch.randn(4, 1, 128, 4, 4).bfloat16().cuda() + mouse = torch.randn(4, NUM_FRAMES, 2).bfloat16().cuda() + btn = torch.randn(4, NUM_FRAMES, 11).bfloat16().cuda() + + with torch.no_grad(): + loss_1, loss_2 = model(x, y, mouse, btn) + print(loss_1, loss_2) + +if __name__ == "__main__": + test_wrapper() diff --git a/owl_wms/models/gamerft_shortcut_audio.py b/owl_wms/models/gamerft_shortcut_audio.py new file mode 100644 index 00000000..33b47a4f --- /dev/null +++ b/owl_wms/models/gamerft_shortcut_audio.py @@ -0,0 +1,220 @@ +""" +Shortcut simple with audio +""" + +import torch +from torch import nn +import torch.nn.functional as F + +import einops as eo + +from ..nn.embeddings import ( + TimestepEmbedding, + StepEmbedding, + ControlEmbedding, + LearnedPosEnc +) +from ..nn.attn import UViT, FinalLayer +from ..nn.mmattn import MMUViT +from ..utils import freeze + +class ShortcutGameRFTCore(nn.Module): + def __init__(self, config): + super().__init__() + + self.transformer = UViT(config) + self.control_embed = ControlEmbedding(config.n_buttons, config.d_model) + + self.step_embed = StepEmbedding(config.d_model) + self.t_embed = TimestepEmbedding(config.d_model) + + self.proj_in = nn.Linear(config.channels, config.d_model, bias = False) + self.proj_out = FinalLayer(config.sample_size, config.d_model, config.channels) + + self.audio_proj = nn.Linear(config.audio_channels, config.d_model, bias = False) + self.audio_proj_out = nn.Linear(config.d_model, config.audio_channels, bias = False) + + self.config = config + + def sample(self, x, audio, mouse, btn, kv_cache = None, t = None, d = None): + """ + This is a function that largely abstracts + away most things for the specific case where + you are only generating the one next token + + The return is one step sample always + """ + + b,n,c,h,w = x.shape + if t is None: + t = torch.ones_like(x[:,:,0,0,0]) + if d is None: + d = torch.ones_like(x[:,:,0,0,0]) + + pred_x, pred_audio = self.forward(x, audio, t, mouse, btn, d, kv_cache) + return x - pred_x, audio - pred_audio + + def forward(self, x, audio, t, mouse, btn, d, kv_cache = None): + # x is [b,n,c,h,w] + # a is [b,c,n] + # t is [b,n] + # d is [b,n] + # mouse is [b,n,2] + # btn is [b,n,n_buttons] + + ctrl_cond = self.control_embed(mouse, btn) + t_cond = self.t_embed(t) + d_cond = self.step_embed(d) + + cond = ctrl_cond + t_cond + d_cond # [b,n,d] + + audio = self.audio_proj(audio.transpose(-1,-2))[:,:,None] # -> [b,n,1,d] + + b,n,c,h,w = x.shape + x = eo.rearrange(x, 'b n c h w -> b (n h w) c') + x = self.proj_in(x) + x = eo.rearrange(x, 'b (n f) d -> b n f d', n = n) + x = torch.cat([x, audio], dim=-2) # [b,n,f,d] + x = eo.rearrange(x, 'b n f d -> b (n f) d') + + x = self.transformer(x, cond, kv_cache) + + x = eo.rearrange(x, 'b (n f) c -> b n f w', n=n) + audio = x[:,:,-1] # [b,n,d] + x = x[:,:,:-1] # [b,n,f,d] + + x = eo.rearrange(x, 'b n f d -> b (n f) d') + x = self.proj_out(x, cond) # -> [b,n*hw,c] + x = eo.rearrange(x, 'b (n h w) c -> b n c h w', n=n,h=h,w=w) + audio = self.audio_proj_out(audio).transpose(-1,-2) # [b,d,n] + + return x, audio + +def sample_discrete_timesteps(steps, eps = 1.0e-6): + # steps is Tensor([1,4,2,64,16]) as an example + b,n = steps.shape + + ts_list = [] + ts = torch.rand(b, n, device=steps.device, dtype=steps.dtype) * (steps - eps) + ts = ts.clamp(eps).ceil() / steps + """ + Example, if d was all 2, ts would be [0,2] + so do clamp, then ceil will be 1 or 2 (0, 2] + then do t / 2 and get 0.5 or 1.0, our desired timesteps + """ + return ts + +def sample_steps(b, n, device, dtype, min_val = 0): + valid = torch.tensor([2**i for i in range(min_val, 8)]) # [1,2,...,128] + inds = torch.randint(low=0,high=len(valid), size = (b,n)) + steps = valid[inds].to(device=device,dtype=dtype) + return steps + +#@torch.compile() +@torch.no_grad() +def get_sc_targets(ema, x, audio, mouse, btn, cfg_scale): + steps_slow = sample_steps(x.shape[0], x.shape[1], x.device, x.dtype, min_val = 1) + steps_fast = steps_slow / 2 + + dt_slow = 1./steps_slow + dt_fast = 1./steps_fast + + def expand(t): + #b,c,h,w = x.shape + #t = eo.repeat(t,'b -> b c h w',c=c,h=h,w=w) + #return t + return t[:,:,None,None,None] + + ts = sample_discrete_timesteps(steps_fast) + cfg_mask = torch.isclose(steps_slow, torch.ones_like(steps_slow)*128) + cfg_mask = expand(cfg_mask) # -> [b,n,1,1,1] + + null_mouse = torch.zeros_like(mouse) + null_btn = torch.zeros_like(btn) + + pred_1_x_uncond, pred_1_a_uncond = ema(x, audio, ts, null_mouse, null_btn, steps_slow) + pred_1_x_cond, pred_1_a_cond = ema(x, audio, ts, mouse, btn, steps_slow) + pred_1_x_cfg = pred_1_x_uncond + cfg_scale * (pred_1_x_cond - pred_1_x_uncond) + pred_1_a_cfg = pred_1_a_uncond + cfg_scale * (pred_1_a_cond - pred_1_a_uncond) + pred_1_x = torch.where(cfg_mask, pred_1_x_cfg, pred_1_x_cond) + pred_1_a = pred_1_a_cfg if cfg_mask.any() else pred_1_a_cond + + x_new = x - pred_1_x * expand(dt_slow) + audio_new = audio - pred_1_a * dt_slow[:,None,:] + ts_new = ts - dt_slow + + pred_2_x_uncond, pred_2_a_uncond = ema(x_new, audio_new, ts_new, null_mouse, null_btn, steps_slow) + pred_2_x_cond, pred_2_a_cond = ema(x_new, audio_new, ts_new, mouse, btn, steps_slow) + pred_2_x_cfg = pred_2_x_uncond + cfg_scale * (pred_2_x_cond - pred_2_x_uncond) + pred_2_a_cfg = pred_2_a_uncond + cfg_scale * (pred_2_a_cond - pred_2_a_uncond) + pred_2_x = torch.where(cfg_mask, pred_2_x_cfg, pred_2_x_cond) + pred_2_a = pred_2_a_cfg if cfg_mask.any() else pred_2_a_cond + + pred_x = 0.5 * (pred_1_x + pred_2_x) + pred_a = 0.5 * (pred_1_a + pred_2_a) + return (pred_x, pred_a), steps_fast, ts + +class ShortcutGameRFT(nn.Module): + def __init__(self, config): + super().__init__() + + self.core = ShortcutGameRFTCore(config) + self.cfg_prob = config.cfg_prob + + self.sc_frac = 0.25 + self.sc_max_steps = 128 + self.cfg_scale = 1.3 + + self.config = config + + def get_sc_loss(self, x, audio, mouse, btn, ema): + (target_x, target_a), steps, ts = get_sc_targets(ema, x, audio, mouse, btn, self.cfg_scale) + pred_x, pred_a = self.core(x, audio, ts, mouse, btn, steps) + sc_loss_x = F.mse_loss(pred_x, target_x) + sc_loss_a = F.mse_loss(pred_a, target_a) + return sc_loss_x + sc_loss_a + + def forward(self, x, audio, mouse, btn, ema): + # x is [b,n,c,h,w] + # audio is [b,c,n] + # mouse is [b,n,2] + # btn is [b,n,n_buttons] + with torch.no_grad(): + _,n,c,h,w = x.shape + + # Split batches between consistency/rf + b = int(len(x) * (1 - self.sc_frac)) + x,x_sc = x[:b], x[b:] + audio,audio_sc = audio[:b], audio[b:] + mouse,mouse_sc = mouse[:b], mouse[b:] + btn,btn_sc = btn[:b], btn[b:] + + # Apply classifier-free guidance dropout + if self.cfg_prob > 0.0: + mask = torch.rand(b, device=x.device) <= self.cfg_prob + null_mouse = torch.zeros_like(mouse) + null_btn = torch.zeros_like(btn) + + # Where mask is True, replace with zeros + mouse = torch.where(mask.unsqueeze(-1).unsqueeze(-1), null_mouse, mouse) + btn = torch.where(mask.unsqueeze(-1).unsqueeze(-1), null_btn, btn) + + d = torch.ones_like(x[:,:,0,0,0])*self.sc_max_steps + ts = sample_discrete_timesteps(d) + ts = torch.randn(b,n,device=x.device,dtype=x.dtype).sigmoid() + + ts_exp = eo.repeat(ts, 'b n -> b n 1 1 1') + z_x = torch.randn_like(x) + z_a = torch.randn_like(audio) + + lerpd_x = x * (1. - ts_exp) + z_x * ts_exp + lerpd_a = audio * (1. - ts[:,None,:]) + z_a * ts[:,None,:] + target_x = z_x - x + target_a = z_a - audio + + pred_x, pred_a = self.core(lerpd_x, lerpd_a, ts, mouse, btn, d) + diff_loss_x = F.mse_loss(pred_x, target_x) + diff_loss_a = F.mse_loss(pred_a, target_a) + sc_loss = self.get_sc_loss(x_sc, audio_sc, mouse_sc, btn_sc, ema) + + return diff_loss_x + diff_loss_a, sc_loss \ No newline at end of file diff --git a/owl_wms/models/gamerft_shortcut_simple.py b/owl_wms/models/gamerft_shortcut_simple.py new file mode 100644 index 00000000..54d031ce --- /dev/null +++ b/owl_wms/models/gamerft_shortcut_simple.py @@ -0,0 +1,192 @@ +""" +Causal-First RFT With Shortcut objective +""" + +import torch +from torch import nn +import torch.nn.functional as F + +import einops as eo + +from ..nn.embeddings import ( + TimestepEmbedding, + StepEmbedding, + ControlEmbedding, + LearnedPosEnc +) +from ..nn.attn import UViT, FinalLayer +from ..nn.mmattn import MMUViT +from ..utils import freeze + +class ShortcutGameRFTCore(nn.Module): + def __init__(self, config): + super().__init__() + + self.transformer = UViT(config) + self.control_embed = ControlEmbedding(config.n_buttons, config.d_model) + + self.step_embed = StepEmbedding(config.d_model) + self.t_embed = TimestepEmbedding(config.d_model) + + self.proj_in = nn.Linear(config.channels, config.d_model, bias = False) + self.proj_out = FinalLayer(config.sample_size, config.d_model, config.channels) + + self.config = config + if config.audio_tokens > 0: + self.audio_proj = nn.Linear(config.audio_channels, config.d_model, bias = False) + + def sample(self, x, mouse, btn, kv_cache = None, t = None, d = None): + """ + This is a function that largely abstracts + away most things for the specific case where + you are only generating the one next token + + The return is one step sample always + """ + + b,n,c,h,w = x.shape + if t is None: + t = torch.ones_like(x[:,:,0,0,0]) + if d is None: + d = torch.ones_like(x[:,:,0,0,0]) + + return x - self.forward(x, t, mouse, btn, d, kv_cache) + + def forward(self, x, t, mouse, btn, d, kv_cache = None): + # x is [b,n,c,h,w] + # y is [b,1,c,h,w] + # t is [b,n] + # d is [b,n] + # mouse is [b,n,2] + # btn is [b,n,n_buttons] + + ctrl_cond = self.control_embed(mouse, btn) + t_cond = self.t_embed(t) + d_cond = self.step_embed(d) + + cond = ctrl_cond + t_cond + d_cond # [b,n,d] + + b,n,c,h,w = x.shape + x = eo.rearrange(x, 'b n c h w -> b (n h w) c') + x = self.proj_in(x) + x = self.transformer(x, cond, kv_cache) + x = self.proj_out(x, cond) # -> [b,n*hw,c] + x = eo.rearrange(x, 'b (n h w) c -> b n c h w', n=n,h=h,w=w) + + return x + +def sample_discrete_timesteps(steps, eps = 1.0e-6): + # steps is Tensor([1,4,2,64,16]) as an example + b,n = steps.shape + + ts_list = [] + ts = torch.rand(b, n, device=steps.device, dtype=steps.dtype) * (steps - eps) + ts = ts.clamp(eps).ceil() / steps + """ + Example, if d was all 2, ts would be [0,2] + so do clamp, then ceil will be 1 or 2 (0, 2] + then do t / 2 and get 0.5 or 1.0, our desired timesteps + """ + return ts + +def sample_steps(b, n, device, dtype, min_val = 0): + valid = torch.tensor([2**i for i in range(min_val, 8)]) # [1,2,...,128] + inds = torch.randint(low=0,high=len(valid), size = (b,n)) + steps = valid[inds].to(device=device,dtype=dtype) + return steps + +#@torch.compile() +@torch.no_grad() +def get_sc_targets(ema, x, mouse, btn, cfg_scale): + steps_slow = sample_steps(x.shape[0], x.shape[1], x.device, x.dtype, min_val = 1) + steps_fast = steps_slow / 2 + + dt_slow = 1./steps_slow + dt_fast = 1./steps_fast + + def expand(t): + #b,c,h,w = x.shape + #t = eo.repeat(t,'b -> b c h w',c=c,h=h,w=w) + #return t + return t[:,:,None,None,None] + + ts = sample_discrete_timesteps(steps_fast) + cfg_mask = torch.isclose(steps_slow, torch.ones_like(steps_slow)*128) + cfg_mask = expand(cfg_mask) # -> [b,n,1,1,1] + + null_mouse = torch.zeros_like(mouse) + null_btn = torch.zeros_like(btn) + + pred_1_uncond = ema(x, ts, null_mouse, null_btn, steps_slow) + pred_1_cond = ema(x, ts, mouse, btn, steps_slow) + pred_1_cfg = pred_1_uncond + cfg_scale * (pred_1_cond - pred_1_uncond) + pred_1 = torch.where(cfg_mask, pred_1_cfg, pred_1_cond) + + x_new = x - pred_1 * expand(dt_slow) + ts_new = ts - dt_slow + + pred_2_uncond = ema(x_new, ts_new, null_mouse, null_btn, steps_slow) + pred_2_cond = ema(x_new, ts_new, mouse, btn, steps_slow) + pred_2_cfg = pred_2_uncond + cfg_scale * (pred_2_cond - pred_2_uncond) + pred_2 = torch.where(cfg_mask, pred_2_cfg, pred_2_cond) + + pred = 0.5 * (pred_1 + pred_2) + return pred, steps_fast, ts + +class ShortcutGameRFT(nn.Module): + def __init__(self, config): + super().__init__() + + self.core = ShortcutGameRFTCore(config) + self.cfg_prob = config.cfg_prob + + self.sc_frac = 0.25 + self.sc_max_steps = 128 + self.cfg_scale = 1.3 + + self.config = config + + def get_sc_loss(self, x, mouse, btn, ema): + target, steps, ts = get_sc_targets(ema, x, mouse, btn, self.cfg_scale) + pred = self.core(x, ts, mouse, btn, steps) + sc_loss = F.mse_loss(pred, target) + return sc_loss + + def forward(self, x, mouse, btn, ema): + # x is [b,n,c,h,w] + # mouse is [b,n,2] + # btn is [b,n,n_buttons] + with torch.no_grad(): + _,n,c,h,w = x.shape + + # Split batches between consistency/rf + b = int(len(x) * (1 - self.sc_frac)) + x,x_sc = x[:b], x[b:] + mouse,mouse_sc = mouse[:b], mouse[b:] + btn,btn_sc = btn[:b], btn[b:] + + # Apply classifier-free guidance dropout + if self.cfg_prob > 0.0: + mask = torch.rand(b, device=x.device) <= self.cfg_prob + null_mouse = torch.zeros_like(mouse) + null_btn = torch.zeros_like(btn) + + # Where mask is True, replace with zeros + mouse = torch.where(mask.unsqueeze(-1).unsqueeze(-1), null_mouse, mouse) + btn = torch.where(mask.unsqueeze(-1).unsqueeze(-1), null_btn, btn) + + d = torch.ones_like(x[:,:,0,0,0])*self.sc_max_steps + ts = sample_discrete_timesteps(d) + ts = torch.randn(b,n,device=x.device,dtype=x.dtype).sigmoid() + + ts_exp = eo.repeat(ts, 'b n -> b n 1 1 1') + z = torch.randn_like(x) + + lerpd = x * (1. - ts_exp) + z * ts_exp + target = z - x + + pred = self.core(lerpd, ts, mouse, btn, d) + diff_loss = F.mse_loss(pred, target) + sc_loss = self.get_sc_loss(x_sc, mouse_sc, btn_sc, ema) + + return diff_loss, sc_loss \ No newline at end of file diff --git a/owl_wms/nn/attn.py b/owl_wms/nn/attn.py index 38cee69e..a57106ff 100644 --- a/owl_wms/nn/attn.py +++ b/owl_wms/nn/attn.py @@ -8,8 +8,7 @@ import einops as eo from .modulation import AdaLN, Gate -#from .embeddings import FlatVideoRoPE -from rotary_embedding_torch import RotaryEmbedding +from .rope import FlatVideoRoPE torch.backends.cuda.enable_flash_sdp(enabled = True) @@ -42,6 +41,8 @@ def __init__(self, config : 'TransformerConfig'): self.qk_norm = QKNorm(config.d_model // config.n_heads) self.layer_ind = None + self.rope = FlatVideoRoPE(config) + self.tokens_per_frame = config.tokens_per_frame self.causal = config.causal @@ -49,11 +50,13 @@ def forward(self, x, kv_cache = None): q,k,v = eo.rearrange(self.qkv(x), 'b n (three h d) -> three b h n d', three = 3, h = self.n_heads) q,k = self.qk_norm(q,k) - if not self.causal or len(kv_cache) > 0: + if not self.causal or (kv_cache is not None and len(kv_cache) > 0): mask = None else: mask = create_block_causal_mask(x.shape[1], self.tokens_per_frame).to(x.device) + mask = mask.to(device=x.device,dtype=x.dtype) mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1) + mask = mask.unsqueeze(1) if kv_cache is not None: old_k, old_v = kv_cache.get(self.layer_ind) @@ -62,7 +65,8 @@ def forward(self, x, kv_cache = None): new_k = torch.cat([old_k, k], dim = 2).contiguous() new_v = torch.cat([old_v, v], dim = 2).contiguous() - + + q,new_k = self.rope(q,new_k) if kv_cache.should_update: kv_cache.update(new_k, new_v, self.layer_ind) @@ -70,6 +74,7 @@ def forward(self, x, kv_cache = None): x = F.scaled_dot_product_attention(q, new_k, new_v, attn_mask = mask) x = x[:,:,-q.shape[2]:] # Skip cached outputs (not relevant now) else: + q,k = self.rope(q,k) x = F.scaled_dot_product_attention(q,k,v, attn_mask = mask) x = eo.rearrange(x, 'b h n d -> b n (h d)') @@ -82,9 +87,6 @@ def __init__(self, config): dim = config.d_model - self.norm1 = LayerNorm(dim) - self.norm2 = LayerNorm(dim) - self.attn = Attn(config) self.mlp = MLP(config) @@ -246,4 +248,4 @@ def test_kv_cache(): print("Cache test complete") if __name__ == "__main__": - test_attn_mask() \ No newline at end of file + test_attn_mask() diff --git a/owl_wms/nn/embeddings.py b/owl_wms/nn/embeddings.py index 01a6c1c2..c10ebddd 100644 --- a/owl_wms/nn/embeddings.py +++ b/owl_wms/nn/embeddings.py @@ -2,6 +2,8 @@ from torch import nn import torch.nn.functional as F +import math + import einops as eo from .mlp import MLPCustom @@ -11,89 +13,20 @@ ) import einops as eo -class VideoRoPE(nn.Module): - """ - Video RoPE embedding for when latents are 3D [n,h,w] - """ - def __init__(self, config : 'TransformerConfig'): - super().__init__() - - dim_head = config.d_model // config.n_heads - self.pos_emb = RotaryEmbedding( - dim = dim_head//8, - freqs_for = 'pixel', - max_freq = 256 - ) - n_patches = config.sample_size // config.patch_size - self.tokens_per_frame = n_patches**2 - - self.rearrange_in = lambda x: eo.rearrange(x, 'b h (n_t n_y n_x) d -> b h n_t n_y n_x d', n_y = n_patches) - self.rearrange_out = lambda x: eo.rearrange(x, 'b h n_t n_y n_x d -> b h (n_t n_y n_x) d') - self.get_freqs = lambda n_t: self.pos_emb.get_axial_freqs(n_t, n_patches, n_patches) - - def forward(self, q, k): - # q k both [b,h,n,d] - q = self.rearrange_in(q) - k = self.rearrange_in(k) - - n_t = q.shape[2] - freqs = self.get_freqs(n_t) - - q = apply_rotary_emb(freqs.float(), q.float()).to(q.dtype) - k = apply_rotary_emb(freqs.float(), k.float()).to(k.dtype) - - q = self.rearrange_out(q) - k = self.rearrange_out(k) - - return q, k - -class FlatVideoRoPE(nn.Module): - """ - Video RoPE embedding for when latents are 2d [n,m] (1D Frame Tokenization) - """ - def __init__(self, config : 'TransformerConfig'): - super().__init__() - - dim_head = config.d_model // config.n_heads - self.pos_emb = RotaryEmbedding( - dim = dim_head//4, - freqs_for = 'pixel', - max_freq = 256 - ) - self.pos_emb.freqs.requires_grad = False - self.tokens_per_frame = config.sample_size - - self.rearrange_in = lambda x: eo.rearrange(x, 'b h (n_t m) d -> b h n_t m d', m = self.tokens_per_frame) - self.rearrange_out = lambda x: eo.rearrange(x, 'b h n_t m d -> b h (n_t m) d') - self.get_freqs = lambda n_t: self.pos_emb.get_axial_freqs(n_t, self.tokens_per_frame) - - def forward(self, q, k): - # q k both [b,h,n,d] - q = self.rearrange_in(q) - k = self.rearrange_in(k) - - n_t = q.shape[2] - with torch.no_grad(): - freqs = self.get_freqs(n_t) - - q = apply_rotary_emb(freqs.float(), q.float()).to(q.dtype) - k = apply_rotary_emb(freqs.float(), k.float()).to(k.dtype) - - q = self.rearrange_out(q) - k = self.rearrange_out(k) - - return q, k - - class LearnedPosEnc(nn.Module): def __init__(self, n_seq, dim): super().__init__() + self.n_seq = n_seq self.p = nn.Parameter(torch.randn(n_seq,dim)*0.02) def forward(self, x): b,n,d = x.shape - p = eo.repeat(self.p, 'n d -> b n d', b = b) + if n < self.n_seq: + # Only add positional embeddings for the last n tokens + p = eo.repeat(self.p[-n:], 'n d -> b n d', b=b) + else: + p = eo.repeat(self.p, 'n d -> b n d', b=b) return x + p class SinCosEmbed(nn.Module): @@ -152,6 +85,26 @@ def forward(self, x): x = self.mlp(x) return x +class StepEmbedding(nn.Module): + def __init__(self, d_out, d_in=512, max_steps=128): + super().__init__() + + self.mlp = MLPCustom(d_in, dim_middle = 4 * d_out, dim_out=d_out) + self.max_steps = max_steps + mult = 1000 / math.log2(max_steps) + self.sincos = SinCosEmbed(d_in, theta=300, mult=mult) + + def forward(self, steps): + if not isinstance(steps, torch.Tensor): + steps = torch.tensor(steps, device=self.mlp.fc_uv.weight.device, dtype=self.mlp.fc_uv.weight.dtype) + if steps.ndim == 0: + steps = steps.unsqueeze(0) + + # Map steps to [0, log2(max_steps)] + t = (math.log2(self.max_steps) - torch.log2(steps.float())).to(steps.dtype) + embs = self.sincos(t) + return self.mlp(embs) + class ConditionEmbedding(nn.Module): def __init__(self, n_classes, dim): super().__init__() diff --git a/owl_wms/nn/kv_cache.py b/owl_wms/nn/kv_cache.py index 99d33eb9..fea43f0c 100644 --- a/owl_wms/nn/kv_cache.py +++ b/owl_wms/nn/kv_cache.py @@ -14,6 +14,9 @@ def __init__(self, config : TransformerConfig): self.should_update = False + self.max_length = config.tokens_per_frame * config.n_frames + self.noise_caches = 0.0 + def enable_cache_updates(self): self.should_update = True @@ -23,6 +26,7 @@ def disable_cache_updates(self): def to(self, device = 'cuda', dtype = torch.bfloat16): self.device = device self.dtype = dtype + return self def reset(self, batch_size = 1): self.shape = (batch_size, self.config.n_heads, 0, self.config.d_model//self.config.n_heads) @@ -33,6 +37,9 @@ def reset(self, batch_size = 1): def get(self, layer_ind): assert self.cache is not None, "Must reset cache before using" k,v = self.cache[layer_ind] + if self.noise_caches > 0.0: + k = k + torch.randn_like(k) * self.noise_caches + v = v + torch.randn_like(v) * self.noise_caches return k,v @torch.no_grad() @@ -46,11 +53,31 @@ def push(self, new_k, new_v, layer_ind): @torch.no_grad() def update(self, new_k, new_v, layer_ind): assert self.cache is not None, "Must reset cache before using" - self.cache[layer_ind] = (new_k,new_v) + + def tuple_truncate(k, v): + k = k[:,:,-self.max_length:] + v = v[:,:,-self.max_length:] + return k, v + + self.cache[layer_ind] = tuple_truncate(new_k,new_v) + + @torch.no_grad() + def truncate(self, truncate_amt): + """ + Truncate frames from the KV cache + """ + truncate_amt = truncate_amt * self.config.tokens_per_frame + def tuple_truncate(k, v): + k = k[:,:,truncate_amt:] + v = v[:,:,truncate_amt:] + return k, v + + for i in range(self.config.n_layers): + self.cache[i] = tuple_truncate(*self.cache[i]) def __len__(self): assert self.cache is not None, "Must reset cache before using" - return self.cache[0].shape[2] + return self.cache[0][0].shape[2] def shape(self): - return self.shape \ No newline at end of file + return self.shape diff --git a/owl_wms/nn/mmattn.py b/owl_wms/nn/mmattn.py new file mode 100644 index 00000000..00855779 --- /dev/null +++ b/owl_wms/nn/mmattn.py @@ -0,0 +1,293 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from .normalization import LayerNorm, RMSNorm, QKNorm +from .mlp import MLP + +import einops as eo + +from .modulation import AdaLN, Gate +from .rope import FlatVideoRoPE + +torch.backends.cuda.enable_flash_sdp(enabled = True) + +from einops._torch_specific import allow_ops_in_compiled_graph +allow_ops_in_compiled_graph() + +""" +This code makes the assumption that there are some +tokens from another modality that must always be attended to +""" + +def create_block_causal_mask_with_mm(tokens, context_tokens, tokens_per_frame): + frames = tokens // tokens_per_frame + + # Create base causal mask, nothing is masked + total_tokens = tokens + context_tokens + mask = torch.zeros(total_tokens, total_tokens) + + # Allow attention within each frame + for i in range(frames): + start = i * tokens_per_frame + end = (i + 1) * tokens_per_frame + mask[start:end, end:tokens] = True # Can't see future frames + + # Context tokens can only attend to themselves + mask[tokens:, :tokens] = True # Mask out attention to regular tokens + + # Regular tokens can still attend to all context tokens (no masking needed) + # The zeros in mask[:, tokens:] allow tokens to attend to all context + + return mask + +class MMAttn(nn.Module): + """ + MMDiT style attention + """ + def __init__(self, config : 'TransformerConfig'): + super().__init__() + + self.n_heads = config.n_heads + + self.qkv_1 = nn.Linear(config.d_model, 3 * config.d_model) + self.qkv_2 = nn.Linear(config.d_model, 3 * config.d_model) + + self.out_1 = nn.Linear(config.d_model, config.d_model) + self.out_2 = nn.Linear(config.d_model, config.d_model) + + self.qk_norm_1 = QKNorm(config.d_model // config.n_heads) + self.qk_norm_2 = QKNorm(config.d_model // config.n_heads) + + self.config = config + self.causal = config.causal + + self.rope = FlatVideoRoPE(config) + + def split(self, qkv): + return eo.rearrange(qkv, 'b n (three h d) -> three b h n d', three = 3, h = self.n_heads) + + def merge(self, x): + return eo.rearrange(x, 'b h n d -> b n (h d)') + + def forward(self, x_1, x_2, kv_cache=None): + n1 = x_1.shape[1] + + q1,k1,v1 = self.split(self.qkv_1(x_1)) + q2,k2,v2 = self.split(self.qkv_2(x_2)) + + q1,k1 = self.qk_norm_1(q1,k1) + q2,k2 = self.qk_norm_2(q2,k2) + + if not self.causal or (kv_cache is not None and len(kv_cache) > 0): + mask = None + else: + mask = create_block_causal_mask_with_mm(x_1.shape[1], x_2.shape[1], self.config.tokens_per_frame) + mask = mask.to(device=x_1.device,dtype=x_1.dtype) + mask = mask.unsqueeze(0).repeat(x_1.shape[0],1,1) + mask = mask.unsqueeze(1) # head dim + + if kv_cache is not None: + if len(kv_cache) > 0: + old_k, old_v = kv_cache.get(self.layer_ind) + + new_k = torch.cat([old_k, k1], dim=2).contiguous() + new_v = torch.cat([old_v, v1], dim=2).contiguous() + else: + new_k = k1.contiguous() + new_v = v1.contiguous() + + if kv_cache.should_update: + kv_cache.update(new_k, new_v, self.layer_ind) + + q1, new_k = self.rope(q1, new_k) + + k = torch.cat([new_k, k2], dim=-2) + v = torch.cat([new_v, v2], dim=-2) + q = torch.cat([q1, q2], dim=-2) + + x = F.scaled_dot_product_attention(q, k, v, attn_mask = mask) + x = x[:,:,-q.shape[2]:] # Only keep latest outputs + x = self.merge(x) + else: + q1, k1 = self.rope(q1,k1) + + q = torch.cat([q1,q2],dim=-2) + k = torch.cat([k1,k2],dim=-2) + v = torch.cat([v1,v2],dim=-2) + + x = F.scaled_dot_product_attention(q,k,v, attn_mask = mask) + x = self.merge(x) + + x_1, x_2 = x[:,:n1], x[:,n1:] + x_1 = self.out_1(x_1) + x_2 = self.out_2(x_2) + + return x_1, x_2 + +class MMDiTBlock(nn.Module): + def __init__(self, config): + super().__init__() + + dim = config.d_model + + self.attn = MMAttn(config) + + self.mlp_1 = MLP(config) + self.mlp_2 = MLP(config) + + # Stream 1 - AdaLN and gating + self.adaln1_1 = AdaLN(dim) + self.gate1_1 = Gate(dim) + self.adaln2_1 = AdaLN(dim) + self.gate2_1 = Gate(dim) + + # Stream 2 - Standard LayerNorm + self.ln1_2 = nn.LayerNorm(dim) + self.ln2_2 = nn.LayerNorm(dim) + + def forward(self, x, y, cond, kv_cache = None): + res1_x = x.clone() + res1_y = y.clone() + + # First attention block + x = self.adaln1_1(x, cond) + y = self.ln1_2(y) + + x, y = self.attn(x, y, kv_cache) + + x = self.gate1_1(x, cond) + + x = res1_x + x + y = res1_y + y + + # Second MLP block + res2_x = x.clone() + res2_y = y.clone() + + x = self.adaln2_1(x, cond) + y = self.ln2_2(y) + + x = self.mlp_1(x) + y = self.mlp_2(y) + + x = self.gate2_1(x, cond) + + x = res2_x + x + y = res2_y + y + + return x, y + +class MMUViT(nn.Module): + def __init__(self, config): + super().__init__() + + blocks = [] + for i in range(config.n_layers): + blocks.append(MMDiTBlock(config)) + blocks[-1].attn.layer_ind = i + + self.blocks = nn.ModuleList(blocks) + + # For odd number of layers, need linear projections for skip connections + n_skip_connections = config.n_layers // 2 + skip_projs = [] + for _ in range(n_skip_connections): + skip_projs.append(nn.Linear(config.d_model * 2, config.d_model)) + self.skip_projs = nn.ModuleList(skip_projs) + + def forward(self, x, y, cond, kv_cache = None): + # Cache early block outputs for skip connections + early_features = [] + n_blocks = len(self.blocks) + mid_idx = n_blocks // 2 + + # Early blocks + for i in range(mid_idx): + x,y = self.blocks[i](x, y, cond, kv_cache) + early_features.append(x) + + # Middle block (if odd number of layers) + x,y = self.blocks[mid_idx](x, y, cond, kv_cache) + + # Late blocks with skip connections + for i in range(mid_idx + 1, n_blocks): + # Get corresponding early block output + early_idx = n_blocks - 1 - i + early_feat = early_features[early_idx] + + # Concatenate early and current features + skip_idx = i - (mid_idx + 1) + x = torch.cat([x, early_feat], dim=-1) + x = self.skip_projs[skip_idx](x) + + x,y = self.blocks[i](x, y, cond, kv_cache) + + return x + + +def test_fwd_with_cache(): + from ..configs import TransformerConfig + from .kv_cache import KVCache + + import matplotlib.pyplot as plt + + cfg = TransformerConfig( + None, + 6, + 6, + 384, + 1, + 128, + 4, + 0.1, + 8, + 16, + True + ) + + model = MMUViT(cfg).bfloat16().cuda() + + NUM_FRAMES = 10 + x = torch.randn(1,16*NUM_FRAMES,384).bfloat16().cuda() + y = torch.randn(1,16,384).bfloat16().cuda() + cond=torch.randn(1,16,384).bfloat16().cuda() + + cache = KVCache(cfg).to(device='cuda',dtype=torch.bfloat16) + cache.reset(1) + + with torch.no_grad(): + cache.enable_cache_updates() + out = model(x,y,cond,cache) + + new_x = torch.randn(1,16,384).bfloat16().cuda() + cond = torch.randn(1,1,384).bfloat16().cuda() + + print(len(cache)) + print(cache.cache[0][0].shape) + new_out = model(new_x, y, cond, cache) + + print(len(cache)) + print(cache.cache[0][0].shape) + +def test_mask(): + import matplotlib.pyplot as plt + + n_frames = 10 + n_tok_per_frame = 16 + n_context = 16 + + mask = create_block_causal_mask_with_mm(n_frames*n_tok_per_frame, n_context, n_tok_per_frame) + + plt.figure(figsize=(10,10)) + plt.imshow(mask.float().cpu().numpy(), cmap='gray') + plt.colorbar() + plt.title(f'Block Causal Mask with MM ({n_frames*n_tok_per_frame} tokens, {n_context} context, {n_tok_per_frame} per frame)') + plt.xlabel('Key Position') + plt.ylabel('Query Position') + plt.savefig('test_mm_mask.png') + plt.close() + + +if __name__ == "__main__": + test_mask() diff --git a/owl_wms/nn/rope.py b/owl_wms/nn/rope.py new file mode 100644 index 00000000..8c5b8cfd --- /dev/null +++ b/owl_wms/nn/rope.py @@ -0,0 +1,97 @@ +""" +Variants of RoPE were becoming heavy for embeddings so +I made a unique script for all of them here +""" + +from rotary_embedding_torch import ( + RotaryEmbedding, + apply_rotary_emb +) +import einops as eo +import torch +from torch import nn + +class VideoRoPE(nn.Module): + """ + Video RoPE embedding for when latents are 3D [n,h,w] + """ + def __init__(self, config : 'TransformerConfig'): + super().__init__() + + dim_head = config.d_model // config.n_heads + self.pos_emb = RotaryEmbedding( + dim = dim_head//8, + freqs_for = 'pixel', + max_freq = 256 + ) + n_patches = config.sample_size // config.patch_size + self.tokens_per_frame = n_patches**2 + + self.rearrange_in = lambda x: eo.rearrange(x, 'b h (n_t n_y n_x) d -> b h n_t n_y n_x d', n_y = n_patches) + self.rearrange_out = lambda x: eo.rearrange(x, 'b h n_t n_y n_x d -> b h (n_t n_y n_x) d') + self.get_freqs = lambda n_t: self.pos_emb.get_axial_freqs(n_t, n_patches, n_patches) + + def forward(self, q, k): + # q k both [b,h,n,d] + q = self.rearrange_in(q) + k = self.rearrange_in(k) + + n_t = q.shape[2] + freqs = self.get_freqs(n_t) + + q = apply_rotary_emb(freqs.float(), q.float()).to(q.dtype) + k = apply_rotary_emb(freqs.float(), k.float()).to(k.dtype) + + q = self.rearrange_out(q) + k = self.rearrange_out(k) + + return q, k + +class FlatVideoRoPE(nn.Module): + """ + Half-flat of RoPE that treats [n_frames, tokens_per_frame] as [n_frames, tokens_per_frame] image + """ + def __init__(self, config): + super().__init__() + + dim_head = config.d_model // config.n_heads + self.pos_emb = RotaryEmbedding( + dim = dim_head//4, + freqs_for='pixel', + max_freq=256 + ) + + self.m = config.tokens_per_frame + + def pad_q(self, q, k): + # Pad Q when it's needed for kv caching + q_len = q.shape[2] + k_len = k.shape[2] + + def forward(self, q, k): + # q|k is [b,h,n_frames*tokens_per_frame,d] + n = k.shape[2]//self.m + m = self.m + + truncate = n + if q.shape[2] < n * m: + truncate = q.shape[2]//m # How many frames is q? + + q = eo.rearrange(q, 'b h (n m) d -> b h n m d', n=q.shape[2]//m,m=m) + k = eo.rearrange(k, 'b h (n m) d -> b h n m d', n=n,m=m) + + with torch.no_grad(): + freqs = self.pos_emb.get_axial_freqs(n,m) + q = apply_rotary_emb(freqs[-truncate:].detach(), q) + k = apply_rotary_emb(freqs.detach(), k) + + q = eo.rearrange(q, 'b h n m d -> b h (n m) d') + k = eo.rearrange(k, 'b h n m d -> b h (n m) d') + + if truncate is not None: + q = q[:,:,-truncate*m:] + + return q,k + + + diff --git a/owl_wms/sampling/__init__.py b/owl_wms/sampling/__init__.py index c35f799e..b2bc45cc 100644 --- a/owl_wms/sampling/__init__.py +++ b/owl_wms/sampling/__init__.py @@ -8,4 +8,13 @@ def get_sampler_cls(sampler_id): elif sampler_id == "cfg": return CFGSampler elif sampler_id == "window": - return WindowCFGSampler \ No newline at end of file + return WindowCFGSampler + elif sampler_id == "shortcut": + from .shortcut_sampler import CacheShortcutSampler + return CacheShortcutSampler + elif sampler_id == "shortcut_2": + from .shortcut_sampler import WindowShortcutSamplerNoKeyframe + return WindowShortcutSamplerNoKeyframe + elif sampler_id == "av_window": + from .av_window import AVWindowSampler + return AVWindowSampler \ No newline at end of file diff --git a/owl_wms/sampling/av_window.py b/owl_wms/sampling/av_window.py new file mode 100644 index 00000000..ee4715ad --- /dev/null +++ b/owl_wms/sampling/av_window.py @@ -0,0 +1,264 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from tqdm import tqdm + +from ..utils import batch_permute_to_length + +def zlerp(x, alpha): + z = torch.randn_like(x) + return x * (1. - alpha) + z * alpha + +class AVWindowSampler: + """ + Window CFG Sampler samples new frames one by one, by inpainting the final frame. + This is basically diffusion forcing. + + :param n_steps: Number of diffusion steps for each frame (diffusoin steps) + :param cfg_scale: CFG scale for each frame + :param window_length: Number of frames to use for each frame generation step + :param num_frames: Number of new frames to sample + :param noise_prev: Noise previous frame + :param only_return_generated: Whether to only return the generated frames + """ + def __init__(self, n_steps = 20, cfg_scale = 1.3, window_length = 60, num_frames = 60, noise_prev = 0.2, only_return_generated = False): + self.n_steps = n_steps + self.cfg_scale = cfg_scale + self.window_length = window_length + self.num_frames = num_frames + self.noise_prev = noise_prev + self.only_return_generated = only_return_generated + + @torch.no_grad() + def __call__(self, model, dummy_batch, audio, mouse, btn, decode_fn = None, audio_decode_fn = None, image_scale = 1, audio_scale = 1): + # dummy_batch is [b,n,c,h,w] + # audio is [b,n,c] and should be treated same as video (it'e being generated) + # mouse is [b,n,2] + # btn is [b,n,n_button] + + # output will be [b,n+self.num_frames,c,h,w] + + sampling_steps = self.n_steps + num_frames = self.num_frames + + dt = 1. / sampling_steps + + clean_history = dummy_batch.clone() + clean_audio_history = audio.clone() + + extended_mouse, extended_btn = batch_permute_to_length(mouse, btn, num_frames + self.window_length) + + def step_history(): + # Video history + new_history = clean_history.clone()[:,-self.window_length:] + b,n,c,h,w = new_history.shape + new_history[:,:-1] = zlerp(new_history[:,1:],self.noise_prev) + new_history[:,-1] = torch.randn_like(new_history[:,0]) + + # Audio history + new_audio = clean_audio_history.clone()[:,-self.window_length:] + new_audio[:,:-1] = zlerp(new_audio[:,1:],self.noise_prev) + new_audio[:,-1] = torch.randn_like(new_audio[:,0]) + + return new_history, new_audio + + for frame_idx in tqdm(range(num_frames)): + local_history, local_audio = step_history() + ts_history = torch.ones(local_history.shape[0], local_history.shape[1], device=local_history.device,dtype=local_history.dtype) + ts_history[:,:-1] = self.noise_prev + + mouse = extended_mouse[:,frame_idx:frame_idx+self.window_length] + btn = extended_btn[:,frame_idx:frame_idx+self.window_length] + + mouse_batch = torch.cat([mouse, torch.zeros_like(mouse)], dim=0) + btn_batch = torch.cat([btn, torch.zeros_like(btn)], dim=0) + for _ in range(sampling_steps): + # CFG Branches + x = local_history.clone() + a = local_audio.clone() + ts = ts_history.clone() + + x_batch = torch.cat([x, x], dim=0) + a_batch = torch.cat([a, a], dim=0) + ts_batch = torch.cat([ts, ts], dim=0) + + pred_video_batch, pred_audio_batch = model(x_batch, a_batch, ts_batch, mouse_batch, btn_batch) + + # Split predictions back into conditional and unconditional + cond_pred_video, uncond_pred_video = pred_video_batch.chunk(2) + cond_pred_audio, uncond_pred_audio = pred_audio_batch.chunk(2) + + pred_video = uncond_pred_video + self.cfg_scale * (cond_pred_video - uncond_pred_video) + pred_audio = uncond_pred_audio + self.cfg_scale * (cond_pred_audio - uncond_pred_audio) + + x = x - pred_video*dt + a = a - pred_audio*dt + ts = ts - dt + + local_history[:,-1] = x[:,-1] + local_audio[:,-1] = a[:,-1] + ts_history[:,-1] = ts[:,-1] + + # Frame is entirely cleaned now + new_frame = local_history[:,-1:] + new_audio = local_audio[:,-1:] + clean_history = torch.cat([clean_history, new_frame], dim=1) + clean_audio_history = torch.cat([clean_audio_history, new_audio], dim=1) + + x = clean_history + audio = clean_audio_history + if self.only_return_generated: + x = x[:,-num_frames:] + audio = audio[:,-num_frames:] + extended_mouse = extended_mouse[:,-num_frames:] + extended_btn = extended_btn[:,-num_frames:] + + if decode_fn is not None: + x = x * image_scale + x = decode_fn(x) + + if audio_decode_fn is not None: + audio = audio * audio_scale + audio = audio_decode_fn(audio) + + return x, audio, extended_mouse, extended_btn + + +class Inference_AV_WindowSampler: + """ + Window CFG Sampler samples new frames one by one, by inpainting the final frame. + This is basically diffusion forcing. + + :param n_steps: Number of diffusion steps for each frame (diffusoin steps) + :param cfg_scale: CFG scale for each frame + :param window_length: Number of frames to use for each frame generation step + :param num_frames: Number of new frames to sample + :param noise_prev: Noise previous frame + :param only_return_generated: Whether to only return the generated frames + """ + def __init__(self, n_steps = 20, cfg_scale = 1.3, window_length = 60, num_frames = 60, noise_prev = 0.2, only_return_generated = False): + self.n_steps = n_steps + self.cfg_scale = cfg_scale + self.window_length = window_length + self.num_frames = num_frames + self.noise_prev = noise_prev + self.only_return_generated = only_return_generated + + @torch.no_grad() + def __call__(self, model, dummy_batch, audio, mouse, btn, decode_fn = None, audio_decode_fn = None, image_scale = 1, audio_scale = 1): + # dummy_batch is [b,n,c,h,w] + # audio is [b,n,c] and should be treated same as video (it'e being generated) + # mouse is [b,n,2] + # btn is [b,n,n_button] + + # output will be [b,n+self.num_frames,c,h,w] + + sampling_steps = self.n_steps + num_frames = self.num_frames + + dt = 1. / sampling_steps + + clean_history = dummy_batch.clone() + clean_audio_history = audio.clone() + + assert mouse.shape[1] == num_frames + self.window_length + assert btn.shape[1] == num_frames + self.window_length + + extended_mouse, extended_btn = mouse, btn + + def step_history(): + # Video history + new_history = clean_history.clone()[:,-self.window_length:] + b,n,c,h,w = new_history.shape + new_history[:,:-1] = zlerp(new_history[:,1:],self.noise_prev) + new_history[:,-1] = torch.randn_like(new_history[:,0]) + + # Audio history + new_audio = clean_audio_history.clone()[:,-self.window_length:] + new_audio[:,:-1] = zlerp(new_audio[:,1:],self.noise_prev) + new_audio[:,-1] = torch.randn_like(new_audio[:,0]) + + return new_history, new_audio + + for frame_idx in tqdm(range(num_frames)): + local_history, local_audio = step_history() + ts_history = torch.ones(local_history.shape[0], local_history.shape[1], device=local_history.device,dtype=local_history.dtype) + ts_history[:,:-1] = self.noise_prev + + mouse = extended_mouse[:,frame_idx:frame_idx+self.window_length] + btn = extended_btn[:,frame_idx:frame_idx+self.window_length] + + mouse_batch = torch.cat([mouse, torch.zeros_like(mouse)], dim=0) + btn_batch = torch.cat([btn, torch.zeros_like(btn)], dim=0) + for _ in range(sampling_steps): + # CFG Branches + x = local_history.clone() + a = local_audio.clone() + ts = ts_history.clone() + + x_batch = torch.cat([x, x], dim=0) + a_batch = torch.cat([a, a], dim=0) + ts_batch = torch.cat([ts, ts], dim=0) + + pred_video_batch, pred_audio_batch = model(x_batch, a_batch, ts_batch, mouse_batch, btn_batch) + + # Split predictions back into conditional and unconditional + cond_pred_video, uncond_pred_video = pred_video_batch.chunk(2) + cond_pred_audio, uncond_pred_audio = pred_audio_batch.chunk(2) + + pred_video = uncond_pred_video + self.cfg_scale * (cond_pred_video - uncond_pred_video) + pred_audio = uncond_pred_audio + self.cfg_scale * (cond_pred_audio - uncond_pred_audio) + + x = x - pred_video*dt + a = a - pred_audio*dt + ts = ts - dt + + local_history[:,-1] = x[:,-1] + local_audio[:,-1] = a[:,-1] + ts_history[:,-1] = ts[:,-1] + + # Frame is entirely cleaned now + new_frame = local_history[:,-1:] + new_audio = local_audio[:,-1:] + clean_history = torch.cat([clean_history, new_frame], dim=1) + clean_audio_history = torch.cat([clean_audio_history, new_audio], dim=1) + + x = clean_history + audio = clean_audio_history + if self.only_return_generated: + x = x[:,-num_frames:] + audio = audio[:,-num_frames:] + extended_mouse = extended_mouse[:,-num_frames:] + extended_btn = extended_btn[:,-num_frames:] + + pixel_latents = x + pixels = None + + audio_latents = audio + audio_wav = None + + if decode_fn is not None: + pixels = decode_fn(pixel_latents * image_scale) + + if audio_decode_fn is not None: + audio_wav = audio_decode_fn(audio_latents * audio_scale) + + return ( + pixel_latents, audio_latents, # NOTE Need this for history + pixels, audio_wav, # NOTE Need this for rendering + extended_mouse, extended_btn, + clean_history, clean_audio_history + ) + + +def test_av_window_sampler(): + sampler = AVWindowSampler() + model = lambda x, ts, mouse, btn: x + dummy_batch = torch.randn(1, 32, 128, 4, 4) + audio = torch.randn(1, 32, 128) + mouse = torch.zeros(1, 32, 2) + btn = torch.zeros(1, 32, 11) + +if __name__ == "__main__": + test_av_window_sampler() diff --git a/owl_wms/sampling/cfg.py b/owl_wms/sampling/cfg.py index 3a262366..2e2dc976 100644 --- a/owl_wms/sampling/cfg.py +++ b/owl_wms/sampling/cfg.py @@ -29,10 +29,10 @@ def __call__(self, model, dummy_batch, mouse, btn, decode_fn = None, scale = 1): x = x - pred*dt ts = ts - dt + pixels = None if decode_fn is not None: - x = x * scale - x = decode_fn(x) - return x, mouse, btn + pixels = decode_fn(x * scale) + return x, pixels, mouse, btn class InpaintCFGSampler(CFGSampler): @torch.no_grad() @@ -63,15 +63,21 @@ def __call__(self, model, dummy_batch, mouse, btn, decode_fn = None, scale = 1): x[:, mid:] = x[:, mid:] - pred[:, mid:]*dt ts[:, mid:] = ts[:, mid:] - dt + pixels = None if decode_fn is not None: - x = x * scale - x = decode_fn(x) - return x, mouse, btn + pixels = decode_fn(x * scale) + return x, pixels, mouse, btn + + +def zlerp(x, alpha): + z = torch.randn_like(x) + return x * (1. - alpha) + z * alpha + if __name__ == "__main__": model = lambda x,t,m,b: x sampler = CFGSampler() - x = sampler(model, torch.randn(4, 128, 16, 128), + x, pixels = sampler(model, torch.randn(4, 128, 16, 128), torch.randn(4, 128, 2), torch.randn(4, 128, 11)) print(x.shape) \ No newline at end of file diff --git a/owl_wms/sampling/shortcut_sampler.py b/owl_wms/sampling/shortcut_sampler.py new file mode 100644 index 00000000..7c38fdd8 --- /dev/null +++ b/owl_wms/sampling/shortcut_sampler.py @@ -0,0 +1,600 @@ +import cv2 +import math +import pathlib +import torch +from torch import Module +from tqdm import tqdm +from typing import Optional + +from ..nn.kv_cache import KVCache +from ..utils import batch_permute_to_length +from ..models.gamerft_shortcut import ShortcutGameRFT + + +def zlerp(x, alpha): + z = torch.randn_like(x) + return x * (1. - alpha) + z * alpha + +def load_mp4_as_tensor(mp4_path: pathlib.Path) -> torch.Tensor: + """Load MP4 as tensor in format [N, C=3, H, W] with values in [-1, 1]""" + video = cv2.VideoCapture(str(mp4_path)) + + if not video.isOpened(): + raise ValueError(f"Could not open video file: {mp4_path}") + + frames = [] + while True: + ret, frame = video.read() + if not ret: + break + + # Convert BGR to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Convert to torch tensor and normalize to [-1, 1] + frame = torch.from_numpy(frame).float() / 127.5 - 1.0 + + # Rearrange from [H, W, C] to [C, H, W] + frame = frame.permute(2, 0, 1) + + frames.append(frame) + + video.release() + + if not frames: + raise ValueError(f"No frames found in video: {mp4_path}") + + # Stack to [N, C, H, W] + return torch.stack(frames) + + +class InferenceCachedShortcutSampler: + + ALPHA = 0.25 + + def __init__(self, + model: ShortcutGameRFT, + window_length = 60, + num_frames = 1, + only_return_generated = False, + vae_scale = 2.17, + decode_fn: Optional[Module] = None, + initial_history_pt_path: Optional[pathlib.Path] = None, + initial_history_mp4_path: Optional[pathlib.Path] = None, + encoder: Optional[Module] = None): + # -- + self.model: ShortcutGameRFT = model + self.window_length = window_length + self.num_frames = num_frames + + self.vae_scale = vae_scale + self.only_return_generated = only_return_generated + + # -- + self._cache_built = False + self.cache = KVCache(model.config) + self.decode_fn = decode_fn + self.initial_history_pt_path = initial_history_pt_path + self.initial_history_mp4_path = initial_history_mp4_path + self.encoder = encoder + + assert initial_history_pt_path is not None or initial_history_mp4_path is not None, \ + 'Either initial_history_pt_path or initial_history_mp4_path must be provided' + + if initial_history_mp4_path is not None: + assert encoder is not None, \ + 'Encoder must be provided if initial_history_mp4_path is provided' + + self.initial_history_bWchw = self.init_history(self.initial_history_pt_path, self.initial_history_mp4_path) + self.keyframe_b1chw = self.initial_history_bWchw[:,0] + + def init_history(self, + initial_history_pt_path: pathlib.Path | None, + initial_history_mp4_path: pathlib.Path | None) -> torch.Tensor: + + if initial_history_pt_path is not None: + history_wchw = torch.load(initial_history_pt_path) + else: + history_wrgb = load_mp4_as_tensor(initial_history_mp4_path).unsqueeze(0) # add batch dim + history_wchw = self.encoder(history_wrgb) + # NOTE This is so we avoid generating the history with a compiled model. + torch.save(history_wchw, initial_history_mp4_path.absolute().replace('.mp4', '.pt')) + + N = self.window_length + C = self.model.config.channels + H = W = int(math.sqrt(self.model.config.tokens_per_frame)) + + assert tuple(history_wchw.shape) == (1, N, C, H, W), \ + f'Initial history must have shape (B=1, {N=}, {C=}, {H=}, {W=}), ' \ + f'but got {tuple(history_wchw.shape)}' + + return history_wchw + + + def init_cache(self, + frames_bWchw, # [B, W, c, h, w] - NOTE history of frames + keyframe_b1chw, # [B, 1, c, h, w] - NOTE keyframe conditioning + mouse_bW2, # [B, W, 2] + button_bW11, # [B, W, 11] + ts_bW, # [B, W] + d_bW): # [B, W] + if self._cache_built: + print(f'WARNING: Cache already built but called `init_cache` again - ignoring.') + return + + B, N, *_ = frames_bWchw.shape + + self.cache.reset(B) ; self.cache.enable_cache_updates() + + # -- noise the history and fwd to kv cache + self.model.core.sample(x=zlerp(frames_bWchw, self.ALPHA), + y=keyframe_b1chw, + mouse=mouse_bW2, + btn=button_bW11, + cache=self.cache, + ts=ts_bW, d=d_bW) + + self.cache.disable_cache_updates() ; self._cache_built = True + print(f'Cache initialized for {B} x {N} frames - {[[i.shape for i in elt] + for elt in self.cache.cache]}') + return self.cache + + def __call__(self, + ctxt_frame_b1chw, # [B, 1, c, h, w] - NOTE Keyframe conditioning + mouse_b1_2, # [B, 1, 2] - NOTE mouse actions + button_b1_11, # [B, 1, 11] - NOTE button actions + ts_alpha_b1, # [B, 1] - NOTE overall denoising timestamp (e.g. 128) + d_alpha_b1, # [B, 1] - NOTE denoising step budget (e.g. 4) + ) -> torch.Tensor: # [B, 1, c, h, w] + # 1. ---- generate next frame ---- + self.cache.disable_cache_updates() + # 1.A) -- use the full context, including entire action history, to generate the next frame given cache. + frame = self.model.core.sample(None, ctxt_frame_b1chw, + mouse_b1_2, button_b1_11, + self.cache, ts=None, d=None) # NOTE simulating one-step sampling + # 2. ---- repopulate cache ---- + self.cache.enable_cache_updates() ; self.cache.truncate(1) + self.model.core.sample( x=zlerp(frame, self.ALPHA), # diffuse with noised frame to repopulate cache + y=ctxt_frame_b1chw, + mouse=mouse_b1_2, + btn=button_b1_11, + cache=self.cache, + ts=ts_alpha_b1, d=d_alpha_b1) + self.cache.disable_cache_updates() + return frame + + @torch.no_grad() + def generate_frames(self, + history_bWchw, # [B, W, c, h, w] - NOTE: MP4 from CoD initially, and after that it's just KV cache. + mouse_bT2, # [B, W+N, 2] - Actions taken by the user. + button_bT11, # [B, W+N, 11] - Actions taken by the user. + ) -> torch.Tensor: # [B, W+N, c, h, w] - either latent or rgb. + + if not self._cache_built: + print(f'WARNING: Cache not built, but called `generate_frames` - initializing cache.') + self.init_cache(history_bWchw, self.keyframe_b1chw, mouse_bT2, button_bT11) + + # If does not have batch-size, add it. This sampler is going to be used for single-user inference so batch-size is always 1. + # The caller might not specify the batch-size, so we have this here. + if history_bWchw.ndim == 4: + history_bWchw = history_bWchw.unsqueeze(1) + + history_bWchw = history_bWchw[:, -self.window_length:, ::] + + assert history_bWchw.shape[1] == self.window_length, \ + f'Window history must be at least {self.window_length} frames long, but got {history_bWchw.shape}' + + ts_alpha_bW = torch.ones_like(history_bWchw[:,:,0,0,0]) * self.ALPHA + d_alpha_bW = torch.ones_like(history_bWchw[:,:,0,0,0]) * round(1./self.ALPHA) + + ts_alpha_b1 = ts_alpha_bW[:,0].unsqueeze(1) + d_alpha_b1 = d_alpha_bW [:,0].unsqueeze(1) + + frames_latent = [] + for frame_idx in range(self.num_frames): + btn_atom = button_bT11[:, self.window_length+frame_idx].unsqueeze(1) + mouse_atom = mouse_bT2 [:, self.window_length+frame_idx].unsqueeze(1) + frame = self.__call__(ctxt_frame_b1chw=self.keyframe_b1chw, + mouse_b1_2=mouse_atom, button_b1_11=btn_atom, + ts_alpha_b1=ts_alpha_b1, d_alpha_b1=d_alpha_b1) + frames_latent += [frame] + + frames_latent = torch.cat(frames_latent, dim=1) + + if self.only_return_generated: frames_latent = frames_latent[:,-self.num_frames:] + + if self.decode_fn is not None: + frames_rgb = self.decode_fn(frames_latent * self.vae_scale) + return frames_rgb, mouse_bT2, button_bT11 + + return frames_latent, mouse_bT2, button_bT11 + + +class CacheShortcutSampler: + """ + Shortcut CFG sampler builds cache with 4 step diffusion. + Samples new frames in 1 step. + + :param window_length: Number of frames to use for each frame generation step + :param num_frames: Number of new frames to sample + :param only_return_generated: Whether to only return the generated frames + """ + def __init__(self, window_length = 60, num_frames = 60, only_return_generated = False): + self.n_steps = n_steps + self.cfg_scale = cfg_scale + self.window_length = window_length + self.num_frames = num_frames + self.only_return_generated = only_return_generated + + @torch.no_grad() + def __call__(self, model, history, keyframe, mouse, btn, decode_fn = None, scale = 1): + # dummy_batch is [b,n,c,h,w] + # mouse is [b,n,2] + # btn is [b,n,n_button] + + # output will be [b,n+self.num_frames,c,h,w] + history = history[:,:self.window_length] + new_frames = [] + alpha = 0.25 # This number is special for our sampler + + # Extended fake controls to use during sampling + extended_mouse, extended_btn = batch_permute_to_length(mouse, btn, num_frames + self.window_length) + + # Generate cache over history + noisy_history = zlerp(history.clone(), alpha) + ts = torch.ones_like(noisy_history[:,:,0,0,0]) * alpha + d = torch.ones_like(noisy_history[:,:,0,0,0]) * round(1./alpha) + ts_single = ts[:,0].unsqueeze(1) + d_single = d[:,0].unsqueeze(1) + + cache = KVCache(model.config) + cache.reset(history.shape[0]) + + cache.enable_cache_updates() + _ = model.sample(noisy_history, keyframe, mouse, btn, cache, ts, d) + cache.disable_cache_updates() + + # Cache is now built! + + for frame_idx in tqdm(range(num_frames)): + cache.truncate(1) # Drop first frame + + # Generate new frame + cache.disable_cache_updates() + mouse = extended_mouse[:,self.window_length+frame_idx].unsqueeze(1) + btn = extended_btn[:,self.window_length+frame_idx].unsqueeze(1) + # N+1 + new_frame = model.sample(None, keyframe, mouse, btn, cache) # [b,1,c,h,w] + new_frames.append(new_frame) + + # Add that frame to the cache + cache.enable_cache_updates() + new_frame_noisy = zlerp(new_frame, alpha) + # N+2, noisy(N+1) gets cached + _ = model.sample(new_frame_noisy, keyframe, mouse, btn, cache, ts_single, d_single) + + new_frames = torch.cat(new_frames, dim = 1) + x = torch.cat([history,new_frames], dim = 1) + + if self.only_return_generated: + x = x[:,-num_frames:] + extended_mouse = extended_mouse[:,-num_frames:] + extended_btn = extended_btn[:,-num_frames:] + + if decode_fn is not None: + x = x * scale + x = decode_fn(x) + + return x, extended_mouse, extended_btn + +class WindowShortcutSampler: + """ + Same as above but with no cache + + :param window_length: Number of frames to use for each frame generation step + :param num_frames: Number of new frames to sample + :param only_return_generated: Whether to only return the generated frames + """ + def __init__(self, window_length = 60, num_frames = 60, only_return_generated = False): + self.window_length = window_length + self.num_frames = num_frames + self.only_return_generated = only_return_generated + + @torch.no_grad() + def __call__(self, model, history, keyframe, mouse, btn, decode_fn = None, scale = 1): + # history is [b,n,c,h,w] + # mouse is [b,n,2] + # btn is [b,n,n_button] + + # output will be [b,n+self.num_frames,c,h,w] + history = history[:,:self.window_length] + new_frames = [] + alpha = 0.25 # This number is special for our sampler + + # Extended fake controls to use during sampling + extended_mouse, extended_btn = batch_permute_to_length(mouse, btn, self.num_frames + self.window_length) + + # Initialize window history + window_history = history.clone() + + for frame_idx in tqdm(range(self.num_frames)): + # Setup window history + x = window_history[:,-self.window_length:].clone() + + # Noise all but last frame to alpha + x[:,:-1] = zlerp(x[:,:-1], alpha) + # Last frame starts as random noise + x[:,-1] = torch.randn_like(x[:,-1]) + + # Setup timesteps - alpha for context, 1.0 for generated + ts = torch.ones_like(x[:,:,0,0,0]) + ts[:,:-1] = alpha + + # Setup diffusion steps - 4 for context, 1 for generated + d = torch.ones_like(x[:,:,0,0,0]) + d[:,:-1] = 4 + + # Get current controls + curr_mouse = extended_mouse[:,frame_idx:frame_idx+self.window_length] + curr_btn = extended_btn[:,frame_idx:frame_idx+self.window_length] + + # Generate new frame + pred = model.sample(x, keyframe, curr_mouse, curr_btn, None, ts, d) + new_frame = pred[:,-1:] # Take only the last frame + new_frames.append(new_frame) + + # Add new frame to window history + window_history = torch.cat([window_history, new_frame], dim=1) + + new_frames = torch.cat(new_frames, dim=1) + x = torch.cat([history, new_frames], dim=1) + + if self.only_return_generated: + x = x[:,-self.num_frames:] + extended_mouse = extended_mouse[:,-self.num_frames:] + extended_btn = extended_btn[:,-self.num_frames:] + + if decode_fn is not None: + x = x * scale + x = decode_fn(x) + + return x, extended_mouse, extended_btn + +<<<<<<< HEAD + +class InferenceWindowShortcutSamplerNoKeyframe: + """ + Window-based shortcut sampler without keyframe conditioning or KV cache. + Generates frames using sliding window approach with diffusion forcing. + + :param model: The shortcut diffusion model + :param window_length: Number of frames to use for each frame generation step + :param num_frames: Number of new frames to sample per generate_frames call + :param only_return_generated: Whether to only return the generated frames + :param vae_scale: Scale factor for VAE decoding + :param decode_fn: Optional decoder function + :param initial_history_pt_path: Path to pre-encoded initial history tensor + :param initial_history_mp4_path: Path to MP4 file for initial history + :param encoder: Encoder module (required if using MP4 path) + """ + + ALPHA = 0.25 # Noise level for context frames (step 3 of 4-step diffusion) + + def __init__(self, + model, + window_length = 60, + num_frames = 1, + only_return_generated = False, + vae_scale = 2.17, # TODO Shab trained a new VAE so this needs to be updated. + decode_fn: Optional[Module] = None, + initial_history_pt_path: Optional[pathlib.Path] = None, + initial_history_mp4_path: Optional[pathlib.Path] = None, + encoder: Optional[Module] = None): + + self.model = model + self.window_length = window_length + self.num_frames = num_frames + self.vae_scale = vae_scale + self.only_return_generated = only_return_generated + self.decode_fn = decode_fn + + self.initial_history_pt_path = initial_history_pt_path + self.initial_history_mp4_path = initial_history_mp4_path + self.encoder = encoder + + assert initial_history_pt_path is not None or initial_history_mp4_path is not None, \ + 'Either initial_history_pt_path or initial_history_mp4_path must be provided' + + if initial_history_mp4_path is not None: + assert encoder is not None, \ + 'Encoder must be provided if initial_history_mp4_path is provided' + + self.initial_history_bWchw = self.init_history(self.initial_history_pt_path, self.initial_history_mp4_path) + + def init_history(self, + initial_history_pt_path: pathlib.Path | None, + initial_history_mp4_path: pathlib.Path | None) -> torch.Tensor: + """Initialize history from either .pt file or MP4 file""" + + if initial_history_pt_path is not None: + history_wchw = torch.load(initial_history_pt_path) + else: + history_wrgb = load_mp4_as_tensor(initial_history_mp4_path).unsqueeze(0) # add batch dim + history_wchw = self.encoder(history_wrgb) + # Save encoded version to avoid re-encoding + torch.save(history_wchw, initial_history_mp4_path.absolute().replace('.mp4', '.pt')) + + N = self.window_length + C = self.model.config.channels if hasattr(self.model, 'config') else history_wchw.shape[2] + H = W = int(math.sqrt(self.model.config.tokens_per_frame)) if hasattr(self.model, 'config') else history_wchw.shape[3] + + assert tuple(history_wchw.shape) == (1, N, C, H, W), \ + f'Initial history must have shape (B=1, {N=}, {C=}, {H=}, {W=}), ' \ + f'but got {tuple(history_wchw.shape)}' + + return history_wchw + + def __call__(self, + window_history_bWchw, # [B, W, c, h, w] - Current window of frames + mouse_bW2, # [B, W, 2] - Mouse actions for window + button_bW11, # [B, W, 11] - Button actions for window + ) -> torch.Tensor: # [B, 1, c, h, w] - Generated frame + """Generate a single frame given current window history""" + + # Setup window for generation + x = window_history_bWchw[:, -self.window_length:].clone() + + # Noise all but last frame to alpha level (diffusion forcing) + x[:, :-1] = zlerp(x[:, :-1], self.ALPHA) + # Last frame starts as random noise + x[:, -1] = torch.randn_like(x[:, -1]) + + # Setup timesteps - ALPHA for context frames, 1.0 for generated frame + ts = torch.ones_like(x[:, :, 0, 0, 0]) + ts[:, :-1] = self.ALPHA + + # Setup diffusion steps - 4 for context frames, 1 for generated frame + d = torch.ones_like(x[:, :, 0, 0, 0]) + d[:, :-1] = 4 # Context frames use 4-step budget + + # Generate new frame using window + pred = self.model.sample(x, mouse_bW2, button_bW11, None, ts, d) + new_frame = pred[:, -1:] # Take only the last (generated) frame + + return new_frame + + @torch.no_grad() + def generate_frames(self, + history_bWchw, # [B, W, c, h, w] - Initial history + mouse_bT2, # [B, W+N, 2] - Mouse actions for entire sequence + button_bT11, # [B, W+N, 11] - Button actions for entire sequence + ) -> torch.Tensor: # [B, W+N, c, h, w] - Generated sequence + """Generate multiple frames using sliding window approach""" + + # Handle batch dimension + if history_bWchw.ndim == 4: + history_bWchw = history_bWchw.unsqueeze(0) + + history_bWchw = history_bWchw[:, -self.window_length:] + + assert history_bWchw.shape[1] == self.window_length, \ + f'History must be exactly {self.window_length} frames long, but got {history_bWchw.shape[1]}' + + # Extended controls for generation + extended_mouse, extended_btn = batch_permute_to_length( + mouse_bT2[:, :self.window_length], + button_bT11[:, :self.window_length], + self.num_frames + self.window_length + ) + + # Initialize window history + window_history = history_bWchw.clone() + frames_latent = [] + + for frame_idx in range(self.num_frames): + # Get current window controls + curr_mouse = extended_mouse[:, frame_idx:frame_idx + self.window_length] + curr_btn = extended_btn[:, frame_idx:frame_idx + self.window_length] + + # Generate single frame + new_frame = self.__call__( + window_history_bWchw=window_history, + mouse_bW2=curr_mouse, + button_bW11=curr_btn + ) + + frames_latent.append(new_frame) + + # Add new frame to window history for next iteration + window_history = torch.cat([window_history, new_frame], dim=1) + + # Combine all generated frames + frames_latent = torch.cat(frames_latent, dim=1) + + # Combine with original history + full_sequence = torch.cat([history_bWchw, frames_latent], dim=1) + + if self.only_return_generated: + full_sequence = full_sequence[:, -self.num_frames:] + extended_mouse = extended_mouse[:, -self.num_frames:] + extended_btn = extended_btn[:, -self.num_frames:] + + if self.decode_fn is not None: + frames_rgb = self.decode_fn(full_sequence * self.vae_scale) + return frames_rgb, extended_mouse, extended_btn + + return full_sequence, extended_mouse, extended_btn +======= +class WindowShortcutSamplerNoKeyframe: + """ + Same as above but with no cache + + :param window_length: Number of frames to use for each frame generation step + :param num_frames: Number of new frames to sample + :param only_return_generated: Whether to only return the generated frames + """ + def __init__(self, window_length = 60, num_frames = 60, only_return_generated = False): + self.window_length = window_length + self.num_frames = num_frames + self.only_return_generated = only_return_generated + + @torch.no_grad() + def __call__(self, model, history, mouse, btn, decode_fn = None, scale = 1): + # history is [b,n,c,h,w] + # mouse is [b,n,2] + # btn is [b,n,n_button] + + # output will be [b,n+self.num_frames,c,h,w] + history = history[:,:self.window_length] + new_frames = [] + alpha = 0.25 # This number is special for our sampler + + # Extended fake controls to use during sampling + extended_mouse, extended_btn = batch_permute_to_length(mouse, btn, self.num_frames + self.window_length) + + # Initialize window history + window_history = history.clone() + + for frame_idx in tqdm(range(self.num_frames)): + # Setup window history + x = window_history[:,-self.window_length:].clone() + + # Noise all but last frame to alpha + x[:,:-1] = zlerp(x[:,:-1], alpha) + # Last frame starts as random noise + x[:,-1] = torch.randn_like(x[:,-1]) + + # Setup timesteps - alpha for context, 1.0 for generated + ts = torch.ones_like(x[:,:,0,0,0]) + ts[:,:-1] = alpha + + # Setup diffusion steps - 4 for context, 1 for generated + d = torch.ones_like(x[:,:,0,0,0]) + d[:,:-1] = 4 + + # Get current controls + curr_mouse = extended_mouse[:,frame_idx:frame_idx+self.window_length] + curr_btn = extended_btn[:,frame_idx:frame_idx+self.window_length] + + # Generate new frame + pred = model.sample(x, curr_mouse, curr_btn, None, ts, d) + new_frame = pred[:,-1:] # Take only the last frame + new_frames.append(new_frame) + + # Add new frame to window history + window_history = torch.cat([window_history, new_frame], dim=1) + + new_frames = torch.cat(new_frames, dim=1) + x = torch.cat([history, new_frames], dim=1) + + if self.only_return_generated: + x = x[:,-self.num_frames:] + extended_mouse = extended_mouse[:,-self.num_frames:] + extended_btn = extended_btn[:,-self.num_frames:] + + if decode_fn is not None: + x = x * scale + x = decode_fn(x) + + return x, extended_mouse, extended_btn +>>>>>>> causvid diff --git a/owl_wms/sampling/simple.py b/owl_wms/sampling/simple.py index 14d75f66..602cd7b6 100644 --- a/owl_wms/sampling/simple.py +++ b/owl_wms/sampling/simple.py @@ -19,9 +19,8 @@ def __call__(self, model, dummy_batch, mouse, btn, decode_fn = None, scale = 1): ts = ts - dt if decode_fn is not None: - x = x * scale - x = decode_fn(x) - return x, mouse, btn + pixels = decode_fn(x * scale) + return x, pixels, mouse, btn class InpaintSimpleSampler: def __init__(self, n_steps=64): @@ -47,15 +46,15 @@ def __call__(self, model, dummy_batch, mouse, btn, decode_fn = None, scale = 1): ts[:, mid:] = ts[:, mid:] - dt if decode_fn is not None: - x = x * scale - x = decode_fn(x) - return x, mouse, btn + pixels = decode_fn(x * scale) + return x, pixels, mouse, btn if __name__ == "__main__": model = lambda x,t,m,b: x - sampler = Sampler() + sampler = SimpleSampler() x = sampler(model, torch.randn(4, 3, 64, 64), torch.randn(4, 2), torch.randn(4, 8)) - print(x.shape) \ No newline at end of file + print(x.shape) + \ No newline at end of file diff --git a/owl_wms/sampling/window.py b/owl_wms/sampling/window.py index 5fdf837e..ff220054 100644 --- a/owl_wms/sampling/window.py +++ b/owl_wms/sampling/window.py @@ -95,11 +95,11 @@ def step_history(): extended_mouse = extended_mouse[:,-num_frames:] extended_btn = extended_btn[:,-num_frames:] + pixels = None if decode_fn is not None: - x = x * scale - x = decode_fn(x) - - return x, extended_mouse, extended_btn + pixels = decode_fn(x * scale) + + return x, pixels, extended_mouse, extended_btn def test_window_cfg_sampler(): diff --git a/owl_wms/trainers/__init__.py b/owl_wms/trainers/__init__.py index 86d17cba..e6133a31 100644 --- a/owl_wms/trainers/__init__.py +++ b/owl_wms/trainers/__init__.py @@ -2,4 +2,19 @@ def get_trainer_cls(trainer_id): if trainer_id == "rft": - return RFTTrainer \ No newline at end of file + return RFTTrainer + if trainer_id == "causvid": + from .causvid import CausVidTrainer + return CausVidTrainer + if trainer_id == "shortcut": + from .shortcut_trainer import ShortcutTrainer + return ShortcutTrainer + if trainer_id == "self_forcing": + from .self_forcing import SelfForcingTrainer + return SelfForcingTrainer + if trainer_id == "shortcut_2": + from .shortcut_trainer_2 import ShortcutTrainer + return ShortcutTrainer + if trainer_id == "av": + from .av_trainer import AVRFTTrainer + return AVRFTTrainer diff --git a/owl_wms/trainers/av_trainer.py b/owl_wms/trainers/av_trainer.py new file mode 100644 index 00000000..4aa78301 --- /dev/null +++ b/owl_wms/trainers/av_trainer.py @@ -0,0 +1,217 @@ +import torch +from ema_pytorch import EMA +import wandb +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist +import einops as eo + +from .base import BaseTrainer + +from ..utils import freeze, Timer, find_unused_params +from ..schedulers import get_scheduler_cls +from ..models import get_model_cls +from ..sampling import get_sampler_cls +from ..data import get_loader +from ..utils.logging import LogHelper, to_wandb_av +from ..muon import init_muon +from ..utils.owl_vae_bridge import get_decoder_only, make_batched_decode_fn, make_batched_audio_decode_fn + +class AVRFTTrainer(BaseTrainer): + """ + Trainer for rectified flow transformer + + :param train_cfg: Configuration for training + :param logging_cfg: Configuration for logging + :param model_cfg: Configuration for model + :param global_rank: Rank across all devices. + :param local_rank: Rank for current device on this process. + :param world_size: Overall number of devices + """ + def __init__(self,*args,**kwargs): + super().__init__(*args,**kwargs) + + model_id = self.model_cfg.model_id + self.model = get_model_cls(model_id)(self.model_cfg) + + # Print model size + if self.rank == 0: + n_params = sum(p.numel() for p in self.model.parameters()) + print(f"Model has {n_params:,} parameters") + + self.ema = None + self.opt = None + self.scheduler = None + self.scaler = None + + self.total_step_counter = 0 + self.decoder = get_decoder_only( + self.train_cfg.vae_id, + self.train_cfg.vae_cfg_path, + self.train_cfg.vae_ckpt_path + ) + + self.audio_decoder = get_decoder_only( + self.train_cfg.audio_vae_id, + self.train_cfg.audio_vae_cfg_path, + self.train_cfg.audio_vae_ckpt_path + ) + + freeze(self.decoder) + + def save(self): + save_dict = { + 'model' : self.model.state_dict(), + 'ema' : self.ema.state_dict(), + 'opt' : self.opt.state_dict(), + 'scaler' : self.scaler.state_dict(), + 'steps': self.total_step_counter + } + if self.scheduler is not None: + save_dict['scheduler'] = self.scheduler.state_dict() + super().save(save_dict) + + def load(self): + has_ckpt = False + try: + if self.train_cfg.resume_ckpt is not None: + save_dict = super().load(self.train_cfg.resume_ckpt) + has_ckpt = True + except: + print("Error loading checkpoint") + + if not has_ckpt: + return + + + self.model.load_state_dict(save_dict['model']) + self.ema.load_state_dict(save_dict['ema']) + self.opt.load_state_dict(save_dict['opt']) + if self.scheduler is not None and 'scheduler' in save_dict: + self.scheduler.load_state_dict(save_dict['scheduler']) + self.scaler.load_state_dict(save_dict['scaler']) + self.total_step_counter = save_dict['steps'] + + def train(self): + torch.cuda.set_device(self.local_rank) + + # Prepare model and ema + self.model = self.model.cuda().train() + if self.world_size > 1: + self.model = DDP(self.model, device_ids=[self.local_rank]) + self.decoder = self.decoder.cuda().eval().bfloat16() + self.audio_decoder = self.audio_decoder.cuda().eval().bfloat16() + + decode_fn = make_batched_decode_fn(self.decoder, self.train_cfg.vae_batch_size) + audio_decode_fn = make_batched_audio_decode_fn(self.audio_decoder, self.train_cfg.vae_batch_size) + + self.ema = EMA( + self.model, + beta = 0.999, + update_after_step = 0, + update_every = 1 + ) + #torch.compile(self.ema.ema_model.module.core if self.world_size > 1 else self.ema.ema_model.core, dynamic=False, fullgraph=True) + + def get_ema_core(): + if self.world_size > 1: + return self.ema.ema_model.module.core + else: + return self.ema.ema_model.core + + # Set up optimizer and scheduler + if self.train_cfg.opt.lower() == "muon": + self.opt = init_muon(self.model, rank=self.rank,world_size=self.world_size,**self.train_cfg.opt_kwargs) + else: + self.opt = getattr(torch.optim, self.train_cfg.opt)(self.model.parameters(), **self.train_cfg.opt_kwargs) + + if self.train_cfg.scheduler is not None: + self.scheduler = get_scheduler_cls(self.train_cfg.scheduler)(self.opt, **self.train_cfg.scheduler_kwargs) + + # Grad accum setup and scaler + accum_steps = self.train_cfg.target_batch_size // self.train_cfg.batch_size // self.world_size + accum_steps = max(1, accum_steps) + self.scaler = torch.amp.GradScaler() + ctx = torch.amp.autocast('cuda',torch.bfloat16) + + self.load() + + # Timer reset + timer = Timer() + timer.reset() + metrics = LogHelper() + if self.rank == 0: + wandb.watch(self.get_module(), log = 'all') + + # Dataset setup + loader = get_loader(self.train_cfg.data_id, self.train_cfg.batch_size, **self.train_cfg.data_kwargs) + sampler = get_sampler_cls(self.train_cfg.sampler_id)(**self.train_cfg.sampler_kwargs) + + local_step = 0 + for _ in range(self.train_cfg.epochs): + for batch_vid, batch_audio, batch_mouse, batch_btn in loader: + batch_vid = batch_vid.cuda().bfloat16() / self.train_cfg.vae_scale + batch_audio = batch_audio.cuda().bfloat16() / self.train_cfg.audio_vae_scale + batch_mouse = batch_mouse.cuda().bfloat16() + batch_btn = batch_btn.cuda().bfloat16() + + with ctx: + loss = self.model(batch_vid,batch_audio,batch_mouse,batch_btn) / accum_steps + + self.scaler.scale(loss).backward() + #find_unused_params(self.model) + + metrics.log('diffusion_loss', loss) + + local_step += 1 + if local_step % accum_steps == 0: + # Updates + if self.train_cfg.opt.lower() != "muon": + self.scaler.unscale_(self.opt) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + + self.scaler.step(self.opt) + self.opt.zero_grad(set_to_none=True) + + self.scaler.update() + + if self.scheduler is not None: + self.scheduler.step() + self.ema.update() + + # Do logging + with torch.no_grad(): + wandb_dict = metrics.pop() + wandb_dict['time'] = timer.hit() + timer.reset() + + # Sampling commented out for now + if self.total_step_counter % self.train_cfg.sample_interval == 0: + with ctx, torch.no_grad(): + n_samples = self.train_cfg.n_samples + samples, audio, sample_mouse, sample_button = sampler( + get_ema_core(), + batch_vid[:n_samples], + batch_audio[:n_samples], + batch_mouse[:n_samples], + batch_btn[:n_samples], + decode_fn, + audio_decode_fn, + self.train_cfg.vae_scale, + self.train_cfg.audio_vae_scale + ) # -> [b,n,c,h,w] + if self.rank == 0: + video, audio = to_wandb_av(samples, audio, sample_mouse, sample_button) + wandb_dict['samples'] = video + wandb_dict['audio_samples'] = audio + + + if self.rank == 0: + wandb.log(wandb_dict) + + self.total_step_counter += 1 + if self.total_step_counter % self.train_cfg.save_interval == 0: + if self.rank == 0: + self.save() + + self.barrier() \ No newline at end of file diff --git a/owl_wms/trainers/causvid.py b/owl_wms/trainers/causvid.py new file mode 100644 index 00000000..4c586d1f --- /dev/null +++ b/owl_wms/trainers/causvid.py @@ -0,0 +1,270 @@ +import torch +from ema_pytorch import EMA +import wandb +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist +import einops as eo +from copy import deepcopy + +from .base import BaseTrainer + +from ..utils import freeze, unfreeze, Timer, find_unused_params, versatile_load +from ..schedulers import get_scheduler_cls +from ..models import get_model_cls +from ..sampling import get_sampler_cls +from ..data import get_loader +from ..utils.logging import LogHelper, to_wandb +from ..muon import init_muon +from ..utils.owl_vae_bridge import get_decoder_only, make_batched_decode_fn + +class CausVidTrainer(BaseTrainer): + """ + CausVid Trainer + + :param train_cfg: Configuration for training + :param logging_cfg: Configuration for logging + :param model_cfg: Configuration for model + :param global_rank: Rank across all devices. + :param local_rank: Rank for current device on this process. + :param world_size: Overall number of devices + """ + def __init__(self,*args,**kwargs): + super().__init__(*args,**kwargs) + + model_id = self.model_cfg.model_id + + student_cfg = deepcopy(self.model_cfg) + teacher_cfg = deepcopy(self.model_cfg) + + student_cfg.causal = True + teacher_cfg.causal = False + + self.model = get_model_cls(model_id)(student_cfg) + self.score_real = get_model_cls(model_id)(teacher_cfg) + + self.score_real.load_state_dict(versatile_load(self.train_cfg.teacher_ckpt)) + self.score_fake = deepcopy(self.score_real) + + freeze(self.score_real) + + # Print model size + if self.rank == 0: + n_params = sum(p.numel() for p in self.model.parameters()) + print(f"Model has {n_params:,} parameters") + + self.ema = None + self.opt = None + self.s_fake_opt = None + self.scheduler = None + self.s_fake_scaler = None + self.scaler = None + + self.total_step_counter = 0 + self.decoder = get_decoder_only() + freeze(self.decoder) + + def save(self): + save_dict = { + 'model' : self.model.state_dict(), + 'ema' : self.ema.state_dict(), + 'opt' : self.opt.state_dict(), + 'scaler' : self.scaler.state_dict(), + 'score_fake': self.score_fake.state_dict(), + 's_fake_opt': self.s_fake_opt.state_dict(), + 's_fake_scaler': self.s_fake_scaler.state_dict(), + 'steps': self.total_step_counter + } + if self.scheduler is not None: + save_dict['scheduler'] = self.scheduler.state_dict() + super().save(save_dict) + + def load(self): + has_ckpt = False + try: + if self.train_cfg.resume_ckpt is not None: + save_dict = super().load(self.train_cfg.resume_ckpt) + has_ckpt = True + except: + print("Error loading checkpoint") + + if not has_ckpt: + return + + + self.model.load_state_dict(save_dict['model']) + self.ema.load_state_dict(save_dict['ema']) + self.opt.load_state_dict(save_dict['opt']) + if self.scheduler is not None and 'scheduler' in save_dict: + self.scheduler.load_state_dict(save_dict['scheduler']) + self.scaler.load_state_dict(save_dict['scaler']) + self.score_fake.load_state_dict(save_dict['score_fake']) + self.s_fake_opt.load_state_dict(save_dict['s_fake_opt']) + self.s_fake_scaler.load_state_dict(save_dict['s_fake_scaler']) + self.total_step_counter = save_dict['steps'] + + def train(self): + torch.cuda.set_device(self.local_rank) + + # Prepare model and ema + self.model = self.model.cuda().train() + self.decoder = self.decoder.cuda().eval().bfloat16() + self.score_real = self.score_real.cuda().eval().bfloat16() + self.score_fake = self.score_fake.cuda().train() + + if self.world_size > 1: + self.model = DDP(self.model) + self.score_fake = DDP(self.score_fake) + + freeze(self.decoder) + freeze(self.score_real) + + #torch.compile(self.score_real, dynamic = False) + + decode_fn = make_batched_decode_fn(self.decoder, self.train_cfg.vae_batch_size) + + self.ema = EMA( + self.model, + beta = 0.999, + update_after_step = 0, + update_every = 1 + ) + # Hard coded stuff, probably #TODO figure out where to put this? + self.update_ratio = 5 + self.cfg_scale = 1.3 + + def get_ema_core(): + if self.world_size > 1: + return self.ema.ema_model.module.core + else: + return self.ema.ema_model.core + + # Don't use MUON pls + self.opt = getattr(torch.optim, self.train_cfg.opt)(self.model.parameters(), **self.train_cfg.opt_kwargs) + self.s_fake_opt = getattr(torch.optim, self.train_cfg.opt)(self.score_fake.parameters(), **self.train_cfg.opt_kwargs) + + if self.train_cfg.scheduler is not None: + self.scheduler = get_scheduler_cls(self.train_cfg.scheduler)(self.opt, **self.train_cfg.scheduler_kwargs) + + # Scaler + self.s_fake_scaler = torch.amp.GradScaler() + self.scaler = torch.amp.GradScaler() + ctx = torch.amp.autocast('cuda',torch.bfloat16) + + self.load() + + # Timer reset + timer = Timer() + timer.reset() + metrics = LogHelper() + if self.rank == 0: + wandb.watch(self.get_module(), log = 'all') + + # Dataset setup + loader = get_loader(self.train_cfg.data_id, self.train_cfg.batch_size, **self.train_cfg.data_kwargs) + sampler = get_sampler_cls(self.train_cfg.sampler_id)() + + # Simplifiying assumptions: data will never stop iter, no grad accum + + def sample_from_gen(vid, mouse, btn): + model_out = self.model(vid, mouse, btn, return_dict = True) + ts = model_out['ts'][:,None,None,None] # [b,n,c,h,w] + lerpd = model_out['lerpd'] # [b,n,c,h,w] + pred = model_out['pred'] # [b,n,c,h,w] + + samples = lerpd - pred*ts + return samples + + def get_dmd_loss(vid, mouse, btn): + s_real_fn = self.score_real.core + s_fake_fn = self.score_fake.module.core + + with torch.no_grad(): + b,n,c,h,w = vid.shape + ts = torch.randn(b,n,device=vid.device,dtype=vid.dtype).sigmoid() + z = torch.randn_like(vid) + ts_exp = ts[:,:,None,None,None] + lerpd = vid * (1. - ts_exp) + z * ts_exp + + null_mouse = torch.zeros_like(mouse) + null_btn = torch.zeros_like(btn) + + s_real_uncond = s_real_fn(lerpd, ts, null_mouse, null_btn) + s_real_cond = s_real_fn(lerpd, ts, mouse, btn) + s_real = s_real_uncond + self.cfg_scale * (s_real_cond - s_real_uncond) + + s_fake = s_fake_fn(lerpd, ts, mouse, btn) + + grad = (s_fake - s_real) + + # Normalizer? + p_real = (vid - s_real) + normalizer = torch.abs(p_real).mean(dim=[1,2,3,4],keepdim=True) + grad = grad / (normalizer + 1.0e-6) + + grad = torch.nan_to_num(grad) + dmd_loss = 0.5 * F.mse_loss(vid.double(), vid.double() - grad.double()) + # ^ simplify to 0.5 * 2 * (vid - vid + grad) = grad, neat! + return dmd_loss + + def optimizer_step(loss, model, scaler, optimizer): + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + scaler.step(optimizer) + optimizer.zero_grad(set_to_none=True) + scaler.update() + + loader = iter(loader) + while True: + freeze(self.model) + unfreeze(self.score_fake) + for _ in range(self.update_ratio): + batch_vid, batch_mouse, batch_btn = next(loader) + with ctx: + with torch.no_grad(): + samples = sample_from_gen(batch_vid, batch_mouse, batch_btn) + s_fake_loss = self.score_fake(samples, batch_mouse, batch_btn) + + optimizer_step(s_fake_loss, self.score_fake, self.s_fake_scaler, self.s_fake_opt) + + metrics.log('s_fake_loss', s_fake_loss) + unfreeze(self.model) + freeze(self.score_fake) + + batch_vid, batch_mouse, batch_btn = next(loader) + with ctx: + samples = sample_from_gen(batch_vid, batch_mouse, batch_btn) + dmd_loss = get_dmd_loss(samples, batch_mouse, batch_btn) + metrics.log('dmd_loss', dmd_loss) + + optimizer_step(dmd_loss, self.model, self.scaler, self.opt) + self.ema.update() + + with torch.no_grad(): + wandb_dict = metrics.pop() + wandb_dict['time'] = timer.hit() + timer.reset() + + if self.total_step_counter % self.train_cfg.sample_interval == 0: + with ctx, torch.no_grad(): + n_samples = self.train_cfg.n_samples + samples, sample_mouse, sample_button = sampler( + get_ema_core(), + batch_vid[:n_samples], + batch_mouse[:n_samples], + batch_btn[:n_samples], + decode_fn = decode_fn, + scale=self.train_cfg.vae_scale + ) # -> [b,n,c,h,w] + if self.rank == 0: wandb_dict['samples'] = to_wandb(samples, sample_mouse, sample_button) + + if self.rank == 0: + wandb.log(wandb_dict) + + self.total_step_counter += 1 + if self.total_step_counter % self.train_cfg.save_interval == 0: + if self.rank == 0: + self.save() + + self.barrier() \ No newline at end of file diff --git a/owl_wms/trainers/gamerft_trainer.py b/owl_wms/trainers/gamerft_trainer.py index 0b6cfe93..998cad82 100644 --- a/owl_wms/trainers/gamerft_trainer.py +++ b/owl_wms/trainers/gamerft_trainer.py @@ -1,7 +1,3 @@ -""" -Trainer for reconstruction only -""" - import torch from ema_pytorch import EMA import wandb @@ -184,7 +180,8 @@ def get_ema_core(): if self.total_step_counter % self.train_cfg.sample_interval == 0: with ctx, torch.no_grad(): n_samples = self.train_cfg.n_samples - samples, sample_mouse, sample_button = sampler( + + latents, samples, sample_mouse, sample_button = sampler( get_ema_core(), batch_vid[:n_samples], batch_mouse[:n_samples], diff --git a/owl_wms/trainers/self_forcing.py b/owl_wms/trainers/self_forcing.py new file mode 100644 index 00000000..0d9a17da --- /dev/null +++ b/owl_wms/trainers/self_forcing.py @@ -0,0 +1,509 @@ +""" +Self-Forcing Trainer for Game World Model +Implements autoregressive self-rollout training with proper gradient truncation +""" + +import torch +from ema_pytorch import EMA +import wandb +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist +import einops as eo +from copy import deepcopy + +from .base import BaseTrainer +from ..utils import freeze, unfreeze, Timer, find_unused_params, versatile_load +from ..schedulers import get_scheduler_cls +from ..models import get_model_cls +from ..sampling import get_sampler_cls +from ..data import get_loader +from ..utils.logging import LogHelper, to_wandb +from ..muon import init_muon +from ..utils.owl_vae_bridge import get_decoder_only, make_batched_decode_fn +from copy import deepcopy + +class SelfForcingTrainer(BaseTrainer): + """ + Self-Forcing Trainer implementing autoregressive self-rollout training + + Key differences from CausVid: + 1. Uses self-generated context during training (not ground truth) + 2. Implements gradient truncation with stochastic steps + 3. Supports DMD, SiD, and GAN losses on correct distribution + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + model_id = self.model_cfg.model_id + + # Create student (causal) and teacher (non-causal) configs + student_cfg = deepcopy(self.model_cfg) + teacher_cfg = deepcopy(self.model_cfg) + + student_cfg.causal = True + teacher_cfg.causal = False + + # Initialize models + self.model = get_model_cls(model_id)(student_cfg) + self.score_real = get_model_cls(model_id)(teacher_cfg) + + # Load pretrained teacher + if self.train_cfg.teacher_ckpt: + self.score_real.load_state_dict(versatile_load(self.train_cfg.teacher_ckpt)) + freeze(self.score_real) + + # Initialize fake score for DMD/SiD losses + self.score_fake = deepcopy(self.score_real) + + # Print model size + if self.rank == 0: + n_params = sum(p.numel() for p in self.model.parameters()) + print(f"Model has {n_params:,} parameters") + + self.ema = None + self.opt = None + self.s_fake_opt = None + self.scheduler = None + self.s_fake_scaler = None + self.scaler = None + + self.total_step_counter = 0 + + # Initialize VAE decoder + self.decoder = get_decoder_only( + self.train_cfg.vae_id, + self.train_cfg.vae_cfg_path, + self.train_cfg.vae_ckpt_path + ) + freeze(self.decoder) + + # Self-forcing specific parameters + self.loss_type = self.train_cfg.get('loss_type', 'dmd') # dmd, sid, or gan + self.gradient_steps = self.train_cfg.get('gradient_steps', 1) # Number of steps to backprop + self.rollout_steps = self.train_cfg.get('rollout_steps', 5) # Total rollout length + self.stochastic_steps = self.train_cfg.get('stochastic_steps', True) # Random gradient truncation + self.update_ratio = self.train_cfg.get('update_ratio', 5) # Critic updates per generator update + self.cfg_scale = self.train_cfg.get('cfg_scale', 1.3) + + def save(self): + save_dict = { + 'model': self.model.state_dict(), + 'ema': self.ema.state_dict(), + 'opt': self.opt.state_dict(), + 'scaler': self.scaler.state_dict(), + 'score_fake': self.score_fake.state_dict(), + 's_fake_opt': self.s_fake_opt.state_dict(), + 's_fake_scaler': self.s_fake_scaler.state_dict(), + 'steps': self.total_step_counter + } + if self.scheduler is not None: + save_dict['scheduler'] = self.scheduler.state_dict() + super().save(save_dict) + + def load(self): + has_ckpt = False + try: + if self.train_cfg.resume_ckpt is not None: + save_dict = super().load(self.train_cfg.resume_ckpt) + has_ckpt = True + except: + print("Error loading checkpoint") + + if not has_ckpt: + return + + self.model.load_state_dict(save_dict['model']) + self.ema.load_state_dict(save_dict['ema']) + self.opt.load_state_dict(save_dict['opt']) + if self.scheduler is not None and 'scheduler' in save_dict: + self.scheduler.load_state_dict(save_dict['scheduler']) + self.scaler.load_state_dict(save_dict['scaler']) + self.score_fake.load_state_dict(save_dict['score_fake']) + self.s_fake_opt.load_state_dict(save_dict['s_fake_opt']) + self.s_fake_scaler.load_state_dict(save_dict['s_fake_scaler']) + self.total_step_counter = save_dict['steps'] + + def autoregressive_rollout(self, initial_latent, mouse, btn, decode_fn=None): + """ + Perform autoregressive rollout with gradient truncation + + Args: + initial_latent: Initial frame(s) [b, init_frames, c, h, w] + mouse: Mouse inputs for entire sequence [b, n, 2] + btn: Button inputs for entire sequence [b, n, n_buttons] + decode_fn: Optional decode function for visualization + + Returns: + generated_latents: Full generated sequence + gradient_mask: Boolean mask indicating which frames get gradients + """ + b = initial_latent.shape[0] + device = initial_latent.device + + # Initialize output with initial frames + generated_latents = [initial_latent] + + # Determine gradient truncation point + if self.stochastic_steps: + # Randomly select which steps to backprop through + grad_start = torch.randint( + max(0, self.rollout_steps - self.gradient_steps), + self.rollout_steps, + (1,) + ).item() + else: + # Always backprop through last gradient_steps + grad_start = self.rollout_steps - self.gradient_steps + + # Generate frames autoregressively + for step in range(self.rollout_steps): + # Get context from previously generated frames + context = torch.cat(generated_latents, dim=1) + context_frames = context.shape[1] + + # Get corresponding actions + step_mouse = mouse[:, :context_frames] + step_btn = btn[:, :context_frames] + + # Determine if this step needs gradients + needs_grad = step >= grad_start + + with torch.set_grad_enabled(needs_grad): + # Add noise to last frame for next prediction + noisy_next = torch.randn(b, 1, *initial_latent.shape[2:], device=device) + + # Prepare input + model_input = torch.cat([context, noisy_next], dim=1) + model_mouse = mouse[:, :context_frames + 1] + model_btn = btn[:, :context_frames + 1] + + # Generate next frame + with torch.amp.autocast('cuda', torch.bfloat16): + # Run diffusion model to denoise + ts = torch.ones(b, context_frames + 1, device=device) + ts[:, -1] = 0.99 # High noise for last frame + ts[:, :-1] = 0.0 # Clean context + + pred = self.model.core(model_input, ts, model_mouse, model_btn) + next_frame = model_input - pred * ts[:, :, None, None, None] + next_frame = next_frame[:, -1:] # Take only the newly generated frame + + generated_latents.append(next_frame) + + # Concatenate all generated frames + full_sequence = torch.cat(generated_latents, dim=1) + + # Create gradient mask + gradient_mask = torch.zeros(b, full_sequence.shape[1], dtype=torch.bool, device=device) + gradient_mask[:, initial_latent.shape[1] + grad_start:] = True + + return full_sequence, gradient_mask + + def compute_dmd_loss(self, generated, mouse, btn, gradient_mask): + """Compute DMD loss on generated sequence""" + s_real_fn = self.score_real.core + s_fake_fn = self.score_fake.module.core if self.world_size > 1 else self.score_fake.core + + with torch.no_grad(): + b, n, c, h, w = generated.shape + ts = torch.rand(b, n, device=generated.device).sigmoid() + z = torch.randn_like(generated) + ts_exp = ts[:, :, None, None, None] + lerpd = generated * (1. - ts_exp) + z * ts_exp + + # Compute real score with CFG + null_mouse = torch.zeros_like(mouse) + null_btn = torch.zeros_like(btn) + + s_real_uncond = s_real_fn(lerpd, ts, null_mouse, null_btn) + s_real_cond = s_real_fn(lerpd, ts, mouse, btn) + s_real = s_real_uncond + self.cfg_scale * (s_real_cond - s_real_uncond) + + # Compute fake score + s_fake = s_fake_fn(lerpd, ts, mouse, btn) + + # DMD gradient + grad = (s_fake - s_real) + + # Normalize + p_real = (generated - s_real) + normalizer = torch.abs(p_real).mean(dim=[2, 3, 4], keepdim=True) + grad = grad / (normalizer + 1e-6) + grad = torch.nan_to_num(grad) + + # Apply gradient mask + if gradient_mask is not None: + grad = grad * gradient_mask[:, :, None, None, None] + + dmd_loss = 0.5 * F.mse_loss(generated.double(), (generated - grad).double()) + + return dmd_loss + + def compute_sid_loss(self, generated, mouse, btn, gradient_mask): + """ + Compute Score identity Distillation (SiD) loss on generated sequence + Based on "Score identity Distillation" paper + """ + s_real_fn = self.score_real.core + s_fake_fn = self.score_fake.module.core if self.world_size > 1 else self.score_fake.core + + with torch.no_grad(): + b, n, c, h, w = generated.shape + ts = torch.rand(b, n, device=generated.device).sigmoid() + z = torch.randn_like(generated) + ts_exp = ts[:, :, None, None, None] + lerpd = generated * (1. - ts_exp) + z * ts_exp + + # Compute real score with CFG + null_mouse = torch.zeros_like(mouse) + null_btn = torch.zeros_like(btn) + + s_real_uncond = s_real_fn(lerpd, ts, null_mouse, null_btn) + s_real_cond = s_real_fn(lerpd, ts, mouse, btn) + s_real = s_real_uncond + self.cfg_scale * (s_real_cond - s_real_uncond) + + # Compute fake score + s_fake = s_fake_fn(lerpd, ts, mouse, btn) + + # SiD loss formulation + # L = (s_real - s_fake) * ((s_real - x) - α(s_real - s_fake)) + alpha = self.train_cfg.get('sid_alpha', 1.0) + + diff_score = s_real - s_fake + diff_real = s_real - generated + + sid_loss = diff_score * (diff_real - alpha * diff_score) + + # Normalize + with torch.no_grad(): + normalizer = torch.abs(diff_real).mean(dim=[2, 3, 4], keepdim=True) + sid_loss = sid_loss / (normalizer + 1e-6) + + # Apply gradient mask + if gradient_mask is not None: + sid_loss = sid_loss * gradient_mask[:, :, None, None, None] + + sid_loss = torch.nan_to_num(sid_loss).mean() + + return sid_loss + + def compute_gan_loss(self, generated, mouse, btn, gradient_mask, train_generator=True): + """ + Simplified GAN loss using the score networks as discriminators + This avoids needing a separate discriminator implementation + """ + with torch.no_grad(): + b, n, c, h, w = generated.shape + # Use a fixed small noise level for discrimination + ts = torch.ones(b, n, device=generated.device) * 0.01 + noise = torch.randn_like(generated) * 0.01 + noisy_generated = generated + noise + + if train_generator: + # Generator loss: make fake score match real score + s_fake = self.score_fake.module.core(noisy_generated, ts, mouse, btn) if self.world_size > 1 else self.score_fake.core(noisy_generated, ts, mouse, btn) + s_real = self.score_real.core(noisy_generated, ts, mouse, btn) + + # L2 loss between scores + score_diff = (s_fake - s_real) ** 2 + + # Apply gradient mask + if gradient_mask is not None: + score_diff = score_diff * gradient_mask[:, :, None, None, None] + + gan_loss = score_diff.mean() + return gan_loss + else: + # Train fake score to distinguish real from fake + # This is similar to standard DMD critic training + with torch.no_grad(): + # For real data, we'd need ground truth + # For now, return a placeholder + return torch.tensor(0.0, device=generated.device) + + + def train(self): + torch.cuda.set_device(self.local_rank) + + # Prepare models + self.model = self.model.cuda().train() + self.decoder = self.decoder.cuda().eval().bfloat16() + self.score_real = self.score_real.cuda().eval().bfloat16() + self.score_fake = self.score_fake.cuda().train() + + if self.world_size > 1: + self.model = DDP(self.model) + self.score_fake = DDP(self.score_fake) + + freeze(self.decoder) + freeze(self.score_real) + + decode_fn = make_batched_decode_fn(self.decoder, self.train_cfg.vae_batch_size) + + # Initialize EMA + self.ema = EMA( + self.model, + beta=0.999, + update_after_step=0, + update_every=1 + ) + + # Initialize optimizers + self.opt = getattr(torch.optim, self.train_cfg.opt)( + self.model.parameters(), + **self.train_cfg.opt_kwargs + ) + self.s_fake_opt = getattr(torch.optim, self.train_cfg.opt)( + self.score_fake.parameters(), + **self.train_cfg.opt_kwargs + ) + + if self.train_cfg.scheduler is not None: + self.scheduler = get_scheduler_cls(self.train_cfg.scheduler)( + self.opt, + **self.train_cfg.scheduler_kwargs + ) + + # Scalers for mixed precision + self.s_fake_scaler = torch.amp.GradScaler() + self.scaler = torch.amp.GradScaler() + ctx = torch.amp.autocast('cuda', torch.bfloat16) + + self.load() + + # Setup logging + timer = Timer() + timer.reset() + metrics = LogHelper() + if self.rank == 0: + wandb.watch(self.get_module(), log='all') + + # Dataset setup + loader = get_loader( + self.train_cfg.data_id, + self.train_cfg.batch_size, + **self.train_cfg.data_kwargs + ) + sampler = get_sampler_cls(self.train_cfg.sampler_id)(**self.train_cfg.sampler_kwargs) + + def optimizer_step(loss, model, scaler, optimizer): + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + scaler.step(optimizer) + optimizer.zero_grad(set_to_none=True) + scaler.update() + + # Training loop + loader = iter(loader) + while True: + # Update critic/fake score + if self.loss_type in ['dmd', 'sid']: + freeze(self.model) + unfreeze(self.score_fake) + + for _ in range(self.update_ratio): + batch_vid, batch_mouse, batch_btn = next(loader) + + # Get initial frames + initial_frames = batch_vid[:, :1] # Use first frame as initial + + with torch.no_grad(): + # Generate sequence autoregressively + generated, _ = self.autoregressive_rollout( + initial_frames, + batch_mouse, + batch_btn + ) + + # Train fake score on generated data + with ctx: + s_fake_loss = self.score_fake( + generated.detach(), + batch_mouse[:, :generated.shape[1]], + batch_btn[:, :generated.shape[1]] + ) + + optimizer_step(s_fake_loss, self.score_fake, self.s_fake_scaler, self.s_fake_opt) + metrics.log('s_fake_loss', s_fake_loss) + + # Update generator + unfreeze(self.model) + freeze(self.score_fake) + + batch_vid, batch_mouse, batch_btn = next(loader) + initial_frames = batch_vid[:, :1] + + # Generate with gradients + with ctx: + generated, gradient_mask = self.autoregressive_rollout( + initial_frames, + batch_mouse, + batch_btn + ) + + # Compute loss based on selected type + if self.loss_type == 'dmd': + loss = self.compute_dmd_loss( + generated, + batch_mouse[:, :generated.shape[1]], + batch_btn[:, :generated.shape[1]], + gradient_mask + ) + elif self.loss_type == 'sid': + loss = self.compute_sid_loss( + generated, + batch_mouse[:, :generated.shape[1]], + batch_btn[:, :generated.shape[1]], + gradient_mask + ) + elif self.loss_type == 'gan': + loss = self.compute_gan_loss( + generated, + batch_mouse[:, :generated.shape[1]], + batch_btn[:, :generated.shape[1]], + gradient_mask, + train_generator=True + ) + + metrics.log(f'{self.loss_type}_loss', loss) + + optimizer_step(loss, self.model, self.scaler, self.opt) + self.ema.update() + + # Logging and visualization + with torch.no_grad(): + wandb_dict = metrics.pop() + wandb_dict['time'] = timer.hit() + timer.reset() + + if self.total_step_counter % self.train_cfg.sample_interval == 0: + with ctx, torch.no_grad(): + n_samples = self.train_cfg.n_samples + + # Get EMA model for sampling + ema_core = self.ema.ema_model.module.core if self.world_size > 1 else self.ema.ema_model.core + + # Sample using the trained model + samples, sample_mouse, sample_button = sampler( + ema_core, + batch_vid[:n_samples, :1], # Initial frame + batch_mouse[:n_samples], + batch_btn[:n_samples], + decode_fn=decode_fn, + scale=self.train_cfg.vae_scale + ) + + if self.rank == 0: + wandb_dict['samples'] = to_wandb(samples, sample_mouse, sample_button) + + if self.rank == 0: + wandb.log(wandb_dict) + + self.total_step_counter += 1 + if self.total_step_counter % self.train_cfg.save_interval == 0: + if self.rank == 0: + self.save() + + self.barrier() diff --git a/owl_wms/trainers/shortcut_trainer.py b/owl_wms/trainers/shortcut_trainer.py new file mode 100644 index 00000000..1c96a9a4 --- /dev/null +++ b/owl_wms/trainers/shortcut_trainer.py @@ -0,0 +1,201 @@ +import torch +from ema_pytorch import EMA +import wandb +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist +import einops as eo + +from .base import BaseTrainer + +from ..utils import freeze, Timer, find_unused_params +from ..schedulers import get_scheduler_cls +from ..models import get_model_cls +from ..sampling import get_sampler_cls +from ..data import get_loader +from ..utils.logging import LogHelper, to_wandb +from ..muon import init_muon +from ..utils.owl_vae_bridge import get_decoder_only, make_batched_decode_fn +from ..models.gamerft_shortcut import get_sc_targets + +class ShortcutTrainer(BaseTrainer): + """ + Trainer for rectified flow transformer with shortcut + + :param train_cfg: Configuration for training + :param logging_cfg: Configuration for logging + :param model_cfg: Configuration for model + :param global_rank: Rank across all devices. + :param local_rank: Rank for current device on this process. + :param world_size: Overall number of devices + """ + def __init__(self,*args,**kwargs): + super().__init__(*args,**kwargs) + + model_id = self.model_cfg.model_id + self.model = get_model_cls(model_id)(self.model_cfg) + + # Print model size + if self.rank == 0: + n_params = sum(p.numel() for p in self.model.parameters()) + print(f"Model has {n_params:,} parameters") + + self.ema = None + self.opt = None + self.scheduler = None + self.scaler = None + + self.total_step_counter = 0 + self.decoder = get_decoder_only( + self.train_cfg.vae_id, + self.train_cfg.vae_cfg_path, + self.train_cfg.vae_ckpt_path + ) + + freeze(self.decoder) + + def save(self): + save_dict = { + 'model' : self.model.state_dict(), + 'ema' : self.ema.state_dict(), + 'opt' : self.opt.state_dict(), + 'scaler' : self.scaler.state_dict(), + 'steps': self.total_step_counter + } + if self.scheduler is not None: + save_dict['scheduler'] = self.scheduler.state_dict() + super().save(save_dict) + + def load(self): + has_ckpt = False + try: + if self.train_cfg.resume_ckpt is not None: + save_dict = super().load(self.train_cfg.resume_ckpt) + has_ckpt = True + except: + print("Error loading checkpoint") + + if not has_ckpt: + return + + + self.model.load_state_dict(save_dict['model']) + self.ema.load_state_dict(save_dict['ema']) + self.opt.load_state_dict(save_dict['opt']) + if self.scheduler is not None and 'scheduler' in save_dict: + self.scheduler.load_state_dict(save_dict['scheduler']) + self.scaler.load_state_dict(save_dict['scaler']) + self.total_step_counter = save_dict['steps'] + + def train(self): + torch.cuda.set_device(self.local_rank) + + # Prepare model and ema + self.model = self.model.cuda().train() + if self.world_size > 1: + self.model = DDP(self.model, find_unused_parameters=True) + + self.decoder = self.decoder.cuda().eval().bfloat16() + decode_fn = make_batched_decode_fn(self.decoder, self.train_cfg.vae_batch_size) + + self.ema = EMA( + self.model, + beta = 0.999, + update_after_step = 0, + update_every = 1 + ) + #torch.compile(self.ema.ema_model.module.core if self.world_size > 1 else self.ema.ema_model.core, dynamic=False, fullgraph=True) + + def get_ema_core(): + if self.world_size > 1: + return self.ema.ema_model.module.core + else: + return self.ema.ema_model.core + + # No muon pls + self.opt = getattr(torch.optim, self.train_cfg.opt)(self.model.parameters(), **self.train_cfg.opt_kwargs) + + # Grad accum setup and scaler + accum_steps = self.train_cfg.target_batch_size // self.train_cfg.batch_size // self.world_size + accum_steps = max(1, accum_steps) + self.scaler = torch.amp.GradScaler() + ctx = torch.amp.autocast('cuda',torch.bfloat16) + + self.load() + + # Timer reset + timer = Timer() + timer.reset() + metrics = LogHelper() + if self.rank == 0: + wandb.watch(self.get_module(), log = 'all') + + # Dataset setup + loader = get_loader(self.train_cfg.data_id, self.train_cfg.batch_size, **self.train_cfg.data_kwargs) + sampler = get_sampler_cls(self.train_cfg.sampler_id)(**self.train_cfg.sampler_kwargs) + + local_step = 0 + for _ in range(self.train_cfg.epochs): + for batch_vid, batch_keyframe, batch_mouse, batch_btn in loader: + batch_vid = batch_vid.cuda().bfloat16() / self.train_cfg.vae_scale + batch_keyframe = batch_keyframe.cuda().bfloat16() + batch_mouse = batch_mouse.cuda().bfloat16() + batch_btn = batch_btn.cuda().bfloat16() + + with ctx: + diff_loss, sc_loss = self.model(batch_vid,batch_keyframe,batch_mouse,batch_btn, get_ema_core()) + loss = diff_loss + sc_loss + loss = loss / accum_steps + + self.scaler.scale(loss).backward() + + metrics.log('diffusion_loss', diff_loss) + metrics.log('shortcut_loss', sc_loss) + + local_step += 1 + if local_step % accum_steps == 0: + # Updates + self.scaler.unscale_(self.opt) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + + self.scaler.step(self.opt) + self.opt.zero_grad(set_to_none=True) + + self.scaler.update() + + if self.scheduler is not None: + self.scheduler.step() + self.ema.update() + + # Do logging + with torch.no_grad(): + wandb_dict = metrics.pop() + wandb_dict['time'] = timer.hit() + wandb_dict['lr'] = self.opt.param_groups[0]['lr'] + timer.reset() + + # Sampling commented out for now + if self.total_step_counter % self.train_cfg.sample_interval == 0: + with ctx, torch.no_grad(): + n_samples = self.train_cfg.n_samples + samples, sample_mouse, sample_button = sampler( + get_ema_core(), + batch_vid[:n_samples], + batch_keyframe[:n_samples], + batch_mouse[:n_samples], + batch_btn[:n_samples], + decode_fn = decode_fn, + scale=self.train_cfg.vae_scale + ) # -> [b,n,c,h,w] + if self.rank == 0: wandb_dict['samples'] = to_wandb(samples, sample_mouse, sample_button) + + + if self.rank == 0: + wandb.log(wandb_dict) + + self.total_step_counter += 1 + if self.total_step_counter % self.train_cfg.save_interval == 0: + if self.rank == 0: + self.save() + + self.barrier() diff --git a/owl_wms/trainers/shortcut_trainer_2.py b/owl_wms/trainers/shortcut_trainer_2.py new file mode 100644 index 00000000..d0e9309a --- /dev/null +++ b/owl_wms/trainers/shortcut_trainer_2.py @@ -0,0 +1,204 @@ +import torch +from ema_pytorch import EMA +import wandb +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist +import einops as eo + +from .base import BaseTrainer + +from ..utils import freeze, Timer, find_unused_params +from ..schedulers import get_scheduler_cls +from ..models import get_model_cls +from ..sampling import get_sampler_cls +from ..data import get_loader +from ..utils.logging import LogHelper, to_wandb +from ..muon import init_muon +from ..utils.owl_vae_bridge import get_decoder_only, make_batched_decode_fn +from ..models.gamerft_shortcut import get_sc_targets + +class ShortcutTrainer(BaseTrainer): + """ + Trainer for rectified flow transformer with shortcut + + :param train_cfg: Configuration for training + :param logging_cfg: Configuration for logging + :param model_cfg: Configuration for model + :param global_rank: Rank across all devices. + :param local_rank: Rank for current device on this process. + :param world_size: Overall number of devices + """ + def __init__(self,*args,**kwargs): + super().__init__(*args,**kwargs) + + model_id = self.model_cfg.model_id + self.model = get_model_cls(model_id)(self.model_cfg) + + # Print model size + if self.rank == 0: + n_params = sum(p.numel() for p in self.model.parameters()) + print(f"Model has {n_params:,} parameters") + + self.ema = None + self.opt = None + self.scheduler = None + self.scaler = None + + self.total_step_counter = 0 + self.decoder = get_decoder_only( + self.train_cfg.vae_id, + self.train_cfg.vae_cfg_path, + self.train_cfg.vae_ckpt_path + ) + + freeze(self.decoder) + + def save(self): + save_dict = { + 'model' : self.model.state_dict(), + 'ema' : self.ema.state_dict(), + 'opt' : self.opt.state_dict(), + 'scaler' : self.scaler.state_dict(), + 'steps': self.total_step_counter + } + if self.scheduler is not None: + save_dict['scheduler'] = self.scheduler.state_dict() + super().save(save_dict) + + def load(self): + has_ckpt = False + try: + if self.train_cfg.resume_ckpt is not None: + save_dict = super().load(self.train_cfg.resume_ckpt) + has_ckpt = True + except: + print("Error loading checkpoint") + + if not has_ckpt: + return + + + self.model.load_state_dict(save_dict['model']) + self.ema.load_state_dict(save_dict['ema']) + self.opt.load_state_dict(save_dict['opt']) + if self.scheduler is not None and 'scheduler' in save_dict: + self.scheduler.load_state_dict(save_dict['scheduler']) + self.scaler.load_state_dict(save_dict['scaler']) + self.total_step_counter = save_dict['steps'] + + def train(self): + torch.cuda.set_device(self.local_rank) + + # Prepare model and ema + self.model = self.model.cuda().train() + if self.world_size > 1: + self.model = DDP(self.model, device_ids=[self.local_rank]) + + self.decoder = self.decoder.cuda().eval().bfloat16() + decode_fn = make_batched_decode_fn(self.decoder, self.train_cfg.vae_batch_size) + + self.ema = EMA( + self.model, + beta = 0.999, + update_after_step = 0, + update_every = 1 + ) + #torch.compile(self.ema.ema_model.module.core if self.world_size > 1 else self.ema.ema_model.core, dynamic=False, fullgraph=True) + + def get_ema_core(): + if self.world_size > 1: + return self.ema.ema_model.module.core + else: + return self.ema.ema_model.core + + # Set up optimizer and scheduler + if self.train_cfg.opt.lower() == "muon": + self.opt = init_muon(self.model, rank=self.rank,world_size=self.world_size,**self.train_cfg.opt_kwargs) + else: + self.opt = getattr(torch.optim, self.train_cfg.opt)(self.model.parameters(), **self.train_cfg.opt_kwargs) + + # Grad accum setup and scaler + accum_steps = self.train_cfg.target_batch_size // self.train_cfg.batch_size // self.world_size + accum_steps = max(1, accum_steps) + self.scaler = torch.amp.GradScaler() + ctx = torch.amp.autocast('cuda',torch.bfloat16) + + self.load() + + # Timer reset + timer = Timer() + timer.reset() + metrics = LogHelper() + if self.rank == 0: + wandb.watch(self.get_module(), log = 'all') + + # Dataset setup + loader = get_loader(self.train_cfg.data_id, self.train_cfg.batch_size, **self.train_cfg.data_kwargs) + sampler = get_sampler_cls(self.train_cfg.sampler_id)(**self.train_cfg.sampler_kwargs) + + local_step = 0 + for _ in range(self.train_cfg.epochs): + for batch_vid, batch_mouse, batch_btn in loader: + batch_vid = batch_vid.cuda().bfloat16() / self.train_cfg.vae_scale + batch_mouse = batch_mouse.cuda().bfloat16() + batch_btn = batch_btn.cuda().bfloat16() + + with ctx: + diff_loss, sc_loss = self.model(batch_vid,batch_mouse,batch_btn, get_ema_core()) + loss = diff_loss + sc_loss + loss = loss / accum_steps + + self.scaler.scale(loss).backward() + + metrics.log('diffusion_loss', diff_loss) + metrics.log('shortcut_loss', sc_loss) + + local_step += 1 + if local_step % accum_steps == 0: + # Updates + if self.train_cfg.opt.lower() != "muon": + self.scaler.unscale_(self.opt) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + + + self.scaler.step(self.opt) + self.opt.zero_grad(set_to_none=True) + + self.scaler.update() + + if self.scheduler is not None: + self.scheduler.step() + self.ema.update() + + # Do logging + with torch.no_grad(): + wandb_dict = metrics.pop() + wandb_dict['time'] = timer.hit() + wandb_dict['lr'] = self.opt.param_groups[0]['lr'] + timer.reset() + + # Sampling commented out for now + if self.total_step_counter % self.train_cfg.sample_interval == 0: + with ctx, torch.no_grad(): + n_samples = self.train_cfg.n_samples + samples, sample_mouse, sample_button = sampler( + get_ema_core(), + batch_vid[:n_samples], + batch_mouse[:n_samples], + batch_btn[:n_samples], + decode_fn = decode_fn, + scale=self.train_cfg.vae_scale + ) # -> [b,n,c,h,w] + if self.rank == 0: wandb_dict['samples'] = to_wandb(samples, sample_mouse, sample_button) + + + if self.rank == 0: + wandb.log(wandb_dict) + + self.total_step_counter += 1 + if self.total_step_counter % self.train_cfg.save_interval == 0: + if self.rank == 0: + self.save() + + self.barrier() diff --git a/owl_wms/trainers/shortcut_trainer_audio.py b/owl_wms/trainers/shortcut_trainer_audio.py new file mode 100644 index 00000000..5f1d7263 --- /dev/null +++ b/owl_wms/trainers/shortcut_trainer_audio.py @@ -0,0 +1,217 @@ +import torch +from ema_pytorch import EMA +import wandb +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist +import einops as eo + +from .base import BaseTrainer + +from ..utils import freeze, Timer, find_unused_params +from ..schedulers import get_scheduler_cls +from ..models import get_model_cls +from ..sampling import get_sampler_cls +from ..data import get_loader +from ..utils.logging import LogHelper, to_wandb +from ..muon import init_muon +from ..utils.owl_vae_bridge import get_decoder_only, make_batched_decode_fn, make_batched_audio_decode_fn +from ..models.gamerft_shortcut import get_sc_targets + +class ShortcutTrainer(BaseTrainer): + """ + Trainer for rectified flow transformer with shortcut + + :param train_cfg: Configuration for training + :param logging_cfg: Configuration for logging + :param model_cfg: Configuration for model + :param global_rank: Rank across all devices. + :param local_rank: Rank for current device on this process. + :param world_size: Overall number of devices + """ + def __init__(self,*args,**kwargs): + super().__init__(*args,**kwargs) + + model_id = self.model_cfg.model_id + self.model = get_model_cls(model_id)(self.model_cfg) + + # Print model size + if self.rank == 0: + n_params = sum(p.numel() for p in self.model.parameters()) + print(f"Model has {n_params:,} parameters") + + self.ema = None + self.opt = None + self.scheduler = None + self.scaler = None + + self.total_step_counter = 0 + self.decoder = get_decoder_only( + self.train_cfg.image_vae_id, + self.train_cfg.image_vae_cfg_path, + self.train_cfg.image_vae_ckpt_path + ) + + self.audio_decoder = get_decoder_only( + self.train_cfg.audio_vae_id, + self.train_cfg.audio_vae_cfg_path, + self.train_cfg.audio_vae_ckpt_path + ) + + freeze(self.decoder) + freeze(self.audio_decoder) + + def save(self): + save_dict = { + 'model' : self.model.state_dict(), + 'ema' : self.ema.state_dict(), + 'opt' : self.opt.state_dict(), + 'scaler' : self.scaler.state_dict(), + 'steps': self.total_step_counter + } + if self.scheduler is not None: + save_dict['scheduler'] = self.scheduler.state_dict() + super().save(save_dict) + + def load(self): + has_ckpt = False + try: + if self.train_cfg.resume_ckpt is not None: + save_dict = super().load(self.train_cfg.resume_ckpt) + has_ckpt = True + except: + print("Error loading checkpoint") + + if not has_ckpt: + return + + + self.model.load_state_dict(save_dict['model']) + self.ema.load_state_dict(save_dict['ema']) + self.opt.load_state_dict(save_dict['opt']) + if self.scheduler is not None and 'scheduler' in save_dict: + self.scheduler.load_state_dict(save_dict['scheduler']) + self.scaler.load_state_dict(save_dict['scaler']) + self.total_step_counter = save_dict['steps'] + + def train(self): + torch.cuda.set_device(self.local_rank) + + # Prepare model and ema + self.model = self.model.cuda().train() + if self.world_size > 1: + self.model = DDP(self.model, device_ids=[self.local_rank]) + + self.decoder = self.decoder.cuda().eval().bfloat16() + self.audio_decoder = self.audio_decoder.cuda().eval().bfloat16() + decode_fn = make_batched_decode_fn(self.decoder, self.train_cfg.vae_batch_size) + audio_decode_fn = make_batched_audio_decode_fn(self.audio_decoder, self.train_cfg.vae_batch_size) + + self.ema = EMA( + self.model, + beta = 0.999, + update_after_step = 0, + update_every = 1 + ) + #torch.compile(self.ema.ema_model.module.core if self.world_size > 1 else self.ema.ema_model.core, dynamic=False, fullgraph=True) + + def get_ema_core(): + if self.world_size > 1: + return self.ema.ema_model.module.core + else: + return self.ema.ema_model.core + + # Set up optimizer and scheduler + if self.train_cfg.opt.lower() == "muon": + self.opt = init_muon(self.model, rank=self.rank,world_size=self.world_size,**self.train_cfg.opt_kwargs) + else: + self.opt = getattr(torch.optim, self.train_cfg.opt)(self.model.parameters(), **self.train_cfg.opt_kwargs) + + # Grad accum setup and scaler + accum_steps = self.train_cfg.target_batch_size // self.train_cfg.batch_size // self.world_size + accum_steps = max(1, accum_steps) + self.scaler = torch.amp.GradScaler() + ctx = torch.amp.autocast('cuda',torch.bfloat16) + + self.load() + + # Timer reset + timer = Timer() + timer.reset() + metrics = LogHelper() + if self.rank == 0: + wandb.watch(self.get_module(), log = 'all') + + # Dataset setup + loader = get_loader(self.train_cfg.data_id, self.train_cfg.batch_size, **self.train_cfg.data_kwargs) + sampler = get_sampler_cls(self.train_cfg.sampler_id)(**self.train_cfg.sampler_kwargs) + + local_step = 0 + for _ in range(self.train_cfg.epochs): + for batch_vid, batch_audio, batch_mouse, batch_btn in loader: + batch_vid = batch_vid.cuda().bfloat16() / self.train_cfg.vae_scale + batch_audio = batch_audio.cuda().bfloat16() + batch_mouse = batch_mouse.cuda().bfloat16() + batch_btn = batch_btn.cuda().bfloat16() + + with ctx: + diff_loss, sc_loss = self.model(batch_vid, batch_audio, batch_mouse, batch_btn, get_ema_core()) + loss = diff_loss + sc_loss + loss = loss / accum_steps + + self.scaler.scale(loss).backward() + + metrics.log('diffusion_loss', diff_loss) + metrics.log('shortcut_loss', sc_loss) + + local_step += 1 + if local_step % accum_steps == 0: + # Updates + if self.train_cfg.opt.lower() != "muon": + self.scaler.unscale_(self.opt) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + + + self.scaler.step(self.opt) + self.opt.zero_grad(set_to_none=True) + + self.scaler.update() + + if self.scheduler is not None: + self.scheduler.step() + self.ema.update() + + # Do logging + with torch.no_grad(): + wandb_dict = metrics.pop() + wandb_dict['time'] = timer.hit() + wandb_dict['lr'] = self.opt.param_groups[0]['lr'] + timer.reset() + + # Sampling commented out for now + if self.total_step_counter % self.train_cfg.sample_interval == 0: + with ctx, torch.no_grad(): + n_samples = self.train_cfg.n_samples + samples, sample_audio, sample_mouse, sample_button = sampler( + get_ema_core(), + batch_vid[:n_samples], + batch_audio[:n_samples], + batch_mouse[:n_samples], + batch_btn[:n_samples], + image_decode_fn = decode_fn, + audio_decode_fn = audio_decode_fn, + image_vae_scale=self.train_cfg.image_vae_scale, + audio_vae_scale=self.train_cfg.audio_vae_scale + ) # -> [b,n,c,h,w] + if self.rank == 0: wandb_dict['samples'] = to_wandb_av(samples, sample_audio, sample_mouse, sample_button) + + + if self.rank == 0: + wandb.log(wandb_dict) + + self.total_step_counter += 1 + if self.total_step_counter % self.train_cfg.save_interval == 0: + if self.rank == 0: + self.save() + + self.barrier() diff --git a/owl_wms/utils/__init__.py b/owl_wms/utils/__init__.py index 312b4a21..a88c23ed 100644 --- a/owl_wms/utils/__init__.py +++ b/owl_wms/utils/__init__.py @@ -96,7 +96,7 @@ def batch_permute_to_length(mouse, button, length): # Calculate how many times we need to double n to exceed length n = mouse.shape[1] factor = 0 - doubled_length = n + doubled_length = mouse.shape[1] while doubled_length < length: factor += 1 doubled_length *= 2 diff --git a/owl_wms/utils/ddp.py b/owl_wms/utils/ddp.py index 8bc23e26..7296a3a2 100644 --- a/owl_wms/utils/ddp.py +++ b/owl_wms/utils/ddp.py @@ -23,4 +23,4 @@ def setup(force=False): def cleanup(): if dist.is_available() and dist.is_initialized(): - dist.destroy_process_group() \ No newline at end of file + dist.destroy_process_group() diff --git a/owl_wms/utils/logging.py b/owl_wms/utils/logging.py index e2ef70c1..fc6912c1 100644 --- a/owl_wms/utils/logging.py +++ b/owl_wms/utils/logging.py @@ -74,4 +74,35 @@ def to_wandb(x, batch_mouse, batch_btn, gather = False, max_samples = 8): x = eo.rearrange(x, '(r c) n d h w -> n d (r h) (c w)', r = 2, c = 4) return wandb.Video(x, format='gif',fps=60) - \ No newline at end of file + +@torch.no_grad() +def to_wandb_av(x, audio, batch_mouse, batch_btn, gather = False, max_samples = 8): + # x is [b,n,c,h,w] + # audio is [b,n,2] + x = x.clamp(-1, 1) + x = x[:max_samples] + audio = audio[:max_samples] + + if dist.is_initialized() and gather: + gathered_x = [None for _ in range(dist.get_world_size())] + gathered_audio = [None for _ in range(dist.get_world_size())] + dist.all_gather(gathered_x, x) + dist.all_gather(gathered_audio, audio) + x = torch.cat(gathered_x, dim=0) + audio = torch.cat(gathered_audio, dim=0) + + # Get labels on frames + x = draw_frames(x, batch_mouse, batch_btn) # -> [b,n,c,h,w] [0,255] uint8 np + + # Convert audio to numpy float32 [-1,1] + audio = audio.cpu().float().numpy() + + # Create grid of videos like in to_wandb + if max_samples == 8: + x = eo.rearrange(x, '(r c) n d h w -> n d (r h) (c w)', r = 2, c = 4) + + # Create video and audio objects + video = wandb.Video(x, format='gif', fps=60) + audio_samples = [wandb.Audio(audio[i], sample_rate=44100) for i in range(len(audio))] + + return video, audio_samples \ No newline at end of file diff --git a/owl_wms/utils/owl_vae_bridge.py b/owl_wms/utils/owl_vae_bridge.py index 331a8582..4c341a6c 100644 --- a/owl_wms/utils/owl_vae_bridge.py +++ b/owl_wms/utils/owl_vae_bridge.py @@ -25,7 +25,7 @@ def get_decoder_only(vae_id, cfg_path, ckpt_path): model = AutoencoderDC.from_pretrained(model_id).bfloat16().cuda().eval() del model.encoder return model.decoder - elif vae_id == "720pr3dc": + else: cfg = Config.from_yaml(cfg_path).model model = get_model_cls(cfg.model_id)(cfg) model.load_state_dict(torch.load(ckpt_path, map_location='cpu',weights_only=False)) @@ -34,19 +34,34 @@ def get_decoder_only(vae_id, cfg_path, ckpt_path): model = model.bfloat16().cuda().eval() return model +def get_encoder_only(vae_id, cfg_path, ckpt_path): + if vae_id == "dcae": + model_id = "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers" + model = AutoencoderDC.from_pretrained(model_id).bfloat16().cuda().eval() + del model.decoder # Keep encoder only + return model.encoder + elif vae_id == "720pr3dc": + cfg = Config.from_yaml(cfg_path).model + model = get_model_cls(cfg.model_id)(cfg) + model.load_state_dict(torch.load(ckpt_path, map_location='cpu',weights_only=False)) + del model.decoder # Keep encoder only + model = model.encoder + model = model.bfloat16().cuda().eval() + return model + @torch.no_grad() -def _make_batched_decode_fn(decoder, batch_size = 8): +def make_batched_decode_fn(decoder, batch_size = 8): def decode(x): - # x is [b,n,m,d] - b,n,m,d = x.shape - x = x.view(b*n,m,d).contiguous() + # x is [b,n,c,h,w] + b,n,c,h,w = x.shape + x = x.view(b*n,c,h,w).contiguous() batches = x.split(batch_size) batch_out = [] for batch in batches: batch_out.append(decoder(batch).bfloat16()) - x = torch.cat(batch_out) # [b*n,3,256,256] + x = torch.cat(batch_out) # [b*n,c,h,w] _,c,h,w = x.shape x = x.view(b,n,c,h,w).contiguous() @@ -54,20 +69,19 @@ def decode(x): return decode @torch.no_grad() -def make_batched_decode_fn(decoder, batch_size = 8): +def make_batched_audio_decode_fn(decoder, batch_size = 8): def decode(x): - # x is [b,n,c,h,w] - b,n,c,h,w = x.shape - x = x.view(b*n,c,h,w).contiguous() + # x is [b,n,c] audio samples + x = x.transpose(1,2) + b,c,n = x.shape - batches = x.split(batch_size) + batches = x.contiguous().split(batch_size) batch_out = [] for batch in batches: batch_out.append(decoder(batch).bfloat16()) - x = torch.cat(batch_out) # [b*n,c,h,w] - _,c,h,w = x.shape - x = x.view(b,n,c,h,w).contiguous() + x = torch.cat(batch_out) # [b,c,n] + x = x.transpose(-1,-2).contiguous() # [b,n,2] return x - return decode \ No newline at end of file + return decode diff --git a/requirements.txt b/requirements.txt index 909273a1..541bc77d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,26 @@ +--extra-index-url https://download.pytorch.org/whl/cu128 opencv-python wandb einops rotary-embedding-torch ema-pytorch omegaconf +torch +torchvision +multimethod diffusers +vector-quantize-pytorch +torchtyping +imageio +einops +numpy +imageio[ffmpeg] +accelerate +fastapi +uvicorn[standard] accelerate boto3 python-dotenv +wandb[media] +alias-free-torch +taskgroup # only here to backport to 3.10 cause hotel wifi is too slow to remake venv diff --git a/train.py b/train.py index 64576921..9126bbd5 100644 --- a/train.py +++ b/train.py @@ -6,6 +6,9 @@ from owl_wms.utils.ddp import cleanup, setup if __name__ == "__main__": + import sys + sys.argv[1:] = ["--config_path", "configs/self_forcing.yaml"] + parser = argparse.ArgumentParser() parser.add_argument("--config_path", type=str, help="Path to config YAML file") diff --git a/webapp/action_converter.py b/webapp/action_converter.py new file mode 100644 index 00000000..5c78440d --- /dev/null +++ b/webapp/action_converter.py @@ -0,0 +1,200 @@ +import time +import torch +import asyncio + +from webapp.streaming import StreamingConfig +from torch.nn import functional as F + +BUTTON_NAMES = ["W", "A", "S", "D", "LSHIFT", "SPACE", "R", "F", "E", "LMB", "RMB"] +BUTTON_INDICES = {name: idx for idx, name in enumerate(BUTTON_NAMES)} + + +def _interpolate(actions: torch.Tensor, + empty_action: torch.Tensor, + target_length: int) -> torch.Tensor: + + """ + Interpolate actions to target_length. + If tensor_batch is longer than target_length, subsample. + If tensor_batch is shorter than target_length, repeat with empty actions. + + Must provide empty_action, which is the action to repeat when the batch is shorter than target_length. + Must also provide target_length, which is the length to interpolate to. + Returns: + actions: [target_length, features] + """ + num_actions = actions.shape[0] + + if num_actions >= target_length: + # subsample actions if somehow longer than frames_per_batch + downsampled = torch.arange(0, num_actions, step=(num_actions // target_length))[:target_length] + return actions[downsampled, :] + + # Repeat with empty actions to fill remaining frames + num_missing_actions = target_length - num_actions + if num_missing_actions == target_length: + return empty_action.repeat(target_length, 1) + + # NOTE: Repeat last action for the remaining frames + last_action = actions[-1, :] + repeated = last_action.repeat(num_missing_actions, 1) # [missing_frames, features] + + return torch.cat([actions, repeated], dim=0) # [target_length, features] + + +class ActionConverter: + """Converts WebSocket messages to model tensor format.""" + + def __init__(self, streaming_config: StreamingConfig): + self.streaming_config = streaming_config + self.device = streaming_config.device + + def websocket_to_action(self, ws_message: dict) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert WebSocket message to action tensors. + + Expected ws_message format: + { + "mouse_x": 0.1, # Mouse movement [-1, 1] + "mouse_y": -0.05, + "W": true, # Button states + "LMB": false, + # ... other buttons + } + + Returns: + mouse: [2] tensor + buttons: [n_buttons] tensor + """ + # Extract mouse movement + mouse_x = ws_message.get("mouse_x", 0.0) + mouse_y = ws_message.get("mouse_y", 0.0) + + # Clamp to valid range + mouse_x = max(min(mouse_x, self.streaming_config.mouse_range[1]), self.streaming_config.mouse_range[0]) + mouse_y = max(min(mouse_y, self.streaming_config.mouse_range[1]), self.streaming_config.mouse_range[0]) + + mouse = torch.tensor([mouse_x, mouse_y], device=self.device, dtype=torch.float32) + + # Extract button states + button_states = torch.zeros(self.streaming_config.n_buttons, device=self.device, dtype=torch.float32) + for button_name, idx in BUTTON_INDICES.items(): + if button_name in ws_message: + button_states[idx] = 1.0 if ws_message[button_name] else 0.0 + + return mouse, button_states + + def actions_to_sequence(self, actions: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert list of individual actions to batch tensors. + + Args: + actions: List of (mouse, buttons) tuples + + Returns: + mouse: [sequence_length, 2] + button: [sequence_length, n_buttons] + """ + if not actions: + # Return empty batch - [0, 2] and [0, n_buttons] + return ( + torch.zeros(0, 2, device=self.device), + torch.zeros(0, self.streaming_config.n_buttons, device=self.device) + ) + + mouse = torch.stack([action[0] for action in actions], dim=0) # [seq_len, 2] + button = torch.stack([action[1] for action in actions], dim=0) # [seq_len, n_buttons] + + return mouse, button + + def buttons_to_dict(self, buttons: torch.Tensor) -> dict: + """Convert buttons tensor to dictionary.""" + return {BUTTON_NAMES[i]: bool(buttons[i]) for i in range(buttons.shape[0])} + + def mouse_to_dict(self, mouse: torch.Tensor) -> dict: + """Convert mouse tensor to dictionary.""" + return {"mouse_x": mouse[0].item(), "mouse_y": mouse[1].item()} + + +class ActionCollector: + """Collects real-time actions into 8-frame batches.""" + + def __init__(self, streaming_config: StreamingConfig): + self.streaming_config = streaming_config + self.converter = ActionConverter(streaming_config) + self.action_queue = asyncio.Queue(maxsize=100) # Buffer incoming actions + + # an empty action is for one frame only + self.empty_mouse = torch.zeros((1, streaming_config.n_mouse_axes), + device=streaming_config.device, dtype=torch.float32) + self.empty_buttons = torch.zeros((1, streaming_config.n_buttons), + device=streaming_config.device, dtype=torch.bool) + + async def add_websocket_action(self, ws_message: dict): + """Add action from WebSocket message.""" + action = self.converter.websocket_to_action(ws_message) + # Add timestamp to track when action was received + timestamped_action = (action, time.time()) + + await self.action_queue.put(timestamped_action) + + async def collect_actions(self) -> tuple[torch.Tensor, torch.Tensor]: + """ + Collect actions from the UI at whatever rate is sent over by the client. + This collects all the frames that have been supplied between the last frame generation and now. + + It takes in as many actions as the model generates frames at once. For example, in CausVid, if + the model generates 4 frames at a time, this function will return [4, 2] and [4, 11] + for mouse and button actions. + + If, for one reason or another, we have <4 actions, we will fill the batch with idle actions. + + Returns: + mouse: [X, 2] + button: [X, n_buttons] + """ + real_actions = [] + start_time = time.time() + + # First, clear any stale actions from the queue (older than 1 second) + stale_threshold = start_time - 1.0 + temp_actions = [] + stale_count = 0 + + # Drain the queue and filter out stale actions + while not self.action_queue.empty(): + try: + timestamped_action = self.action_queue.get_nowait() + action, timestamp = timestamped_action + if timestamp >= stale_threshold: + temp_actions.append(action) + else: + stale_count += 1 + except asyncio.QueueEmpty: + break + + # Re-add fresh actions to the queue + for action in temp_actions: + try: + await self.action_queue.put((action, time.time())) + except asyncio.QueueFull: + break # Skip if queue is full + + # Now collect actions for the current batch + while start_time + self.streaming_config.batch_duration > time.time(): + try: + timeout = max(0.01, self.streaming_config.batch_duration - (time.time() - start_time)) + timestamped_action = await asyncio.wait_for(self.action_queue.get(), timeout=timeout) + action, timestamp = timestamped_action + real_actions.append(action) + except asyncio.TimeoutError: + pass + + # Convert real actions to batch tensors + mouse, button = self.converter.actions_to_sequence(real_actions) + mouse = _interpolate(mouse, empty_action=self.empty_mouse, + target_length=self.streaming_config.frames_per_batch) + button = _interpolate(button,empty_action=self.empty_buttons, + target_length=self.streaming_config.frames_per_batch) + + return mouse, button diff --git a/webapp/checkpoints/configs/ae.yml b/webapp/checkpoints/configs/ae.yml new file mode 100644 index 00000000..570f334b --- /dev/null +++ b/webapp/checkpoints/configs/ae.yml @@ -0,0 +1,53 @@ +# Config for a simple 256 -> 16 autoencoder +model: + model_id: dcae + sample_size: [360,640] + channels: 3 + latent_size: 4 + latent_channels: 128 + + noise_decoder_inputs: 0.0 + ch_0: 128 + ch_max: 1024 + + encoder_blocks_per_stage: [3, 3, 3, 3, 3, 3, 3, 3] + decoder_blocks_per_stage: [3, 3, 3, 3, 3, 3, 3, 3] + + checkpoint_grads: true + +train: + trainer_id: rec + data_id: s3_cod + target_batch_size: 128 + batch_size: 16 + + epochs: 200 + + opt: AdamW + opt_kwargs: + lr: 1.0e-4 + weight_decay: 1.0e-4 + betas: [0.9, 0.95] + eps: 1.0e-15 + + lpips_type: convnext + loss_weights: + latent_reg: 1.0e-6 + lpips: 10.0 + se_reg: 0.0 + + scheduler: LinearWarmup + scheduler_kwargs: + warmup_steps: 3000 + min_lr: 5.0e-6 + + checkpoint_dir: webapp/checkpoints/models/cod_128x_30k_ema + resume_ckpt: null + + sample_interval: 1000 + save_interval: 5000 + +wandb: + name: ${env:WANDB_USER_NAME} + project: new_vaes + run_name: 128x_cod \ No newline at end of file diff --git a/webapp/checkpoints/configs/audio.yml b/webapp/checkpoints/configs/audio.yml new file mode 100644 index 00000000..1fde568d --- /dev/null +++ b/webapp/checkpoints/configs/audio.yml @@ -0,0 +1,57 @@ +model: + model_id: audio_ae + + channels: 2 + latent_channels: 64 + ch_0: 128 + ch_max: 512 + + strides: [3, 5, 7, 7, 1] + + eq: true + checkpoint_grads: true + +train: + trainer_id: audio_rec + data_id: local_cod_audio + data_kwargs: + window_length: 88200 + root: "../cod_download/raw" + + target_batch_size: 128 + batch_size: 16 + epochs: 100 + + opt: AdamW + opt_kwargs: + lr: 1.0e-4 + eps: 1.0e-15 + betas: [0.9, 0.95] + weight_decay: 1.0e-4 + + loss_weights: + recon: 2.5 + stft: 1.5 + kl: 1.0e-5 + lr_ms_ratio: 0.5 + hubert: 0.0 + crt: 4.0 + + sample_rate: 44100 + n_fft_list: [1024, 2048, 512] + + scheduler: LinearWarmup + scheduler_kwargs: + warmup_steps: 1500 + min_lr: 1.0e-6 + + checkpoint_dir: webapp/checkpoints/models/audio_20k_ema.pt + sample_interval: 500 + save_interval: 5000 + + resume_ckpt: null + +wandb: + name: ${env:WANDB_USER_NAME} + project: owl_audio_vaes + run_name: audio_ae_baseline \ No newline at end of file diff --git a/webapp/checkpoints/configs/av.yml b/webapp/checkpoints/configs/av.yml new file mode 100644 index 00000000..65085dcc --- /dev/null +++ b/webapp/checkpoints/configs/av.yml @@ -0,0 +1,75 @@ +model: + model_id: game_rft_audio + sample_size: 4 + channels: 128 + audio_channels: 64 + + n_layers: 13 + n_heads: 16 + d_model: 1024 + + tokens_per_frame: 17 + n_buttons: 11 + n_mouse_axes: 2 + + cfg_prob: 0.1 + n_frames: 30 + + causal: false + +train: + trainer_id: av + data_id: cod_s3_audio + data_kwargs: + window_length: 30 + bucket_name: cod-data-latent-360x640to4x4 + + target_batch_size: 256 + batch_size: 32 + + epochs: 200 + + opt: Muon + opt_kwargs: + lr: 1.0e-3 + momentum: 0.95 + adamw_lr: 1.0e-4 + adamw_wd: 1.0e-4 + adamw_eps: 1.0e-15 + adamw_betas: [0.9, 0.95] + adamw_keys: [core.proj_in, core.proj_out.proj] + + scheduler: null + + checkpoint_dir: checkpoints/360p + + sample_interval: 1000 + save_interval: 5000 + + sampler_id: av_window + sampler_kwargs: + n_steps: 10 + cfg_scale: 1.3 + window_length: 30 + num_frames: 60 + noise_prev: 0.2 + only_return_generated: false + + n_samples: 8 + + vae_id: null + vae_batch_size: 4 + vae_scale: 0.13 + audio_vae_scale: 0.17 + + vae_cfg_path: webapp/checkpoints/configs/ae.yml + vae_ckpt_path: webapp/checkpoints/models/cod_128x_30k_ema.pt + + audio_vae_id: null + audio_vae_cfg_path: webapp/checkpoints/configs/audio.yml + audio_vae_ckpt_path: webapp/checkpoints/models/cod_audio_20k_ema.pt + +wandb: + name: shahbuland + project: video_models + run_name: av \ No newline at end of file diff --git a/webapp/server.py b/webapp/server.py new file mode 100644 index 00000000..d66ed616 --- /dev/null +++ b/webapp/server.py @@ -0,0 +1,93 @@ +import os +from dotenv import load_dotenv ; load_dotenv() + +from contextlib import asynccontextmanager +from fastapi import FastAPI, WebSocket +from fastapi.staticfiles import StaticFiles +from fastapi.responses import FileResponse + +from webapp.streaming import StreamingFrameGenerator +from webapp.user_session import UserGameSession +from webapp.utils.configs import WebappConfig + + +DEBUG = True + +# -- lifespan +config: WebappConfig = None +webapp_config_path = "./webapp/webapp_config.yaml" ; assert os.path.exists(webapp_config_path) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + global config, DEBUG + config = WebappConfig.from_yaml(webapp_config_path) + yield + config = None + + +def run(): + """Create and configure the FastAPI app with routes.""" + app = FastAPI(lifespan=lifespan) + + @app.get("/") + async def read_root(): + """Serve the main game page.""" + return FileResponse("webapp/static/index.html") + + @app.websocket("/ws/game") + async def websocket_endpoint(websocket: WebSocket): + global DEBUG + await websocket.accept() + + # Create streaming session for this user + frame_generator = StreamingFrameGenerator(streaming_config=config.stream_config, + run_config=config.run_config, + debug=DEBUG) + session = UserGameSession(frame_generator) + await session.run_session(websocket) + + app.mount("/assets", StaticFiles(directory="webapp/static"), name="assets") + return app + + +def main(): + global DEBUG + import argparse + import uvicorn + + + parser = argparse.ArgumentParser() + parser.add_argument("--debug", action="store_true", help="Enable debug mode") + parser.add_argument("--no-debug", action="store_true", default=True, help="Disable debug mode") + parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") + args = parser.parse_args() + + + if args.debug: + DEBUG = True + elif args.no_debug: + DEBUG = False + # Otherwise keep the default value (True) + + # Create app AFTER setting DEBUG + app = run() + + print("🚀 Starting OWL-WMS FastAPI Server...") + print("📡 WebSocket endpoint: ws://localhost:8000/ws/game") + print("🌐 Access via: http://localhost:8000") + print("🔄 Auto-reload enabled for development") + print("🔄 DEBUG is set to:", DEBUG) + print("🔄 PORT is set to:", args.port) + + uvicorn.run( + app, # Pass the app object directly instead of module string + host="0.0.0.0", # Allow external connections + port=args.port, + reload=False, # Can't use reload with app object + log_level="info" + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/webapp/static/index.html b/webapp/static/index.html new file mode 100644 index 00000000..4142d931 --- /dev/null +++ b/webapp/static/index.html @@ -0,0 +1,526 @@ + + +
+ +Move your mouse over the canvas to interact with the game world!
+ + + + diff --git a/webapp/streaming.py b/webapp/streaming.py new file mode 100644 index 00000000..004a897f --- /dev/null +++ b/webapp/streaming.py @@ -0,0 +1,132 @@ +import time +import torch +import asyncio +from torch import nn + +from webapp.utils.configs import StreamingConfig +from webapp.utils.av_window_inference_pipeline import AV_WindowInferencePipeline +from owl_wms.configs import Config as RunConfig + + +class FrameBuffer: + """ + Manages frame streaming at precise timing, to adhere to a max FPS. + """ + + def __init__(self, streaming_config: StreamingConfig): + self.streaming_config = streaming_config + self.video_frame_queue = asyncio.Queue(maxsize=streaming_config.frames_per_batch * 2) + self.audio_frame_queue = asyncio.Queue(maxsize=streaming_config.frames_per_batch * 2) # Add audio queue + self.buttons_queue = asyncio.Queue(maxsize=streaming_config.frames_per_batch * 2) + self.mouse_queue = asyncio.Queue(maxsize=streaming_config.frames_per_batch * 2) + self.last_frame_time = 0.0 + + async def queue_frames(self, + video_frames: torch.Tensor, # [t,c,h,w] + audio_frames: torch.Tensor, # [t,?,2] + mouse: torch.Tensor, + button: torch.Tensor): + # video_frames shape: [frames_per_batch, channels, height, width] + # audio_frames shape: [frames_per_batch, 2] + num_frames = video_frames.shape[0] + + for i in range(num_frames): + await self.video_frame_queue.put(video_frames[i]) + await self.audio_frame_queue.put(audio_frames[i]) # Queue audio frames + await self.buttons_queue.put(button[i]) + await self.mouse_queue.put(mouse[i]) + + async def get_next_frames(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Get next video, audio, and input frames for streaming at capped FPS.""" + now = time.time() + time_since_last = now - self.last_frame_time + time_to_wait = max(0, self.streaming_config.frame_interval - time_since_last) + + if time_to_wait > 0: + await asyncio.sleep(time_to_wait) + + video_frame = await self.video_frame_queue.get() + audio_frame = await self.audio_frame_queue.get() + button = await self.buttons_queue.get() + mouse = await self.mouse_queue.get() + self.last_frame_time = time.time() + return video_frame, audio_frame, button, mouse + + +class StreamingFrameGenerator: + """Wraps WindowCFGSampler to generate frames.""" + + def __init__(self, + streaming_config: StreamingConfig, + run_config: RunConfig, + debug: bool = False): + + self.run_config = run_config + self.streaming_config = streaming_config + self.debug = debug + + self.av_window_inference_pipeline = AV_WindowInferencePipeline( + config = self.run_config, + ckpt_path = self.streaming_config.model_checkpoint_path, + video_latent_history = self.streaming_config.video_latent_history, + audio_latent_history = self.streaming_config.audio_latent_history, + mouse_history = self.streaming_config.mouse_history, + button_history = self.streaming_config.button_history, + return_only_generated = True, + compile = True + ) + + + async def generate_frames(self, mouse: torch.Tensor, button: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Generate window_length frames, return separate video and overlay frames for streaming. + + Args: + mouse: [window_length, 2] , user input mouse + button: [window_length, n_buttons] , user input button + + Returns: + tuple: (video_frames, audio_frames) + video_frames: [frames_per_batch, 3, 256, 256] - pure video frames + audio_frames: [frames_per_batch, 2] - audio frames + """ + mouse = mouse.to(self.streaming_config.device) + button = button.to(self.streaming_config.device) + + if self.debug: + num_frames = mouse.shape[0] + # Create gradient from white to black to white across columns + col_indices = torch.arange(256, device=self.streaming_config.device) + # Create gradient that goes from 1 to 0 to 1 + gradient = torch.where( + col_indices < 128, + 1.0 - (col_indices / 127.0), # First half: 1 to 0 + (col_indices - 128) / 127.0 # Second half: 0 to 1 + ).view(1, 1, 1, -1) # [1, 1, 1, 256] + full_frames = gradient.expand(num_frames, 3, 256, 256).to(torch.bfloat16) + audio_frames = torch.randn(num_frames, 2, device=self.streaming_config.device, dtype=torch.bfloat16) + # to between 0 and 1 + full_frames = (full_frames - full_frames.min()) / (full_frames.max() - full_frames.min()) + # between 0 and 255 + full_frames = (full_frames * 255).to(torch.uint8) + else: + device_type = 'cuda' if self.streaming_config.device == 'cuda' else 'cpu' + with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.bfloat16): + full_frames, audio_frames = self.av_window_inference_pipeline( + user_input_mouse=mouse.float().unsqueeze(0), # NOTE Need batch dimension + user_input_button=button.float().unsqueeze(0) # NOTE Need batch dimension + ) # [3, 256, 256], [f, 2] + + # convert the frames to a pixel-range of [0-255] from [-1,1] + full_frames = (full_frames + 1) / 2 # NOTE for some reason this is slightly off of [-1, 1] + full_frames = (full_frames * 255).to(torch.uint8) # [3, 256, 256] + full_frames = torch.clip(full_frames, 0, 255) # bandaid for the [-1,1] + + return full_frames.unsqueeze(0), audio_frames.unsqueeze(0) # [t, 3, 256, 256], [t, f, 2] where t = 1 for frame-by-frame rollouts + + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + torch.cuda.empty_cache() diff --git a/webapp/user_session.py b/webapp/user_session.py new file mode 100644 index 00000000..ef4a1f22 --- /dev/null +++ b/webapp/user_session.py @@ -0,0 +1,158 @@ +import io +import wave +import cv2 +import time +import json +import torch +import base64 +import asyncio +import termcolor +import numpy as np + +from fastapi import WebSocket +from webapp.action_converter import ActionCollector +from webapp.streaming import StreamingFrameGenerator, FrameBuffer +from taskgroup import TaskGroup + +class UserGameSession: + """ + Orchestrates receiving actions from the UI, generating frames, and displaying them. + """ + def __init__(self, frame_generator: StreamingFrameGenerator): + self.frame_generator = frame_generator + self.action_collector = ActionCollector(frame_generator.streaming_config) + self.frame_buffer = FrameBuffer(frame_generator.streaming_config) + + async def run_session(self, websocket: WebSocket): + with self.frame_generator: + print(termcolor.colored(f"Starting streaming session at {self.frame_generator.streaming_config.fps} FPS", "green")) + print(termcolor.colored(f"Generating {self.frame_generator.streaming_config.frames_per_batch} frames per batch", "green")) + print(termcolor.colored(f"Batch duration: {self.frame_generator.streaming_config.batch_duration:.3f}s", "green")) + + async with TaskGroup() as tg: + tg.create_task(self._action_input_loop (websocket)) + tg.create_task(self._frame_generation_loop ()) + tg.create_task(self._frame_display_loop (websocket)) + + async def _action_input_loop(self, websocket: WebSocket): + while True: + try: + message = await websocket.receive_text() + action_data = json.loads(message) + await self.action_collector.add_websocket_action(action_data) + except Exception as e: + # Check if this is a WebSocket disconnect + if "websocket.close" in str(e) or "response already completed" in str(e) or "WebSocket" in str(e): + print(termcolor.colored("🔌 WebSocket disconnected - stopping action input", "yellow")) + break + else: + print(f"Error processing action: {e}") + break # Exit loop on any other error + + async def _frame_generation_loop(self): + """Generate frame batches continuously.""" + print(termcolor.colored("Frame generation loop started", "green")) + while True: + try: + # Collect multiple frames worth of actions, e.g. X frames, that have happened between the last frame generation and now. + mouse, button = await self.action_collector.collect_actions() + # Generate Y frames from X actions by taking the X[-1]'th action. Typically, X >> Y, because they are sampled at uncapped FPS from the UI, + # whereas Y frames are sampled from the model one at a time. + video_frames, audio_frames = await self.frame_generator.generate_frames(mouse, button) # TODO What are dims of mouse, button here? + # Queue frames for streaming at a capped FPS. If model predictions speed up or slow down, it won't cause any dilation of frames being displayed. + # However, if the model predictions are too slow, the frames will be displayed at a lower FPS than the capped FPS. + await self.frame_buffer.queue_frames(video_frames, audio_frames, mouse, button) # TODO What should we pass into here? mouse should have 1 frame, button should have 1 frame? + except Exception as e: + import traceback + print(termcolor.colored(f"Error in frame generation: {e} :\n {traceback.format_exc()}", "red")) + await asyncio.sleep(0.05) # Brief pause before retry + + async def _frame_display_loop(self, websocket: WebSocket): + while True: + try: + # Check if WebSocket is still connected before processing + if websocket.client_state.name != 'CONNECTED': + print(termcolor.colored("🔌 WebSocket no longer connected - stopping frame stream", "yellow")) + break + + video_frame, audio_frames, button, mouse = await self.frame_buffer.get_next_frames() + await self._send_frames_to_client(websocket, video_frame, audio_frames, button, mouse) + except Exception as e: + # Check if this is a WebSocket disconnect + if ("websocket.close" in str(e) or + "response already completed" in str(e) or + "Cannot call \"send\" once a close message has been sent" in str(e) or + "RuntimeError" in str(e)): + print(termcolor.colored("🔌 WebSocket disconnected - stopping frame stream", "yellow")) + break + else: + import traceback + print(termcolor.colored(f"Error in frame streaming: {e} :\n {traceback.format_exc()}", "red")) + await asyncio.sleep(self.frame_generator.streaming_config.frame_interval) + + async def _send_frames_to_client(self, + websocket: WebSocket, + video_frame: torch.Tensor, + audio_frames: torch.Tensor, # Add audio_frame parameter + button: torch.Tensor, + mouse: torch.Tensor): + try: + # Check WebSocket state before sending + if websocket.client_state.name != 'CONNECTED': + raise RuntimeError("WebSocket is not connected") + + # Convert video frame to base64 JPEG (existing code) + video_frame_np = video_frame.float().cpu().numpy().transpose(1, 2, 0) + if video_frame_np.max() <= 1.0: + video_frame_np = (video_frame_np * 255).clip(0, 255).astype(np.uint8) + else: + video_frame_np = video_frame_np.clip(0, 255).astype(np.uint8) + + _, video_buffer = cv2.imencode('.jpg', video_frame_np) + video_base64 = base64.b64encode(video_buffer).decode('utf-8') + + # Convert audio frame to base64 WAV + audio_base64 = self._encode_audio_to_wav(audio_frames) + + await websocket.send_json({ + "type": "frame", + "video_data": video_base64, + "audio_data": audio_base64, # Add audio data + "button_data": self.action_collector.converter.buttons_to_dict(button), + "mouse_data": self.action_collector.converter.mouse_to_dict(mouse), + "timestamp": time.time() + }) + except Exception as e: + raise e + + def _encode_audio_to_wav(self, audio_frames: torch.Tensor, sample_rate: int = 44100) -> str: + """ + Convert audio tensor to base64 encoded WAV data. + + Args: + audio_frames: [window_length, 2] tensor representing stereo audio for one frame + sample_rate: Audio sample rate (default 44100 Hz) + """ + # Convert to numpy and ensure it's in the right format + audio_np = audio_frames.float().cpu().numpy() + + # Normalize audio to [-1, 1] range if needed + if audio_np.dtype == torch.bfloat16 or audio_np.max() > 1.0: + audio_np = np.clip(audio_np, -1.0, 1.0) + + # Convert to 16-bit PCM (standard for WAV) + audio_int16 = (audio_np * 32767).astype(np.int16) + + # Create WAV data in memory + wav_buffer = io.BytesIO() + with wave.open(wav_buffer, 'wb') as wav_file: + wav_file.setnchannels(2) # Stereo + wav_file.setsampwidth(2) # 2 bytes per sample (16-bit) + wav_file.setframerate(sample_rate) + wav_file.writeframes(audio_int16.tobytes()) + + # Get WAV data and encode as base64 + wav_data = wav_buffer.getvalue() + return base64.b64encode(wav_data).decode('utf-8') + + # TODO Audio diff --git a/webapp/utils/__init__.py b/webapp/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/webapp/utils/action_builder.py b/webapp/utils/action_builder.py new file mode 100644 index 00000000..47be882e --- /dev/null +++ b/webapp/utils/action_builder.py @@ -0,0 +1,321 @@ +import math +import torch +from enum import Enum +from dataclasses import dataclass +from typing import List, Dict, Optional, Union, Tuple, Callable + + +# Button mapping from the codebase +BUTTON_NAMES = ["W", "A", "S", "D", "LSHIFT", "SPACE", "R", "F", "E", "LMB", "RMB"] +BUTTON_INDICES = {name: idx for idx, name in enumerate(BUTTON_NAMES)} + + +class ActionPattern(Enum): + IDLE = "idle" + WALK_FORWARD = "walk_forward" + STRAFE_LEFT = "strafe_left" + STRAFE_RIGHT = "strafe_right" + WALK_BACKWARD = "walk_backward" + CIRCLE_STRAFE = "circle_strafe" + LOOK_AROUND = "look_around" + SHOOT = "shoot" + SPRINT_FORWARD = "sprint_forward" + RELOAD = "reload" + + +@dataclass(frozen=True) +class ActionConfig: + sequence_length: int + device: Union[str, torch.device] = 'cpu' + dtype: torch.dtype = torch.float32 + n_buttons: int = 11 + mouse_range: Tuple[float, float] = (-1.0, 1.0) + smooth_transitions: bool = True + random_seed: Optional[int] = None + + +class MouseGenerator: + @staticmethod + def _apply_smoothing(window_size: int, values: torch.Tensor) -> torch.Tensor: + """Apply smoothing to mouse movements using convolution.""" + if window_size <= 1: + return values + + # values is [2, sequence_length] for mouse x,y coordinates + # Use groups=2 to smooth each channel independently + kernel = torch.ones(2, 1, window_size, device=values.device) / window_size + padding = window_size // 2 + padded = torch.nn.functional.pad(values, (padding, padding), mode='reflect') + smoothed = torch.nn.functional.conv1d( + padded.unsqueeze(0), # Add batch dim: [1, 2, sequence_length] + kernel, padding=0, + groups=2 # Each output channel only depends on corresponding input channel + ).squeeze(0) # Remove batch dim: [2, sequence_length] + return smoothed + + @staticmethod + def idle(config: ActionConfig) -> torch.Tensor: + """Generate idle mouse movement (minimal random drift).""" + return torch.randn(config.sequence_length, 2, + device=config.device, dtype=config.dtype) * 0.05 + + @staticmethod + def look_around(config: ActionConfig, + speed: float = 0.3, + amplitude: float = 0.7) -> torch.Tensor: + """Generate smooth looking around movement.""" + t = torch.linspace(0, 4 * math.pi, config.sequence_length, + device=config.device, dtype=config.dtype) + + # Create smooth sinusoidal movement + mouse_x = amplitude * torch.sin(t * speed) * torch.cos(t * speed * 0.3) + mouse_y = amplitude * torch.cos(t * speed * 0.7) * torch.sin(t * speed * 0.2) + + movement = torch.stack([mouse_x, mouse_y], dim=1) + + if config.smooth_transitions: + return MouseGenerator._apply_smoothing(5, movement.T).T + + return movement + + @staticmethod + def aim_tracking(config: ActionConfig, + target_speed: float = 0.1, + noise_level: float = 0.02) -> torch.Tensor: + """Generate aiming/tracking movement with micro-adjustments.""" + # Create base tracking movement + t = torch.linspace(0, 2 * math.pi, config.sequence_length, + device=config.device, dtype=config.dtype) + + # Smooth circular tracking + base_x = 0.3 * torch.sin(t * target_speed) + base_y = 0.2 * torch.cos(t * target_speed * 1.2) + + # Add realistic micro-movements + noise_x = torch.randn(config.sequence_length, device=config.device) * noise_level + noise_y = torch.randn(config.sequence_length, device=config.device) * noise_level + + return torch.stack([base_x + noise_x, base_y + noise_y], dim=1) + + @staticmethod + def custom_path(config: ActionConfig, + path_points: List[Tuple[float, float]], + interpolation: str = 'linear') -> torch.Tensor: + """Generate mouse movement following a custom path.""" + if len(path_points) < 2: + return MouseGenerator.idle(config) + + # Convert to tensors + points = torch.tensor(path_points, device=config.device, dtype=config.dtype) + + # Create interpolation indices + t = torch.linspace(0, len(points) - 1, config.sequence_length, + device=config.device, dtype=config.dtype) + + # Linear interpolation between points + indices = t.long() + weights = t - indices.float() + + # Handle edge case + indices = torch.clamp(indices, 0, len(points) - 2) + weights = weights.unsqueeze(1) + + interpolated = (1 - weights) * points[indices] + weights * points[indices + 1] + + return interpolated + + +class ButtonGenerator: + @staticmethod + def idle(config: ActionConfig) -> torch.Tensor: + """Generate idle button state (all buttons released).""" + return torch.zeros(config.sequence_length, config.n_buttons, + device=config.device, dtype=config.dtype) + + @staticmethod + def hold_buttons(config: ActionConfig, + button_names: List[str], + start_frame: int = 0, + duration: Optional[int] = None) -> torch.Tensor: + """Hold specific buttons for a duration.""" + buttons = torch.zeros(config.sequence_length, config.n_buttons, + device=config.device, dtype=config.dtype) + + end_frame = start_frame + (duration or config.sequence_length) + end_frame = min(end_frame, config.sequence_length) + + for button_name in button_names: + if button_name in BUTTON_INDICES: + idx = BUTTON_INDICES[button_name] + buttons[start_frame:end_frame, idx] = 1.0 + + return buttons + + @staticmethod + def tap_sequence(config: ActionConfig, + button_sequences: List[Tuple[str, int, int]]) -> torch.Tensor: + """Create button taps at specific times. + + Args: + button_sequences: List of (button_name, start_frame, duration) tuples + """ + buttons = torch.zeros(config.sequence_length, config.n_buttons, + device=config.device, dtype=config.dtype) + + for button_name, start_frame, duration in button_sequences: + if button_name in BUTTON_INDICES and start_frame < config.sequence_length: + idx = BUTTON_INDICES[button_name] + end_frame = min(start_frame + duration, config.sequence_length) + buttons[start_frame:end_frame, idx] = 1.0 + + return buttons + + @staticmethod + def pattern_from_name(config: ActionConfig, pattern: ActionPattern) -> torch.Tensor: + """Generate button pattern from predefined patterns.""" + if pattern == ActionPattern.WALK_FORWARD: + return ButtonGenerator.hold_buttons(config, ["W"]) + elif pattern == ActionPattern.STRAFE_LEFT: + return ButtonGenerator.hold_buttons(config, ["A"]) + elif pattern == ActionPattern.STRAFE_RIGHT: + return ButtonGenerator.hold_buttons(config, ["D"]) + elif pattern == ActionPattern.WALK_BACKWARD: + return ButtonGenerator.hold_buttons(config, ["S"]) + elif pattern == ActionPattern.SPRINT_FORWARD: + return ButtonGenerator.hold_buttons(config, ["W", "LSHIFT"]) + elif pattern == ActionPattern.RELOAD: + return ButtonGenerator.tap_sequence(config, [("R", 10, 20)]) + else: + return ButtonGenerator.idle(config) + + +class ActionSequenceBuilder: + def __init__(self, config: ActionConfig): + self.config = config + self.mouse_sequence = torch.zeros(config.sequence_length, 2, + device=config.device, dtype=config.dtype) + self.button_sequence = torch.zeros(config.sequence_length, config.n_buttons, + device=config.device, dtype=config.dtype) + + def add_mouse_segment(self, + start_frame: int, + end_frame: int, + generator_func: Callable[[ActionConfig], torch.Tensor], + **kwargs) -> 'ActionSequenceBuilder': + """Add a mouse movement segment.""" + segment_config = ActionConfig( + sequence_length=end_frame - start_frame, + device=self.config.device, + dtype=self.config.dtype, + n_buttons=self.config.n_buttons + ) + + segment = generator_func(segment_config, **kwargs) + self.mouse_sequence[start_frame:end_frame] = segment + return self + + def add_button_segment(self, + start_frame: int, + end_frame: int, + generator_func: Callable[[ActionConfig], torch.Tensor], + **kwargs) -> 'ActionSequenceBuilder': + """Add a button press segment.""" + segment_config = ActionConfig( + sequence_length=end_frame - start_frame, + device=self.config.device, + dtype=self.config.dtype, + n_buttons=self.config.n_buttons + ) + + segment = generator_func(segment_config, **kwargs) + self.button_sequence[start_frame:end_frame] = segment + return self + + def build(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Build and return the final action sequence.""" + return self.mouse_sequence, self.button_sequence + + +class ActionSequenceGenerator: + def __init__(self, config: ActionConfig): + self.config = config + if config.random_seed is not None: + torch.manual_seed(config.random_seed) + + def generate_pattern(self, + pattern: ActionPattern, + mouse_kwargs: Optional[Dict] = None, + button_kwargs: Optional[Dict] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate actions for a predefined pattern.""" + mouse_kwargs = mouse_kwargs or {} + button_kwargs = button_kwargs or {} + + if pattern == ActionPattern.IDLE: + mouse = MouseGenerator.idle(self.config) + buttons = ButtonGenerator.idle(self.config) + + elif pattern == ActionPattern.LOOK_AROUND: + mouse = MouseGenerator.look_around(self.config, **mouse_kwargs) + buttons = ButtonGenerator.idle(self.config) + + elif pattern == ActionPattern.SHOOT: + mouse = MouseGenerator.aim_tracking(self.config, **mouse_kwargs) + # Add some shooting + shoot_times = [(i * 30, 5) for i in range(self.config.sequence_length // 30)] + buttons = ButtonGenerator.tap_sequence( + self.config, + [("LMB", start, dur) for start, dur in shoot_times] + ) + + elif pattern == ActionPattern.CIRCLE_STRAFE: + # Combine circular mouse movement with strafing + mouse = MouseGenerator.look_around(self.config, speed=0.2, amplitude=0.5) + buttons = ButtonGenerator.hold_buttons(self.config, ["A", "W"]) + + else: + mouse = MouseGenerator.idle(self.config) + buttons = ButtonGenerator.pattern_from_name(self.config, pattern) + + return mouse, buttons + + def generate_custom_sequence(self) -> ActionSequenceBuilder: + """Get a builder for creating custom action sequences.""" + return ActionSequenceBuilder(self.config) + + def generate_batch(self, + batch_size: int, + patterns: Optional[List[ActionPattern]] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate a batch of action sequences.""" + if patterns is None: + patterns = [ActionPattern.IDLE] * batch_size + + if len(patterns) != batch_size: + # Repeat or truncate patterns to match batch size + patterns = (patterns * (batch_size // len(patterns) + 1))[:batch_size] + + mouse_batch = [] + button_batch = [] + + for pattern in patterns: + mouse, buttons = self.generate_pattern(pattern) + mouse_batch.append(mouse) + button_batch.append(buttons) + + return torch.stack(mouse_batch), torch.stack(button_batch) + + +if __name__ == "__main__": + print("=== Action Generation Examples ===") + config = ActionConfig(sequence_length=100) + builder = ActionSequenceGenerator(config).generate_custom_sequence() + + mouse_custom, button_custom = (builder + .add_mouse_segment(0, 30, MouseGenerator.idle) + .add_mouse_segment(30, 70, MouseGenerator.look_around, speed=0.5) + .add_mouse_segment(70, 100, MouseGenerator.aim_tracking) + .add_button_segment(0, 50, ButtonGenerator.hold_buttons, button_names=["W"]) + .add_button_segment(50, 100, ButtonGenerator.hold_buttons, button_names=["W", "LSHIFT"]) + .build()) + + print(f"Custom sequence mouse shape: {mouse_custom.shape}, button shape: {button_custom.shape}") + print("Button names mapping:", BUTTON_INDICES) \ No newline at end of file diff --git a/webapp/utils/av_window_inference_pipeline.py b/webapp/utils/av_window_inference_pipeline.py new file mode 100644 index 00000000..cb3fd3fe --- /dev/null +++ b/webapp/utils/av_window_inference_pipeline.py @@ -0,0 +1,145 @@ +import os +import time +import torch +from torch.nn import Module +from owl_wms.models import get_model_cls +from owl_wms.utils.owl_vae_bridge import get_decoder_only +from owl_wms.configs import Config as RunConfig +from owl_wms.models.gamerft_audio import GameRFTCore + +def zlerp(x, alpha): + return x * (1. - alpha) + alpha * torch.randn_like(x) + + +def print_duration(func): + """Decorator that logs the input and output of a function.""" + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + execution_time = end_time - start_time + print(f"{func.__name__} took {execution_time:.3f} seconds to execute, which would yield FPS of {1/execution_time:.3f}") + return result + return wrapper + + +class AV_WindowInferencePipeline: + def __init__(self, + config: RunConfig, + video_latent_history: torch.Tensor, + audio_latent_history: torch.Tensor, + mouse_history: torch.Tensor, + button_history: torch.Tensor, + ckpt_path: str = "av_dfot_35k_ema_200m.pt", + alpha: float = 0.2, + cfg_scale: float = 1.3, + sampling_steps: int = 10, + audio_f: int = 735, + return_only_generated: bool = True, + compile: bool = True, + device: str = 'cuda'): + + self.return_only_generated = return_only_generated + self.config = config + self.device = device + + self.model: GameRFTCore = get_model_cls(self.config.model.model_id)(self.config.model).core + state_dict = torch .load(ckpt_path, map_location="cpu") + self.model .load_state_dict(state_dict) + self.model .eval() + self.model .to(self.device) + + self.frame_decoder: Module = get_decoder_only( + None, + self.config.train.vae_cfg_path, + self.config.train.vae_ckpt_path + ) + self.frame_decoder.eval() + self.frame_decoder.to(self.device) + + self.audio_decoder: Module = get_decoder_only( + None, + self.config.train.audio_vae_cfg_path, + self.config.train.audio_vae_ckpt_path + ) + self.audio_decoder.eval() + self.audio_decoder.to(self.device) + + self.frame_scale = self.config.train.vae_scale + self.audio_scale = self.config.train.audio_vae_scale + + self.history_buffer = (video_latent_history / self.frame_scale).to(self.device) + self.audio_buffer = (audio_latent_history / self.audio_scale).to(self.device) + self.mouse_buffer = mouse_history.to(self.device) + self.button_buffer = button_history.to(self.device) + + self.alpha = alpha + self.cfg_scale = cfg_scale + self.sampling_steps = sampling_steps + self.audio_f = audio_f + + if compile: + print(f'Compiling models...') + torch.compile(self.model) + torch.compile(self.frame_decoder) + torch.compile(self.audio_decoder) + + @print_duration + @torch.no_grad() + def __call__(self, + user_input_mouse: torch.Tensor, # b,1,2 + user_input_button: torch.Tensor # b,1,11 + ) -> tuple[torch.Tensor, torch.Tensor]: + + noised_history = zlerp(self.history_buffer[:,1:], self.alpha) + noised_audio = zlerp(self.audio_buffer[:,1:], self.alpha) + + noised_history = torch.cat([noised_history, torch.randn_like(noised_history[:,0:1])], dim = 1) + noised_audio = torch.cat([noised_audio, torch.randn_like(noised_audio[:,0:1])], dim = 1) + + self.mouse_buffer = torch.cat([self.mouse_buffer[:,1:],user_input_mouse],dim=1) + self.button_buffer = torch.cat([self.button_buffer[:,1:],user_input_button],dim=1) + + dt = 1. / self.sampling_steps + + x = noised_history + a = noised_audio + ts = torch.ones_like(noised_history[:,:,0,0,0]) + ts[:,:-1] = self.alpha + + # mouse_batch = torch.cat([self.mouse_buffer, torch.zeros_like(user_input_mouse)], dim=0) + # btn_batch = torch.cat([self.button_buffer, torch.zeros_like(user_input_button)], dim=0) + # TODO Who knows bruh idk. I think this is to get cfg - uncond is just no actions (zeros) + mouse_batch = torch.cat([self.mouse_buffer, torch.zeros_like(self.mouse_buffer)], dim=0) + btn_batch = torch.cat([self.button_buffer, torch.zeros_like(self.button_buffer)], dim=0) + + for _ in range(self.sampling_steps): + x_batch = torch.cat([x, x], dim=0) + a_batch = torch.cat([a, a], dim=0) + ts_batch = torch.cat([ts, ts], dim=0) + + video_rollout, audio_rollout = self.model(x_batch,a_batch,ts_batch,mouse_batch,btn_batch) + + cond_pred_video, uncond_pred_video = video_rollout.chunk(2) + cond_pred_audio, uncond_pred_audio = audio_rollout.chunk(2) + + pred_video = uncond_pred_video + self.cfg_scale * (cond_pred_video - uncond_pred_video) + pred_audio = uncond_pred_audio + self.cfg_scale * (cond_pred_audio - uncond_pred_audio) + + x[:,-1] = x[:,-1] - dt * pred_video[:,-1] + a[:,-1] = a[:,-1] - dt * pred_audio[:,-1] + ts[:,-1] = ts[:,-1] - dt + + new_frame = x[:,-1:] # [1,1,c,h,w] + new_audio = a[:,-1:] # [1,1,c] + + self.history_buffer = torch.cat([self.history_buffer[:,1:], new_frame], dim=1) + self.audio_buffer = torch.cat([self.audio_buffer[:,1:], new_audio], dim=1) + + frame = self.frame_decoder(new_frame[0] * self.frame_scale).squeeze() # [c,h,w] + audio = self.audio_decoder( + self.audio_buffer.permute(0,2,1) # need this as [b,c,t] for some reason + * self.audio_scale + ).squeeze()[-self.audio_f:].T # [735,2] + + return frame, audio diff --git a/webapp/utils/configs.py b/webapp/utils/configs.py new file mode 100644 index 00000000..9a3a222e --- /dev/null +++ b/webapp/utils/configs.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import os +import yaml +import torch +from dataclasses import dataclass + +from owl_wms.configs import Config as RunConfig + +@dataclass +class WebappConfig: + run_config : RunConfig + stream_config : StreamingConfig + sampling_config : SamplingConfig + run_config_path : os.PathLike + device : str = 'cuda' + + @classmethod + def from_yaml(cls, path: os.PathLike) -> WebappConfig: + # + with open(path, 'r') as wcp: + config = yaml.safe_load(wcp) + config['run_config'] = RunConfig.from_yaml(config['run_config_path']) + config['sampling_config'] = SamplingConfig(**config['sampling_config']) + config['stream_config'] = StreamingConfig(**config['stream_config']) + + return cls(**config) + + +@dataclass +class SamplingConfig: + sampling_steps : int = 20 + vae_scale: float = 1.0 + cfg_scale : float = 1.3 + window_length : int = 60 + num_frames : int = 60 + noise_prev : float = 0.2 + + +@dataclass +class StreamingConfig: + model_checkpoint_path : os.PathLike + fps: int = 20 + frames_per_batch: int = 1 + device: str = 'cuda' + n_buttons: int = 11 + n_mouse_axes: int = 2 + mouse_range: tuple[float, float] = (-1.0, 1.0) + video_latent_history_path: os.PathLike = None + audio_latent_history_path: os.PathLike = None + mouse_history_path: os.PathLike = None + button_history_path: os.PathLike = None + + @property + def frame_interval(self) -> float: return 1.0 / self.fps + + @property + def batch_duration(self) -> float: return self.frames_per_batch / self.fps + + @property + def video_latent_history(self) -> torch.Tensor: + if self.video_latent_history_path is None: raise ValueError("video_latent_history_path is not set") + return torch.load(self.video_latent_history_path) + + @property + def audio_latent_history(self) -> torch.Tensor: + if self.audio_latent_history_path is None: raise ValueError("audio_latent_history_path is not set") + return torch.load(self.audio_latent_history_path) + + @property + def mouse_history(self) -> torch.Tensor: + if self.mouse_history_path is None: raise ValueError("mouse_history_path is not set") + return torch.load(self.mouse_history_path) + + @property + def button_history(self) -> torch.Tensor: + if self.button_history_path is None: raise ValueError("button_history_path is not set") + return torch.load(self.button_history_path) \ No newline at end of file diff --git a/webapp/utils/create_samplers.py b/webapp/utils/create_samplers.py new file mode 100644 index 00000000..87fbbd8a --- /dev/null +++ b/webapp/utils/create_samplers.py @@ -0,0 +1,178 @@ +from torch import Tensor +from typing import Literal, Callable +from functools import partial, cache +from multimethod import multimethod +from owl_wms.sampling.cfg import CFGSampler +from owl_wms.sampling.simple import SimpleSampler, InpaintSimpleSampler +from owl_wms.sampling.window import WindowCFGSampler +from owl_wms.sampling.av_window import Inference_AV_WindowSampler +from owl_wms.utils.owl_vae_bridge import make_batched_decode_fn, make_batched_audio_decode_fn + +SAMPLING_STEPS = 60 +SCALE = 2.17 +CFG_SCALE = 1.3 + +MouseData = Tensor +ButtonData = Tensor +VideoData = Tensor +LatentData = Tensor + +@multimethod +def create_sampler(sampler_id: Literal['cfg'], encoder, decoder, + batch_size: int = 8, + n_steps: int = 20, + cfg_scale: float = 1.3, + vae_scale: float = 1.0, + **kwargs) -> Callable: + """Create CFG sampler with its specific parameters.""" + + @cache + def _sampler(): + return CFGSampler(n_steps=n_steps, cfg_scale=cfg_scale) + + return partial( + _sampler().__call__, + decode_fn=make_batched_decode_fn(decoder, batch_size=batch_size), + scale=vae_scale, + model=encoder + ) + + +@multimethod +def create_sampler(sampler_id: Literal['simple'], encoder, decoder, + batch_size: int = 8, + n_steps: int = 64, + vae_scale: float = 1.0, + **kwargs) -> Callable: + """Create Simple sampler with its specific parameters.""" + + @cache + def _sampler(): + return SimpleSampler(n_steps=n_steps) + + return partial( + _sampler().__call__, + decode_fn=make_batched_decode_fn(decoder, batch_size=batch_size), + scale=vae_scale, + model=encoder + ) + + +@multimethod +def create_sampler(sampler_id: Literal['inpaint_simple'], encoder, decoder, + batch_size: int = 8, + n_steps: int = 64, + vae_scale: float = 1.0, + **kwargs) -> Callable: + """Create Inpaint Simple sampler with its specific parameters.""" + + @cache + def _sampler(): + return InpaintSimpleSampler(n_steps=n_steps) + + return partial( + _sampler().__call__, + decode_fn=make_batched_decode_fn(decoder, batch_size=batch_size), + scale=vae_scale, + model=encoder + ) + + +@multimethod +def create_sampler(sampler_id: Literal['window'], encoder, decoder, + batch_size: int = 8, + n_steps: int = 20, + cfg_scale: float = 1.3, + window_length: int = 60, + num_frames: int = 60, + noise_prev: float = 0.2, + only_return_generated: bool = False, + vae_scale: float = 1.0, + **kwargs) -> Callable: + """Create Window CFG sampler with its specific parameters.""" + + @cache + def _sampler(): + return WindowCFGSampler( + n_steps=n_steps, + cfg_scale=cfg_scale, + window_length=window_length, + num_frames=num_frames, + noise_prev=noise_prev, + only_return_generated=only_return_generated + ) + + return partial( + _sampler().__call__, + decode_fn=make_batched_decode_fn(decoder, batch_size=batch_size), + scale=vae_scale, + model=encoder + ) + + +@multimethod +def create_sampler(sampler_id: Literal['av_window'], + encoder, decoder, audio_decoder, + batch_size: int = 8, + n_steps: int = 20, + cfg_scale: float = 1.3, + window_length: int = 60, + num_frames: int = 60, + noise_prev: float = 0.2, + only_return_generated: bool = True, + vae_scale: float = 1.0, + **kwargs) -> Callable: + + @cache + def _sampler(): + return Inference_AV_WindowSampler( + n_steps=n_steps, + cfg_scale=cfg_scale, + window_length=window_length, + num_frames=num_frames, + noise_prev=noise_prev, + only_return_generated=only_return_generated + ) + + # TODO `dummy_batch` is history, `audio` is audio history, mouse is mouse history, btn is button history + return partial( + _sampler().__call__, + decode_fn=make_batched_decode_fn(decoder, batch_size=batch_size), + audio_decode_fn=make_batched_audio_decode_fn(audio_decoder, batch_size=batch_size), + scale=vae_scale, + model=encoder + ) + + +if __name__ == "__main__": + # Each sampler type can be created with its specific parameters + import webapp.utils.models + encoder, decoder, model_config = webapp.utils.models.load_models() + + # CFG sampler + cfg_sampler = create_sampler( + 'cfg', encoder, decoder, + n_steps=25, cfg_scale=1.5 + ) + + # Simple sampler + simple_sampler = create_sampler( + 'simple', encoder, decoder, + n_steps=50 + ) + + # Window sampler with all its specific params + window_sampler = create_sampler( + 'window', encoder, decoder, + n_steps=30, cfg_scale=1.4, + window_length=80, num_frames=120, + noise_prev=0.3, only_return_generated=True + ) + + # AV Window sampler with all its specific params + av_window_sampler = create_sampler( + 'av_window', encoder, decoder, + n_steps=10, cfg_scale=1.3, + window_length=30, num_frames=60, + noise_prev=0.2, only_return_generated=False + ) \ No newline at end of file diff --git a/webapp/utils/demo_streaming.py b/webapp/utils/demo_streaming.py new file mode 100644 index 00000000..6ba54914 --- /dev/null +++ b/webapp/utils/demo_streaming.py @@ -0,0 +1,73 @@ +import asyncio +import torch +import termcolor + +from webapp.utils.render import generate_dummy_actions, save_video +from webapp.streaming import StreamingConfig, StreamingFrameGenerator +from webapp.utils.models import load_models +from webapp.utils.action_builder import ActionPattern +from webapp.utils.configs import SamplingConfig + + +DEBUG = True + +async def demo_streaming_generation(pattern=ActionPattern.LOOK_AROUND): + """Generate one batch using StreamingFrameGenerator instead of regular sampler.""" + + print(termcolor.colored("🎮 OWL-WMS Streaming Demo", "green")) + print(termcolor.colored("=" * 50, "green")) + + # Configuration + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + streaming_config = StreamingConfig( + fps=20, + frames_per_batch=8, + window_length=60, + device=device + ) + sampling_config = SamplingConfig() + + # Load models (reuse render.py's load_models) + print("📦 Loading models...") + encoder, decoder, train_config = load_models(device=device, verbose=True) + model_config = train_config.model + training_config = train_config.train + + # Create streaming frame generator + print("🎬 Creating streaming frame generator...") + frame_generator = StreamingFrameGenerator( + encoder, decoder, + streaming_config, model_config, training_config, sampling_config, + debug=DEBUG + ) + + # Generate actions (reuse render.py's generate_dummy_actions) + print(f"🎯 Generating {pattern.value} actions...") + mouse_batch, button_batch = generate_dummy_actions(pattern, streaming_config.window_length) + + # Generate frames using streaming generator + print("🎨 Generating frames with streaming generator...") + with frame_generator: + frame_batch = await frame_generator.generate_frame_batch(mouse_batch, button_batch) + + print(f"Generated {frame_batch.shape[0]} frames with shape: {frame_batch.shape}") + + # Save video (reuse render.py's save_video) + print("💾 Saving video...") + output_path = save_video(frame_batch, f"streaming_demo_{pattern.value.lower()}", fps=streaming_config.fps) + + print(termcolor.colored(f"🎉 Demo complete! Video: {output_path}", "green")) + return output_path + + +if __name__ == "__main__": + # Try different patterns by changing this: + pattern = ActionPattern.CIRCLE_STRAFE # or AIM_AND_SHOOT, CIRCLE_STRAFE, etc. + + print("Available patterns:") + for p in ActionPattern: + print(f" - {p.value}") + print() + + asyncio.run(demo_streaming_generation(pattern)) \ No newline at end of file diff --git a/webapp/utils/models.py b/webapp/utils/models.py new file mode 100644 index 00000000..e43d1d07 --- /dev/null +++ b/webapp/utils/models.py @@ -0,0 +1,159 @@ +import torch +import torch.nn as nn +from pathlib import Path +from typing import Dict, Optional, Union +from dataclasses import dataclass + +from termcolor import colored + +from owl_wms.models import get_model_cls +from owl_wms.configs import Config as RunConfig +from owl_wms.utils.owl_vae_bridge import get_decoder_only +from owl_wms.utils import freeze + +@dataclass(frozen=True) +class ModelPaths: + config: Path + checkpoint: Path + + @classmethod + def from_strings(cls, config_path: str, checkpoint_path: str) -> 'ModelPaths': + return cls( + config=Path(config_path), + checkpoint=Path(checkpoint_path) + ) + + def validate(self) -> None: + if not self.config.exists(): + raise FileNotFoundError(f"Config file not found: {self.config}") + if not self.checkpoint.exists(): + raise FileNotFoundError(f"Checkpoint file not found: {self.checkpoint}") + + +class ModelLoader: + DEFAULT_PATHS = ModelPaths.from_strings( + config_path='webapp/checkpoints/shortcut.yaml', + checkpoint_path='webapp/checkpoints/shortcut/step_165000.pt' + ) + + def __init__(self, paths: Optional[ModelPaths] = None): + self.paths = paths or self.DEFAULT_PATHS + self.paths.validate() + + @staticmethod + def _append_state_dict_prefix(state_dict: Dict, prefix: str = 'core.') -> Dict: + return {prefix+key: value for key, value in state_dict.items()} + + @staticmethod + def _count_parameters(model: nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + def _load_config(self) -> RunConfig: + return RunConfig.from_yaml(str(self.paths.config)) + + def _load_checkpoint(self) -> Dict: + return torch.load(str(self.paths.checkpoint), map_location='cpu') + + def load_model(self, + device: Optional[Union[str, torch.device]] = None, + eval_mode: bool = True, + verbose: bool = True) -> nn.Module: + # Load configuration and create model + config = self._load_config() + model_cls = get_model_cls(config.model.model_id) + model = model_cls(config.model) + + # Load and filter state dict + checkpoint = self._load_checkpoint() + + if config.model.model_id == "game_rft": + checkpoint = self._append_state_dict_prefix(checkpoint) + + model.load_state_dict(checkpoint) + + # Configure model + if eval_mode: + model.eval() + + if device is not None: + model = model.to(device) + + # Print model information if requested + if verbose: + param_count = self._count_parameters(model) + print(f'{colored("Model loaded", "blue")}\t\t {colored("successfully", "green")}') + print(f'{colored("Parameters", "blue")} \t\t {colored(f"{param_count:,}", "green")}') + print(f'{colored("Config", "blue")} \t\t {colored(str(self.paths.config), "green", attrs=["bold"])}') + print(f'{colored("Checkpoint", "blue")} \t\t {colored(str(self.paths.checkpoint), "green", attrs=["bold"])}') + + return model + + def load_decoder(self, + device: Optional[Union[str, torch.device]] = None, + eval_mode: bool = True, + verbose: bool = True) -> nn.Module: + decoder = get_decoder_only(vae_id='dcae', cfg_path=str(self.paths.config), ckpt_path=str(self.paths.checkpoint)) + freeze(decoder) + + if verbose: + print(f'{colored("Decoder loaded", "blue")}\t\t {colored("successfully", "green")}') + print(f'{colored("Parameters", "blue")} \t\t {colored(f"{self._count_parameters(decoder):,}", "green")}') + print(f'{colored("Config", "blue")} \t\t {colored(str(self.paths.config), "green", attrs=["bold"])}') + print(f'{colored("Checkpoint", "blue")} \t\t {colored(str(self.paths.checkpoint), "green", attrs=["bold"])}') + + if device is not None: + decoder = decoder.to(device) + + if eval_mode: + decoder.eval() + + return decoder + + +def load_models(config_path: Optional[str] = None, + checkpoint_path: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + eval_mode: bool = True, + verbose: bool = True) -> tuple[nn.Module, nn.Module, RunConfig]: + """ + Convenience function for loading models with custom paths. + + Args: + config_path: Path to model configuration file + checkpoint_path: Path to model checkpoint file + device: Target device for the model + eval_mode: Whether to set model to evaluation mode + verbose: Whether to print model information + + Returns: + Loaded PyTorch model + """ + if config_path or checkpoint_path: + # Use custom paths if provided + paths = ModelPaths.from_strings( + config_path or str(ModelLoader.DEFAULT_PATHS.config), + checkpoint_path or str(ModelLoader.DEFAULT_PATHS.checkpoint) + ) + loader = ModelLoader(paths) + else: + # Use default paths + loader = ModelLoader() + + encoder = loader.load_model(device=device, eval_mode=eval_mode, verbose=verbose) + decoder = loader.load_decoder(device=device, eval_mode=eval_mode, verbose=verbose) + return encoder, decoder, loader._load_config() + + +if __name__ == "__main__": + # Example usage with different approaches + + # Method 1: Using the class directly + print("=== Loading model using ModelLoader class ===") + loader = ModelLoader() + model = loader.load_model() + print(f"Model type: {type(model).__name__}") + + # Method 2: Using the convenience function + print("=== Loading model using convenience function ===") + encoder, decoder, config = load_models() + print(f"Config: {config}") diff --git a/webapp/utils/render.py b/webapp/utils/render.py new file mode 100644 index 00000000..069901cc --- /dev/null +++ b/webapp/utils/render.py @@ -0,0 +1,178 @@ +import torch +import math +import time +from datetime import datetime +from pathlib import Path +import imageio +import numpy as np + +import einops as eo + +from webapp.utils.models import load_models +from webapp.utils.create_samplers import create_sampler, CFG_SCALE +from webapp.utils.action_builder import ActionSequenceGenerator, ActionConfig, ActionPattern + +HEIGHT = 256 +WIDTH = 256 +D_MODEL = 1024 +CHANNELS = 128 +SEQUENCE_LENGTH = 60 +TOKENS_PER_FRAME = 16 +DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' +OUTPUT_DIR = "generated_videos" +SAMPLER_TYPE = 'window' +DEFAULT_PATTERN = ActionPattern.LOOK_AROUND +VAE_SCALE = 2.17 + +def setup_output_dir(): Path(OUTPUT_DIR).mkdir(exist_ok=True) + + +def generate_dummy_actions(pattern=DEFAULT_PATTERN, length=SEQUENCE_LENGTH): + """Generate dummy actions for video conditioning.""" + config = ActionConfig( + sequence_length=length, + device=DEVICE, + dtype=torch.float32 + ) + + generator = ActionSequenceGenerator(config) + mouse, buttons = generator.generate_pattern(pattern) + + # Add batch dimension + return mouse.unsqueeze(0), buttons.unsqueeze(0) + + +def synthesize_video(mouse_actions, button_actions, encoder, decoder, sampler): + """Generate video using model and sampler.""" + batch_size, sequence_length = mouse_actions.shape[:2] + + dummy_batch = torch.randn( + batch_size, sequence_length, CHANNELS, + int(math.sqrt(TOKENS_PER_FRAME)), # H + int(math.sqrt(TOKENS_PER_FRAME)), # W + device=DEVICE, dtype=torch.float32 + ) + + # Ensure actions are on correct device + mouse_actions = mouse_actions.to(DEVICE) + button_actions = button_actions.to(DEVICE) + + # Generate video + with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16): + latents, video, mouse, button = sampler( + dummy_batch=dummy_batch, + mouse=mouse_actions, + btn=button_actions + ) + + return video + + +def save_video(video_tensor: torch.Tensor, filename="generated_video", fps=30): + """ + Save video tensor as MP4 file. + + Args: + video_tensor: Tensor with shape [batch_size, sequence_length, channels, height, width] + Expected range: [-1, 1] (VAE decoder output) + filename: Base filename (without extension) + fps: Frames per second for output video + + Returns: + str: Path to saved video file + """ + setup_output_dir() + + output_path = Path(OUTPUT_DIR) / f"{filename}.mp4" + video_np: np.ndarray = video_tensor.float().cpu().detach().numpy() + + # Take first batch item if batch_size > 1 + if video_np.ndim == 5: # [batch, seq, channels, height, width] + video_np = video_np[0] # Take first batch item: [seq, channels, height, width] + + # Convert from [seq, channels, height, width] to [seq, height, width, channels] + video_np = eo.rearrange(video_np, 'seq c h w -> seq h w c') + + # Denormalize from actual range to [0,255] + print(f'Video range before denorm: [{video_np.min():.3f}, {video_np.max():.3f}]') + + # Normalize to [0, 1] using actual min/max + video_min, video_max = video_np.min(), video_np.max() + if video_max > video_min: # Avoid division by zero + video_np = (video_np - video_min) / (video_max - video_min) # -> [0, 1] + else: + video_np = np.zeros_like(video_np) # Handle edge case of constant values + + video_np = (video_np * 255.0).clip(0, 255).astype(np.uint8) # [0,1] -> [0,255] + + # Handle grayscale (single channel) by converting to RGB + if video_np.shape[-1] == 1: + video_np = np.repeat(video_np, 3, axis=-1) + elif video_np.shape[-1] > 3: + print(f'Warning: video has {video_np.shape[-1]} channels, taking only first 3') + video_np = video_np[:, :, :, :3] + + try: + imageio.mimsave(output_path, video_np, fps=fps, codec='libx264') + return str(output_path) + except Exception as e: + print(f"Warning: Could not save as MP4 ({e}), falling back to .pt format") + + # Fallback: save as PyTorch tensor + fallback_path = Path(OUTPUT_DIR) / f"{filename}.pt" + torch.save(video_tensor.cpu(), fallback_path) + return str(fallback_path) + + +def render_video( + pattern=DEFAULT_PATTERN, + length=SEQUENCE_LENGTH, + verbose=True, + sampler_type=SAMPLER_TYPE, + vae_scale=VAE_SCALE, + cfg_scale=CFG_SCALE, + encoder=None, decoder=None): + """Simple video rendering - just load, generate, save.""" + if verbose: + print(f"🎬 Rendering video with pattern: {pattern.value}") + + # Load model and create sampler + if encoder is None or decoder is None: + if verbose: + print("Loading model...") + encoder, decoder, model_config = load_models(device=DEVICE, verbose=verbose) + + if verbose: + print("Creating sampler...") + sampler = create_sampler(sampler_type, encoder, decoder, vae_scale=vae_scale, cfg_scale=cfg_scale) + + # Generate actions + if verbose: + print("Generating actions...") + mouse, buttons = generate_dummy_actions(pattern, length) + + # Synthesize video + if verbose: + print("Synthesizing video...") + video = synthesize_video(mouse, buttons, encoder, decoder, sampler) + + # Save video + if verbose: + print("Saving video...") + path = save_video(video, f"{sampler_type}_render_{pattern.value}_{vae_scale=:.2f}_{cfg_scale=:.2f}") + + if verbose: + print(f"✅ Done! Video saved to: {path}") + print(f" Video shape: {video.shape}") + + return path + + +if __name__ == "__main__": + # Simple usage examples + print("🎮 Simple OWL-WMS Video Renderer") + + # Render with default settings + encoder, decoder, model_config = load_models(device=DEVICE, verbose=True) + render_video(verbose=True, sampler_type='window', encoder=encoder, decoder=decoder) + render_video(verbose=True, sampler_type='cfg', encoder=encoder, decoder=decoder) diff --git a/webapp/utils/visualize_overlay_actions.py b/webapp/utils/visualize_overlay_actions.py new file mode 100644 index 00000000..03903b3f --- /dev/null +++ b/webapp/utils/visualize_overlay_actions.py @@ -0,0 +1,401 @@ +import imageio as imio +import cv2 +import numpy as np +import math +import torch +from typing import Optional +from contextlib import contextmanager + + +# Global configuration +KEYBINDS = ["W", "A", "S", "D", "LSHIFT", "SPACE", "R", "F", "E", "LMB", "RMB"] +MINIMUM_FRAME_SIZE = 150 +# Colors (BGR format for OpenCV) +COLOR_PRESSED = (50, 200, 50) # Green +COLOR_UNPRESSED = (100, 100, 100) # Gray +COLOR_TEXT = (255, 255, 255) # White +COLOR_BACKGROUND = (30, 30, 30) # Dark gray +COLOR_MOUSE_ARROW = COLOR_PRESSED # Green +COLOR_UNCERTAINTY = (100, 255, 255) # Yellow-ish +COLOR_LMB_SECTOR = COLOR_PRESSED # Green +COLOR_RMB_SECTOR = COLOR_PRESSED # Green + +# Key dimensions +KEY_SIZE = 20 +KEY_MARGIN = 5 +SHIFT_WIDTH = int(KEY_SIZE * 2 + KEY_MARGIN) # Two keys worth of width +SPACE_WIDTH = int(KEY_SIZE * 5) + +# Mouse compass dimensions +COMPASS_RADIUS = 24 # 80% of 80 +COMPASS_START_X_PERCENT = 0.80 + +# Mouse button arc dimensions +MOUSE_BUTTON_OFFSET = 6 # Pixels outside the main circle +MOUSE_BUTTON_THICKNESS = 6 # Increased thickness + +# Arrow scaling parameters +ARROW_SCALE_FACTOR = 1 # Base scaling factor for arrow length +ARROW_MIN_LENGTH = 55 # Minimum arrow length in pixels +ARROW_MAX_SCALE = 0.75 # Maximum scale relative to compass radius + +START_X_PERCENT = 0.18 +START_Y_PERCENT = 0.85 + + +@contextmanager +def _rescale_icons(ratio: float): + """ + Rescale all icon sizes to the ratio of the original video to the 512x512 video + """ + global KEY_SIZE, KEY_MARGIN, SHIFT_WIDTH, SPACE_WIDTH, COMPASS_RADIUS, MOUSE_BUTTON_OFFSET, MOUSE_BUTTON_THICKNESS + global ARROW_SCALE_FACTOR, ARROW_MIN_LENGTH, ARROW_MAX_SCALE + global START_X_PERCENT, START_Y_PERCENT + + try: + old_values = { + "KEY_SIZE": KEY_SIZE, + "KEY_MARGIN": KEY_MARGIN, + "SHIFT_WIDTH": SHIFT_WIDTH, + "SPACE_WIDTH": SPACE_WIDTH, + "COMPASS_RADIUS": COMPASS_RADIUS, + "MOUSE_BUTTON_OFFSET": MOUSE_BUTTON_OFFSET, + "MOUSE_BUTTON_THICKNESS": MOUSE_BUTTON_THICKNESS, + "ARROW_SCALE_FACTOR": ARROW_SCALE_FACTOR, + } + + KEY_SIZE *= ratio ; KEY_SIZE = int(KEY_SIZE) + KEY_MARGIN *= ratio ; KEY_MARGIN = int(KEY_MARGIN) + SHIFT_WIDTH *= ratio ; SHIFT_WIDTH = int(SHIFT_WIDTH) + SPACE_WIDTH *= ratio ; SPACE_WIDTH = int(SPACE_WIDTH) + COMPASS_RADIUS *= ratio ; COMPASS_RADIUS = int(COMPASS_RADIUS) + MOUSE_BUTTON_OFFSET *= ratio ; MOUSE_BUTTON_OFFSET = int(MOUSE_BUTTON_OFFSET) + MOUSE_BUTTON_THICKNESS *= ratio ; MOUSE_BUTTON_THICKNESS = int(MOUSE_BUTTON_THICKNESS) + ARROW_SCALE_FACTOR *= ratio ; ARROW_SCALE_FACTOR = float(ARROW_SCALE_FACTOR) + # Note: ARROW_MIN_LENGTH and ARROW_MAX_SCALE are now calculated relative to COMPASS_RADIUS + # so they don't need separate scaling + yield + finally: + for key, value in old_values.items(): + globals()[key] = value + + +def _get_adaptive_positioning(frame_width: int, frame_height: int) -> tuple[float, float, float]: + """ + Calculate adaptive positioning percentages based on frame dimensions. + Returns (keyboard_start_y_percent, mouse_start_y_percent, start_x_percent) + """ + aspect_ratio = frame_width / frame_height + + # Calculate how much vertical space the keyboard needs + keyboard_height_needed = KEY_SIZE * 3 + KEY_MARGIN * 2 + 20 # 3 rows + margins + buffer + + # Adaptive Y positioning - ensure keyboard fits + if frame_height <= keyboard_height_needed + 40: # Very short frame + keyboard_start_y_percent = 0.1 # Start near top + mouse_start_y_percent = 0.6 # Place mouse in middle-bottom + elif frame_height < 300: # Short frame + keyboard_start_y_percent = 0.4 + mouse_start_y_percent = 0.75 + else: # Normal/tall frame + keyboard_start_y_percent = 0.75 + mouse_start_y_percent = 0.85 + + # Adaptive X positioning based on aspect ratio + if aspect_ratio > 1.5: # Wide frame + start_x_percent = 0.08 + elif aspect_ratio < 0.75: # Tall frame + start_x_percent = 0.15 + else: # Near square frame + start_x_percent = 0.12 + + return keyboard_start_y_percent, mouse_start_y_percent, start_x_percent + + + +def _draw_buttons( + frame: np.ndarray, + button_sequence: list[bool], +) -> None: + """ + Draw keyboard buttons on the frame with adaptive positioning. + """ + frame_height, frame_width = frame.shape[:2] + + # Get adaptive positioning + keyboard_start_y_percent, _, start_x_percent = _get_adaptive_positioning(frame_width, frame_height) + + # Starting position for keyboard layout + start_x = int(frame_width * start_x_percent) + start_y = int(frame_height * keyboard_start_y_percent) + + # Rest of the function remains the same... + key_positions = { + # Top row: W E R (W above S) + "W": (start_x + (KEY_SIZE + KEY_MARGIN) * 1, start_y), + "E": (start_x + (KEY_SIZE + KEY_MARGIN) * 2, start_y), + "R": (start_x + (KEY_SIZE + KEY_MARGIN) * 3, start_y), + + # Middle row: A S D F + "A": (start_x + (KEY_SIZE + KEY_MARGIN) * 0, start_y + KEY_SIZE + KEY_MARGIN), + "S": (start_x + (KEY_SIZE + KEY_MARGIN) * 1, start_y + KEY_SIZE + KEY_MARGIN), + "D": (start_x + (KEY_SIZE + KEY_MARGIN) * 2, start_y + KEY_SIZE + KEY_MARGIN), + "F": (start_x + (KEY_SIZE + KEY_MARGIN) * 3, start_y + KEY_SIZE + KEY_MARGIN), + + # Bottom row: LSHIFT SPACE + "LSHIFT": (start_x - (KEY_SIZE + KEY_MARGIN), start_y + (KEY_SIZE + KEY_MARGIN) * 2), + "SPACE": (start_x + SHIFT_WIDTH + KEY_MARGIN, start_y + (KEY_SIZE + KEY_MARGIN) * 2), + } + + # Draw each key (rest remains the same) + for i, key in enumerate(KEYBINDS[:-2]): # Exclude LMB and RMB + if key in key_positions: + x, y = key_positions[key] + + # Ensure keys stay within frame bounds + if x < 0 or y < 0 or x + KEY_SIZE > frame_width or y + KEY_SIZE > frame_height: + continue + + # Determine key dimensions + if key == "LSHIFT": + width = SHIFT_WIDTH + height = KEY_SIZE + elif key == "SPACE": + width = SPACE_WIDTH + height = KEY_SIZE + else: + width = KEY_SIZE + height = KEY_SIZE + + # Final bounds check with actual key dimensions + if x + width > frame_width or y + height > frame_height: + continue + + # Determine color based on pressed state + color = COLOR_PRESSED if button_sequence[i] else COLOR_UNPRESSED + + # Draw key background + cv2.rectangle(frame, (x, y), (x + width, y + height), color, -1) + + # Draw key border + cv2.rectangle(frame, (x, y), (x + width, y + height), COLOR_TEXT, 1) + + # Draw key label + text_size = cv2.getTextSize(key, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0] + text_x = x + (width - text_size[0]) // 2 + text_y = y + (height + text_size[1]) // 2 + cv2.putText(frame, key, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, + 0.5, COLOR_TEXT, 1, cv2.LINE_AA) + +def _draw_mouse( + frame: np.ndarray, + LMB_on: bool, + RMB_on: bool, + mouse_delta: tuple[float, float], + center: tuple[int, int], +) -> None: + """ + Draw mouse compass with direction arrow and properly positioned labels. + Arrow length is proportional to mouse movement magnitude. + + Args: + frame: numpy array representing the image frame + LMB_on: bool indicating if left mouse button is pressed + RMB_on: bool indicating if right mouse button is pressed + mouse_delta: tuple of floats (x, y) representing mouse direction + center: tuple (x, y) for compass center position + """ + # Draw compass circle + cv2.circle(frame, center, COMPASS_RADIUS, COLOR_TEXT, 2) + + # Calculate outer radius for mouse buttons (slightly outside the main compass) + button_radius = COMPASS_RADIUS + MOUSE_BUTTON_OFFSET + + # Draw LMB arc (top-left 45 degrees) - outside the main circle and thicker + color_lmb = COLOR_LMB_SECTOR if LMB_on else (60, 60, 60) # Gray when off + cv2.ellipse(frame, center, (button_radius, button_radius), + 0, 225, 270, color_lmb, MOUSE_BUTTON_THICKNESS) + + # Draw RMB arc (top-right 45 degrees) - outside the main circle and thicker + color_rmb = COLOR_RMB_SECTOR if RMB_on else (60, 60, 60) # Gray when off + cv2.ellipse(frame, center, (button_radius, button_radius), + 0, 270, 315, color_rmb, MOUSE_BUTTON_THICKNESS) + + # Fixed label positioning - spread them out more and position them better + # LMB label - position it to the left of the compass + text_loc_lmb = (center[0] - COMPASS_RADIUS - 15, center[1] - COMPASS_RADIUS + 5) + cv2.putText(frame, 'LMB', text_loc_lmb, cv2.FONT_HERSHEY_SIMPLEX, 0.4, COLOR_TEXT, 1, cv2.LINE_AA) + + # RMB label - position it to the right of the compass + text_loc_rmb = (center[0] + COMPASS_RADIUS - 10, center[1] - COMPASS_RADIUS + 5) + cv2.putText(frame, 'RMB', text_loc_rmb, cv2.FONT_HERSHEY_SIMPLEX, 0.4, COLOR_TEXT, 1, cv2.LINE_AA) + + # Calculate mouse direction + mouse_x, mouse_y = mouse_delta + magnitude = math.sqrt(mouse_x**2 + mouse_y**2) + + if magnitude > 0: + # Normalize to unit vector + unit_x = mouse_x / magnitude + unit_y = mouse_y / magnitude + + # Fixed arrow scaling - ensure arrow stays within compass bounds + # Calculate minimum and maximum arrow lengths relative to compass size + min_arrow_length = COMPASS_RADIUS * 0.3 # 30% of compass radius + max_arrow_length = COMPASS_RADIUS * 0.8 # 80% of compass radius (stays within circle) + + # Scale by magnitude but clamp to reasonable bounds + arrow_scale = max( + min_arrow_length, # Minimum visible arrow + min( + COMPASS_RADIUS * magnitude * ARROW_SCALE_FACTOR, # Proportional to magnitude + max_arrow_length # Cap at 80% of compass radius + ) + ) + + end_x = int(center[0] + unit_x * arrow_scale) + end_y = int(center[1] - unit_y * arrow_scale) # Negative because y-axis is inverted + + # Draw direction arrow with proportional length + cv2.arrowedLine(frame, center, (end_x, end_y), COLOR_MOUSE_ARROW, 2, tipLength=0.4) + + # Optional: Display magnitude as text for debugging + # magnitude_text = f"Mag: {magnitude:.2f}" + # cv2.putText(frame, magnitude_text, (center[0] - 40, center[1] + button_radius + 40), + # cv2.FONT_HERSHEY_SIMPLEX, 0.4, COLOR_TEXT, 1, cv2.LINE_AA) + + # Draw center dot + cv2.circle(frame, center, 2, COLOR_TEXT, -1) + + # Add "Mouse" label below the compass + cv2.putText(frame, "Mouse", (center[0] - 20, center[1] + button_radius + 20), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, COLOR_TEXT, 1, cv2.LINE_AA) + + +def _draw_frame( + frame: np.ndarray, + buttons: list[bool] | torch.Tensor, + mouse_delta: tuple[float, float], +) -> np.ndarray: + """ + Overlay keyboard & mouse on a single frame with adaptive positioning. + """ + buttons, lmb, rmb = buttons[:-2], buttons[-2], buttons[-1] + _draw_buttons(frame, buttons) + + frame_height, frame_width = frame.shape[:2] + + # Get adaptive positioning + _, mouse_start_y_percent, _ = _get_adaptive_positioning(frame_width, frame_height) + + # Calculate compass position + compass_x_percent = COMPASS_START_X_PERCENT + compass_x = int(frame_width * compass_x_percent) - COMPASS_RADIUS - MOUSE_BUTTON_OFFSET + compass_y = int(frame_height * mouse_start_y_percent) + + # Ensure compass fits within frame boundaries + min_x = COMPASS_RADIUS + MOUSE_BUTTON_OFFSET + 20 + max_x = frame_width - COMPASS_RADIUS - MOUSE_BUTTON_OFFSET - 20 + compass_x = max(min_x, min(compass_x, max_x)) + + min_y = COMPASS_RADIUS + MOUSE_BUTTON_OFFSET + 30 + max_y = frame_height - COMPASS_RADIUS - MOUSE_BUTTON_OFFSET - 50 + compass_y = max(min_y, min(compass_y, max_y)) + + _draw_mouse(frame, + LMB_on=lmb, RMB_on=rmb, + mouse_delta=mouse_delta, + center=(compass_x, compass_y)) + return frame + +def _draw_video( + video: torch.Tensor, + buttons: torch.Tensor, + mouse_delta: torch.Tensor, + save_path: Optional[str] = None, + fps: int = 30, + arrow_scale_factor: Optional[float] = None, + arrow_max_scale: Optional[float] = None, +) -> list[np.ndarray]: + """ + Draw video with input device monitoring overlays. + + Args: + video: torch tensor of video frames, [num_frames, 256, 256, 3] + buttons: torch tensor of button states + mouse_delta: torch tensor of mouse deltas between frames + save_path: optional path to save video + fps: frames per second for output video + arrow_scale_factor: optional override for arrow scaling + arrow_max_scale: optional override for maximum arrow scale + """ + # Update global arrow parameters if provided + global ARROW_SCALE_FACTOR, ARROW_MAX_SCALE + if arrow_scale_factor is not None: + ARROW_SCALE_FACTOR = arrow_scale_factor + if arrow_max_scale is not None: + ARROW_MAX_SCALE = arrow_max_scale + + video = video.float().cpu().numpy() + + # Get original dimensions + original_height, original_width = video.shape[1], video.shape[2] + + # Calculate scaling to meet minimum size while preserving aspect ratio + if original_height < MINIMUM_FRAME_SIZE or original_width < MINIMUM_FRAME_SIZE: + # Calculate scale factor to make smallest dimension equal to MINIMUM_FRAME_SIZE + scale_factor = MINIMUM_FRAME_SIZE / min(original_height, original_width) + new_height = int(original_height * scale_factor) + new_width = int(original_width * scale_factor) + + frames = [cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_CUBIC) for frame in video] + # UI scaling ratio - keep icons at base size since we're scaling up small frames + ratio = 1.0 + else: + frames = video + # UI scaling ratio - scale icons proportional to how much larger the frame is than minimum + # This ensures icons don't become too large on big frames + ratio = min(original_height, original_width) / MINIMUM_FRAME_SIZE + + with _rescale_icons(ratio): + frames = [ + _draw_frame(frame, buttons[i], mouse_delta[i]) + for i, frame in enumerate(frames) + ] + if save_path is not None: + imio.mimsave(save_path, [f.astype(np.uint8) for f in frames], fps=fps) + + return frames + + +# Example usage +if __name__ == "__main__": + # Example button states (9 buttons, excluding LMB and RMB) + button_states = [True, False, True, False, True, False, False, True, False] + + # Example mouse state + LMB_pressed = True + RMB_pressed = False + mouse_direction = (0.7, 0.7) # Diagonal up-right + mouse_uncertainty = (0.2, 0.1) + + vidpath, mousepath, buttonpath = '.pt' + + video = torch.load(vidpath, map_location='cpu', mmap=True) + mouse = torch.load(mousepath, map_location='cpu', mmap=True) + buttons = torch.load(buttonpath, map_location='cpu', mmap=True) # [1, window_length, n_buttons] + + min_len = min(len(video), len(mouse), len(buttons)) + video = video[:min_len] + mouse = mouse[:min_len] + buttons = buttons[:min_len] + + video = video.float() + buttons = buttons + mouse = mouse + + # You can adjust arrow scaling parameters here + _draw_video(video, buttons, mouse, + save_path="groundtruth.mp4", + arrow_scale_factor=0.3, # Adjust this to control arrow sensitivity + arrow_max_scale=0.9) # Maximum arrow length relative to compass diff --git a/webapp/webapp_config.yaml b/webapp/webapp_config.yaml new file mode 100644 index 00000000..bc4b4b73 --- /dev/null +++ b/webapp/webapp_config.yaml @@ -0,0 +1,27 @@ +run_config_path: "webapp/checkpoints/configs/av.yml" +device: "cuda" + +stream_config: + fps: 10 + frames_per_batch: 1 + device: "cuda" + n_buttons: 11 + n_mouse_axes: 2 + mouse_range: [-1.0, 1.0] + video_latent_history_path: "webapp/static/histories/base1/video_latent.pt" + audio_latent_history_path: "webapp/static/histories/base1/audio_latent.pt" + mouse_history_path: "webapp/static/histories/base1/mouse.pt" + button_history_path: "webapp/static/histories/base1/buttons.pt" + model_checkpoint_path: "webapp/checkpoints/models/av_dfot_35k_ema_200m.pt" + +sampling_config: + sampling_steps: 20 + vae_scale: 1.0 + cfg_scale: 1.3 + window_length: 60 + num_frames: 1 + noise_prev: 0.25 + +run_config: null # loaded at runtime from model_config_path, and used to access model and train config + +