Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# Rope
.ropeproject

# Django stuff:
*.log
*.pot

# Sphinx documentation
docs/_build/

# Other
wandb/
checkpoints/
94 changes: 94 additions & 0 deletions configs/av_ddpo.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Config for DDPO trainer - based on basic.yml
model:
model_id: game_rft_audio
sample_size: 4
channels: 128
audio_channels: 64

n_layers: 12 # 25
n_heads: 12 # 24
d_model: 768 # 1536

tokens_per_frame: 17
n_buttons: 11
n_mouse_axes: 2

cfg_prob: 0.1
n_frames: 60

causal: false

train:
trainer_id: ddpo
data_id: cod_s3_audio
data_kwargs:
window_length: 60
bucket_name: cod-data-latent-360x640to4x4

# DDPO-specific parameters
sampling_steps: 64 # might not be right?
timestep_fraction: 0.5
clip_range: 0.2
adv_clip_max: 5.0
sample_batch_size: 16 # TODO: Remove this maybe? Is it used?
num_batches_per_epoch: 8
num_inner_epochs: 1

# Reward function configuration
# Option 1: Load from file
# reward_fn:
# module: "rewards/example_reward.py"
# function: "reward_function"
# Option 2: Load from importable module
# reward_fn: "my_rewards.simple_reward"
reward_fn:
module: "/home/pcurtin/owl-wms/rewards.py"
function: "darkness_reward"

# Standard training parameters
target_batch_size: 256
batch_size: 4
epochs: 200

opt: AdamW
opt_kwargs:
lr: 1.0e-3
weight_decay: 1.0e-4
eps: 1.0e-8
betas: [0.9, 0.999]

scheduler: null

checkpoint_dir: checkpoints/ddpo
resume_ckpt: null

sample_interval: 1000
save_interval: 5000

# VAE configuration (required for DDPO)
vae_id: null
vae_cfg_path: ../models/cod_128x.yml
vae_ckpt_path: ../models/cod_128x_30k_ema.pt
vae_scale: 0.13
audio_vae_scale: 0.17
vae_batch_size: 4

audio_vae_id: null
audio_vae_cfg_path: ../models/cod_audio.yml
audio_vae_ckpt_path: ../models/cod_audio_20k_ema.pt

sampler_id: av_window
sampler_kwargs:
n_steps: 20
cfg_scale: 1.3
window_length: 60
num_frames: 120
noise_prev: 0.2
only_return_generated: false

n_samples: 4

wandb:
name: peter_curtin
project: owl
run_name: av_ddpo
76 changes: 76 additions & 0 deletions configs/av_peter.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
model:
model_id: game_rft_audio
sample_size: 4
channels: 128
audio_channels: 64

n_layers: 25
n_heads: 24
d_model: 1536

tokens_per_frame: 17
n_buttons: 11
n_mouse_axes: 2

cfg_prob: 0.1
n_frames: 60

causal: false

train:
trainer_id: av
data_id: cod_s3_audio
data_kwargs:
window_length: 60
bucket_name: cod-data-latent-360x640to4x4

target_batch_size: 256
batch_size: 2

epochs: 200

opt: Muon
opt_kwargs:
lr: 1.0e-3
momentum: 0.95
adamw_lr: 1.0e-4
adamw_wd: 1.0e-4
adamw_eps: 1.0e-15
adamw_betas: [0.9, 0.95]
adamw_keys: [core.proj_in, core.proj_out.proj]

scheduler: null

checkpoint_dir: checkpoints/av_huge
resume_ckpt: null # checkpoints/av_huge/step_50000.pt

sample_interval: 1000
save_interval: 5000

sampler_id: av_window
sampler_kwargs:
n_steps: 20
cfg_scale: 1.3
window_length: 60
num_frames: 120
noise_prev: 0.2
only_return_generated: false

n_samples: 4

vae_id: null
vae_batch_size: 4
vae_scale: 0.13
audio_vae_scale: 0.17

vae_cfg_path: /home/pcurtin/models/cod_128x.yml # configs/owl_vaes/cod_128x.yml
vae_ckpt_path: /home/pcurtin/models/cod_128x_30k_ema.pt # checkpoints/owl_vaes/cod_128x_30k_ema.pt

audio_vae_id: null
audio_vae_cfg_path: /home/pcurtin/models/cod_audio.yml # configs/owl_vaes/cod_audio.yml
audio_vae_ckpt_path: /home/pcurtin/models/cod_audio_20k_ema.pt # checkpoints/owl_vaes/cod_audio_20k_ema.pt

wandb:
name: shahbuland
project: video_models
run_name: av
2 changes: 1 addition & 1 deletion owl-vaes
4 changes: 3 additions & 1 deletion owl_wms/data/s3_cod_latent_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def background_download_tars(self):
tar_data = response['Body'].read()
self.tar_queue.add(tar_data)
except Exception as e:
print(f"Error downloading tar {tar_path}: {e}")
# TODO: Uncomment this before merge - can't stand the error messages.
# print(f"Error downloading tar {tar_path}: {e}")
pass
else:
time.sleep(1)

Expand Down
5 changes: 4 additions & 1 deletion owl_wms/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@ def get_trainer_cls(trainer_id):
return CausVidTrainer
if trainer_id == "av":
from .av_trainer import AVRFTTrainer
return AVRFTTrainer
return AVRFTTrainer
if trainer_id == "ddpo":
from .ddpo_trainer import DDPOTrainer
return DDPOTrainer
Loading