This repository provides a clean, minimal PyTorch training pipeline for an unconditional diffusion model built on a custom-configured UNet2DModel from the HuggingFace diffusers library. The script is designed for fast prototyping and single-class image generation, and supports mixed precision, safe image loading, basic augmentations, and periodic sample export.
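For orientation: a diffusion model of this kind is typically trained to predict the noise added by a fixed forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, using a mean-squared-error loss. The pure-Python sketch below shows that forward process for a single scalar pixel. It assumes the linear beta schedule that diffusers' `DDPMScheduler` uses by default (1000 steps, betas from 0.0001 to 0.02); whether `train.py` uses exactly these settings is an assumption, and the helper names are hypothetical.

```python
import math
import random

def linear_betas(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Evenly spaced noise variances (assumed defaults, like DDPMScheduler's linear schedule)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]

def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out

def add_noise(x0, t, abar, noise):
    """Forward process q(x_t | x_0) applied to one scalar pixel value."""
    return math.sqrt(abar[t]) * x0 + math.sqrt(1.0 - abar[t]) * noise

abar = alpha_bars(linear_betas())
eps = random.gauss(0.0, 1.0)
print(add_noise(0.5, 0, abar, eps))    # close to the clean value 0.5
print(add_noise(0.5, 999, abar, eps))  # dominated by the noise eps
```

During training the UNet receives x_t and t and is asked to recover eps; at sampling time the process is run in reverse.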
- SafeImageFolder: Loads training images while tolerating corrupted or unreadable files.
- Flexible Augmentation: Includes random flips and color jitter.
- Mixed Precision Training: Uses torch.amp for faster, memory-efficient training.
- Cosine Annealing Scheduler: Smoothly anneals the learning rate.
- Progress Image Saving: Periodically generates and saves sample images during training.
- Model Checkpointing: Saves the final model and scheduler so they can be reloaded later for sampling or further training.
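The cosine annealing feature follows the closed form documented for PyTorch's `CosineAnnealingLR`: lr_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2. A pure-Python sketch of that curve, valid for 0 <= t <= T_max; the function name and the example values are illustrative assumptions, not taken from `train.py`:

```python
import math

def cosine_annealing_lr(step, total_steps, base_lr, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at step 0, eta_min at total_steps."""
    progress = step / total_steps
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * progress)) / 2

# Learning rate over a hypothetical 100-epoch run starting at 1e-4.
for epoch in (0, 25, 50, 75, 100):
    print(epoch, cosine_annealing_lr(epoch, 100, 1e-4))
```

In PyTorch the same schedule comes from `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=...)`; whether the script steps it once per epoch or once per batch determines what T_max should count.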
- Python 3.8+
- PyTorch
- torchvision
- diffusers
- tqdm
- Pillow
You can install the dependencies using:
pip install torch torchvision diffusers tqdm pillow
Prepare your dataset:
- Place all training images in a folder named images in the project root.
- Only image files (e.g., .jpg, .png) should be in this folder.
 
Run the training script: python train.py
- Training progress and sample images are saved in the progress/ directory.
- The final model weights and scheduler configuration are saved to unet_final/ and scheduler_final/.
 
- SafeImageFolder: Custom Dataset for robust image loading.
- train_diffusion(): Handles all training logic, including loading data, training loop, and checkpointing.
- save_sample(): Denoises and saves generated samples at regular intervals.
- main(): Entry point that sets up the environment and triggers training.
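One way a dataset like SafeImageFolder can tolerate bad files is to check every candidate once at construction time and keep only the ones Pillow can read. The sketch below is an assumption about the approach, not the repository's actual class: the extension list, the verify-at-init strategy, and returning the image without a label are all illustrative choices.

```python
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}  # assumed accepted formats

class SafeImageFolder(Dataset):
    """Flat folder of images; unreadable or non-image files are skipped up front."""

    def __init__(self, root, transform=None):
        self.transform = transform
        self.paths = []
        for p in sorted(Path(root).iterdir()):
            if not p.is_file() or p.suffix.lower() not in IMAGE_EXTS:
                continue
            try:
                with Image.open(p) as img:
                    img.verify()  # cheap integrity check without decoding all pixels
            except (OSError, SyntaxError):  # UnidentifiedImageError subclasses OSError
                continue
            self.paths.append(p)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        with Image.open(self.paths[idx]) as img:
            img = img.convert("RGB")  # normalize grayscale/RGBA inputs to 3 channels
        return self.transform(img) if self.transform else img
```

Filtering once at startup keeps `__getitem__` simple, at the cost of opening every file before training begins.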
- Hyperparameters: Adjust image_size, batch_size, num_epochs, and learning_rate directly in the script.
- Model Architecture: Modify the UNet2DModel parameters for a deeper, wider, or more complex network.
- Data Augmentation: Tune or extend the transform pipeline for your dataset.
This repository uses HuggingFace Diffusers and PyTorch. If you use this codebase, please consider citing the respective libraries.
This project is licensed under the MIT License.