Skip to content
Open

DPO #30

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
71c3014
Add base trainer for any accelerate model
shahbuland Sep 17, 2023
af358b7
Add PickaPic pipeline for DPO
shahbuland Sep 27, 2023
b140623
add skeleton for DPO trainer
shahbuland Sep 27, 2023
704a85c
Pipeline for DPO
Jan 23, 2024
68ec789
Allow for list of images instead of just list of np arrays for sample…
Jan 23, 2024
e70e537
Add sampler for DPO
Jan 23, 2024
2202d9f
Add method config for DPO
Jan 23, 2024
9304bac
Add DPO trainer initial version
Jan 23, 2024
1752621
basic debugs
Jan 24, 2024
b11e873
Remove streaming
Jan 24, 2024
9201d6a
minor bug fixes
Jan 24, 2024
35dd03d
Moved saving from DDPO trainer to base accelerate
Jan 24, 2024
e16526f
LoRA, refactorings, quick bug fixes
Jan 25, 2024
14fe254
small bug fixes
Jan 25, 2024
e121257
bug fixes
Jan 25, 2024
6d9e03d
Fix import errors and checkpointing
Jan 25, 2024
c2350cb
Add base model loss deviation to sampling as metric
Jan 26, 2024
765b9f6
Add base model loss deviation to trainer logging as metric
Jan 26, 2024
74012cc
Add non-lora training with memory saving options in config
Jan 26, 2024
ef91f92
some refactorings to sampling, add rmsprop
Jan 28, 2024
38847c5
Delete old DPO example, push new one
Jan 28, 2024
be05515
Rename DPO2 to DPO
Feb 13, 2024
e6023a3
Move DPO and DDPO sampler to their own files for better organization
Feb 13, 2024
5253473
prepare for adding SDXL
Feb 13, 2024
54f6ec1
Fix issue with modularizing samplers
Feb 13, 2024
44f163d
Add SDXL support and reorganize config for model
Feb 13, 2024
565efe6
Remove mandatory gradient clipping and fix model saving with new config
Feb 13, 2024
4324932
Update dpo_pickapic.yml
shahbuland Feb 13, 2024
dde1265
Update dpo_pickapic.yml
shahbuland Feb 13, 2024
70f1827
Update dpo_pickapic.yml
shahbuland Feb 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions configs/dpo_pickapic.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Example configuration for DPO training on the Pick-a-Pic dataset.
# Loaded by examples/DPO/train_dpo_pickapic.py via DRLXConfig.load_yaml.

# Selects the method config registered under the name "DPO".
method:
  name : "DPO"

model:
  model_path: "stabilityai/stable-diffusion-2-1"
  # Forwarded as keyword arguments to the pipeline's from_pretrained call.
  pipeline_kwargs:
    use_safetensors: True
    variant: "fp16"
  sdxl: False # Set to True when model_path points at an SDXL model
  model_arch_type: "LDMUnet"
  attention_slicing: True
  xformers_memory_efficient: False
  gradient_checkpointing: True


sampler:
  guidance_scale: 7.5
  num_inference_steps: 50

optimizer:
  name: "adamw"
  kwargs:
    lr: 2.048e-8
    weight_decay: 1.0e-4
    betas: [0.9, 0.999]

scheduler:
  name: "linear" # Name of learning rate scheduler
  kwargs:
    # NOTE(review): equal start and end factors presumably keep the learning rate constant - confirm
    start_factor: 1.0
    end_factor: 1.0

logging:
  run_name: 'dpo_pickapic' # The example script looks for checkpoints under checkpoints/<run_name>
  #wandb_entity: None
  #wandb_project: None

train:
  num_epochs: 500
  num_samples_per_epoch: 256
  # NOTE(review): batch_size 1 with target_batch 256 presumably means gradient accumulation - confirm in the trainer
  batch_size: 1
  target_batch: 256
  checkpoint_interval: 640
  tf32: True
  suppress_log_keywords: "diffusers.pipelines,transformers"
24 changes: 24 additions & 0 deletions examples/DPO/train_dpo_pickapic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Example script: train a model with DPO on the Pick-a-Pic preference dataset.
# Paths below are relative, so this is presumably meant to be run from the repository root.
import sys

# Make the in-repo `drlx` package importable; must run before the drlx imports below.
sys.path.append("./src")

from drlx.trainer.dpo_trainer import DPOTrainer
from drlx.configs import DRLXConfig
from drlx.utils import get_latest_checkpoint

# Pipeline first
from drlx.pipeline.pickapic_dpo import PickAPicDPOPipeline

import torch

# Constructing the pipeline loads and filters the dataset.
pipe = PickAPicDPOPipeline()
# Set to True to continue from the latest checkpoint of this run.
resume = False

config = DRLXConfig.load_yaml("configs/dpo_pickapic.yml")
trainer = DPOTrainer(config)

if resume:
    # Checkpoints are looked up under checkpoints/<run_name>.
    cp_dir = get_latest_checkpoint(f"checkpoints/{config.logging.run_name}")
    trainer.load_checkpoint(cp_dir)

trainer.train(pipe)
34 changes: 25 additions & 9 deletions src/drlx/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,22 @@ class DDPOConfig(MethodConfig):
buffer_size: int = 32 # Set to None to avoid using per prompt stat tracker
min_count: int = 16

@register_method("DPO")
@dataclass
class DPOConfig(MethodConfig):
    """
    Config for DPO-related hyperparams

    :param name: Name of the method, same string the class is registered under
    :type name: str

    :param beta: Deviation from initial model
    :type beta: float

    :param ref_mem_strategy: Strategy for managing reference model on memory. Either None (the default) or "half"; "half" presumably puts it in 16 bit.
    :type ref_mem_strategy: str
    """
    name : str = "DPO"
    beta : float = 0.9
    ref_mem_strategy : str = None # None or "half"

@dataclass
class TrainConfig(ConfigClass):
"""
Expand Down Expand Up @@ -144,7 +160,7 @@ class TrainConfig(ConfigClass):
num_epochs: int = 50
total_samples: int = None
num_samples_per_epoch: int = 256
grad_clip: float = 1.0
grad_clip: float = -1
checkpoint_interval: int = 10
checkpoint_path: str = "checkpoints"
seed: int = 0
Expand Down Expand Up @@ -219,14 +235,14 @@ class ModelConfig(ConfigClass):
:param model_path: Path or name of the model (local or on huggingface hub)
:type model_path: str

:param model_arch_type: Type of model architecture.
:type model_arch_type: str
:param pipeline_kwargs: Keyword arguments for pipeline if model is being loaded from one
:type pipeline_kwargs: dict

:param use_safetensors: Use safe tensors when loading pipeline?
:type use_safetensors: bool
:param sdxl: Using SDXL model?
:type sdxl: bool

:param local_model: Force model to load checkpoint locally only
:type local_model: bool
:param model_arch_type: Type of model architecture. Defaults to LDM UNet
:type model_arch_type: str

:param attention_slicing: Whether to use attention slicing
:type attention_slicing: bool
Expand All @@ -242,9 +258,9 @@ class ModelConfig(ConfigClass):
"""

model_path: str = None
pipeline_kwargs : dict = None
sdxl : bool = False
model_arch_type: str = None
use_safetensors : bool = False
local_model : bool = False
attention_slicing: bool = False
xformers_memory_efficient: bool = False
gradient_checkpointing: bool = False
Expand Down
35 changes: 31 additions & 4 deletions src/drlx/denoisers/ldm_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ def __init__(self, config : ModelConfig, sampler_config : SamplerConfig = None,
super().__init__(config, sampler_config, sampler)

self.unet : UNet2DConditionModel = None

self.text_encoder = None
self.text_encoder_2 = None # SDXL Support, just needs to be here for device mapping

self.vae = None
self.encode_prompt : Callable = None

Expand All @@ -37,6 +40,8 @@ def __init__(self, config : ModelConfig, sampler_config : SamplerConfig = None,

self.scale_factor = None

self.sdxl_flag = self.config.sdxl

def get_input_shape(self) -> Tuple[int]:
"""
Figure out latent noise input shape for the UNet. Requires that unet and vae are defined
Expand Down Expand Up @@ -65,16 +70,25 @@ def from_pretrained_pipeline(self, cls : Type, path : str):
:rtype: LDMUNet
"""

pipe = cls.from_pretrained(path, use_safetensors = self.config.use_safetensors, local_files_only = self.config.local_model)
kwargs = self.config.pipeline_kwargs
kwargs["torch_dtype"] = torch.float32

pipe = cls.from_pretrained(path, **kwargs)

if self.config.attention_slicing: pipe.enable_attention_slicing()
if self.config.xformers_memory_efficient: pipe.enable_xformers_memory_efficient_attention()

self.unet = pipe.unet
self.text_encoder = pipe.text_encoder

# SDXL compat
if self.sdxl_flag:
self.text_encoder_2 = pipe.text_encoder_2

self.vae = pipe.vae
self.scale_factor = pipe.vae_scale_factor
self.encode_prompt = pipe._encode_prompt

self.encode_prompt = pipe.encode_prompt

self.text_encoder.requires_grad_(False)
self.vae.requires_grad_(False)
Expand Down Expand Up @@ -149,7 +163,8 @@ def forward(
time_step : Union[TensorType["batch"], int], # Note diffusers tyically does 999->0 as steps
input_ids : TensorType["batch", "seq_len"] = None,
attention_mask : TensorType["batch", "seq_len"] = None,
text_embeds : TensorType["batch", "d"] = None
text_embeds : TensorType["batch", "d"] = None,
added_cond_kwargs = {}
) -> TensorType["batch", "channels", "height", "width"]:
"""
For text conditioned UNET, inputs are assumed to be:
Expand All @@ -162,8 +177,20 @@ def forward(
return self.unet(
pixel_values,
time_step,
encoder_hidden_states = text_embeds
encoder_hidden_states = text_embeds,
added_cond_kwargs = added_cond_kwargs
).sample

@property
def device(self):
    """
    Device reported by the wrapped UNet.
    """
    unet = self.unet
    return unet.device

def enable_adapters(self):
    """
    Turn the UNet's adapters on. Does nothing when the config's lora_rank is unset (falsy).
    """
    if not self.config.lora_rank:
        return
    self.unet.enable_adapters()

def disable_adapters(self):
    """
    Turn the UNet's adapters off. Does nothing when the config's lora_rank is unset (falsy).
    """
    if not self.config.lora_rank:
        return
    self.unet.disable_adapters()


30 changes: 30 additions & 0 deletions src/drlx/pipeline/dpo_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from abc import abstractmethod
from typing import Callable, Iterable, Tuple

from PIL import Image

from drlx.pipeline import Pipeline

class DPOPipeline(Pipeline):
    """
    Pipeline for training with DPO. Returns prompts, chosen images, and rejected images
    """
    def __init__(self, *args, **kwargs):
        # Forward everything to the base Pipeline unchanged
        super().__init__(*args, **kwargs)

    @abstractmethod
    def __getitem__(self, index : int) -> Tuple[str, Image.Image, Image.Image]:
        """
        Get a single preference example.

        :param index: Index of the example
        :type index: int

        :return: Tuple of (prompt, chosen image, rejected image)
        :rtype: Tuple[str, Image.Image, Image.Image]
        """
        pass

    def make_default_collate(self, prep : Callable):
        """
        Build a collate function that regroups a batch of examples by field and hands them to prep.

        :param prep: Callable invoked as prep(prompts, chosen, rejected), each argument being a list
        :type prep: Callable

        :return: Collate function mapping a batch of (prompt, chosen, rejected) tuples to whatever prep returns
        :rtype: Callable
        """
        def collate(batch : Iterable[Tuple[str, Image.Image, Image.Image]]):
            prompts, chosen, rejected = [], [], []
            # Single pass, so one-shot iterables (e.g. generators) are handled as well as lists
            for example in batch:
                prompts.append(example[0])
                chosen.append(example[1])
                rejected.append(example[2])

            return prep(prompts, chosen, rejected)

        return collate



65 changes: 65 additions & 0 deletions src/drlx/pipeline/pickapic_dpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from datasets import load_dataset
import io

from drlx.pipeline.dpo_pipeline import DPOPipeline

import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

def convert_bytes_to_image(image_bytes, id):
    """
    Decode an encoded image into a 512x512 RGB PIL image.

    :param image_bytes: Raw bytes of the encoded image file
    :type image_bytes: bytes

    :param id: Identifier of the image, only used to make the error message traceable

    :return: The decoded and resized image, or None if the bytes could not be decoded
    :rtype: Image.Image
    """
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Force 3 channels so every image in a batch yields a tensor of the same shape
        image = image.convert("RGB")
        return image.resize((512, 512))
    except Exception as e:
        # Best effort: report the failure and let the caller decide what to do with None
        print(f"An error occurred while decoding image {id}: {e}")
        return None

def create_train_dataset():
    """
    Load the train split of Pick-a-Pic v2, keeping only rows that have a label
    and whose label_0 is not 0.5.

    :return: The filtered dataset
    """
    def keep_row(example):
        return example['has_label'] == True and example['label_0'] != 0.5

    train_split = load_dataset("yuvalkirstain/pickapic_v2",split='train')
    return train_split.filter(keep_row)

class Collator:
    """
    Collate function for Pick-a-Pic rows. Decodes both images of every row, orders them
    as (chosen, rejected) and returns them as tensors alongside the prompts.
    """
    def __call__(self, batch):
        """
        :param batch: List of dataset rows (dicts) with keys 'jpg_0', 'jpg_1',
            'image_0_uid', 'image_1_uid', 'label_0' and 'caption'
        :type batch: list

        :return: Dict with "chosen_pixel_values" and "rejected_pixel_values" (stacked image
            tensors, rescaled with x * 2 - 1) and "prompts" (list of captions)
        :rtype: dict
        """
        to_tensor = transforms.ToTensor()

        prompts, chosen, rejected = [], [], []
        for row in batch:
            # Keep each image's bytes paired with its own uid so they move together
            preferred = (row['jpg_0'], row['image_0_uid'])
            other = (row['jpg_1'], row['image_1_uid'])
            if not row['label_0']: # label_1 is 1 => jpg_1 is the chosen one
                preferred, other = other, preferred

            prompts.append(row['caption'])
            # NOTE: convert_bytes_to_image gives None for undecodable bytes, in which case to_tensor raises
            chosen.append(to_tensor(convert_bytes_to_image(*preferred)))
            rejected.append(to_tensor(convert_bytes_to_image(*other)))

        return {
            "chosen_pixel_values" : torch.stack(chosen) * 2 - 1,
            "rejected_pixel_values" : torch.stack(rejected) * 2 - 1,
            "prompts" : prompts
        }

class PickAPicDPOPipeline(DPOPipeline):
    """
    DPO pipeline over the Pick-a-Pic training data, for training LDM with DPO
    """
    def __init__(self):
        # Dataset is loaded and filtered eagerly at construction time
        dataset = create_train_dataset()
        self.train_ds = dataset
        self.dc = Collator()

    def create_loader(self, **kwargs):
        """
        Create a DataLoader over the training dataset using this pipeline's collator.

        :param kwargs: Extra keyword arguments passed through to DataLoader

        :return: The data loader
        :rtype: DataLoader
        """
        loader = DataLoader(self.train_ds, collate_fn = self.dc, **kwargs)
        return loader
Loading