llama2_prebiasing

⚗️ Status: This is a fun project created on a weekend, so it is still in alpha and may change without warning.

[DALL-E generated image of a llama reading books]

📃 Overview

The current machine learning paradigm heavily relies on a pretraining-followed-by-fine-tuning strategy. llama2_prebiasing aims to critically assess this strategy by exploring its underlying principles and effectiveness.

Objectives

The project is driven by two main considerations:

1. Saturation Curves

  • Goal 1: Investigate the occurrence of saturation, where a model, given effectively all available data and a small set of tasks, ceases to show significant improvement on a given metric. This goal seeks to understand how saturation in one metric relates to other metrics, to the model architecture, to the nature of the training data, and to task specificity.
  • Goal 2: Analyze the model's performance across training checkpoints and data inputs, pre- and post-saturation. This goal compares performance metrics and losses against the data used, to understand different models' tendencies and learning behaviors.

2. Durability

  • Goal 1: Examine the malleability of fundamental structures (such as embeddings and attention mechanisms) within models, which are not uniform throughout training or across architectures. This is demonstrated in the context of ordered batch sampling using batch metadata.
  • Goal 2: Explore the effects of continued pretraining at different stages of the process, sensitivity to labels, and overt-to-covert bias. This includes controlling for losses and performance metrics to understand the model's ability to learn or unlearn biases and to improve performance.
[Figures: Saturation and Influence Curves; Durability]

Outcomes

  • 🌐 Training with Peripheral Vision: A method for evaluating the data used and model performance, and for detecting covert attributes, throughout LM training, either in real time or post hoc at the checkpoint level (on HF?).
  • 📊 Fragility: Analysis of the model's flexibility to new information during pretraining.
  • 🔍 Bias-aware Training: Tools for Data and Model Card integration for LLMs before and after fine-tuning.


Workpackages

1. TinyStories (Working)

a. Download Dataset

First, we download the TinyStories dataset.

python download/download_datasets.py tinystories

This creates a new folder called data in the root directory that contains 'TinyStories_all_data'.
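For context, here is a minimal sketch of what the download step amounts to, assuming the TinyStories archive is fetched from the Hugging Face dataset page and unpacked into data/TinyStories_all_data (the URL and function name are illustrative, not the project's exact code):

```python
# Sketch of the download step (illustrative; the real script may differ).
import os
import urllib.request

DATA_DIR = "data"
# Assumed source: the TinyStories archive hosted on Hugging Face.
URL = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"

def download_tinystories():
    os.makedirs(DATA_DIR, exist_ok=True)
    archive = os.path.join(DATA_DIR, "TinyStories_all_data.tar.gz")
    if not os.path.exists(archive):
        urllib.request.urlretrieve(URL, archive)   # fetch the tarball
    out_dir = os.path.join(DATA_DIR, "TinyStories_all_data")
    os.makedirs(out_dir, exist_ok=True)
    os.system(f"tar -xzf {archive} -C {out_dir}")  # unpack the JSON shards

if __name__ == "__main__":
    download_tinystories()
```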

b. Tokenize and Evaluate Batches

We then tokenize the dataset using a SentencePiece tokenizer model. This creates a .bin file in the data/TinyStories_all_data folder that contains the tokenized dataset.

python tokenization/pretokenize.py pretokenize

During this process, each shard is tokenized and global IDs are created.
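A rough sketch of per-shard tokenization with SentencePiece, assuming a trained tokenizer model at data/tok.model and JSON shards whose samples carry a "story" field (both the file layout and the field name are assumptions):

```python
# Sketch of tokenizing one shard with SentencePiece (paths and field names assumed).
import glob
import json
import numpy as np
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="data/tok.model")

def pretokenize_shard(shard_path):
    with open(shard_path) as f:
        examples = json.load(f)                      # list of {"story": ...} samples
    all_tokens = []
    for ex in examples:
        ids = sp.encode(ex["story"], out_type=int)   # tokenize one sample
        all_tokens.extend([sp.bos_id()] + ids)       # BOS token delimits samples
    arr = np.array(all_tokens, dtype=np.uint16)
    arr.tofile(shard_path.replace(".json", ".bin"))  # write the .bin next to the shard

for path in sorted(glob.glob("data/TinyStories_all_data/*.json")):
    pretokenize_shard(path)
```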

c. Metadata

python metadata/batch_metadata.py compute_metadata

Here each sample is evaluated using metrics such as perplexity and sentence length. These are the same metrics used for model evaluation in section 3.

Batch metrics are stored in out/tables/batch_metrics.csv
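As an illustration of those metrics: per-sample perplexity is the exponential of the mean next-token cross-entropy under a language model, and sentence length can be a simple token or word count. The sketch below assumes a model that returns logits of shape (batch, time, vocab) and the column names shown, which are not taken from the repository:

```python
# Sketch of per-sample metadata (perplexity, sentence length); interfaces assumed.
import csv
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_perplexity(model, token_ids):
    """Perplexity = exp(mean next-token cross-entropy) for one tokenized sample."""
    x = torch.tensor(token_ids[:-1]).unsqueeze(0)   # inputs
    y = torch.tensor(token_ids[1:]).unsqueeze(0)    # shifted targets
    logits = model(x)                               # assumed shape: (1, T, vocab)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
    return math.exp(loss.item())

def write_batch_metrics(rows, path="out/tables/batch_metrics.csv"):
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["global_id", "perplexity", "sentence_len"])
        writer.writeheader()
        writer.writerows(rows)
```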

2. Pretraining (Working)

python modelling/train.py

The Modelling folder contains the model configurations, training loop, and dataset class. It also contains the custom sampler and transformation functions that allow modification of the batch order.

During training, checkpoint files record the embeddings, attention weights, and the batches used up to that point, while the model folder contains the .bin files. These are all stored in the out folder.
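One way such a custom sampler can work is a torch Sampler that yields dataset indices ordered by a chosen metadata column; the sketch below is a minimal version under that assumption (the CSV and column names are illustrative):

```python
# Sketch of a metadata-driven sampler for ordered batch sampling (names assumed).
import pandas as pd
from torch.utils.data import Sampler

class MetadataOrderedSampler(Sampler):
    """Yields dataset indices sorted by a chosen batch-metadata metric."""

    def __init__(self, metadata_csv="out/tables/batch_metrics.csv",
                 sort_by="perplexity", ascending=True):
        meta = pd.read_csv(metadata_csv)
        self.order = meta.sort_values(sort_by, ascending=ascending)["global_id"].tolist()

    def __iter__(self):
        return iter(self.order)

    def __len__(self):
        return len(self.order)
```

A DataLoader built with sampler=MetadataOrderedSampler(...) then feeds samples in that order instead of the default shuffle.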

3. Model Evaluation (Working)

There are three main scripts for checkpoint evaluation:

a. Attention Visualization

python visualize_attn.py

This takes the stored attention weights from each checkpoint and plots them for each layer and head.

All plots from this section are stored in the model-specific out directory, under visualize.
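At its core this is one heatmap per (layer, head) pair; a minimal matplotlib sketch, assuming the checkpoint stores attention weights as a tensor of shape (layers, heads, T, T) under the key "attn" (an assumption):

```python
# Sketch of per-layer, per-head attention heatmaps (checkpoint keys assumed).
import os
import torch
import matplotlib.pyplot as plt

ckpt = torch.load("out/ckpt.pt", map_location="cpu")
attn = ckpt["attn"]                       # assumed shape: (layers, heads, T, T)

os.makedirs("out/visualize", exist_ok=True)
for layer in range(attn.shape[0]):
    for head in range(attn.shape[1]):
        plt.figure(figsize=(4, 4))
        plt.imshow(attn[layer, head], cmap="viridis")   # one T x T attention map
        plt.title(f"layer {layer}, head {head}")
        plt.colorbar()
        plt.savefig(f"out/visualize/attn_l{layer}_h{head}.png")
        plt.close()
```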

b. Embedding Visualization

python visualize_embd.py

This takes the stored embeddings from each checkpoint and plots the embeddings using PCA.
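A minimal PCA sketch using scikit-learn, assuming the checkpoint stores the token-embedding matrix under the key "tok_embeddings" (the key and output path are assumptions):

```python
# Sketch of a 2-D PCA projection of the token-embedding matrix (key name assumed).
import torch
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

ckpt = torch.load("out/ckpt.pt", map_location="cpu")
emb = ckpt["tok_embeddings"].float().numpy()        # (vocab_size, dim)

coords = PCA(n_components=2).fit_transform(emb)     # project to 2 components
plt.scatter(coords[:, 0], coords[:, 1], s=2, alpha=0.4)
plt.title("Token embeddings (PCA)")
plt.savefig("out/visualize/embeddings_pca.png")
```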

c. Input/Output Evaluation

python eval.py

To evaluate the output, this script loads the model and generates a sequence of 200 tokens. It then compares the generated sequence to the expected sequence, computes other metrics such as sentence length, and writes the results to a CSV file in out/tables.

The batch metrics were already calculated during the metadata step (1c); this script just loads that table and adds the metrics of the batches used at a given checkpoint to a summary table.
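Schematically, the evaluation loop generates a fixed-length continuation per prompt and appends the metrics to a CSV; the generate/encode/decode interfaces and column names below are assumptions rather than the repository's exact API:

```python
# Sketch of the checkpoint input/output evaluation loop (interfaces assumed).
import csv

def evaluate_checkpoint(model, tokenizer, prompts, out_csv="out/tables/eval_metrics.csv"):
    rows = []
    for prompt, expected in prompts:                       # (prompt, expected continuation)
        generated = model.generate(tokenizer.encode(prompt), max_new_tokens=200)
        text = tokenizer.decode(generated)
        rows.append({
            "prompt": prompt,
            "generated": text,
            "sentence_len": len(text.split()),             # simple word-count proxy
            "matches_expected": int(text.strip() == expected.strip()),
        })
    with open(out_csv, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=rows[0].keys())
        writer.writeheader()
        writer.writerows(rows)
```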

d. Saturation curves

python visualize_sat_curves.py

This script takes the batch metadata (1c) and the model output metadata (3c) and plots the metrics for each checkpoint side by side. This allows for easy comparison of the metrics of the model's inputs and outputs.
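Conceptually this is a two-panel plot per metric: input-side batch metadata next to output-side evaluation metrics across checkpoints. The sketch below assumes both summary tables carry a "checkpoint" column and shared metric names, which may not match the repository's exact layout:

```python
# Sketch of side-by-side input/output saturation curves (file and column names assumed).
import pandas as pd
import matplotlib.pyplot as plt

inputs = pd.read_csv("out/tables/batch_summary.csv")      # batch metadata summary (1c)
outputs = pd.read_csv("out/tables/eval_metrics.csv")      # model output metrics (3c)

for metric in ["perplexity", "sentence_len"]:
    fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(8, 3))
    inputs.groupby("checkpoint")[metric].mean().plot(ax=ax_in, title=f"input {metric}")
    outputs.groupby("checkpoint")[metric].mean().plot(ax=ax_out, title=f"output {metric}")
    fig.savefig(f"out/visualize/sat_curve_{metric}.png")
```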

4. Experiments (Working)

a. Durability

python experiments/durability/experiment_durability.py

This script takes the model input metadata (1c) and sorts the batches by a given metric. It then trains the model on the sorted batches and evaluates the model output (3c).

Saturation curves are then plotted for each metric and compared to the saturation curves of the unsorted model. Comparisons are made across time as well as at the end points, to see whether the same data presented in different orders has different effects on the final model.
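The end-of-experiment comparison boils down to overlaying the two runs' curves for each metric; a sketch under assumed output paths (the directory and column names are illustrative):

```python
# Sketch of comparing a sorted run against the unsorted baseline (paths assumed).
import pandas as pd
import matplotlib.pyplot as plt

baseline = pd.read_csv("out/unsorted/tables/eval_metrics.csv")
sorted_run = pd.read_csv("out/sorted_perplexity/tables/eval_metrics.csv")

ax = baseline.groupby("checkpoint")["perplexity"].mean().plot(label="unsorted")
sorted_run.groupby("checkpoint")["perplexity"].mean().plot(ax=ax, label="sorted by perplexity")
ax.legend()
ax.set_title("Durability: effect of batch order on output perplexity")
plt.savefig("out/visualize/durability_perplexity.png")
```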

5. Dashboard (Working)

python dash_llama/app.py
[Screenshot: current dashboard]
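For context, a Dash app of this kind reduces to a layout plus figures built from the metrics tables; the following is a minimal, self-contained sketch, not the project's actual app (the table path and columns are assumptions):

```python
# Minimal Dash sketch serving one metric plot (illustrative, not the project's app).
import pandas as pd
import plotly.express as px
from dash import Dash, dcc, html

df = pd.read_csv("out/tables/eval_metrics.csv")           # assumed metrics table
fig = px.line(df, x="checkpoint", y="perplexity",
              title="Output perplexity per checkpoint")

app = Dash(__name__)
app.layout = html.Div([html.H2("llama2_prebiasing dashboard"), dcc.Graph(figure=fig)])

if __name__ == "__main__":
    app.run(debug=True)
```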

