diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..e294ff2b --- /dev/null +++ b/.gitignore @@ -0,0 +1,54 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +# Sphinx documentation +docs/_build/ + +# Other +wandb/ +checkpoints/ \ No newline at end of file diff --git a/configs/av_ddpo.yml b/configs/av_ddpo.yml new file mode 100644 index 00000000..871a2dd6 --- /dev/null +++ b/configs/av_ddpo.yml @@ -0,0 +1,94 @@ +# Config for DDPO trainer - based on basic.yml +model: + model_id: game_rft_audio + sample_size: 4 + channels: 128 + audio_channels: 64 + + n_layers: 12 # 25 + n_heads: 12 # 24 + d_model: 768 # 1536 + + tokens_per_frame: 17 + n_buttons: 11 + n_mouse_axes: 2 + + cfg_prob: 0.1 + n_frames: 60 + + causal: false + +train: + trainer_id: ddpo + data_id: cod_s3_audio + data_kwargs: + window_length: 60 + bucket_name: cod-data-latent-360x640to4x4 + + # DDPO-specific parameters + sampling_steps: 64 # might not be right? + timestep_fraction: 0.5 + clip_range: 0.2 + adv_clip_max: 5.0 + sample_batch_size: 16 # TODO: Remove this maybe? Is it used? + num_batches_per_epoch: 8 + num_inner_epochs: 1 + + # Reward function configuration + # Option 1: Load from file + # reward_fn: + # module: "rewards/example_reward.py" + # function: "reward_function" + # Option 2: Load from importable module + # reward_fn: "my_rewards.simple_reward" + reward_fn: + module: "/home/pcurtin/owl-wms/rewards.py" + function: "darkness_reward" + + # Standard training parameters + target_batch_size: 256 + batch_size: 4 + epochs: 200 + + opt: AdamW + opt_kwargs: + lr: 1.0e-3 + weight_decay: 1.0e-4 + eps: 1.0e-8 + betas: [0.9, 0.999] + + scheduler: null + + checkpoint_dir: checkpoints/ddpo + resume_ckpt: null + + sample_interval: 1000 + save_interval: 5000 + + # VAE configuration (required for DDPO) + vae_id: null + vae_cfg_path: ../models/cod_128x.yml + vae_ckpt_path: ../models/cod_128x_30k_ema.pt + vae_scale: 0.13 + audio_vae_scale: 0.17 + vae_batch_size: 4 + + audio_vae_id: null + audio_vae_cfg_path: ../models/cod_audio.yml + audio_vae_ckpt_path: ../models/cod_audio_20k_ema.pt + + sampler_id: av_window + sampler_kwargs: + n_steps: 20 + cfg_scale: 1.3 + window_length: 60 + num_frames: 120 + noise_prev: 0.2 + only_return_generated: false + + n_samples: 4 + +wandb: + name: peter_curtin + project: owl + run_name: av_ddpo \ No newline at end of file diff --git a/configs/av_peter.yml b/configs/av_peter.yml new file mode 100644 index 00000000..44d507dc --- /dev/null +++ b/configs/av_peter.yml @@ -0,0 +1,76 @@ +model: + model_id: game_rft_audio + sample_size: 4 + channels: 128 + audio_channels: 64 + + n_layers: 25 + n_heads: 24 + d_model: 1536 + + tokens_per_frame: 17 + n_buttons: 11 + n_mouse_axes: 2 + + cfg_prob: 0.1 + n_frames: 60 + + causal: false + +train: + trainer_id: av + data_id: cod_s3_audio + data_kwargs: + window_length: 60 + bucket_name: cod-data-latent-360x640to4x4 + + target_batch_size: 256 + batch_size: 2 + + epochs: 200 + + opt: Muon + opt_kwargs: + lr: 1.0e-3 + momentum: 0.95 + adamw_lr: 1.0e-4 + adamw_wd: 1.0e-4 + adamw_eps: 1.0e-15 + adamw_betas: [0.9, 0.95] + adamw_keys: [core.proj_in, core.proj_out.proj] + + scheduler: null + + checkpoint_dir: checkpoints/av_huge + resume_ckpt: null # checkpoints/av_huge/step_50000.pt + + sample_interval: 1000 + save_interval: 5000 + + sampler_id: av_window + sampler_kwargs: + n_steps: 20 + cfg_scale: 1.3 + window_length: 60 + num_frames: 120 + noise_prev: 0.2 + only_return_generated: false + + n_samples: 4 + + vae_id: null + vae_batch_size: 4 + vae_scale: 0.13 + audio_vae_scale: 0.17 + + vae_cfg_path: /home/pcurtin/models/cod_128x.yml # configs/owl_vaes/cod_128x.yml + vae_ckpt_path: /home/pcurtin/models/cod_128x_30k_ema.pt # checkpoints/owl_vaes/cod_128x_30k_ema.pt + + audio_vae_id: null + audio_vae_cfg_path: /home/pcurtin/models/cod_audio.yml # configs/owl_vaes/cod_audio.yml + audio_vae_ckpt_path: /home/pcurtin/models/cod_audio_20k_ema.pt # checkpoints/owl_vaes/cod_audio_20k_ema.pt + +wandb: + name: shahbuland + project: video_models + run_name: av diff --git a/owl-vaes b/owl-vaes index cdde9f3e..303c2eeb 160000 --- a/owl-vaes +++ b/owl-vaes @@ -1 +1 @@ -Subproject commit cdde9f3e93cbae5ed77c99e6d7f29926ba89af42 +Subproject commit 303c2eeb03a6a361e603307d28c85776e8bfded7 diff --git a/owl_wms/data/s3_cod_latent_audio.py b/owl_wms/data/s3_cod_latent_audio.py index 1a8b7873..79341b18 100644 --- a/owl_wms/data/s3_cod_latent_audio.py +++ b/owl_wms/data/s3_cod_latent_audio.py @@ -84,7 +84,9 @@ def background_download_tars(self): tar_data = response['Body'].read() self.tar_queue.add(tar_data) except Exception as e: - print(f"Error downloading tar {tar_path}: {e}") + # TODO: Uncomment this before merge - can't stand the error messages. + # print(f"Error downloading tar {tar_path}: {e}") + pass else: time.sleep(1) diff --git a/owl_wms/trainers/__init__.py b/owl_wms/trainers/__init__.py index d6fef3f9..80e29750 100644 --- a/owl_wms/trainers/__init__.py +++ b/owl_wms/trainers/__init__.py @@ -7,4 +7,7 @@ def get_trainer_cls(trainer_id): return CausVidTrainer if trainer_id == "av": from .av_trainer import AVRFTTrainer - return AVRFTTrainer \ No newline at end of file + return AVRFTTrainer + if trainer_id == "ddpo": + from .ddpo_trainer import DDPOTrainer + return DDPOTrainer \ No newline at end of file diff --git a/owl_wms/trainers/ddpo_trainer.py b/owl_wms/trainers/ddpo_trainer.py new file mode 100644 index 00000000..6258880b --- /dev/null +++ b/owl_wms/trainers/ddpo_trainer.py @@ -0,0 +1,586 @@ +import torch +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist +from collections import defaultdict +import time +from concurrent import futures +import numpy as np +import wandb +from functools import partial +import tqdm +from ema_pytorch import EMA +import importlib.util +import os +import math + +from .base import BaseTrainer +from ..utils import freeze, Timer, find_unused_params +from ..schedulers import get_scheduler_cls +from ..models import get_model_cls +from ..sampling import get_sampler_cls +from ..data import get_loader +from ..utils.logging import LogHelper, to_wandb +from ..utils.owl_vae_bridge import get_decoder_only, make_batched_decode_fn +from ..muon import init_muon +from omegaconf.dictconfig import DictConfig + +tqdm = partial(tqdm.tqdm, dynamic_ncols=True) + + +class DDPOTrainer(BaseTrainer): + """ + Trainer for DDPO (Denoising Diffusion Policy Optimization) for video world models. + + Implements reinforcement learning with human feedback (RLHF) for world models + using PPO-style policy gradient optimization. World models take video frames, + mouse movements, and button presses as inputs. + + :param train_cfg: Configuration for training + :param logging_cfg: Configuration for logging + :param model_cfg: Configuration for model + :param global_rank: Rank across all devices + :param local_rank: Rank for current device on this process + :param world_size: Overall number of devices + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Initialize model + model_id = self.model_cfg.model_id + self.model = get_model_cls(model_id)(self.model_cfg) + + # Print model size + if self.rank == 0: + n_params = sum(p.numel() for p in self.model.parameters()) + print(f"Model has {n_params:,} parameters") + + # Initialize training components + self.ema = None + self.opt = None + self.scheduler = None + self.scaler = None + self.total_step_counter = 0 + + # Initialize video decoder + self.decoder = get_decoder_only( + self.train_cfg.vae_id, + self.train_cfg.vae_cfg_path, + self.train_cfg.vae_ckpt_path + ) + freeze(self.decoder) + + # DDPO specific components + self.reward_fn = None # Will be set by user or loaded from config + self._load_reward_function() + + # Initialize executor for async reward computation + self.executor = futures.ThreadPoolExecutor(max_workers=2) + + # DDPO hyperparameters from train_cfg + self.sampling_steps = self.train_cfg.get('sampling_steps', 64) + self.num_train_timesteps = int( + self.sampling_steps * self.train_cfg.get('timestep_fraction', 0.5) + ) + self.clip_range = self.train_cfg.get('clip_range', 0.2) + self.adv_clip_max = self.train_cfg.get('adv_clip_max', 5.0) + + # Sampling configuration + self.sample_batch_size = self.train_cfg.get('sample_batch_size', 4) + self.num_batches_per_epoch = self.train_cfg.get('num_batches_per_epoch', 4) + self.num_inner_epochs = self.train_cfg.get('num_inner_epochs', 1) + + def _load_reward_function(self): + """Load reward function from config if specified.""" + reward_config = getattr(self.train_cfg, 'reward_fn', None) + if reward_config is None: + return + + if isinstance(reward_config, DictConfig): + # Config format: {"module": "path/to/file.py", "function": "reward_function_name"} + module_path = reward_config.get('module') + function_name = reward_config.get('function') + + if module_path and function_name: + try: + # Load module from file path + if not os.path.isabs(module_path): + # Make relative paths relative to config directory + config_dir = os.path.dirname(os.path.abspath("")) # Assumes we're in project root + module_path = os.path.join(config_dir, module_path) + + spec = importlib.util.spec_from_file_location("reward_module", module_path) + reward_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(reward_module) + + + + # Get the function from the module + if hasattr(reward_module, function_name): + self.reward_fn = getattr(reward_module, function_name) + if self.rank == 0: + print(f"Loaded reward function '{function_name}' from {module_path}") + else: + raise AttributeError(f"Function '{function_name}' not found in {module_path}") + + except Exception as e: + if self.rank == 0: + print(f"Error loading reward function: {e}") + raise + elif isinstance(reward_config, str): + # Simple format: "module.function" (assumes module is importable) + try: + module_name, function_name = reward_config.rsplit('.', 1) + module = importlib.import_module(module_name) + self.reward_fn = getattr(module, function_name) + if self.rank == 0: + print(f"Loaded reward function '{function_name}' from module '{module_name}'") + except Exception as e: + if self.rank == 0: + print(f"Error loading reward function: {e}") + raise + + def set_reward_function(self, reward_fn): + """Set the reward function for DDPO training.""" + self.reward_fn = reward_fn + + def setup_training(self): + """Setup optimizer, scheduler, and other training components.""" + torch.cuda.set_device(self.local_rank) + + # Prepare model + self.model = self.model.cuda().train() + if self.world_size > 1: + self.model = DDP(self.model, device_ids=[self.local_rank]) + + # Setup decoder + self.decoder = self.decoder.cuda().eval().bfloat16() + self.decode_fn = make_batched_decode_fn(self.decoder, self.train_cfg.vae_batch_size) + + # Setup EMA + self.ema = EMA( + self.model, + beta=0.999, + update_after_step=0, + update_every=1 + ) + + # Setup optimizer + if self.train_cfg.opt.lower() == "muon": + self.opt = init_muon(self.model, rank=self.rank, world_size=self.world_size, **self.train_cfg.opt_kwargs) + else: + self.opt = getattr(torch.optim, self.train_cfg.opt)(self.model.parameters(), **self.train_cfg.opt_kwargs) + + # Setup scheduler + if self.train_cfg.scheduler is not None: + self.scheduler = get_scheduler_cls(self.train_cfg.scheduler)(self.opt, **self.train_cfg.scheduler_kwargs) + + # Setup gradient accumulation and scaler + self.accum_steps = self.train_cfg.target_batch_size // self.train_cfg.batch_size // self.world_size + self.accum_steps = max(1, self.accum_steps) + self.scaler = torch.amp.GradScaler() + self.ctx = torch.amp.autocast('cuda', torch.bfloat16) + + def sample_batch_with_logprob(self, batch_vid, batch_mouse, batch_btn, batch_audio): + """ + Sample a batch of videos using the current model with log probability tracking. + + Treats the model as a standard denoising diffusion model with fresh noise sampling. + + :param batch_vid: Input video frames + :param batch_mouse: Mouse movement data + :param batch_btn: Button press data + :param batch_audio: Audio data + :return: Dictionary containing samples with latents, log_probs, etc. + """ + self.model.eval() + + with torch.no_grad(): + # Get EMA model for sampling + def get_ema_core(): + if self.world_size > 1: + return self.ema.ema_model.module.core + else: + return self.ema.ema_model.core + + model = get_ema_core() + batch_size = batch_vid.shape[0] + + # Initialize with noise + x = torch.randn_like(batch_vid) + n_frames = batch_mouse.shape[1] # Get sequence length from mouse data + ts = torch.ones(batch_size, n_frames, device=x.device, dtype=x.dtype) + dt = 1.0 / self.sampling_steps + + # Store trajectory data for DDPO + all_latents = [x.clone()] + all_log_probs = [] + all_timesteps = [ts.clone()] + all_preds = [] + all_noise = [] + + # Sampling loop with log probability tracking + for step in range(self.sampling_steps): + with self.ctx: + # Get model prediction (this is our v-network prediction) + pred_video, pred_audio = model(x, batch_audio, ts, batch_mouse, batch_btn) + all_preds.append((pred_video.clone(), pred_audio.clone())) + + if step < self.sampling_steps - 1: + # Sample fresh noise for the transition + noise = torch.randn_like(x) + all_noise.append(noise.clone()) + + # Standard DDPM update: x = x - pred_video * dt + noise * sqrt(dt) + # where pred_video is the velocity field (v-network output) + next_x = x - pred_video * dt + noise * torch.sqrt(torch.tensor(dt, device=x.device)) + next_ts = ts - dt + + # Empirical variance for velocity difference + actual_velocity = (x - next_x) / dt + velocity_var = torch.var(actual_velocity).item() + 1e-6 + # Clamp to a reasonable minimum to avoid division by zero + velocity_var = max(velocity_var, 1e-6) + + # Log probability based on velocity prediction + log_prob = -0.5 * torch.sum((pred_video - actual_velocity) ** 2, dim=(1, 2, 3, 4)) / velocity_var + + all_latents.append(next_x.clone()) + all_log_probs.append(log_prob) + all_timesteps.append(next_ts.clone()) + + x = next_x + ts = next_ts + + # Stack trajectory data + all_latents = torch.stack(all_latents, dim=1) # (batch, timesteps+1, ...) + all_log_probs = torch.stack(all_log_probs, dim=1) # (batch, timesteps) + all_timesteps = torch.stack(all_timesteps[:-1], dim=1) # (batch, timesteps) + all_noise = torch.stack(all_noise, dim=1) # (batch, timesteps, ...) + + # Separate video and audio predictions + all_video_preds = torch.stack([pred[0] for pred in all_preds], dim=1) # (batch, timesteps, ...) + all_audio_preds = torch.stack([pred[1] for pred in all_preds], dim=1) # (batch, timesteps, ...) + + return { + 'latents': all_latents[:, :-1], # Remove last timestep + 'next_latents': all_latents[:, 1:], # Remove first timestep + 'log_probs': all_log_probs, + 'timesteps': all_timesteps, + 'video_predictions': all_video_preds, + 'audio_predictions': all_audio_preds, + 'noise': all_noise, + 'final_latents': x, + 'mouse': batch_mouse, + 'buttons': batch_btn, + 'audio': batch_audio, + } + + def compute_rewards(self, final_latents, mouse_data, button_data, audio_data=None): + """ + Compute rewards for generated videos. + + :param final_latents: Final latent representations + :param mouse_data: Mouse movement data + :param button_data: Button press data + :param audio_data: Audio data (optional) + :return: Tensor of rewards + """ + if self.reward_fn is None: + raise ValueError("Reward function not set. Use set_reward_function() to set it.") + + # Decode latents to videos + # with torch.no_grad(): + # videos = self.decode_fn(final_latents * self.train_cfg.vae_scale) + # TODO: undo this + videos = final_latents + + # Compute rewards asynchronously + if audio_data is not None: + rewards_future = self.executor.submit(self.reward_fn, videos, mouse_data, button_data, audio_data) + else: + rewards_future = self.executor.submit(self.reward_fn, videos, mouse_data, button_data) + rewards = rewards_future.result() + + return torch.tensor(rewards, device=final_latents.device) + + def compute_advantages(self, rewards): + """Compute advantages for PPO.""" + # Simple advantage computation - can be enhanced with value function + advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) + return advantages + + def ppo_loss(self, log_probs, old_log_probs, advantages): + """ + Compute PPO clipped loss. + + :param log_probs: Current policy log probabilities + :param old_log_probs: Old policy log probabilities + :param advantages: Computed advantages + :return: PPO loss + """ + # https://github.com/pmcurtin/owl-wms/blob/5f5a118f9a02be4dea1e75460494e9354a80f194/resources/train.py#L539 + # Importance sampling + # j is a timestep + # ratio = torch.exp(log_prob - sample["log_probs"][:, j]) + ratio = torch.exp(log_probs - old_log_probs) + + # Clipped advantages + advantages = torch.clamp(advantages, -self.adv_clip_max, self.adv_clip_max) + + # PPO loss + unclipped_loss = -advantages * ratio + clipped_loss = -advantages * torch.clamp( + ratio, 1.0 - self.clip_range, 1.0 + self.clip_range + ) + loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss)) + + return loss + + def train_step(self, samples, local_step): + """ + Perform one training step using PPO. + + :param samples: Dictionary containing trajectory samples + :param local_step: Current local step for gradient accumulation + :return: Dictionary of training metrics + """ + self.model.train() + + batch_size, num_timesteps = samples['timesteps'].shape[:2] + info = defaultdict(list) + + # Train on subset of timesteps (randomly sample for efficiency) + train_timesteps = min(num_timesteps, self.num_train_timesteps) + timestep_indices = torch.randperm(num_timesteps)[:train_timesteps] + + for t_idx in timestep_indices: + # Get current timestep data + latents = samples['latents'][:, t_idx] + next_latents = samples['next_latents'][:, t_idx] + timesteps = samples['timesteps'][:, t_idx] + old_log_probs = samples['log_probs'][:, t_idx] + old_video_predictions = samples['video_predictions'][:, t_idx] + advantages = samples['advantages'] + + with self.ctx: + # Forward pass - get current model prediction + model = self.get_module() + # Note: mouse and buttons need to be full sequences for this model + current_pred_video, current_pred_audio = model.core(latents, samples['audio'], timesteps, samples['mouse'], samples['buttons']) + + # Empirical variance for velocity difference (use detached tensors to avoid gradients through var) + actual_velocity = (latents - next_latents).detach() / (1.0 / self.sampling_steps) + velocity_var = torch.var(actual_velocity).item() + 1e-6 + velocity_var = max(velocity_var, 1e-6) + + # Log probability based on velocity prediction + log_prob = -0.5 * torch.sum((current_pred_video - actual_velocity) ** 2, dim=(1, 2, 3, 4)) / velocity_var + + # Compute PPO loss + loss = self.ppo_loss(log_prob, old_log_probs, advantages) / self.accum_steps + + # Backward pass + self.scaler.scale(loss).backward() + + # Compute metrics + with torch.no_grad(): + ratio = torch.exp(log_prob - old_log_probs) + info['approx_kl'].append(0.5 * torch.mean((log_prob - old_log_probs) ** 2)) + info['clipfrac'].append(torch.mean((torch.abs(ratio - 1.0) > self.clip_range).float())) + info['ppo_loss'].append(loss * self.accum_steps) + info['log_prob'].append(log_prob.mean()) + info['old_log_prob'].append(old_log_probs.mean()) + + # Optimization step (if gradient accumulation is complete) + if local_step % self.accum_steps == 0: + # Gradient clipping + if self.train_cfg.opt.lower() != "muon": + self.scaler.unscale_(self.opt) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + + self.scaler.step(self.opt) + self.opt.zero_grad(set_to_none=True) + self.scaler.update() + + if self.scheduler is not None: + self.scheduler.step() + self.ema.update() + + # Average metrics + if info: + info = {k: torch.mean(torch.stack(v)) for k, v in info.items()} + + return info + + def train_epoch(self, epoch, loader): + """Train for one epoch using DDPO.""" + if self.reward_fn is None: + raise ValueError("Reward function not set. Use set_reward_function() to set it.") + + # Sampling phase + self.model.eval() + all_samples = [] + all_rewards = [] + + for batch_idx, (batch_vid, batch_audio, batch_mouse, batch_btn) in enumerate(tqdm( + loader, + desc=f"Epoch {epoch}: Sampling", + disable=self.rank != 0, + total=self.num_batches_per_epoch + )): + # Prepare batch data + batch_vid = batch_vid.cuda().bfloat16() / self.train_cfg.vae_scale + batch_mouse = batch_mouse.cuda().bfloat16() + batch_btn = batch_btn.cuda().bfloat16() + batch_audio = batch_audio.cuda().bfloat16() + + # print(batch_vid.shape) + # print(batch_mouse.shape) + # print(batch_btn.shape) + # print(batch_audio.shape) + + # Sample trajectories with log probabilities + samples = self.sample_batch_with_logprob(batch_vid, batch_mouse, batch_btn, batch_audio) + + # Compute rewards + rewards = self.compute_rewards(samples['final_latents'], batch_mouse, batch_btn, batch_audio) + samples['rewards'] = rewards + + all_samples.append(samples) + all_rewards.extend(rewards.cpu().float().numpy()) + + # Limit number of batches for sampling + if batch_idx >= self.num_batches_per_epoch - 1: + + break + + # Compute advantages + all_rewards = np.array(all_rewards) + advantages = self.compute_advantages(torch.tensor(all_rewards)) + + # Add advantages to samples + start_idx = 0 + for samples in all_samples: + end_idx = start_idx + len(samples['rewards']) + samples['advantages'] = advantages[start_idx:end_idx].to(f'cuda:{self.local_rank}') + start_idx = end_idx + + # Training phase - multiple inner epochs over collected data + epoch_info = defaultdict(list) + local_step = 0 + + for inner_epoch in range(self.num_inner_epochs): + # Shuffle samples for training + shuffled_samples = all_samples.copy() + np.random.shuffle(shuffled_samples) + + for samples in tqdm( + shuffled_samples, + desc=f"Epoch {epoch}.{inner_epoch}: Training", + disable=self.rank != 0 + ): + info = self.train_step(samples, local_step) + local_step += 1 + + # Collect metrics + for k, v in info.items(): + epoch_info[k].append(v) + + # Log metrics + if self.rank == 0: + log_dict = { + 'epoch': epoch, + 'reward_mean': np.mean(all_rewards), + 'reward_std': np.std(all_rewards), + 'num_samples': len(all_rewards), + } + + # Add training metrics + for k, v in epoch_info.items(): + if len(v) > 0: + if isinstance(v[0], torch.Tensor): + log_dict[k] = torch.mean(torch.stack(v)).item() + else: + log_dict[k] = np.mean(v) + + # if self.logging_cfg is not None: + wandb.log(log_dict, step=self.total_step_counter) + + print(f"Epoch {epoch}: Reward {log_dict['reward_mean']:.3f} ± {log_dict['reward_std']:.3f}") + + self.total_step_counter += 1 + + return epoch_info + + def save(self): + save_dict = { + 'model': self.model.state_dict(), + 'ema': self.ema.state_dict(), + 'opt': self.opt.state_dict(), + 'scaler': self.scaler.state_dict(), + 'steps': self.total_step_counter + } + if self.scheduler is not None: + save_dict['scheduler'] = self.scheduler.state_dict() + super().save(save_dict) + + def load(self): + has_ckpt = False + try: + if self.train_cfg.resume_ckpt is not None: + save_dict = super().load(self.train_cfg.resume_ckpt) + has_ckpt = True + except: + print("Error loading checkpoint") + + if not has_ckpt: + return + + self.model.load_state_dict(save_dict['model']) + self.ema.load_state_dict(save_dict['ema']) + self.opt.load_state_dict(save_dict['opt']) + if self.scheduler is not None and 'scheduler' in save_dict: + self.scheduler.load_state_dict(save_dict['scheduler']) + self.scaler.load_state_dict(save_dict['scaler']) + self.total_step_counter = save_dict['steps'] + + def train(self): + """Main training loop.""" + if self.rank == 0: + print("Starting DDPO training...") + + # Setup training components + self.setup_training() + + # Load checkpoint if available + self.load() + + # Setup data loader + loader = get_loader(self.train_cfg.data_id, self.train_cfg.batch_size, **self.train_cfg.data_kwargs) + + # Timer and metrics + timer = Timer() + timer.reset() + metrics = LogHelper() + if self.rank == 0: + wandb.watch(self.get_module(), log='all') + + # Training loop + for epoch in range(self.train_cfg.epochs): + self.barrier() + + # Train for one epoch + self.train_epoch(epoch, loader) + + # Save checkpoint + if epoch % self.train_cfg.get('save_interval', 10) == 0 and self.rank == 0: + self.save() + print(f"Saved checkpoint at epoch {epoch}") + + # Cleanup + self.executor.shutdown(wait=True) + + if self.rank == 0: + print("DDPO training completed!") diff --git a/owl_wms/utils/owl_vae_bridge.py b/owl_wms/utils/owl_vae_bridge.py index e4cfd74a..4ff003d6 100644 --- a/owl_wms/utils/owl_vae_bridge.py +++ b/owl_wms/utils/owl_vae_bridge.py @@ -33,6 +33,21 @@ def get_decoder_only(vae_id, cfg_path, ckpt_path): model = model.decoder model = model.bfloat16().cuda().eval() return model + +def get_encoder_only(vae_id, cfg_path, ckpt_path): + if vae_id == "dcae": + model_id = "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers" + model = AutoencoderDC.from_pretrained(model_id).bfloat16().cuda().eval() + del model.decoder # Keep encoder only + return model.encoder + else: + cfg = Config.from_yaml(cfg_path).model + model = get_model_cls(cfg.model_id)(cfg) + model.load_state_dict(torch.load(ckpt_path, map_location='cpu',weights_only=False)) + del model.decoder # Keep encoder only + model = model.encoder + model = model.bfloat16().cuda().eval() + return model @torch.no_grad() def make_batched_decode_fn(decoder, batch_size = 8): diff --git a/resources/ddim_with_logprob.py b/resources/ddim_with_logprob.py new file mode 100644 index 00000000..7c1e824c --- /dev/null +++ b/resources/ddim_with_logprob.py @@ -0,0 +1,192 @@ +# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py +# with the following modifications: +# - It computes and returns the log prob of `prev_sample` given the UNet prediction. +# - Instead of `variance_noise`, it takes `prev_sample` as an optional argument. If `prev_sample` is provided, +# it uses it to compute the log prob. +# - Timesteps can be a batched torch.Tensor. + +from typing import Optional, Tuple, Union + +import math +import torch + +from diffusers.utils import randn_tensor +from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler + + +def _left_broadcast(t, shape): + assert t.ndim <= len(shape) + return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape) + + +def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to( + timestep.device + ) + alpha_prod_t_prev = torch.where( + prev_timestep.cpu() >= 0, + self.alphas_cumprod.gather(0, prev_timestep.cpu()), + self.final_alpha_cumprod, + ).to(timestep.device) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + +def ddim_step_with_logprob( + self: DDIMScheduler, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + prev_sample: Optional[torch.FloatTensor] = None, +) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + generator: random number generator. + variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://arxiv.org/abs/2210.05559) + return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + assert isinstance(self, DDIMScheduler) + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = ( + timestep - self.config.num_train_timesteps // self.num_inference_steps + ) + # to prevent OOB on gather + prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1) + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu()) + alpha_prod_t_prev = torch.where( + prev_timestep.cpu() >= 0, + self.alphas_cumprod.gather(0, prev_timestep.cpu()), + self.final_alpha_cumprod, + ) + alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device) + alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to( + sample.device + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = ( + sample - beta_prod_t ** (0.5) * model_output + ) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = ( + sample - alpha_prod_t ** (0.5) * pred_original_sample + ) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - ( + beta_prod_t**0.5 + ) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + ( + beta_prod_t**0.5 + ) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = _get_variance(self, timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = ( + sample - alpha_prod_t ** (0.5) * pred_original_sample + ) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** ( + 0.5 + ) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample_mean = ( + alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + ) + + if prev_sample is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and prev_sample. Please make sure that either `generator` or" + " `prev_sample` stays `None`." + ) + + if prev_sample is None: + variance_noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, + ) + prev_sample = prev_sample_mean + std_dev_t * variance_noise + + # log prob of prev_sample given prev_sample_mean and std_dev_t + log_prob = ( + -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)) + - torch.log(std_dev_t) + - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi))) + ) + # mean along all but batch dimension + log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) + + return prev_sample.type(sample.dtype), log_prob diff --git a/resources/pipeline_with_logprob.py b/resources/pipeline_with_logprob.py new file mode 100644 index 00000000..b4320c93 --- /dev/null +++ b/resources/pipeline_with_logprob.py @@ -0,0 +1,254 @@ +# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +# with the following modifications: +# - It uses the patched version of `ddim_step_with_logprob` from `ddim_with_logprob.py`. As such, it only supports the +# `ddim` scheduler. +# - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step. + +from typing import Any, Callable, Dict, List, Optional, Union + +import torch + +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( + StableDiffusionPipeline, + rescale_noise_cfg, +) +from .ddim_with_logprob import ddim_step_with_logprob + + +@torch.no_grad() +def pipeline_with_logprob( + self: StableDiffusionPipeline, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, +): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) + if cross_attention_kwargs is not None + else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + all_latents = [latents] + all_log_probs = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents, log_prob = ddim_step_with_logprob( + self.scheduler, noise_pred, t, latents, **extra_step_kwargs + ) + + all_latents.append(latents) + all_log_probs.append(log_prob) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return image, has_nsfw_concept, all_latents, all_log_probs diff --git a/resources/train.py b/resources/train.py new file mode 100644 index 00000000..2c876b1b --- /dev/null +++ b/resources/train.py @@ -0,0 +1,600 @@ +from collections import defaultdict +import contextlib +import os +import datetime +from concurrent import futures +import time +from absl import app, flags +from ml_collections import config_flags +from accelerate import Accelerator +from accelerate.utils import set_seed, ProjectConfiguration +from accelerate.logging import get_logger +from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +import numpy as np +import ddpo_pytorch.prompts +import ddpo_pytorch.rewards +from ddpo_pytorch.stat_tracking import PerPromptStatTracker +from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob +from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob +import torch +import wandb +from functools import partial +import tqdm +import tempfile +from PIL import Image + +tqdm = partial(tqdm.tqdm, dynamic_ncols=True) + + +FLAGS = flags.FLAGS +config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.") + +logger = get_logger(__name__) + + +def main(_): + # basic Accelerate and logging setup + config = FLAGS.config + + unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S") + if not config.run_name: + config.run_name = unique_id + else: + config.run_name += "_" + unique_id + + if config.resume_from: + config.resume_from = os.path.normpath(os.path.expanduser(config.resume_from)) + if "checkpoint_" not in os.path.basename(config.resume_from): + # get the most recent checkpoint in this directory + checkpoints = list( + filter(lambda x: "checkpoint_" in x, os.listdir(config.resume_from)) + ) + if len(checkpoints) == 0: + raise ValueError(f"No checkpoints found in {config.resume_from}") + config.resume_from = os.path.join( + config.resume_from, + sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))[-1], + ) + + # number of timesteps within each trajectory to train on + num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction) + + accelerator_config = ProjectConfiguration( + project_dir=os.path.join(config.logdir, config.run_name), + automatic_checkpoint_naming=True, + total_limit=config.num_checkpoint_limit, + ) + + accelerator = Accelerator( + log_with="wandb", + mixed_precision=config.mixed_precision, + project_config=accelerator_config, + # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the + # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get + # the total number of optimizer steps to accumulate across. + gradient_accumulation_steps=config.train.gradient_accumulation_steps + * num_train_timesteps, + ) + if accelerator.is_main_process: + accelerator.init_trackers( + project_name="ddpo-pytorch", + config=config.to_dict(), + init_kwargs={"wandb": {"name": config.run_name}}, + ) + logger.info(f"\n{config}") + + # set seed (device_specific is very important to get different prompts on different devices) + set_seed(config.seed, device_specific=True) + + # load scheduler, tokenizer and models. + pipeline = StableDiffusionPipeline.from_pretrained( + config.pretrained.model, revision=config.pretrained.revision + ) + # freeze parameters of models to save more memory + pipeline.vae.requires_grad_(False) + pipeline.text_encoder.requires_grad_(False) + pipeline.unet.requires_grad_(not config.use_lora) + # disable safety checker + pipeline.safety_checker = None + # make the progress bar nicer + pipeline.set_progress_bar_config( + position=1, + disable=not accelerator.is_local_main_process, + leave=False, + desc="Timestep", + dynamic_ncols=True, + ) + # switch to DDIM scheduler + pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + inference_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + inference_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + inference_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to inference_dtype + pipeline.vae.to(accelerator.device, dtype=inference_dtype) + pipeline.text_encoder.to(accelerator.device, dtype=inference_dtype) + if config.use_lora: + pipeline.unet.to(accelerator.device, dtype=inference_dtype) + + if config.use_lora: + # Set correct lora layers + lora_attn_procs = {} + for name in pipeline.unet.attn_processors.keys(): + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else pipeline.unet.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = pipeline.unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[ + block_id + ] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = pipeline.unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) + pipeline.unet.set_attn_processor(lora_attn_procs) + + # this is a hack to synchronize gradients properly. the module that registers the parameters we care about (in + # this case, AttnProcsLayers) needs to also be used for the forward pass. AttnProcsLayers doesn't have a + # `forward` method, so we wrap it to add one and capture the rest of the unet parameters using a closure. + class _Wrapper(AttnProcsLayers): + def forward(self, *args, **kwargs): + return pipeline.unet(*args, **kwargs) + + unet = _Wrapper(pipeline.unet.attn_processors) + else: + unet = pipeline.unet + + # set up diffusers-friendly checkpoint saving with Accelerate + + def save_model_hook(models, weights, output_dir): + assert len(models) == 1 + if config.use_lora and isinstance(models[0], AttnProcsLayers): + pipeline.unet.save_attn_procs(output_dir) + elif not config.use_lora and isinstance(models[0], UNet2DConditionModel): + models[0].save_pretrained(os.path.join(output_dir, "unet")) + else: + raise ValueError(f"Unknown model type {type(models[0])}") + weights.pop() # ensures that accelerate doesn't try to handle saving of the model + + def load_model_hook(models, input_dir): + assert len(models) == 1 + if config.use_lora and isinstance(models[0], AttnProcsLayers): + # pipeline.unet.load_attn_procs(input_dir) + tmp_unet = UNet2DConditionModel.from_pretrained( + config.pretrained.model, + revision=config.pretrained.revision, + subfolder="unet", + ) + tmp_unet.load_attn_procs(input_dir) + models[0].load_state_dict( + AttnProcsLayers(tmp_unet.attn_processors).state_dict() + ) + del tmp_unet + elif not config.use_lora and isinstance(models[0], UNet2DConditionModel): + load_model = UNet2DConditionModel.from_pretrained( + input_dir, subfolder="unet" + ) + models[0].register_to_config(**load_model.config) + models[0].load_state_dict(load_model.state_dict()) + del load_model + else: + raise ValueError(f"Unknown model type {type(models[0])}") + models.pop() # ensures that accelerate doesn't try to handle loading of the model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + # Initialize the optimizer + if config.train.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + unet.parameters(), + lr=config.train.learning_rate, + betas=(config.train.adam_beta1, config.train.adam_beta2), + weight_decay=config.train.adam_weight_decay, + eps=config.train.adam_epsilon, + ) + + # prepare prompt and reward fn + prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn) + reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)() + + # generate negative prompt embeddings + neg_prompt_embed = pipeline.text_encoder( + pipeline.tokenizer( + [""], + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=pipeline.tokenizer.model_max_length, + ).input_ids.to(accelerator.device) + )[0] + sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1) + train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1) + + # initialize stat tracker + if config.per_prompt_stat_tracking: + stat_tracker = PerPromptStatTracker( + config.per_prompt_stat_tracking.buffer_size, + config.per_prompt_stat_tracking.min_count, + ) + + # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses + # more memory + autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast + # autocast = accelerator.autocast + + # Prepare everything with our `accelerator`. + unet, optimizer = accelerator.prepare(unet, optimizer) + + # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a + # remote server running llava inference. + executor = futures.ThreadPoolExecutor(max_workers=2) + + # Train! + samples_per_epoch = ( + config.sample.batch_size + * accelerator.num_processes + * config.sample.num_batches_per_epoch + ) + total_train_batch_size = ( + config.train.batch_size + * accelerator.num_processes + * config.train.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num Epochs = {config.num_epochs}") + logger.info(f" Sample batch size per device = {config.sample.batch_size}") + logger.info(f" Train batch size per device = {config.train.batch_size}") + logger.info( + f" Gradient Accumulation steps = {config.train.gradient_accumulation_steps}" + ) + logger.info("") + logger.info(f" Total number of samples per epoch = {samples_per_epoch}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}" + ) + logger.info( + f" Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}" + ) + logger.info(f" Number of inner epochs = {config.train.num_inner_epochs}") + + assert config.sample.batch_size >= config.train.batch_size + assert config.sample.batch_size % config.train.batch_size == 0 + assert samples_per_epoch % total_train_batch_size == 0 + + if config.resume_from: + logger.info(f"Resuming from {config.resume_from}") + accelerator.load_state(config.resume_from) + first_epoch = int(config.resume_from.split("_")[-1]) + 1 + else: + first_epoch = 0 + + global_step = 0 + for epoch in range(first_epoch, config.num_epochs): + #################### SAMPLING #################### + pipeline.unet.eval() + samples = [] + prompts = [] + for i in tqdm( + range(config.sample.num_batches_per_epoch), + desc=f"Epoch {epoch}: sampling", + disable=not accelerator.is_local_main_process, + position=0, + ): + # generate prompts + prompts, prompt_metadata = zip( + *[ + prompt_fn(**config.prompt_fn_kwargs) + for _ in range(config.sample.batch_size) + ] + ) + + # encode prompts + prompt_ids = pipeline.tokenizer( + prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=pipeline.tokenizer.model_max_length, + ).input_ids.to(accelerator.device) + prompt_embeds = pipeline.text_encoder(prompt_ids)[0] + + # sample + with autocast(): + images, _, latents, log_probs = pipeline_with_logprob( + pipeline, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=sample_neg_prompt_embeds, + num_inference_steps=config.sample.num_steps, + guidance_scale=config.sample.guidance_scale, + eta=config.sample.eta, + output_type="pt", + ) + + latents = torch.stack( + latents, dim=1 + ) # (batch_size, num_steps + 1, 4, 64, 64) + log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1) + timesteps = pipeline.scheduler.timesteps.repeat( + config.sample.batch_size, 1 + ) # (batch_size, num_steps) + + # compute rewards asynchronously + rewards = executor.submit(reward_fn, images, prompts, prompt_metadata) + # yield to to make sure reward computation starts + time.sleep(0) + + samples.append( + { + "prompt_ids": prompt_ids, + "prompt_embeds": prompt_embeds, + "timesteps": timesteps, + "latents": latents[ + :, :-1 + ], # each entry is the latent before timestep t + "next_latents": latents[ + :, 1: + ], # each entry is the latent after timestep t + "log_probs": log_probs, + "rewards": rewards, + } + ) + + # wait for all rewards to be computed + for sample in tqdm( + samples, + desc="Waiting for rewards", + disable=not accelerator.is_local_main_process, + position=0, + ): + rewards, reward_metadata = sample["rewards"].result() + # accelerator.print(reward_metadata) + sample["rewards"] = torch.as_tensor(rewards, device=accelerator.device) + + # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) + samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} + + # this is a hack to force wandb to log the images as JPEGs instead of PNGs + with tempfile.TemporaryDirectory() as tmpdir: + for i, image in enumerate(images): + pil = Image.fromarray( + (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8) + ) + pil = pil.resize((256, 256)) + pil.save(os.path.join(tmpdir, f"{i}.jpg")) + accelerator.log( + { + "images": [ + wandb.Image( + os.path.join(tmpdir, f"{i}.jpg"), + caption=f"{prompt:.25} | {reward:.2f}", + ) + for i, (prompt, reward) in enumerate( + zip(prompts, rewards) + ) # only log rewards from process 0 + ], + }, + step=global_step, + ) + + # gather rewards across processes + rewards = accelerator.gather(samples["rewards"]).cpu().numpy() + + # log rewards and images + accelerator.log( + { + "reward": rewards, + "epoch": epoch, + "reward_mean": rewards.mean(), + "reward_std": rewards.std(), + }, + step=global_step, + ) + + # per-prompt mean/std tracking + if config.per_prompt_stat_tracking: + # gather the prompts across processes + prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy() + prompts = pipeline.tokenizer.batch_decode( + prompt_ids, skip_special_tokens=True + ) + advantages = stat_tracker.update(prompts, rewards) + else: + advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) + + # ungather advantages; we only need to keep the entries corresponding to the samples on this process + samples["advantages"] = ( + torch.as_tensor(advantages) + .reshape(accelerator.num_processes, -1)[accelerator.process_index] + .to(accelerator.device) + ) + + del samples["rewards"] + del samples["prompt_ids"] + + total_batch_size, num_timesteps = samples["timesteps"].shape + assert ( + total_batch_size + == config.sample.batch_size * config.sample.num_batches_per_epoch + ) + assert num_timesteps == config.sample.num_steps + + #################### TRAINING #################### + for inner_epoch in range(config.train.num_inner_epochs): + # shuffle samples along batch dimension + perm = torch.randperm(total_batch_size, device=accelerator.device) + samples = {k: v[perm] for k, v in samples.items()} + + # shuffle along time dimension independently for each sample + perms = torch.stack( + [ + torch.randperm(num_timesteps, device=accelerator.device) + for _ in range(total_batch_size) + ] + ) + for key in ["timesteps", "latents", "next_latents", "log_probs"]: + samples[key] = samples[key][ + torch.arange(total_batch_size, device=accelerator.device)[:, None], + perms, + ] + + # rebatch for training + samples_batched = { + k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) + for k, v in samples.items() + } + + # dict of lists -> list of dicts for easier iteration + samples_batched = [ + dict(zip(samples_batched, x)) for x in zip(*samples_batched.values()) + ] + + # train + pipeline.unet.train() + info = defaultdict(list) + for i, sample in tqdm( + list(enumerate(samples_batched)), + desc=f"Epoch {epoch}.{inner_epoch}: training", + position=0, + disable=not accelerator.is_local_main_process, + ): + if config.train.cfg: + # concat negative prompts to sample prompts to avoid two forward passes + embeds = torch.cat( + [train_neg_prompt_embeds, sample["prompt_embeds"]] + ) + else: + embeds = sample["prompt_embeds"] + + for j in tqdm( + range(num_train_timesteps), + desc="Timestep", + position=1, + leave=False, + disable=not accelerator.is_local_main_process, + ): + with accelerator.accumulate(unet): + with autocast(): + if config.train.cfg: + noise_pred = unet( + torch.cat([sample["latents"][:, j]] * 2), + torch.cat([sample["timesteps"][:, j]] * 2), + embeds, + ).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = ( + noise_pred_uncond + + config.sample.guidance_scale + * (noise_pred_text - noise_pred_uncond) + ) + else: + noise_pred = unet( + sample["latents"][:, j], + sample["timesteps"][:, j], + embeds, + ).sample + # compute the log prob of next_latents given latents under the current model + _, log_prob = ddim_step_with_logprob( + pipeline.scheduler, + noise_pred, + sample["timesteps"][:, j], + sample["latents"][:, j], + eta=config.sample.eta, + prev_sample=sample["next_latents"][:, j], + ) + + # ppo logic + advantages = torch.clamp( + sample["advantages"], + -config.train.adv_clip_max, + config.train.adv_clip_max, + ) + ratio = torch.exp(log_prob - sample["log_probs"][:, j]) + unclipped_loss = -advantages * ratio + clipped_loss = -advantages * torch.clamp( + ratio, + 1.0 - config.train.clip_range, + 1.0 + config.train.clip_range, + ) + loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss)) + + # debugging values + # John Schulman says that (ratio - 1) - log(ratio) is a better + # estimator, but most existing code uses this so... + # http://joschu.net/blog/kl-approx.html + info["approx_kl"].append( + 0.5 + * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2) + ) + info["clipfrac"].append( + torch.mean( + ( + torch.abs(ratio - 1.0) > config.train.clip_range + ).float() + ) + ) + info["loss"].append(loss) + + # backward pass + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_( + unet.parameters(), config.train.max_grad_norm + ) + optimizer.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + assert (j == num_train_timesteps - 1) and ( + i + 1 + ) % config.train.gradient_accumulation_steps == 0 + # log training-related stuff + info = {k: torch.mean(torch.stack(v)) for k, v in info.items()} + info = accelerator.reduce(info, reduction="mean") + info.update({"epoch": epoch, "inner_epoch": inner_epoch}) + accelerator.log(info, step=global_step) + global_step += 1 + info = defaultdict(list) + + # make sure we did an optimization step at the end of the inner epoch + assert accelerator.sync_gradients + + if epoch != 0 and epoch % config.save_freq == 0 and accelerator.is_main_process: + accelerator.save_state() + + +if __name__ == "__main__": + app.run(main) diff --git a/rewards.py b/rewards.py new file mode 100644 index 00000000..03785f3b --- /dev/null +++ b/rewards.py @@ -0,0 +1,11 @@ +import numpy as np + +def random_reward(*args): + # assumes first arg has a batch size... + batch_size = len(args[0]) + + return np.random.random(batch_size) + +def darkness_reward(videos, *args): + + return -videos.mean(dim=(1,2,3,4)) \ No newline at end of file