# Image_captioning

This repo was made for a project in Deep Learning and Multimedia Data Analysis in the MSc in Artificial Intelligence at Aristotle University of Thessaloniki.
This document provides a comprehensive guide to the ResNet34-LSTM image captioning model, including architecture details, training instructions, and troubleshooting tips.
- Overview
- Installation
- Data Preparation
- Model Architecture
- Training the Model
- Generating Captions
- Model Outputs
- Key Functions
- Troubleshooting
- How A2C Works
This image captioning system combines a ResNet34 convolutional neural network (CNN) for image encoding with a Long Short-Term Memory (LSTM) network with attention for text generation. The model follows the encoder-decoder architecture pattern common in sequence generation tasks.
Key features of the model:
- ResNet34 CNN encoder for feature extraction
- LSTM decoder with attention mechanism
- Beam search for caption generation
- Attention visualization
- BLEU score tracking
- Training metrics tracking
- Pretrained weights support
- Training resumption capability
```bash
# Install Python dependencies
pip install torch torchvision numpy matplotlib nltk tqdm scikit-image Pillow
pip install graphviz torchviz
```

For model architecture visualization, Graphviz is required:
- Download and install from https://graphviz.org/download/
- Add the Graphviz bin directory to your system PATH:
  - Windows: Add to the system environment variables
  - macOS: `export PATH=$PATH:/usr/local/Cellar/graphviz/X.XX/bin`
  - Linux: `export PATH=$PATH:/usr/bin/graphviz`
- Restart your terminal/command prompt
```
resNet_LSTM/
├── model.py                 # Model architecture definition
├── train.py                 # Training script
├── test_caption.py          # Script to generate captions with visualization
├── captions.py              # Simpler script to generate captions
├── create_data_n_prep.py    # Script to prepare the dataset
├── model_architecture.py    # Script to display the model architecture
├── utils.py                 # Utility functions
├── dataset.py               # Dataset loading and processing
├── visualize_metrics.py     # Script to visualize training metrics
├── data_output/             # Directory for dataset files
└── model_outputs/           # Directory for model outputs
```
Before training the model, you need to prepare the dataset. This involves downloading the image dataset and corresponding caption annotations, then processing them into the format required by the model.
Download one of the supported datasets:
Download the JSON file containing the training/validation/test splits created by Andrej Karpathy:
Use the create_data_n_prep.py script to process the images and captions:
```bash
python create_data_n_prep.py
```

By default, the script is configured for Flickr8k with the following settings:

```python
create_input_data(
    dataset='flickr8k',
    json_path='data/caption_datasets/dataset_flickr8k.json',
    image_folder='data/flickr8k/Images',
    captions_per_image=5,
    min_word_freq=5,
    output_folder='data_output',
    max_len=50
)
```

Modify these parameters in the script if you're using a different dataset or want to change the configuration:
- `dataset`: Choose from 'flickr8k', 'flickr30k', or 'coco'
- `json_path`: Path to the JSON file
- `image_folder`: Directory containing the images
- `captions_per_image`: Number of captions to use per image
- `min_word_freq`: Minimum frequency for a word to be included in the vocabulary
- `output_folder`: Directory to save processed data
- `max_len`: Maximum caption length
The script will create the following files in the data_output directory:
- Word map: `WORDMAP_flickr8k_5_5.json`
- Encoded captions and caption lengths: `TRAIN/VAL/TEST_CAPTIONS/CAPLENS_flickr8k_5_5_*.json`
- Image features: `TRAIN/VAL/TEST_IMAGES_flickr8k_5_5.hdf5`
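As a quick sanity check, you can open the prepared files and confirm their shapes. This is a minimal sketch assuming the default Flickr8k file names above; the HDF5 key name is an assumption, so adjust it if your file differs:

```python
import json
import h5py

# Assumes the default Flickr8k output names produced by create_data_n_prep.py
with open('data_output/WORDMAP_flickr8k_5_5.json') as f:
    word_map = json.load(f)  # word -> index; typically includes <start>, <end>, <pad>, <unk>
print('Vocabulary size:', len(word_map))

with h5py.File('data_output/TRAIN_IMAGES_flickr8k_5_5.hdf5', 'r') as h:
    print('Datasets in file:', list(h.keys()))
    images = h['images']          # key name assumed; typically (num_images, 3, 256, 256)
    print('Training images:', images.shape)
```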
The image captioning model consists of three main components: an encoder, an attention mechanism, and a decoder.
The encoder is a ResNet34 convolutional neural network with the following characteristics:
- Input: RGB images of size 256x256 pixels
- Output: Feature maps of size 14x14 with 512 channels
- Structure: ResNet34 architecture (3-4-6-3 layers configuration)
- Pretrained Option: Can use pretrained weights from torchvision or train from scratch
- Fine-tuning: The encoder can be fine-tuned during training or frozen
The encoder extracts spatial features from the input image, which are then used by the attention mechanism to focus on relevant parts of the image during caption generation.
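For orientation, here is a minimal sketch of how such an encoder can be built by truncating torchvision's ResNet34 and pooling to a 14x14 grid. It is an illustration, not the repo's exact `EncoderCNN`, and it assumes a recent torchvision that accepts the `weights` argument:

```python
import torch
import torch.nn as nn
import torchvision

class EncoderSketch(nn.Module):
    """Illustrative ResNet34 encoder: 256x256 RGB in, 14x14x512 feature maps out."""
    def __init__(self, encoded_size=14, pretrained=True):
        super().__init__()
        resnet = torchvision.models.resnet34(weights='IMAGENET1K_V1' if pretrained else None)
        # Drop the average-pooling and classification layers, keep the convolutional trunk
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        # ResNet34 reduces 256x256 inputs to an 8x8 grid; resize to the grid the decoder attends over
        self.pool = nn.AdaptiveAvgPool2d((encoded_size, encoded_size))

    def forward(self, images):                 # (batch, 3, 256, 256)
        features = self.backbone(images)       # (batch, 512, 8, 8)
        features = self.pool(features)         # (batch, 512, 14, 14)
        return features.permute(0, 2, 3, 1)    # (batch, 14, 14, 512)

enc = EncoderSketch(pretrained=False)
print(enc(torch.randn(2, 3, 256, 256)).shape)  # torch.Size([2, 14, 14, 512])
```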
The attention mechanism allows the decoder to focus on different parts of the image when generating each word:
- Inputs:
- Encoder features: (batch_size, 196, 512)
- Decoder hidden state: (batch_size, 512)
- Processing:
- Transform encoder features with a linear layer
- Transform decoder hidden state with a linear layer
- Combine transformed features and apply tanh activation
- Score each pixel location with another linear layer
- Apply softmax to get attention weights
- Compute weighted sum of encoder features using attention weights
- Output:
- Attention weights: (batch_size, 196)
- Context vector: (batch_size, 512)
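A minimal PyTorch sketch of this additive attention, following the shapes listed above (illustrative only; the repo's `Attention` class in `model.py` may differ in layer sizes and names):

```python
import torch
import torch.nn as nn

class AttentionSketch(nn.Module):
    """Additive attention over the 196 (14x14) image locations."""
    def __init__(self, encoder_dim=512, decoder_dim=512, attention_dim=512):
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # transform encoder features
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # transform decoder hidden state
        self.score = nn.Linear(attention_dim, 1)                  # score each pixel location

    def forward(self, encoder_out, decoder_hidden):
        # encoder_out: (batch, 196, 512), decoder_hidden: (batch, 512)
        att1 = self.encoder_att(encoder_out)                      # (batch, 196, attention_dim)
        att2 = self.decoder_att(decoder_hidden)                   # (batch, attention_dim)
        scores = self.score(torch.tanh(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch, 196)
        alpha = torch.softmax(scores, dim=1)                      # attention weights
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)   # (batch, 512) weighted sum
        return context, alpha

attn = AttentionSketch()
ctx, alpha = attn(torch.randn(4, 196, 512), torch.randn(4, 512))
print(ctx.shape, alpha.shape)  # torch.Size([4, 512]) torch.Size([4, 196])
```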
The decoder is an LSTM network that generates captions word by word:
- Inputs:
- Encoder features: (batch_size, 196, 512)
- Captions: (batch_size, caption_length) [during training]
- Embedding Layer: Converts word indices to dense vectors of size 512
- LSTM Cell: Processes embedded word and context vector to update hidden state
- Attention Gate: Controls how much attention information to use
- Output Layer: Linear projection from LSTM hidden state to vocabulary size
- Outputs:
- Word predictions: (batch_size, max_length, vocab_size)
- Attention weights: (batch_size, max_length, 196)
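And a hedged sketch of a single decoding step that ties these pieces together: embed the previous word, attend over the encoder features, gate the context vector, update the LSTM state, and score the vocabulary. The real `LSTMDecoderWithAttention` also handles teacher forcing and variable caption lengths; the vocabulary size here is arbitrary:

```python
import torch
import torch.nn as nn

class DecoderStepSketch(nn.Module):
    """One word-generation step of an attention LSTM decoder (illustrative)."""
    def __init__(self, vocab_size, embed_dim=512, decoder_dim=512, encoder_dim=512, attention_dim=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.enc_att = nn.Linear(encoder_dim, attention_dim)
        self.dec_att = nn.Linear(decoder_dim, attention_dim)
        self.att_score = nn.Linear(attention_dim, 1)
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)              # attention gate
        self.lstm = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)
        self.fc = nn.Linear(decoder_dim, vocab_size)

    def forward(self, prev_words, encoder_out, h, c):
        # prev_words: (batch,), encoder_out: (batch, 196, 512), h/c: (batch, 512)
        scores = self.att_score(torch.tanh(self.enc_att(encoder_out) + self.dec_att(h).unsqueeze(1)))
        alpha = torch.softmax(scores.squeeze(2), dim=1)                # (batch, 196)
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)        # (batch, 512)
        context = torch.sigmoid(self.f_beta(h)) * context              # gate how much attention to use
        h, c = self.lstm(torch.cat([self.embedding(prev_words), context], dim=1), (h, c))
        return self.fc(h), alpha, h, c                                 # word scores, weights, new state

dec = DecoderStepSketch(vocab_size=5000)                               # arbitrary demo vocabulary
h = c = torch.zeros(2, 512)
logits, alpha, h, c = dec(torch.tensor([1, 1]), torch.randn(2, 196, 512), h, c)
print(logits.shape, alpha.shape)  # torch.Size([2, 5000]) torch.Size([2, 196])
```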
The image captioning system is designed to generate descriptive captions for images by combining a convolutional neural network (CNN) encoder with a Long Short-Term Memory (LSTM) decoder. Here's an overview of how the system works:
- The `create_data_n_prep.py` script processes the dataset by:
  - Extracting image features using ResNet34.
  - Tokenizing and encoding captions into numerical format.
  - Creating a word map (vocabulary) and saving it as a JSON file.
  - Splitting the dataset into training, validation, and test sets.
- Encoder: A ResNet34 CNN extracts spatial features from input images.
- Attention Mechanism: Focuses on relevant parts of the image for each word in the caption.
- Decoder: An LSTM generates captions word by word, guided by the attention mechanism.
- The `train.py` script trains the model using the prepared dataset.
- Key steps during training:
  - Images are passed through the encoder to extract features.
  - Captions are tokenized and fed into the decoder.
  - The decoder generates predictions for the next word in the sequence.
  - Loss is calculated using cross-entropy and attention regularization (a rough sketch of this loss follows the list).
  - Gradients are computed and used to update model weights.
- Training metrics (loss, accuracy, BLEU scores) are tracked and saved for visualization.
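The loss mentioned above is typically assembled as cross-entropy over the predicted words plus a "doubly stochastic" attention regularizer. The sketch below is illustrative: the regularization constant and the exact reduction are assumptions, not values taken from `train.py`:

```python
import torch
import torch.nn as nn

def captioning_loss(scores, targets, alphas, alpha_c=1.0):
    """Cross-entropy over predicted words plus attention regularization (sketch).

    scores:  (batch, max_len, vocab_size) word predictions from the decoder
    targets: (batch, max_len) ground-truth word indices
    alphas:  (batch, max_len, 196) attention weights per generated word
    """
    ce = nn.CrossEntropyLoss()(scores.reshape(-1, scores.size(-1)), targets.reshape(-1))
    # Encourage the attention weights at each location to sum to ~1 across time steps
    att_reg = alpha_c * ((1.0 - alphas.sum(dim=1)) ** 2).mean()
    return ce + att_reg

loss = captioning_loss(torch.randn(4, 20, 5000),
                       torch.randint(0, 5000, (4, 20)),
                       torch.rand(4, 20, 196))
print(loss.item())
```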
- The `captions.py` script generates captions for new images.
- Steps:
  - The image is passed through the encoder to extract features.
  - The decoder generates a caption using beam search or greedy decoding.
  - Optionally, attention weights are visualized to show which parts of the image influenced each word.
- BLEU-4 scores are used to evaluate the quality of generated captions.
- Validation and test sets are used to measure model performance.
- Model checkpoints, training metrics, and generated captions are saved in the `model_outputs` directory.
- Attention visualizations and captioned images are also saved for analysis.
This system is modular and extensible, allowing you to experiment with different datasets, architectures, and training strategies.
```bash
python train.py
```

This will train the model with default parameters:
- 120 epochs
- Batch size of 64
- Random initialization of ResNet34 weights
- No fine-tuning of the encoder
```bash
python train.py --pretrained
```

This loads pretrained ResNet34 weights from torchvision to initialize the encoder.
```bash
python train.py --pretrained --fine_tune_encoder
```

This allows the encoder parameters to be updated during training, which can improve performance.
You can customize the training with the following parameters:
- `--resume`: Resume training from the latest checkpoint
- `--fine_tune_encoder`: Enable fine-tuning of the encoder
- `--epochs`: Number of training epochs (default: 120)
- `--batch_size`: Batch size for training (default: 64)
- `--checkpoint`: Specific checkpoint to resume from
- `--pretrained`: Use pretrained ResNet34 weights
- `--one_hot`: Use one-hot encoding for the word embeddings of the captions (default: False)
Example with custom parameters:
```bash
python train.py --pretrained --fine_tune_encoder --epochs 50 --batch_size 32 --one_hot
```

To resume training from the last saved checkpoint:
```bash
python train.py --resume
```

When using the `--resume` flag, the model will:
- Look for the checkpoint file at `model_outputs/checkpoint_flickr8k_5_5.pth.tar`
- Load the model weights, optimizer states, and training progress
- Continue training from the epoch where it left off
- Preserve the BLEU score history
- Load and continue tracking training metrics (loss and accuracy)
You can also resume with modified settings:
```bash
python train.py --resume --fine_tune_encoder --epochs 20
```

Note: When resuming training, the `--epochs` parameter specifies the total number of epochs to train, not additional epochs. For example, if you've already trained for 5 epochs and set `--epochs 10`, the model will train for 5 more epochs.
The training script automatically tracks and saves the following metrics for each epoch:
- Training Loss: Average loss across all training batches
- Training Top-5 Accuracy: Percentage of predictions where the correct word is among the top 5 predictions
- Validation Loss: Average loss on validation data
- Validation Top-5 Accuracy: Top-5 accuracy on validation data
- BLEU-4 Score: Evaluation metric for generated captions
All metrics are saved to JSON files in the model_outputs directory. To visualize the training progress:
```bash
python visualize_metrics.py
```

This will generate comprehensive plots and print summary statistics about the training.
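If you prefer a quick look without the full script, the saved metrics JSON can be plotted directly. The file name and keys below are assumptions based on the default Flickr8k configuration; adjust them to whatever `visualize_metrics.py` actually reads:

```python
import json
import matplotlib.pyplot as plt

# Assumed metrics file name for the default Flickr8k configuration
with open('model_outputs/training_metrics_flickr8k_5_5.json') as f:
    metrics = json.load(f)

# Assumed keys; inspect the JSON to see what is actually stored
plt.plot(metrics['train_loss'], label='train loss')
plt.plot(metrics['val_loss'], label='val loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.savefig('model_outputs/loss_quicklook.png')
```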
For image captioning models, BLEU-4 scores typically fall in these ranges:
- 0.00-0.05: Poor performance
- 0.05-0.10: Basic performance (model is learning)
- 0.10-0.20: Reasonable/moderate performance
- 0.20-0.30: Good performance
- Above 0.30: Excellent (state-of-the-art on some datasets)
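For reference, corpus-level BLEU-4 is commonly computed with NLTK roughly as follows. The token lists here are toy examples; the repo's `validate()` aggregates references and hypotheses over the whole validation set:

```python
from nltk.translate.bleu_score import corpus_bleu

# One image: several reference captions, one generated hypothesis (tokenized, special tokens removed)
references = [[
    ['a', 'dog', 'runs', 'through', 'the', 'grass'],
    ['a', 'brown', 'dog', 'is', 'running', 'outside'],
]]
hypotheses = [['a', 'brown', 'dog', 'is', 'running', 'through', 'the', 'grass']]

bleu4 = corpus_bleu(references, hypotheses)  # default weights (0.25, 0.25, 0.25, 0.25), i.e. BLEU-4
print(round(bleu4, 4))
```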
- Start with pretrained weights (`--pretrained`): This gives the model a better starting point.
- Initially train with a frozen encoder: First train without `--fine_tune_encoder` for about 10-15 epochs.
- Fine-tune the encoder: After initial training, enable `--fine_tune_encoder` to improve performance.
- Resume training: Use `--resume` if training is interrupted.
- Monitor the metrics: Watch for plateaus in the validation metrics to determine when to stop training.
- Automatic learning rate adjustment: If BLEU scores don't improve for 10 epochs, the learning rate is automatically reduced.
- Early stopping: After 20 epochs without improvement, training stops automatically.
There are two scripts for generating captions: `test_caption.py` and `captions.py`. The first provides more features, including attention visualization, while the second is simpler.
The `captions.py` script provides a simpler way to generate captions:
```bash
python captions.py --img path/to/your/image.jpg --model model_outputs/BEST_flickr8k_5_5.pth.tar --word_map data_output/WORDMAP_flickr8k_5_5.json
```

Required arguments:
- `--img` or `-i`: Path to the image you want to caption
- `--model` or `-m`: Path to your trained model checkpoint file
- `--word_map` or `-wm`: Path to the word map JSON file (created during data preparation)
Optional arguments:
- `--beam_size` or `-b`: Size of the beam search (default: 5)
- `--visualize_attention` or `-v`: Add this flag to visualize attention weights for each word
- `--captions_json` or `-cj`: Path to the original captions JSON file to display original captions alongside generated ones
The script will:
- Load the specified model and word map
- Generate a caption using beam search
- Print the generated caption to the console
- If original captions are provided, print them as well
- Display the image with both generated and original captions (if available)
- Save the captioned image to `model_outputs/caption_result.png`
- If original captions are available, save a comparison visualization to `model_outputs/caption_comparison.png`
- If attention visualization is enabled, also generate and save attention maps to `model_outputs/attention_visualization.png`
Example for visualizing attention:
```bash
python captions.py --img test_images/dog.jpg --model model_outputs/BEST_flickr8k_5_5.pth.tar --word_map data_output/WORDMAP_flickr8k_5_5.json --visualize_attention
```

Example with original captions:

```bash
python captions.py --img test_images/dog.jpg --model model_outputs/BEST_flickr8k_5_5.pth.tar --word_map data_output/WORDMAP_flickr8k_5_5.json --captions_json data/caption_datasets/dataset_flickr8k.json
```

To visualize which parts of the image the model focuses on when generating each word:
```bash
python test_caption.py --image path/to/your/image.jpg --visualize_attention
```

This will generate:
- A caption for the image
- A visualization showing attention weights for each word in the caption
- The attention visualization is saved to `model_outputs/attention_visualization.png`
The model uses beam search to generate captions during inference:
- Beam Size: Controls how many caption candidates are maintained during generation (default: 5)
- Process:
- Start with the start token
- At each step, consider all possible next words for each candidate sequence
- Keep the top-k (beam size) sequences with highest probability
- Continue until all sequences reach the end token or maximum length
- Return the sequence with highest probability
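A simplified, self-contained sketch of this search is shown below. A toy `step` function stands in for the real decoder, which additionally carries the encoder features, hidden states, and attention maps, and may normalize scores by length:

```python
import torch

def beam_search(step, start_token, end_token, beam_size=5, max_len=20):
    """Generic beam search: step(tokens) returns log-probabilities over the vocabulary
    for the next word given the sequence so far."""
    beams = [([start_token], 0.0)]       # (sequence, cumulative log-probability)
    completed = []
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            log_probs = step(seq)                        # (vocab_size,)
            top_lp, top_ix = log_probs.topk(beam_size)   # expand each beam by its best next words
            for lp, ix in zip(top_lp.tolist(), top_ix.tolist()):
                candidates.append((seq + [ix], score + lp))
        # Keep the top-k candidates; move finished sequences aside
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_size]:
            (completed if seq[-1] == end_token else beams).append((seq, score))
        if not beams:
            break
    completed.extend(beams)                              # fall back to unfinished beams
    return max(completed, key=lambda c: c[1])[0]

# Toy decoder: a fixed preference over a 6-word vocabulary, just to exercise the search
def toy_step(seq):
    logits = torch.tensor([0.1, 2.0, 1.0, 0.5, 0.2, 3.0 if len(seq) > 3 else -5.0])
    return torch.log_softmax(logits, dim=0)

print(beam_search(toy_step, start_token=0, end_token=5, beam_size=3))
```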
To adjust the beam size:
```bash
python test_caption.py --image path/to/your/image.jpg --beam_size 10
```

All model outputs are saved to the `model_outputs` directory, which is automatically created by the scripts if it doesn't exist.
| File Pattern | Description |
|---|---|
| `checkpoint_*.pth.tar` | Regular model checkpoints saved during training |
| `BEST_*.pth.tar` | Best model weights based on validation BLEU scores |
| `bleu_scores_*.json` | JSON file tracking BLEU scores across epochs |
| `bleu_scores_*.png` | Plot of BLEU scores across training epochs |
| `training_metrics_*.json` | JSON file with loss and accuracy values |
| `loss_plot_*.png` | Plot of training and validation loss |
| `accuracy_plot_*.png` | Plot of training and validation accuracy |
| `attention_visualization.png` | Visualization of attention weights for each word in a caption |
| `caption_result.png` | Generated caption overlaid on the input image |
| `model_architecture.png` | Visual diagram of the model architecture |
| `training_visualization_*.png` | Comprehensive visualization of all training metrics |
To visualize the model architecture:
```bash
python model_architecture.py
```

This will:
- Print detailed information about the model architecture
- Generate a visual diagram of the model architecture (requires Graphviz)
- Save the visualization to `model_outputs/model_architecture.png`
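Diagrams of this kind are typically produced with torchviz's `make_dot`. The snippet below is a minimal stand-alone example of that call; the tiny model is just a placeholder, not the captioning model:

```python
import os
import torch
from torchviz import make_dot

os.makedirs('model_outputs', exist_ok=True)

# Placeholder model, used only to demonstrate the torchviz call
model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 10))
x = torch.randn(1, 512)

graph = make_dot(model(x), params=dict(model.named_parameters()))
graph.render('model_outputs/model_architecture', format='png')  # requires Graphviz on PATH
```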
- main(): Entry point for training, handles model initialization and training loop
- train(): Performs one epoch of training
- validate(): Performs validation and calculates BLEU scores
- save_bleu_scores(): Saves BLEU scores to a file and creates a plot
- save_training_metrics(): Saves training metrics to a file and creates plots
- save_checkpoint(): Saves model state to a checkpoint file
- adjust_learning_rate(): Reduces learning rate to help convergence
- AverageMeter: Class for tracking metrics during training
- clip_gradient(): Prevents gradient explosion during training
- accuracy(): Calculates top-k accuracy for evaluation
- EncoderCNN: ResNet34-based encoder class
- Attention: Attention mechanism implementation
- LSTMDecoderWithAttention: Decoder with attention implementation
- ImageCaptioningModel: Combined model for visualization
- caption_image(): Generates a caption for a given image
- visualize_att(): Visualizes attention weights for each word
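For orientation, here are minimal versions of two of the helpers listed above; the actual implementations in `utils.py` may differ in details:

```python
import torch

def accuracy(scores, targets, k=5):
    """Top-k accuracy (%): fraction of positions where the true word is among the k best predictions."""
    _, top_ix = scores.topk(k, dim=1)                      # (batch, k)
    correct = top_ix.eq(targets.unsqueeze(1)).any(dim=1)   # (batch,)
    return correct.float().mean().item() * 100.0

def clip_gradient(optimizer, grad_clip):
    """Clamp each gradient element to [-grad_clip, grad_clip] to avoid exploding gradients."""
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)

scores = torch.randn(8, 5000)              # (batch, vocab_size) word scores
targets = torch.randint(0, 5000, (8,))     # true word indices
print(accuracy(scores, targets, k=5))
```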
- Out of Memory Errors:
  - Reduce the batch size with `--batch_size 32` or even lower
  - Disable fine-tuning of the encoder
  - Use a smaller model variant
- Slow Training:
  - Ensure you're using a GPU if available
  - Set `workers=0` to avoid h5py pickling issues
  - Optimize data loading
- Low BLEU Scores:
  - Ensure you're using pretrained weights
  - Train for more epochs
  - Enable fine-tuning after initial training
  - Verify dataset preparation is correct
- Training Doesn't Progress When Resuming:
  - Check that the `--epochs` parameter is greater than the current epoch number
  - Verify the checkpoint file exists
  - Make sure there's enough disk space for new checkpoints
If you encounter errors with model visualization:
- Verify Graphviz is installed: run `dot -v` in a terminal
- If not installed, download it from https://graphviz.org/download/
- Add the Graphviz bin directory to your system PATH
- On Windows, you may need to restart your computer after installation
- The model will function without Graphviz, but visualization will not be available
- "TypeError: h5py objects cannot be pickled":
  - Set `num_workers=0` in the DataLoader to resolve this
- "CUDA out of memory":
  - Reduce the batch size
  - Move to CPU if necessary (`device = torch.device('cpu')`)
- "checkpoint_*.pth.tar not found":
  - Ensure you have previously trained the model
  - Check that the `model_outputs` directory exists
- "ImportError: cannot import name 'ResNet34_Weights'":
  - Update to a newer version of PyTorch/torchvision
  - The code should fall back to an older method if this fails
- "Dimension out of range" in encoder outputs:
  - This can happen if your model architecture has changed or was trained differently
  - The `captions.py` script has been updated to handle different tensor dimensions automatically
  - If you still encounter this error, check whether your encoder output dimensions match what the decoder expects
  - The common tensor shape for encoder outputs is (batch_size, channels, height, width)
- Image loading errors in `captions.py`:
  - The script supports both PIL and OpenCV image loading methods
  - Ensure your image file exists and is a valid image format (jpg, png, etc.)
  - If you're getting errors with PIL, the script will automatically try OpenCV as a fallback
The Advantage Actor-Critic (A2C) reinforcement learning approach is implemented to improve the image captioning model by directly optimizing for BLEU-4 scores. Here's an overview of how A2C works in this project:
- Actor: The existing image captioning model (encoder-decoder with attention) acts as the policy network.
- It generates captions by sampling words from its probability distribution.
- Critic: A separate neural network predicts the expected reward (value) of the generated sequence at each time step.
- Sampling: The actor generates captions by sampling words from its probability distribution.
- Reward: BLEU-4 scores are used as rewards for the generated captions.
- Advantage Calculation: The advantage is computed as the difference between the actual rewards and the critic's predicted values.
- Advantage = Reward - Predicted Value
- Policy Gradient Update: The actor is updated using the policy gradient method, weighted by the advantage.
- Critic Update: The critic is trained to minimize the mean squared error between its predictions and the actual rewards.
- Entropy Regularization: Encourages exploration by penalizing low-entropy (overconfident) predictions.
- Discount Factor (Gamma): Rewards are discounted over time to prioritize immediate rewards.
- Gradient Clipping: Prevents exploding gradients during training.
- The actor uses the existing encoder-decoder model with attention.
- The critic is a simple multi-layer perceptron (MLP) that takes the decoder's hidden states as input and predicts the expected reward.
- BLEU-4 scores are computed for each generated caption and used as the final reward.
- The A2C training loop alternates between updating the actor and the critic.
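The core of one A2C update can be sketched as follows. The shapes, default weights, and reward placement are illustrative assumptions, not code taken from `src/A2C.py`:

```python
import torch
import torch.nn.functional as F

def a2c_losses(log_probs, values, rewards, entropies, gamma=0.99,
               entropy_weight=0.05, value_loss_weight=0.2):
    """Actor/critic losses for one sampled caption (sketch).

    log_probs: (T,) log-probability of each sampled word
    values:    (T,) critic's value prediction at each step
    rewards:   (T,) per-step rewards (e.g. BLEU-4 assigned at the final step)
    entropies: (T,) entropy of the word distribution at each step
    """
    # Discounted returns, accumulated from the end of the caption backwards
    returns, R = torch.zeros_like(rewards), 0.0
    for t in reversed(range(len(rewards))):
        R = rewards[t] + gamma * R
        returns[t] = R
    advantage = returns - values.detach()            # how much better than the critic expected
    actor_loss = -(log_probs * advantage).mean() - entropy_weight * entropies.mean()
    critic_loss = F.mse_loss(values, returns)        # critic regresses toward the actual returns
    return actor_loss + value_loss_weight * critic_loss

T = 12
loss = a2c_losses(torch.randn(T).clamp(max=0), torch.randn(T),
                  torch.cat([torch.zeros(T - 1), torch.tensor([0.18])]),  # BLEU-4 as final reward
                  torch.rand(T))
print(loss.item())
```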
To train the model using A2C, run the following command:
```bash
python -m src.A2C --checkpoint model_outputs/BEST_flickr8k_5_5.pth.tar --epochs 20 --batch_size 32
```

- Freeze Encoder: You can now completely freeze the CNN encoder during training using the `--freeze_encoder` flag.
- Resume Training: Continue training from a previous A2C checkpoint with the `--resume` flag, preserving optimizer states and metrics.
- Metrics Saving: Training and BLEU metrics are automatically saved and exported to the model output folder.
- Actor Loss: Measures how well the actor improves its policy based on the advantage.
- Critic Loss: Measures the accuracy of the critic's value predictions.
- BLEU-4 Scores: Tracks the quality of generated captions over time.
This approach allows the model to directly optimize for BLEU-4 scores, leading to more accurate and diverse captions.
The A2C implementation has been enhanced with several new features to improve training stability and performance:
- Freeze Encoder: You can now completely freeze the CNN encoder during training using the `--freeze_encoder` flag. This is useful for stabilizing training when the encoder is already well-trained.
- Resume Training: Continue training from a previous A2C checkpoint with the `--resume` flag, preserving optimizer states and metrics.
- Entropy Annealing: Gradually reduces the entropy weight over epochs to balance exploration and exploitation.
- Reward Clipping: Stabilizes training by clipping rewards to a specified range.
- Reward Baseline: Reduces variance in policy updates by using a baseline calculated as the mean reward of the batch.
- Temperature Control: Adjusts the exploration level during sampling with the `--temperature` parameter.
- Metrics Saving: Training and BLEU metrics are automatically saved and exported to the model output folder.
To train the model using the updated A2C implementation, use the following command:
```bash
python -m src.A2C --checkpoint model_outputs/BEST_flickr8k_5_5.pth.tar --epochs 30 --batch_size 32 --temperature 1.2 --entropy_weight 0.05 --value_loss_weight 0.2 --freeze_encoder
```

- `--checkpoint`: Path to the model checkpoint to resume from.
- `--epochs`: Number of epochs to train.
- `--batch_size`: Batch size for training.
- `--temperature`: Temperature for sampling (higher values = more exploration).
- `--entropy_weight`: Initial weight for entropy regularization.
- `--value_loss_weight`: Weight for the value loss term.
- `--freeze_encoder`: Completely freeze the CNN encoder during training.
- `--resume`: Resume training from a previous A2C checkpoint.
- `--no_entropy_annealing`: Disable entropy weight annealing.
These updates provide more control over the training process and help achieve better performance by directly optimizing for BLEU-4 scores.
This implementation is based on the "Show, Attend and Tell" paper by Xu et al.