Maheshram1/Parallel_JEPA

Parallel JEPA

Parallel JEPA is a scalable PyTorch framework for training Vision Transformer-based student-teacher models under the Joint Embedding Predictive Architecture (JEPA) paradigm. The student encoder splits the image patches into multiple parts, shuffles and masks them, and parallel decoder streams reconstruct the full patch representations. The training loss is the mean-squared error against the intermediate features of a frozen teacher ViT. Training is distributed with PyTorch DistributedDataParallel (DDP), supports mixed precision through torch.cuda.amp, and uses a custom SOAP optimizer for adaptive gradient preconditioning.

Features

  • Student & Teacher ViTs: Implements both the masked/shuffled student JEPA model (VisionTransformer) and a standard teacher ViT (VisionTransformer1), sharing patch embedding but differing in masking and decoder depth (model.py).
  • Parallel Reconstruction: Divides the full sequence of patches into num_parts blocks processed independently by decoder streams, enabling parallel JEPA reconstruction (model.py#L138-L168).
  • Distributed Training: Leverages torch.distributed and DDP for multi‑GPU scaling, with helper utilities for setup/cleanup, synchronized checkpointing, and main‑process logging (utils.py, main.py).
  • Mixed Precision: Optional AMP via GradScaler for faster throughput and lower memory usage (main.py#L188-L189).
  • Custom SOAP Optimizer: Integrates a Shampoo‑style optimizer with second‑order gradient preconditioning for the student model (optimizer.py).
  • Flexible Hyperparameters: All core settings (image size, patch size, embedding dim, batch‑size warmup, learning‑rate schedule) are centralized in config.py with dynamic batch‑size support (config.py).
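The Parallel Reconstruction feature depends on dividing the patch sequence into num_parts blocks. The following is a minimal sketch of that index bookkeeping only, assuming equal-sized blocks and a seeded shuffle. The function name partition_patches and the choice num_parts=4 are hypothetical and not taken from model.py, which may shuffle, mask, and route patches differently.

```python
import random


def partition_patches(num_patches, num_parts, seed=0):
    """Shuffle patch indices and split them into num_parts equal blocks.

    Hypothetical helper for illustration; not the repository's code.
    """
    if num_patches % num_parts != 0:
        raise ValueError("num_patches must be divisible by num_parts")
    rng = random.Random(seed)          # local RNG so the shuffle is reproducible
    order = list(range(num_patches))
    rng.shuffle(order)
    block = num_patches // num_parts
    return [order[i * block:(i + 1) * block] for i in range(num_parts)]


# With img_size=224 and patch_size=14 there are 16 * 16 = 256 patches.
parts = partition_patches(num_patches=256, num_parts=4)
```

Every patch index lands in exactly one block, so each decoder stream could be handed one block without overlap.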

Installation

  1. Clone the repo

    git clone https://github.com/Maheshram1/Parallel_JEPA.git
    cd Parallel_JEPA
  2. Install dependencies

    pip install -r requirements.txt

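    Requires torch>=1.13.0, torchvision>=0.14.0, and tqdm (requirements.txt). Note that torch.compile, which main.py can optionally use, is only available from PyTorch 2.0 onward.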

Data Preparation

  • ImageNet: Place the training and validation sets in separate directories and update the two paths in main.py (main.py#L94-L95):
    # main.py
    imagenet_train_path = '/path/to/imagenet/train' # UPDATE THIS
    imagenet_val_path   = '/path/to/imagenet/val'   # UPDATE THIS

Configuration

All hyperparameters live in config.py:

# config.py
config = Config()
config.img_size = 224
config.patch_size = 14
config.embed_dim = 1280
config.num_layers = 16
config.num_heads  = 32
config.initial_batch_size = 128 # Global batch size start
config.final_batch_size   = 128 # Global batch size end
config.num_epochs = 100
config.base_learning_rate = 1e-3
config.warmup_epochs = 4
config.use_amp = True
# ... and others

You can modify settings like learning rates, warmup, batch size schedule, and AMP usage directly in the Config class (config.py).
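A few quantities follow directly from the values listed above, and checking them before a long run can catch inconsistent settings. The sketch below is plain arithmetic on those numbers. The world size of 4 and the even split of the global batch across GPUs are assumptions made for illustration, not something config.py is confirmed to do.

```python
# Values copied from the config listing above.
img_size, patch_size = 224, 14
embed_dim, num_heads = 1280, 32
global_batch_size = 128

# Assumption: 4 GPUs, with the global batch divided evenly across them.
world_size = 4

# The image must tile exactly into patches, and heads must divide the width.
assert img_size % patch_size == 0
assert embed_dim % num_heads == 0
assert global_batch_size % world_size == 0

patches_per_side = img_size // patch_size      # 16
num_patches = patches_per_side ** 2            # 256
head_dim = embed_dim // num_heads              # 40
per_gpu_batch = global_batch_size // world_size  # 32

print(num_patches, head_dim, per_gpu_batch)  # 256 40 32
```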

Training

Launch distributed training with torchrun (equivalently, python -m torch.distributed.run):

Example: Single node, 4 GPUs

torchrun --standalone --nnodes=1 --nproc_per_node=4 main.py
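torchrun starts one worker process per GPU and exports RANK, LOCAL_RANK, and WORLD_SIZE into each worker's environment. As a rough sketch of how a script can read them, the hypothetical helper below (read_dist_env is not a function from utils.py) falls back to single-process defaults when the variables are absent.

```python
import os


def read_dist_env(env=None):
    """Read the rank variables that torchrun sets for each worker.

    Hypothetical helper; defaults describe a plain single-process run.
    """
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", 0)),              # global process index
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # GPU index on this node
        "world_size": int(env.get("WORLD_SIZE", 1)),  # total number of processes
    }


info = read_dist_env()
is_main = info["rank"] == 0  # only this process would write logs and checkpoints
```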

The script (main.py) will:

  1. Initialize DDP and set the CUDA device based on LOCAL_RANK (utils.py#L15-L30).
  2. Build student (VisionTransformer) and teacher (VisionTransformer1) models and optionally compile them with torch.compile (main.py#L120-L147).
  3. Load or initialize teacher weights, freeze its parameters, and wrap both models in DDP (main.py#L150-L175).
  4. Create MSE loss, SOAP optimizer, and LambdaLR scheduler with cosine decay & linear warmup (main.py#L178-L220).
  5. Run epochs: dynamic batch resizing, forwarding through student & teacher, computing loss, backward pass with AMP, optimizer & scheduler steps, teacher refresh, and periodic checkpointing/logging (main.py#L253-L350, engine.py).
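The schedule named in step 4 (linear warmup followed by cosine decay) can be written as a multiplier function of the kind LambdaLR expects. This is a sketch under assumptions: the function name make_lr_lambda is hypothetical, the schedule is stepped once per epoch, and the decay ends at zero. main.py may step per iteration or use a different floor.

```python
import math


def make_lr_lambda(warmup_steps, total_steps, min_ratio=0.0):
    """Return a function mapping a step index to an LR multiplier in [0, 1]."""

    def lr_lambda(step):
        if step < warmup_steps:
            # Linear warmup: reaches 1.0 on the last warmup step.
            return (step + 1) / warmup_steps
        # Cosine decay over the remaining steps, clamped at the end.
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        progress = min(1.0, progress)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_ratio + (1.0 - min_ratio) * cosine

    return lr_lambda


# Assumption: one scheduler step per epoch, using warmup_epochs=4, num_epochs=100.
lr_lambda = make_lr_lambda(warmup_steps=4, total_steps=100)
base_lr = 1e-3
lrs = [base_lr * lr_lambda(epoch) for epoch in range(100)]
```

The returned function can be passed to torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda), which multiplies the optimizer's base learning rate by the value returned at each step.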

Evaluation

Validation runs at the end of each training epoch under the same DDP setup. It reports the average MSE loss between the student's reconstructed patch representations and the teacher's intermediate features over the validation set (engine.py#L167-L258).
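When batches differ in size across ranks, a validation-set average is best computed from a summed loss and a sample count rather than from a mean of per-batch means. The sketch below simulates that aggregation in plain Python. The helper global_mean_loss and the sample numbers are made up for illustration, and engine.py may aggregate differently. In a real DDP run the two totals would typically be combined across processes with torch.distributed.all_reduce.

```python
def global_mean_loss(per_rank_batches):
    """Sample-weighted mean of per-batch losses over all ranks.

    per_rank_batches: one list per rank of (batch_mean_loss, batch_size) pairs.
    """
    total_loss, total_count = 0.0, 0
    for batches in per_rank_batches:
        for mean_loss, batch_size in batches:
            total_loss += mean_loss * batch_size  # undo the per-batch averaging
            total_count += batch_size
    return total_loss / total_count


# Made-up values: two ranks, the last batch on rank 1 is smaller.
rank0 = [(0.5, 32), (0.3, 32)]
rank1 = [(0.4, 32), (0.2, 16)]
val_loss = global_mean_loss([rank0, rank1])
```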

Checkpoints & Logging

  • Checkpoints: Saved by the main process (rank 0) into the checkpoints/ directory:
    • latest_checkpoint.pth: Overwritten after every epoch.
    • best_model.pth: Overwritten when validation loss improves.
    • checkpoint_epoch_*.pth: Saved periodically (default: every 10 epochs) (main.py#L325-L341).
  • Logs: Saved by the main process into the logs/ directory:
    • training_log.txt: Detailed text log including epoch times, losses, and LR.
    • loss_log.csv: CSV file tracking Epoch, TrainLoss, ValLoss, and LearningRate for easier analysis (main.py#L317-L324).
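Because loss_log.csv is plain CSV, it can be inspected with the standard library alone. The sketch below assumes the file begins with a header row using exactly the column names listed above. The helper best_epoch and the sample rows are made up for illustration.

```python
import csv
import io


def best_epoch(csv_file):
    """Return (epoch, val_loss) for the row with the lowest ValLoss."""
    rows = list(csv.DictReader(csv_file))
    best = min(rows, key=lambda r: float(r["ValLoss"]))
    return int(best["Epoch"]), float(best["ValLoss"])


# Made-up rows in the assumed layout; real values come from logs/loss_log.csv.
sample = io.StringIO(
    "Epoch,TrainLoss,ValLoss,LearningRate\n"
    "1,0.90,0.85,0.00025\n"
    "2,0.75,0.70,0.00050\n"
    "3,0.72,0.74,0.00075\n"
)
result = best_epoch(sample)
```

For a real run, open the file with open("logs/loss_log.csv", newline="") and pass the file object to best_epoch.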

Utilities

  • DDP Setup & Cleanup: setup(), cleanup(), and is_main_process() in utils.py manage the distributed environment (utils.py#L15-L46).
  • Checkpointing Helpers: save_checkpoint() and load_checkpoint() in utils.py provide robust saving/loading, handling DDP/compile wrappers and ensuring CPU-based saving for portability (utils.py#L63-L219).
  • Teacher Utilities: load_pretrained_teacher_weights() loads external weights into the teacher. refresh_teacher() copies the student weights into the teacher each epoch (a direct copy that stands in for a momentum update) and then freezes the teacher again (utils.py#L224-L368).
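One reason checkpoint helpers need to handle wrappers is that DDP prefixes parameter names with "module." and torch.compile prefixes them with "_orig_mod.", so a state dict saved from a wrapped model does not match an unwrapped one. The sketch below shows one common way to normalize the keys. The function strip_wrapper_prefixes is hypothetical and is not the code in utils.py, and the integer values stand in for tensors.

```python
def strip_wrapper_prefixes(state_dict, prefixes=("module.", "_orig_mod.")):
    """Return a copy of state_dict with leading wrapper prefixes removed."""
    cleaned = {}
    for key, value in state_dict.items():
        changed = True
        while changed:  # wrappers can be nested in either order
            changed = False
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    changed = True
        cleaned[key] = value
    return cleaned


# Placeholder values instead of tensors; the key names are illustrative.
wrapped = {
    "module._orig_mod.blocks.0.weight": 1,
    "_orig_mod.module.head.bias": 2,
    "patch_embed.proj.weight": 3,
}
plain = strip_wrapper_prefixes(wrapped)
```

Only leading prefixes are removed, so a parameter that merely contains "module." deeper in its name is left untouched.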
