diff --git a/train.py b/train.py index 8199964..ec43a94 100644 --- a/train.py +++ b/train.py @@ -734,6 +734,8 @@ def finetune_unet(batch, train_encoder=False): is_enabled=True, negation=unet_negation ) + if lora_manager.use_unet_lora: + unet.conv_in.requires_grad_(True) # Convert videos to latent space pixel_values = batch["pixel_values"]