diff --git a/configs/dpo_pickapic.yml b/configs/dpo_pickapic.yml new file mode 100644 index 0000000..b65e562 --- /dev/null +++ b/configs/dpo_pickapic.yml @@ -0,0 +1,45 @@ +method: + name : "DPO" + +model: + model_path: "stabilityai/stable-diffusion-2-1" + pipeline_kwargs: + use_safetensors: True + variant: "fp16" + sdxl: False + model_arch_type: "LDMUnet" + attention_slicing: True + xformers_memory_efficient: False + gradient_checkpointing: True + + +sampler: + guidance_scale: 7.5 + num_inference_steps: 50 + +optimizer: + name: "adamw" + kwargs: + lr: 2.048e-8 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + +scheduler: + name: "linear" # Name of learning rate scheduler + kwargs: + start_factor: 1.0 + end_factor: 1.0 + +logging: + run_name: 'dpo_pickapic' + #wandb_entity: None + #wandb_project: None + +train: + num_epochs: 500 + num_samples_per_epoch: 256 + batch_size: 1 + target_batch: 256 + checkpoint_interval: 640 + tf32: True + suppress_log_keywords: "diffusers.pipelines,transformers" diff --git a/examples/DPO/train_dpo_pickapic.py b/examples/DPO/train_dpo_pickapic.py new file mode 100644 index 0000000..c581c52 --- /dev/null +++ b/examples/DPO/train_dpo_pickapic.py @@ -0,0 +1,24 @@ +import sys + +sys.path.append("./src") + +from drlx.trainer.dpo_trainer import DPOTrainer +from drlx.configs import DRLXConfig +from drlx.utils import get_latest_checkpoint + +# Pipeline first +from drlx.pipeline.pickapic_dpo import PickAPicDPOPipeline + +import torch + +pipe = PickAPicDPOPipeline() +resume = False + +config = DRLXConfig.load_yaml("configs/dpo_pickapic.yml") +trainer = DPOTrainer(config) + +if resume: + cp_dir = get_latest_checkpoint(f"checkpoints/{config.logging.run_name}") + trainer.load_checkpoint(cp_dir) + +trainer.train(pipe) \ No newline at end of file diff --git a/src/drlx/configs.py b/src/drlx/configs.py index 51c9199..4912bc1 100644 --- a/src/drlx/configs.py +++ b/src/drlx/configs.py @@ -91,6 +91,22 @@ class DDPOConfig(MethodConfig): buffer_size: int = 32 # 
Set to None to avoid using per prompt stat tracker min_count: int = 16 +@register_method("DPO") +@dataclass +class DPOConfig(MethodConfig): + """ + Config for DPO-related hyperparams + + :param beta: Deviation from initial model + :type beta: float + + :param ref_mem_strategy: Strategy for managing reference model on memory. By default, puts it in 16 bit. + :type ref_mem_strategy: str + """ + name : str = "DPO" + beta : float = 0.9 + ref_mem_strategy : str = None # None or "half" + @dataclass class TrainConfig(ConfigClass): """ @@ -144,7 +160,7 @@ class TrainConfig(ConfigClass): num_epochs: int = 50 total_samples: int = None num_samples_per_epoch: int = 256 - grad_clip: float = 1.0 + grad_clip: float = -1 checkpoint_interval: int = 10 checkpoint_path: str = "checkpoints" seed: int = 0 @@ -219,14 +235,14 @@ class ModelConfig(ConfigClass): :param model_path: Path or name of the model (local or on huggingface hub) :type model_path: str - :param model_arch_type: Type of model architecture. - :type model_arch_type: str + :param pipeline_kwargs: Keyword arguments for pipeline if model is being loaded from one + :type pipeline_kwargs: dict - :param use_safetensors: Use safe tensors when loading pipeline? - :type use_safetensors: bool + :param sdxl: Using SDXL model? + :type sdxl: bool - :param local_model: Force model to load checkpoint locally only - :type local_model: bool + :param model_arch_type: Type of model architecture. 
Defaults to LDM UNet + :type model_arch_type: str :param attention_slicing: Whether to use attention slicing :type attention_slicing: bool @@ -242,9 +258,9 @@ class ModelConfig(ConfigClass): """ model_path: str = None + pipeline_kwargs : dict = None + sdxl : bool = False model_arch_type: str = None - use_safetensors : bool = False - local_model : bool = False attention_slicing: bool = False xformers_memory_efficient: bool = False gradient_checkpointing: bool = False diff --git a/src/drlx/denoisers/ldm_unet.py b/src/drlx/denoisers/ldm_unet.py index 3ef62d0..f92253b 100644 --- a/src/drlx/denoisers/ldm_unet.py +++ b/src/drlx/denoisers/ldm_unet.py @@ -28,7 +28,10 @@ def __init__(self, config : ModelConfig, sampler_config : SamplerConfig = None, super().__init__(config, sampler_config, sampler) self.unet : UNet2DConditionModel = None + self.text_encoder = None + self.text_encoder_2 = None # SDXL Support, just needs to be here for device mapping + self.vae = None self.encode_prompt : Callable = None @@ -37,6 +40,8 @@ def __init__(self, config : ModelConfig, sampler_config : SamplerConfig = None, self.scale_factor = None + self.sdxl_flag = self.config.sdxl + def get_input_shape(self) -> Tuple[int]: """ Figure out latent noise input shape for the UNet. 
Requires that unet and vae are defined @@ -65,16 +70,25 @@ def from_pretrained_pipeline(self, cls : Type, path : str): :rtype: LDMUNet """ - pipe = cls.from_pretrained(path, use_safetensors = self.config.use_safetensors, local_files_only = self.config.local_model) + kwargs = self.config.pipeline_kwargs + kwargs["torch_dtype"] = torch.float32 + + pipe = cls.from_pretrained(path, **kwargs) if self.config.attention_slicing: pipe.enable_attention_slicing() if self.config.xformers_memory_efficient: pipe.enable_xformers_memory_efficient_attention() self.unet = pipe.unet self.text_encoder = pipe.text_encoder + + # SDXL compat + if self.sdxl_flag: + self.text_encoder_2 = pipe.text_encoder_2 + self.vae = pipe.vae self.scale_factor = pipe.vae_scale_factor - self.encode_prompt = pipe._encode_prompt + + self.encode_prompt = pipe.encode_prompt self.text_encoder.requires_grad_(False) self.vae.requires_grad_(False) @@ -149,7 +163,8 @@ def forward( time_step : Union[TensorType["batch"], int], # Note diffusers tyically does 999->0 as steps input_ids : TensorType["batch", "seq_len"] = None, attention_mask : TensorType["batch", "seq_len"] = None, - text_embeds : TensorType["batch", "d"] = None + text_embeds : TensorType["batch", "d"] = None, + added_cond_kwargs = {} ) -> TensorType["batch", "channels", "height", "width"]: """ For text conditioned UNET, inputs are assumed to be: @@ -162,8 +177,20 @@ def forward( return self.unet( pixel_values, time_step, - encoder_hidden_states = text_embeds + encoder_hidden_states = text_embeds, + added_cond_kwargs = added_cond_kwargs ).sample + @property + def device(self): + return self.unet.device + + def enable_adapters(self): + if self.config.lora_rank: + self.unet.enable_adapters() + + def disable_adapters(self): + if self.config.lora_rank: + self.unet.disable_adapters() diff --git a/src/drlx/pipeline/dpo_pipeline.py b/src/drlx/pipeline/dpo_pipeline.py new file mode 100644 index 0000000..95dc880 --- /dev/null +++ 
b/src/drlx/pipeline/dpo_pipeline.py @@ -0,0 +1,30 @@ +from abc import abstractmethod +from typing import Tuple, Callable, Iterable + +from PIL import Image + +from drlx.pipeline import Pipeline + +class DPOPipeline(Pipeline): + """ + Pipeline for training with DPO. Returns prompts, chosen images, and rejected images + """ + def __init__(self, *args): + super().__init__(*args) + + @abstractmethod + def __getitem__(self, index : int) -> Tuple[str, Image.Image, Image.Image]: + pass + + def make_default_collate(self, prep : Callable): + def collate(batch : Iterable[Tuple[str, Image.Image, Image.Image]]): + prompts = [d[0] for d in batch] + chosen = [d[1] for d in batch] + rejected = [d[2] for d in batch] + + return prep(prompts, chosen, rejected) + + return collate + + + diff --git a/src/drlx/pipeline/pickapic_dpo.py b/src/drlx/pipeline/pickapic_dpo.py new file mode 100644 index 0000000..f032c23 --- /dev/null +++ b/src/drlx/pipeline/pickapic_dpo.py @@ -0,0 +1,65 @@ +from datasets import load_dataset +import io + +from drlx.pipeline.dpo_pipeline import DPOPipeline + +import torch +from torchvision import transforms +from torch.utils.data import Dataset, DataLoader +from PIL import Image + +def convert_bytes_to_image(image_bytes, id): + try: + image = Image.open(io.BytesIO(image_bytes)) + image = image.resize((512, 512)) + return image + except Exception as e: + print(f"An error occurred: {e}") + +def create_train_dataset(): + ds = load_dataset("yuvalkirstain/pickapic_v2",split='train') + ds = ds.filter(lambda example: example['has_label'] == True and example['label_0'] != 0.5) + return ds + +class Collator: + def __call__(self, batch): + # Batch is list of rows which are dicts + image_0_bytes = [b['jpg_0'] for b in batch] + image_1_bytes = [b['jpg_1'] for b in batch] + uid_0 = [b['image_0_uid'] for b in batch] + uid_1 = [b['image_1_uid'] for b in batch] + + label_0s = [b['label_0'] for b in batch] + + for i in range(len(batch)): + if not label_0s[i]: # label_1 is 1 => jpg_1 is 
the chosen one + image_0_bytes[i], image_1_bytes[i] = image_1_bytes[i], image_0_bytes[i] + # Swap so image_0 is always the chosen one + + prompts = [b['caption'] for b in batch] + + images_0 = [convert_bytes_to_image(i, id) for (i, id) in zip(image_0_bytes, uid_0)] + images_1 = [convert_bytes_to_image(i, id) for (i, id) in zip(image_1_bytes, uid_1)] + + images_0 = torch.stack([transforms.ToTensor()(image) for image in images_0]) + images_0 = images_0 * 2 - 1 + + images_1 = torch.stack([transforms.ToTensor()(image) for image in images_1]) + images_1 = images_1 * 2 - 1 + + return { + "chosen_pixel_values" : images_0, + "rejected_pixel_values" : images_1, + "prompts" : prompts + } + +class PickAPicDPOPipeline(DPOPipeline): + """ + Pipeline for training LDM with DPO + """ + def __init__(self): + self.train_ds = create_train_dataset() + self.dc = Collator() + + def create_loader(self, **kwargs): + return DataLoader(self.train_ds, collate_fn = self.dc, **kwargs) \ No newline at end of file diff --git a/src/drlx/sampling/__init__.py b/src/drlx/sampling/__init__.py index e7c5593..fe8a7f5 100644 --- a/src/drlx/sampling/__init__.py +++ b/src/drlx/sampling/__init__.py @@ -1,290 +1,3 @@ -from typing import Union, Iterable, Tuple, Any, Optional -from torchtyping import TensorType - -import torch -from tqdm import tqdm -import math -import einops as eo - -from drlx.utils import rescale_noise_cfg - -from drlx.configs import SamplerConfig, DDPOConfig - -class Sampler: - """ - Generic class for sampling generations using a denoiser. 
Assumes LDMUnet - """ - def __init__(self, config : SamplerConfig = SamplerConfig()): - self.config = config - - def cfg_rescale(self, pred : TensorType["2 * b", "c", "h", "w"]): - """ - Applies classifier free guidance to prediction and rescales if cfg_rescaling is enabled - - :param pred: - Assumed to be batched repeated prediction with first half consisting of - unconditioned (empty token) predictions and second half being conditioned - predictions - """ - - pred_uncond, pred_cond = pred.chunk(2) - pred = pred_uncond + self.config.guidance_scale * (pred_cond - pred_uncond) - - if self.config.guidance_rescale is not None: - pred = rescale_noise_cfg(pred, pred_cond, self.config.guidance_rescale) - - return pred - - @torch.no_grad() - def sample(self, prompts : Iterable[str], denoiser, device = None, show_progress : bool = False, accelerator = None): - """ - Samples latents given some prompts and a denoiser - - :param prompts: Text prompts for image generation (to condition denoiser) - :param denoiser: Model to use for denoising - :param device: Device on which to perform model inference - :param show_progress: Whether to display a progress bar for the sampling steps - :param accelerator: Accelerator object for accelerated training (optional) - - :return: Latents unless postprocess flag is set to true in config, in which case VAE decoded latents are returned (i.e. 
images) - """ - if accelerator is None: - denoiser_unwrapped = denoiser - else: - denoiser_unwrapped = accelerator.unwrap_model(denoiser) - - scheduler = denoiser_unwrapped.scheduler - preprocess = denoiser_unwrapped.preprocess - noise_shape = denoiser_unwrapped.get_input_shape() - - text_embeds = preprocess( - prompts, mode = "embeds", device = device, - num_images_per_prompt = 1, - do_classifier_free_guidance = self.config.guidance_scale > 1.0 - ).detach() - - scheduler.set_timesteps(self.config.num_inference_steps, device = device) - latents = torch.randn(len(prompts), *noise_shape, device = device) - - for i, t in enumerate(tqdm(scheduler.timesteps), disable = not show_progress): - input = torch.cat([latents] * 2) - input = scheduler.scale_model_input(input, t) - - pred = denoiser( - pixel_values=input, - time_step = t, - text_embeds = text_embeds - ) - - # guidance - pred = self.cfg_rescale(pred) - - # step backward - scheduler_out = scheduler.step(pred, t, latents, self.config.eta) - latents = scheduler_out.prev_sample - - if self.config.postprocess: - return denoiser_unwrapped.postprocess(latents) - else: - return latents - -class DDPOSampler(Sampler): - def step_and_logprobs(self, - scheduler, - pred : TensorType["b", "c", "h", "w"], - t : float, - latents : TensorType["b", "c", "h", "w"], - old_pred : Optional[TensorType["b", "c", "h", "w"]] = None - ): - """ - Steps backwards using scheduler. Considers the prediction as an action sampled - from a normal distribution and returns average log probability for that prediction. - Can also be used to find probability of current model giving some other prediction (old_pred) - - :param scheduler: Scheduler being used for diffusion process - :param pred: Denoiser prediction with CFG and scaling accounted for - :param t: Timestep in diffusion process - :param latents: Latent vector given as input to denoiser - :param old_pred: Alternate prediction. 
If given, computes log probability of current model predicting alternative output. - """ - scheduler_out = scheduler.step(pred, t, latents, self.config.eta, variance_noise=0) - - # computing log_probs - t_1 = t - scheduler.config.num_train_timesteps // self.config.num_inference_steps - variance = scheduler._get_variance(t, t_1) - std_dev_t = self.config.eta * variance ** 0.5 - prev_sample_mean = scheduler_out.prev_sample - prev_sample = prev_sample_mean + torch.randn_like(prev_sample_mean) * std_dev_t - - std_dev_t = torch.clip(std_dev_t, 1e-6) # force sigma > 1e-6 - - # If old_pred provided, we are finding probability of new model outputting same action as before - # Otherwise finding probability of current action - action = old_pred if old_pred is not None else prev_sample # Log prob of new model giving old output - log_probs = -((action.detach() - prev_sample_mean) ** 2) / (2 * std_dev_t ** 2) - torch.log(std_dev_t) - math.log(math.sqrt(2 * math.pi)) - log_probs = eo.reduce(log_probs, 'b c h w -> b', 'mean') - - return prev_sample, log_probs - - @torch.no_grad() - def sample( - self, prompts, denoiser, device, - show_progress : bool = False, - accelerator = None - ) -> Iterable[torch.Tensor]: - """ - DDPO sampling is analagous to playing a game in an RL environment. This function samples - given denoiser and prompts but in addition to giving latents also gives log probabilities - for predictions as well as ALL predictions (i.e. at each timestep) - - :param prompts: Text prompts to condition denoiser - :param denoiser: Denoising model - :param device: Device to do inference on - :param show_progress: Display progress bar? 
- :param accelerator: Accelerator object for accelerated training (optional) - - :return: triple of final denoised latents, all model predictions, all log probabilities for each prediction - """ - - if accelerator is None: - denoiser_unwrapped = denoiser - else: - denoiser_unwrapped = accelerator.unwrap_model(denoiser) - - scheduler = denoiser_unwrapped.scheduler - preprocess = denoiser_unwrapped.preprocess - noise_shape = denoiser_unwrapped.get_input_shape() - - text_embeds = preprocess( - prompts, mode = "embeds", device = device, - num_images_per_prompt = 1, - do_classifier_free_guidance = self.config.guidance_scale > 1.0 - ).detach() - - scheduler.set_timesteps(self.config.num_inference_steps, device = device) - latents = torch.randn(len(prompts), *noise_shape, device = device) - - all_step_preds, all_log_probs = [latents], [] - - for t in tqdm(scheduler.timesteps, disable = not show_progress): - latent_input = torch.cat([latents] * 2) - latent_input = scheduler.scale_model_input(latent_input, t) - - pred = denoiser( - pixel_values = latent_input, - time_step = t, - text_embeds = text_embeds - ) - - # cfg - pred = self.cfg_rescale(pred) - - # step - prev_sample, log_probs = self.step_and_logprobs(scheduler, pred, t, latents) - - all_step_preds.append(prev_sample) - all_log_probs.append(log_probs) - latents = prev_sample - - return latents, torch.stack(all_step_preds), torch.stack(all_log_probs) - - def compute_loss( - self, prompts, denoiser, device, - show_progress : bool = False, - advantages = None, old_preds = None, old_log_probs = None, - method_config : DDPOConfig = None, - accelerator = None - ): - - - """ - Computes the loss for the DDPO sampling process. This function is used to train the denoiser model. 
- - :param prompts: Text prompts to condition the denoiser - :param denoiser: Denoising model - :param device: Device to perform model inference on - :param show_progress: Whether to display a progress bar for the sampling steps - :param advantages: Normalized advantages obtained from reward computation - :param old_preds: Previous predictions from past model - :param old_log_probs: Log probabilities of predictions from past model - :param method_config: Configuration for the DDPO method - :param accelerator: Accelerator object for accelerated training (optional) - - :return: Total loss computed over the sampling process - """ - - # All metrics are reduced and gathered before result is returned - metrics = { - "loss" : [], - "kl_div" : [], # ~ KL div between new policy and old one (average) - "clip_frac" : [], # Proportion of policy updates where magnitude of update was clipped - } - - if accelerator is None: - denoiser_unwrapped = denoiser - else: - denoiser_unwrapped = accelerator.unwrap_model(denoiser) - - scheduler = denoiser_unwrapped.scheduler - preprocess = denoiser_unwrapped.preprocess - - adv_clip = method_config.clip_advantages # clip value for advantages - pi_clip = method_config.clip_ratio # clip value for policy ratio - - text_embeds = preprocess( - prompts, mode = "embeds", device = device, - num_images_per_prompt = 1, - do_classifier_free_guidance = self.config.guidance_scale > 1.0 - ).detach() - - scheduler.set_timesteps(self.config.num_inference_steps, device = device) - total_loss = 0. 
- - for i, t in enumerate(tqdm(scheduler.timesteps, disable = not show_progress)): - latent_input = torch.cat([old_preds[i].detach()] * 2) - latent_input = scheduler.scale_model_input(latent_input, t) - - pred = denoiser( - pixel_values = latent_input, - time_step = t, - text_embeds = text_embeds - ) - - # cfg - pred = self.cfg_rescale(pred) - - # step - prev_sample, log_probs = self.step_and_logprobs( - scheduler, pred, t, old_preds[i], - old_preds[i+1] - ) - - # Need to be computed and detached again because of autograd weirdness - clipped_advs = torch.clip(advantages,-adv_clip,adv_clip).detach() - - # ppo actor loss - - ratio = torch.exp(log_probs - old_log_probs[i].detach()) - surr1 = -clipped_advs * ratio - surr2 = -clipped_advs * torch.clip(ratio, 1. - pi_clip, 1. + pi_clip) - loss = torch.max(surr1, surr2).mean() - if accelerator is not None: - accelerator.backward(loss) - else: - loss.backward() - - # Metric computations - kl_div = 0.5 * (log_probs - old_log_probs[i]).mean() ** 2 - clip_frac = ((ratio < 1 - pi_clip) | (ratio > 1 + pi_clip)).float().mean() - - metrics["loss"].append(loss.item()) - metrics["kl_div"].append(kl_div.item()) - metrics["clip_frac"].append(clip_frac.item()) - - # Reduce across timesteps then across devices - for k in metrics: - metrics[k] = torch.tensor(metrics[k]).mean().cuda() # Needed for reduction to work - if accelerator is not None: - metrics = accelerator.reduce(metrics, 'mean') - - return metrics \ No newline at end of file +from .base import Sampler +from .ddpo_sampler import DDPOSampler +from .dpo_sampler import DPOSampler \ No newline at end of file diff --git a/src/drlx/sampling/base.py b/src/drlx/sampling/base.py new file mode 100644 index 0000000..88cf118 --- /dev/null +++ b/src/drlx/sampling/base.py @@ -0,0 +1,89 @@ +from typing import Union, Iterable, Tuple, Any, Optional +from torchtyping import TensorType + +import torch +from tqdm import tqdm +import math +import einops as eo +import torch.nn.functional as F + 
+from drlx.utils import rescale_noise_cfg +from drlx.configs import SamplerConfig + +class Sampler: + """ + Generic class for sampling generations using a denoiser. Assumes LDMUnet + """ + def __init__(self, config : SamplerConfig = SamplerConfig()): + self.config = config + + def cfg_rescale(self, pred : TensorType["2 * b", "c", "h", "w"]): + """ + Applies classifier free guidance to prediction and rescales if cfg_rescaling is enabled + + :param pred: + Assumed to be batched repeated prediction with first half consisting of + unconditioned (empty token) predictions and second half being conditioned + predictions + """ + + pred_uncond, pred_cond = pred.chunk(2) + pred = pred_uncond + self.config.guidance_scale * (pred_cond - pred_uncond) + + if self.config.guidance_rescale is not None: + pred = rescale_noise_cfg(pred, pred_cond, self.config.guidance_rescale) + + return pred + + @torch.no_grad() + def sample(self, prompts : Iterable[str], denoiser, device = None, show_progress : bool = False, accelerator = None): + """ + Samples latents given some prompts and a denoiser + + :param prompts: Text prompts for image generation (to condition denoiser) + :param denoiser: Model to use for denoising + :param device: Device on which to perform model inference + :param show_progress: Whether to display a progress bar for the sampling steps + :param accelerator: Accelerator object for accelerated training (optional) + + :return: Latents unless postprocess flag is set to true in config, in which case VAE decoded latents are returned (i.e. 
images) + """ + if accelerator is None: + denoiser_unwrapped = denoiser + else: + denoiser_unwrapped = accelerator.unwrap_model(denoiser) + + scheduler = denoiser_unwrapped.scheduler + preprocess = denoiser_unwrapped.preprocess + noise_shape = denoiser_unwrapped.get_input_shape() + + text_embeds = preprocess( + prompts, mode = "embeds", device = device, + num_images_per_prompt = 1, + do_classifier_free_guidance = self.config.guidance_scale > 1.0 + ).detach() + + scheduler.set_timesteps(self.config.num_inference_steps, device = device) + latents = torch.randn(len(prompts), *noise_shape, device = device) + + for i, t in enumerate(tqdm(scheduler.timesteps), disable = not show_progress): + input = torch.cat([latents] * 2) + input = scheduler.scale_model_input(input, t) + + pred = denoiser( + pixel_values=input, + time_step = t, + text_embeds = text_embeds + ) + + # guidance + pred = self.cfg_rescale(pred) + + # step backward + scheduler_out = scheduler.step(pred, t, latents, self.config.eta) + latents = scheduler_out.prev_sample + + if self.config.postprocess: + return denoiser_unwrapped.postprocess(latents) + else: + return latents \ No newline at end of file diff --git a/src/drlx/sampling/ddpo_sampler.py b/src/drlx/sampling/ddpo_sampler.py new file mode 100644 index 0000000..8439e2b --- /dev/null +++ b/src/drlx/sampling/ddpo_sampler.py @@ -0,0 +1,213 @@ +from torchtyping import TensorType +from typing import Iterable, Optional + +import einops as eo +import torch +from tqdm import tqdm +import math + +from drlx.sampling.base import Sampler +from drlx.configs import DDPOConfig + +class DDPOSampler(Sampler): + def step_and_logprobs(self, + scheduler, + pred : TensorType["b", "c", "h", "w"], + t : float, + latents : TensorType["b", "c", "h", "w"], + old_pred : Optional[TensorType["b", "c", "h", "w"]] = None + ): + """ + Steps backwards using scheduler. 
Considers the prediction as an action sampled + from a normal distribution and returns average log probability for that prediction. + Can also be used to find probability of current model giving some other prediction (old_pred) + + :param scheduler: Scheduler being used for diffusion process + :param pred: Denoiser prediction with CFG and scaling accounted for + :param t: Timestep in diffusion process + :param latents: Latent vector given as input to denoiser + :param old_pred: Alternate prediction. If given, computes log probability of current model predicting alternative output. + """ + scheduler_out = scheduler.step(pred, t, latents, self.config.eta, variance_noise=0) + + # computing log_probs + t_1 = t - scheduler.config.num_train_timesteps // self.config.num_inference_steps + variance = scheduler._get_variance(t, t_1) + std_dev_t = self.config.eta * variance ** 0.5 + prev_sample_mean = scheduler_out.prev_sample + prev_sample = prev_sample_mean + torch.randn_like(prev_sample_mean) * std_dev_t + + std_dev_t = torch.clip(std_dev_t, 1e-6) # force sigma > 1e-6 + + # If old_pred provided, we are finding probability of new model outputting same action as before + # Otherwise finding probability of current action + action = old_pred if old_pred is not None else prev_sample # Log prob of new model giving old output + log_probs = -((action.detach() - prev_sample_mean) ** 2) / (2 * std_dev_t ** 2) - torch.log(std_dev_t) - math.log(math.sqrt(2 * math.pi)) + log_probs = eo.reduce(log_probs, 'b c h w -> b', 'mean') + + return prev_sample, log_probs + + @torch.no_grad() + def sample( + self, prompts, denoiser, device, + show_progress : bool = False, + accelerator = None + ) -> Iterable[torch.Tensor]: + """ + DDPO sampling is analogous to playing a game in an RL environment. This function samples + given denoiser and prompts but in addition to giving latents also gives log probabilities + for predictions as well as ALL predictions (i.e. 
at each timestep) + + :param prompts: Text prompts to condition denoiser + :param denoiser: Denoising model + :param device: Device to do inference on + :param show_progress: Display progress bar? + :param accelerator: Accelerator object for accelerated training (optional) + + :return: triple of final denoised latents, all model predictions, all log probabilities for each prediction + """ + + if accelerator is None: + denoiser_unwrapped = denoiser + else: + denoiser_unwrapped = accelerator.unwrap_model(denoiser) + + scheduler = denoiser_unwrapped.scheduler + preprocess = denoiser_unwrapped.preprocess + sdxl_flag = denoiser_unwrapped.sdxl_flag + noise_shape = denoiser_unwrapped.get_input_shape() + + text_embeds = preprocess( + prompts, mode = "embeds", device = device, + num_images_per_prompt = 1, + do_classifier_free_guidance = self.config.guidance_scale > 1.0 + ) + + # If not SDXL, we assume encode prompts gave normal and negative embeds, which we concat + if sdxl_flag: + pass # TODO: SDXL Support for DDPO + else: + text_embeds = torch.cat([text_embeds[1], text_embeds[0]]).detach() + + scheduler.set_timesteps(self.config.num_inference_steps, device = device) + latents = torch.randn(len(prompts), *noise_shape, device = device) + + all_step_preds, all_log_probs = [latents], [] + + for t in tqdm(scheduler.timesteps, disable = not show_progress): + latent_input = torch.cat([latents] * 2) # Double for CFG + latent_input = scheduler.scale_model_input(latent_input, t) + + pred = denoiser( + pixel_values = latent_input, + time_step = t, + text_embeds = text_embeds + ) + + # cfg + pred = self.cfg_rescale(pred) + + # step + prev_sample, log_probs = self.step_and_logprobs(scheduler, pred, t, latents) + + all_step_preds.append(prev_sample) + all_log_probs.append(log_probs) + latents = prev_sample + + return latents, torch.stack(all_step_preds), torch.stack(all_log_probs) + + def compute_loss( + self, prompts, denoiser, device, + show_progress : bool = False, + advantages = 
None, old_preds = None, old_log_probs = None, + method_config : DDPOConfig = None, + accelerator = None + ): + + + """ + Computes the loss for the DDPO sampling process. This function is used to train the denoiser model. + + :param prompts: Text prompts to condition the denoiser + :param denoiser: Denoising model + :param device: Device to perform model inference on + :param show_progress: Whether to display a progress bar for the sampling steps + :param advantages: Normalized advantages obtained from reward computation + :param old_preds: Previous predictions from past model + :param old_log_probs: Log probabilities of predictions from past model + :param method_config: Configuration for the DDPO method + :param accelerator: Accelerator object for accelerated training (optional) + + :return: Total loss computed over the sampling process + """ + + # All metrics are reduced and gathered before result is returned + metrics = { + "loss" : [], + "kl_div" : [], # ~ KL div between new policy and old one (average) + "clip_frac" : [], # Proportion of policy updates where magnitude of update was clipped + } + + scheduler = accelerator.unwrap_model(denoiser).scheduler + preprocess = accelerator.unwrap_model(denoiser).preprocess + + adv_clip = method_config.clip_advantages # clip value for advantages + pi_clip = method_config.clip_ratio # clip value for policy ratio + + text_embeds = preprocess( + prompts, mode = "embeds", device = device, + num_images_per_prompt = 1, + do_classifier_free_guidance = self.config.guidance_scale > 1.0 + ).detach() + + scheduler.set_timesteps(self.config.num_inference_steps, device = device) + total_loss = 0. 
+ + for i, t in enumerate(tqdm(scheduler.timesteps, disable = not show_progress)): + latent_input = torch.cat([old_preds[i].detach()] * 2) + latent_input = scheduler.scale_model_input(latent_input, t) + + pred = denoiser( + pixel_values = latent_input, + time_step = t, + text_embeds = text_embeds + ) + + # cfg + pred = self.cfg_rescale(pred) + + # step + prev_sample, log_probs = self.step_and_logprobs( + scheduler, pred, t, old_preds[i], + old_preds[i+1] + ) + + # Need to be computed and detached again because of autograd weirdness + clipped_advs = torch.clip(advantages,-adv_clip,adv_clip).detach() + + # ppo actor loss + + ratio = torch.exp(log_probs - old_log_probs[i].detach()) + surr1 = -clipped_advs * ratio + surr2 = -clipped_advs * torch.clip(ratio, 1. - pi_clip, 1. + pi_clip) + loss = torch.max(surr1, surr2).mean() + if accelerator is not None: + accelerator.backward(loss) + else: + loss.backward() + + # Metric computations + kl_div = 0.5 * (log_probs - old_log_probs[i]).mean() ** 2 + clip_frac = ((ratio < 1 - pi_clip) | (ratio > 1 + pi_clip)).float().mean() + + metrics["loss"].append(loss.item()) + metrics["kl_div"].append(kl_div.item()) + metrics["clip_frac"].append(clip_frac.item()) + + # Reduce across timesteps then across devices + for k in metrics: + metrics[k] = torch.tensor(metrics[k]).mean().cuda() # Needed for reduction to work + if accelerator is not None: + metrics = accelerator.reduce(metrics, 'mean') + + return metrics diff --git a/src/drlx/sampling/dpo_sampler.py b/src/drlx/sampling/dpo_sampler.py new file mode 100644 index 0000000..d3c8b03 --- /dev/null +++ b/src/drlx/sampling/dpo_sampler.py @@ -0,0 +1,156 @@ +import torch +import torch.nn.functional as F +import einops as eo + +from drlx.sampling.base import Sampler +from drlx.configs import DPOConfig +from drlx.utils.sdxl import get_time_ids + +class DPOSampler(Sampler): + def compute_loss( + self, + prompts, + chosen_img, + rejected_img, + denoiser, + vae, + device, + method_config : 
DPOConfig, + accelerator = None, + ref_denoiser = None + ): + """ + Compute metrics and do backwards pass on loss. Assumes LoRA if reference is not given. + """ + do_lora = ref_denoiser is None + + scheduler = accelerator.unwrap_model(denoiser).scheduler + preprocess = accelerator.unwrap_model(denoiser).preprocess + sdxl_flag = accelerator.unwrap_model(denoiser).sdxl_flag + encode = accelerator.unwrap_model(vae).encode + + beta = method_config.beta + ref_strategy = method_config.ref_mem_strategy + + # Text and image preprocessing + with torch.no_grad(): + text_embeds = preprocess( + prompts, mode = "embeds", device = device, + num_images_per_prompt = 1, + do_classifier_free_guidance = self.config.guidance_scale > 1.0 + ) + + # The value returned above varies depending on model + # With most models its two values, positive and negative prompts + # With DPO we don't care about CFG, so just only get the positive prompts + added_cond_kwargs = {} + if sdxl_flag: + added_cond_kwargs['text_embeds'] = text_embeds[2].detach() # Pooled prompt embeds + added_cond_kwargs['time_ids'] = get_time_ids(chosen_img) + + text_embeds = text_embeds[0].detach() + + chosen_latent = encode(chosen_img).latent_dist.sample() + rejected_latent = encode(rejected_img).latent_dist.sample() + + # sample random ts + timesteps = torch.randint( + 0, self.config.num_inference_steps, (len(chosen_img),), device = device, dtype = torch.long + ) + + # One step of noising to samples + noise = torch.randn_like(chosen_latent) # [B, C, H, W] + + # Doubling across chosen and rejected + def double_up(x): + return torch.cat([x,x], dim = 0) + + def double_down(x): + n = len(x) + return x[:n//2], x[n//2:] + + # Double everything up so we can input both chosen and rejected at the same time + timesteps = double_up(timesteps) + noise = double_up(noise) + text_embeds = double_up(text_embeds) + + if sdxl_flag: + added_cond_kwargs['text_embeds'] = double_up(added_cond_kwargs['text_embeds']) + 
added_cond_kwargs['time_ids'] = double_up(added_cond_kwargs['time_ids']) + + latent = torch.cat([chosen_latent, rejected_latent]) + + noisy_inputs = scheduler.add_noise( + latent, + noise, + timesteps + ) + + # Get targets + if scheduler.config.prediction_type == "epsilon": + target = noise + elif scheduler.config.prediction_type == "v_prediction": + target = scheduler.get_velocity( + latent, + noise, + timesteps + ) + + # utility function to get loss simpler + def split_mse(pred, target): + mse = eo.reduce(F.mse_loss(pred, target, reduction = 'none'), 'b ... -> b', reduction = "mean") + chosen, rejected = double_down(mse) + return chosen - rejected, mse.mean() + + # Forward pass and loss for DPO denoiser + pred = denoiser( + pixel_values = noisy_inputs, + time_step = timesteps, + text_embeds = text_embeds, + added_cond_kwargs = added_cond_kwargs + ) + model_diff, base_loss = split_mse(pred, target) + + # Forward pass and loss for refrence + with torch.no_grad(): + if do_lora: + accelerator.unwrap_model(denoiser).disable_adapters() + + ref_pred = denoiser( + pixel_values = noisy_inputs, + time_step = timesteps, + text_embeds = text_embeds, + added_cond_kwargs = added_cond_kwargs + ) + ref_diff, ref_loss = split_mse(ref_pred, target) + + accelerator.unwrap_model(denoiser).enable_adapters() + else: + ref_inputs = { + "sample" : noisy_inputs.half() if ref_strategy == "half" else noisy_inputs, + "timestep" : timesteps, + "encoder_hidden_states" : text_embeds.half() if ref_strategy == "half" else text_embeds, + "added_cond_kwargs" : added_cond_kwargs + } + ref_pred = ref_denoiser(**ref_inputs).sample + ref_diff, ref_loss = split_mse(ref_pred, target) + + # DPO Objective + surr_loss = -beta * (model_diff - ref_diff) + loss = -1 * F.logsigmoid(surr_loss.mean()) + + # Get approx accuracy as models probability of giving chosen over rejected + acc = (surr_loss > 0).sum().float() / len(surr_loss) + acc += 0.5 * (surr_loss == 0).sum().float() / len(surr_loss) # 50% for when 
both match + + if accelerator is None: + loss.backward() + else: + accelerator.backward(loss) + + return { + "loss" : loss.item(), + "diffusion_loss" : base_loss.item(), + "accuracy" : acc.item(), + "ref_deviation" : (ref_loss - base_loss) ** 2 + } \ No newline at end of file diff --git a/src/drlx/trainer/base_accelerate.py b/src/drlx/trainer/base_accelerate.py new file mode 100644 index 0000000..daa49c9 --- /dev/null +++ b/src/drlx/trainer/base_accelerate.py @@ -0,0 +1,139 @@ +from drlx.trainer import BaseTrainer +from drlx.configs import DRLXConfig +from drlx.sampling import Sampler +from drlx.utils import suppress_warnings + +from accelerate import Accelerator +import wandb +import logging +import torch +from diffusers import StableDiffusionPipeline +import os + +from diffusers.utils import convert_state_dict_to_diffusers +from peft.utils import get_peft_model_state_dict + +class AcceleratedTrainer(BaseTrainer): + """ + Base class for any trainer using accelerate. Assumes model comes from a pretrained + pipeline + + :param config: DRLX config. Method config can be anything. 
+ :type config: DRLXConfig + """ + def __init__(self, config : DRLXConfig): + super().__init__(config) + # Figure out batch size and accumulation steps + if self.config.train.target_batch is not None: # Just use normal batch_size + self.accum_steps = (self.config.train.target_batch // self.config.train.batch_size) + else: + self.accum_steps = 1 + + self.accelerator = Accelerator( + log_with = config.logging.log_with, + gradient_accumulation_steps = self.accum_steps + ) + + # Disable tokenizer warnings since they clutter the CLI + kw_str = self.config.train.suppress_log_keywords + if kw_str is not None: + for prefix in kw_str.split(","): + suppress_warnings(prefix.strip()) + + self.pipe = None # Store reference to pipeline so that we can use save_pretrained later + self.model = self.setup_model() + self.optimizer = self.setup_optimizer() + self.scheduler = self.setup_scheduler() + + self.sampler = self.model.sampler + self.model, self.optimizer, self.scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.scheduler + ) + + # Setup tracking + + tracker_kwargs = {} + self.use_wandb = not (config.logging.wandb_project is None) + if self.use_wandb: + log = config.logging + tracker_kwargs["wandb"] = { + "name" : log.run_name, + "entity" : log.wandb_entity, + "mode" : "online" + } + + self.accelerator.init_trackers( + project_name = log.wandb_project, + config = config.to_dict(), + init_kwargs = tracker_kwargs + ) + + self.world_size = self.accelerator.state.num_processes + + def setup_model(self): + """ + Set up model from config. 
+ """ + model = self.get_arch(self.config)(self.config.model, sampler = Sampler(self.config.sampler)) + if self.config.model.model_path is not None: + model, pipe = model.from_pretrained_pipeline(StableDiffusionPipeline, self.config.model.model_path) + + self.pipe = pipe + return model + + def extract_pipeline(self): + """ + Return original pipeline with finetuned denoiser plugged in + + :return: Diffusers pipeline + """ + + self.pipe.unet = self.accelerator.unwrap_model(self.model).unet + return self.pipe + + def load_checkpoint(self, fp : str): + """ + Load checkpoint + + :param fp: File path to checkpoint to load from + """ + self.accelerator.load_state(fp) + self.accelerator.print("Succesfully loaded checkpoint") + + def save_checkpoint(self, fp : str, components = None): + """ + Save checkpoint in main process + + :param fp: File path to save checkpoint to + """ + if self.accelerator.is_main_process: + os.makedirs(fp, exist_ok = True) + self.accelerator.save_state(output_dir=fp) + self.accelerator.wait_for_everyone() # need to use this twice or a corrupted state is saved + + def save_pretrained(self, fp : str): + """ + Save model into pretrained pipeline so it can be loaded in pipeline later + + :param fp: File path to save to + """ + if self.accelerator.is_main_process: + os.makedirs(fp, exist_ok = True) + unwrapped_model = self.accelerator.unwrap_model(self.model) + if self.config.model.lora_rank is not None: + unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_model.unet)) + StableDiffusionPipeline.save_lora_weights(fp, unet_lora_layers=unet_lora_state_dict, safe_serialization = unwrapped_model.config.use_safetensors) + else: + self.pipe.unet = unwrapped_model.unet + self.pipe.save_pretrained(fp, safe_serialization = unwrapped_model.config.pipeline_kwargs['use_safetensors']) + self.accelerator.wait_for_everyone() + + def extract_pipeline(self): + """ + Return original pipeline with finetuned denoiser plugged in + + 
:return: Diffusers pipeline + """ + + self.pipe.unet = self.accelerator.unwrap_model(self.model).unet + return self.pipe diff --git a/src/drlx/trainer/ddpo_trainer.py b/src/drlx/trainer/ddpo_trainer.py index 2a15762..f867252 100644 --- a/src/drlx/trainer/ddpo_trainer.py +++ b/src/drlx/trainer/ddpo_trainer.py @@ -3,7 +3,7 @@ from accelerate import Accelerator from drlx.configs import DRLXConfig, DDPOConfig -from drlx.trainer import BaseTrainer +from drlx.trainer.base_accelerate import AcceleratedTrainer from drlx.sampling import DDPOSampler from drlx.utils import suppress_warnings, Timer, PerPromptStatTracker, scoped_seed, save_images @@ -82,7 +82,7 @@ def collate(batch): return DataLoader(self, collate_fn=collate, **kwargs) -class DDPOTrainer(BaseTrainer): +class DDPOTrainer(AcceleratedTrainer): """ DDPO Accelerated Trainer initilization from config. During init, sets up model, optimizer, sampler and logging @@ -95,53 +95,6 @@ def __init__(self, config : DRLXConfig): assert isinstance(self.config.method, DDPOConfig), "ERROR: Method config must be DDPO config" - # Figure out batch size and accumulation steps - if self.config.train.target_batch is not None: # Just use normal batch_size - self.accum_steps = (self.config.train.target_batch // self.config.train.batch_size) - else: - self.accum_steps = 1 - - self.accelerator = Accelerator( - log_with = config.logging.log_with, - gradient_accumulation_steps = self.accum_steps - ) - - # Disable tokenizer warnings since they clutter the CLI - kw_str = self.config.train.suppress_log_keywords - if kw_str is not None: - for prefix in kw_str.split(","): - suppress_warnings(prefix.strip()) - - self.pipe = None # Store reference to pipeline so that we can use save_pretrained later - self.model = self.setup_model() - self.optimizer = self.setup_optimizer() - self.scheduler = self.setup_scheduler() - - self.sampler = self.model.sampler - self.model, self.optimizer, self.scheduler = self.accelerator.prepare( - self.model, 
self.optimizer, self.scheduler - ) - - # Setup tracking - - tracker_kwargs = {} - self.use_wandb = not (config.logging.wandb_project is None) - if self.use_wandb: - log = config.logging - tracker_kwargs["wandb"] = { - "name" : log.run_name, - "entity" : log.wandb_entity, - "mode" : "online" - } - - self.accelerator.init_trackers( - project_name = log.wandb_project, - config = config.to_dict(), - init_kwargs = tracker_kwargs - ) - - self.world_size = self.accelerator.state.num_processes - def setup_model(self): """ Set up model from config. @@ -383,51 +336,4 @@ def time_per_1k(n_samples : int): last_epoch_time = time_per_1k(self.config.train.num_samples_per_epoch) - del metrics, dataloader, experience_loader - - def save_checkpoint(self, fp : str, components = None): - """ - Save checkpoint in main process - - :param fp: File path to save checkpoint to - """ - if self.accelerator.is_main_process: - os.makedirs(fp, exist_ok = True) - self.accelerator.save_state(output_dir=fp) - self.accelerator.wait_for_everyone() # need to use this twice or a corrupted state is saved - - def save_pretrained(self, fp : str): - """ - Save model into pretrained pipeline so it can be loaded in pipeline later - - :param fp: File path to save to - """ - if self.accelerator.is_main_process: - os.makedirs(fp, exist_ok = True) - unwrapped_model = self.accelerator.unwrap_model(self.model) - if self.config.model.lora_rank is not None: - unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_model.unet)) - StableDiffusionPipeline.save_lora_weights(fp, unet_lora_layers=unet_lora_state_dict, safe_serialization = unwrapped_model.config.use_safetensors) - else: - self.pipe.unet = unwrapped_model.unet - self.pipe.save_pretrained(fp, safe_serialization = unwrapped_model.config.use_safetensors) - self.accelerator.wait_for_everyone() - - def extract_pipeline(self): - """ - Return original pipeline with finetuned denoiser plugged in - - :return: Diffusers pipeline - 
""" - - self.pipe.unet = self.accelerator.unwrap_model(self.model).unet - return self.pipe - - def load_checkpoint(self, fp : str): - """ - Load checkpoint - - :param fp: File path to checkpoint to load from - """ - self.accelerator.load_state(fp) - self.accelerator.print("Succesfully loaded checkpoint") + del metrics, dataloader, experience_loader \ No newline at end of file diff --git a/src/drlx/trainer/dpo_trainer.py b/src/drlx/trainer/dpo_trainer.py new file mode 100644 index 0000000..6033ebd --- /dev/null +++ b/src/drlx/trainer/dpo_trainer.py @@ -0,0 +1,204 @@ +from torchtyping import TensorType +from typing import Iterable, Tuple, Callable + +from accelerate import Accelerator +from drlx.configs import DRLXConfig, DPOConfig +from drlx.trainer.base_accelerate import AcceleratedTrainer +from drlx.sampling import DPOSampler +from drlx.utils import suppress_warnings, Timer, scoped_seed, save_images + +import torch +import einops as eo +import os +import gc +import logging +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm +import numpy as np +import wandb +import accelerate.utils +from PIL import Image +from copy import deepcopy + +from diffusers import DiffusionPipeline + +class DPOTrainer(AcceleratedTrainer): + """ + DDPO Accelerated Trainer initilization from config. During init, sets up model, optimizer, sampler and logging + + :param config: DRLX config + :type config: DRLXConfig + """ + + def __init__(self, config : DRLXConfig): + super().__init__(config) + + # DPO requires we use vae encode, so let's put it on all GPUs + self.vae = self.accelerator.unwrap_model(self.model).vae + self.vae = self.accelerator.prepare(self.vae) + + assert isinstance(self.config.method, DPOConfig), "ERROR: Method config must be DPO config" + + def setup_model(self): + """ + Set up model from config. 
+ """ + model = self.get_arch(self.config)(self.config.model, sampler = DPOSampler(self.config.sampler)) + if self.config.model.model_path is not None: + model, pipe = model.from_pretrained_pipeline(DiffusionPipeline, self.config.model.model_path) + + self.pipe = pipe + self.pipe.set_progress_bar_config(disable=True) + return model + + def loss( + self, + prompts, chosen_img, rejected_img, ref_denoiser + ): + """ + Get loss for training + + :param chosen_batch_preds: Predictions for the ba + """ + return self.sampler.compute_loss( + prompts=prompts, chosen_img=chosen_img, rejected_img=rejected_img, + denoiser=self.model, ref_denoiser=ref_denoiser, vae=self.vae, + device=self.accelerator.device, + method_config=self.config.method, + accelerator=self.accelerator + ) + + @torch.no_grad() + def deterministic_sample(self, prompts): + """ + Sample images deterministically. Utility for visualizing changes for fixed prompts through training. + """ + gen = torch.Generator(device=self.pipe.device).manual_seed(self.config.train.seed) + self.pipe.unet = self.accelerator.unwrap_model(self.model).unet + return self.pipe(prompts, generator = gen).images + + def train(self, pipeline): + """ + Trains the model based on config parameters. Needs to be passed a prompt pipeline and reward function. 
+ + :param pipeline: Pipeline to draw tuples from with prompts + :type prompt_pipeline: DPOPipeline + """ + + # === SETUP === + do_lora = self.config.model.lora_rank is not None + + # Singular dataloader made to get a sample of prompts + # This sample batch is dependent on config seed so it can be same across runs + with scoped_seed(self.config.train.seed): + dataloader = self.accelerator.prepare( + pipeline.create_loader(batch_size = self.config.train.batch_size, shuffle = False) + ) + sample_prompts = self.config.train.sample_prompts + if sample_prompts is None: + sample_prompts = [] + if len(sample_prompts) < self.config.train.batch_size: + new_sample_prompts = next(iter(dataloader))["prompts"] + sample_prompts += new_sample_prompts + sample_prompts = sample_prompts[:self.config.train.batch_size] + + # Now make main dataloader + + assert isinstance(self.sampler, DPOSampler), "Error: Model Sampler for DPO training must be DPO sampler" + + # Set the epoch count + epochs = self.config.train.num_epochs + if self.config.train.total_samples is not None: + epochs = int(self.config.train.total_samples // self.config.train.num_samples_per_epoch) + + # Timer to measure time per 1k images (as metric) + timer = Timer() + def time_per_1k(n_samples : int): + total_time = timer.hit() + return total_time * 1000 / n_samples + last_batch_time = timer.hit() + + # Ref model + if not do_lora: + ref_model = deepcopy(self.accelerator.unwrap_model(self.model).unet) + ref_model.requires_grad = False + if self.config.method.ref_mem_strategy == "half": ref_model = ref_model.half() + ref_model = self.accelerator.prepare(ref_model) + else: + ref_model = None + + # === MAIN TRAINING LOOP === + + mean_rewards = [] + accum = 0 + last_epoch_time = timer.hit() + for epoch in range(epochs): + dataloader = pipeline.create_loader(batch_size = self.config.train.batch_size, shuffle = True) + dataloader = self.accelerator.prepare(dataloader) + + # Clean up unused resources + 
self.accelerator._dataloaders = [] # Clear dataloaders + gc.collect() + torch.cuda.empty_cache() + + self.accelerator.print(f"Epoch {epoch}/{epochs}.") + + for batch in tqdm(dataloader): + with self.accelerator.accumulate(self.model): + metrics = self.loss( + prompts = batch['prompts'], + chosen_img = batch['chosen_pixel_values'], + rejected_img = batch['rejected_pixel_values'], + ref_denoiser = ref_model + ) + + self.accelerator.wait_for_everyone() + + # Optimizer step + if self.config.train.grad_clip > 0: + self.accelerator.clip_grad_norm_( + filter(lambda p: p.requires_grad, self.model.parameters()), + self.config.train.grad_clip + ) + self.optimizer.step() + self.scheduler.step() + self.optimizer.zero_grad() + + # Generate the sample prompts + with torch.no_grad(): + with scoped_seed(self.config.train.seed): + sample_imgs = self.deterministic_sample(sample_prompts) + sample_imgs_wandb = [wandb.Image(img, caption = prompt) for (img, prompt) in zip(sample_imgs, sample_prompts)] + + # Logging + if self.use_wandb: + self.accelerator.log({ + "base_loss" : metrics["diffusion_loss"], + "accuracy" : metrics["accuracy"], + "dpo_loss" : metrics["loss"], + "ref_deviation" : metrics["ref_deviation"], + "time_per_1k" : last_batch_time, + "img_sample" : sample_imgs_wandb + }) + # save images + if self.accelerator.is_main_process and self.config.train.save_samples: + save_images(sample_imgs, f"./samples/{self.config.logging.run_name}/{epoch}") + + + + # Save model every [interval] epochs + accum += 1 + if accum % self.config.train.checkpoint_interval == 0 and self.config.train.checkpoint_interval > 0: + self.accelerator.print("Saving...") + base_path = f"./checkpoints/{self.config.logging.run_name}" + output_path = f"./output/{self.config.logging.run_name}" + self.accelerator.wait_for_everyone() + # Commenting this out for now so I can test rest of the code even though this is broken + self.save_checkpoint(f"{base_path}/{accum}") + self.save_pretrained(output_path) + + 
last_epoch_time = time_per_1k(self.config.train.num_samples_per_epoch) + + del metrics + del dataloader + diff --git a/src/drlx/utils/__init__.py b/src/drlx/utils/__init__.py index 49e776d..1cab956 100644 --- a/src/drlx/utils/__init__.py +++ b/src/drlx/utils/__init__.py @@ -22,6 +22,7 @@ class OptimizerName(str, Enum): ADAM_8BIT_BNB: str = "adam_8bit_bnb" ADAMW_8BIT_BNB: str = "adamw_8bit_bnb" SGD: str = "sgd" + RMSPROP: str = "rmsprop" def get_optimizer_class(name: OptimizerName): @@ -57,6 +58,8 @@ def get_optimizer_class(name: OptimizerName): ) if name == OptimizerName.SGD.value: return torch.optim.SGD + if name == OptimizerName.RMSPROP.value: + return torch.optim.RMSprop supported_optimizers = [o.value for o in OptimizerName] raise ValueError(f"`{name}` is not a supported optimizer. " f"Supported optimizers are: {supported_optimizers}") @@ -223,7 +226,12 @@ def save_images(images : np.array, fp : str): os.makedirs(fp, exist_ok = True) - images = [Image.fromarray(image) for image in images] + if isinstance(images, np.ndarray): + images = [Image.fromarray(image) for image in images] + elif isinstance(images, list) and all(isinstance(i, Image.Image) for i in images): + pass + else: + raise ValueError("Images should be either a numpy array or a list of PIL Images") for i, image in enumerate(images): image.save(os.path.join(fp,f"{i}.png")) diff --git a/src/drlx/utils/sdxl.py b/src/drlx/utils/sdxl.py new file mode 100644 index 0000000..205d24d --- /dev/null +++ b/src/drlx/utils/sdxl.py @@ -0,0 +1,14 @@ +import einops as eo +import torch + +def get_time_ids(batch): + """ + Computes time ids needed for SDXL in a heavily simplified manner that only requires image size + (assumes square images). Assumes crop top left is (0,0) for all images. Infers all needed info from batch of images. 
+ """ + + b, c, h, w = batch.shape + + # input_size, crop, input_size + add_time_ids = torch.tensor([h, w, 0, 0, h, w], device = batch.device, dtype = batch.dtype) + return eo.repeat(add_time_ids, 'd -> (b d)', b = b)