Quokka: Easiest Way to Create, Train, and Customize SoTA LLMs
Quokka is the one and only fully open-source SoTA LLM framework for AI education. All code and features are optimized for exploring and understanding the deep learning behind LLMs like ChatGPT.
- 3x faster download
  - download only the rows of data you actually need, as you train.
  - instead of downloading 100+ GB of data up front, stream a few KB at a time (often 3x faster than a full GET). See the loader sketch after this list.
- Infinite dataset shuffling
  - tokenize the training data on the fly.
  - instead of loading all training data into memory, load one batch at a time.
  - dataset loaders with 10K and 100M rows of text use the same RAM, even with shuffling enabled! 🤯
- Efficient long-context pretraining
  - instead of the traditional approach of one row of text per forward pass, pack each context window to the brim with tokens, significantly reducing the number of steps per epoch.
- Least overwhelming
  - Quokka is more complete, easier to understand, and easier to build and train new SoTA models with than existing open-source Llama 3.1, GPT-3, and Transformer implementations.
- Extremely simple
  - the main Quokka model is just 157 lines of code, including comments, layers, weight initialization, and forward propagation with training & inference compatibility;
  - all done without using Hugging Face transformers. 🤯
- Infinite batch size
  - regular batch size (limited by VRAM) × gradient accumulation steps (unlimited) = effective batch size (unlimited); see the accumulation sketch after this list.
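The first three features above boil down to streaming rows, tokenizing on the fly, and packing tokens into full context windows. Here is a minimal sketch of those ideas, not Quokka's internal `get_dataloaders`, assuming the Hugging Face `datasets` library, `tiktoken`, and FineWeb's `text` field:

```python
import tiktoken
import torch
from datasets import load_dataset

enc = tiktoken.get_encoding("p50k_base")
MAX_CTX_LEN = 2048

# streaming=True downloads only the rows you actually consume, and shuffle()
# keeps a fixed-size buffer, so RAM stays flat no matter how large the dataset is.
stream = load_dataset("HuggingFaceFW/fineweb", split="train", streaming=True)
stream = stream.shuffle(seed=42, buffer_size=10_000)

def packed_batches(rows, batch_size=8):
    """Tokenize rows on the fly, concatenate them, and cut fixed-length
    blocks so every forward pass sees a full context window."""
    buffer = []
    for row in rows:
        buffer.extend(enc.encode(row["text"]) + [enc.eot_token])
        while len(buffer) >= batch_size * MAX_CTX_LEN:
            chunk, buffer = buffer[:batch_size * MAX_CTX_LEN], buffer[batch_size * MAX_CTX_LEN:]
            yield torch.tensor(chunk).view(batch_size, MAX_CTX_LEN)

first_batch = next(iter(packed_batches(stream)))  # shape: (8, 2048), ready for a forward pass
```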
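The "infinite batch size" feature is plain gradient accumulation. The toy loop below (a stand-in linear model, not Quokka's Trainer) shows the arithmetic: effective batch size = per-step batch size × accumulation steps.

```python
import torch

# Tiny stand-in model and data, just to make the pattern runnable.
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
data = [(torch.randn(4, 16), torch.randn(4, 1)) for _ in range(32)]  # per-step batch size 4

ACCUMULATION_STEPS = 8  # unlimited: raise this instead of the per-step batch size

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(data):
    loss = torch.nn.functional.mse_loss(model(inputs), targets) / ACCUMULATION_STEPS
    loss.backward()                       # gradients accumulate in .grad
    if (step + 1) % ACCUMULATION_STEPS == 0:
        optimizer.step()                  # one weight update per effective batch (4 * 8 = 32)
        optimizer.zero_grad()
```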
-
- [Oct 17, 2024]
  🚀 Implemented Learn, Focus, and Review in Quokka.
  🏆 20x speedup (40x with SophiaG) compared to OpenAI's GPT-2 pretraining,
  📈 requiring only 5% of the forward passes to reach lower (better) perplexity scores across 4 datasets.
- [Oct 15, 2024]
  🚀 Implemented GPT-4o, GPT-4/GPT-3.5, and GPT-3 tokenization using tiktoken (see the encoding sketch after this changelog).
  🏆 3-6x faster than Hugging Face tokenizers.
  ✌️ No license restrictions like Llama or Qwen.
- [Oct 15, 2024]
  🚀 Implemented Grouped Query Attention in Quokka.
- [Oct 14, 2024]
  🔍 Explored Zigzag Ring Attention with Flash Attention.
  ❗️ No improvement for single-GPU training; it increased VRAM usage by 0.4 GB at 2048 context length. Only useful for distributed multi-GPU training.
- [Oct 12, 2024]
  🎯 Achieved a perplexity score of 1.0 🤯 on an 80/20 train/val split of a 30,000-sample subset of FineWeb.
  🧠 Pretrained with Multi-Query Attention, an effective batch size of 32, 2 epochs, 350M parameters, 2048 context length, and 12.6 GB VRAM. Works on free Google Colab!
- [Oct 12, 2024]
  🏆 Applied Sophia, achieving a 2x speedup and less total compute than Adam.
  🔄 Applied RoPE, the rotary token-embedding technique from Llama 2 by Meta AI.
  📈 Reduced complexity and improved performance, especially for longer sequences.
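For reference, these are the tiktoken encodings behind the tokenizer entry above. The encoding names come from tiktoken itself rather than Quokka's wrapper, and `p50k_base` is the one the training and generation examples later in this README pass around:

```python
import tiktoken

enc_gpt4o = tiktoken.get_encoding("o200k_base")   # GPT-4o
enc_gpt4  = tiktoken.get_encoding("cl100k_base")  # GPT-4 / GPT-3.5
enc_gpt3  = tiktoken.get_encoding("p50k_base")    # GPT-3-era models; used in the examples below

tokens = enc_gpt3.encode("let him cook")
print(tokens)
print(enc_gpt3.decode(tokens))
```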
- (optional) Installing Flash Attention on Windows
```bash
pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121
pip install ninja==1.11.1
pip install flash-attn --no-build-isolation
```
Compilation might take a while (~2 hours) on Windows, depending on your system specs.
Please read this issue for other setups that worked on Windows.
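After installation, a quick sanity check is to import the package and print its version (a minimal check; `flash_attn` ships a `__version__` attribute):

```python
# Quick check that the compiled extension imports correctly.
import flash_attn
print(flash_attn.__version__)
```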
- Build
```bash
pip install git+https://github.com/aistrova/quokka.git
```
- Train
```python
import math

import torch
from transformers import get_linear_schedule_with_warmup

from quokka.config import Quokka230MConfig, PretrainConfig
from quokka.data import get_dataloaders
from quokka.model import Quokka
from quokka.optimizers import SophiaG
from quokka.trainer import Trainer

CACHE_DIR = "F:/hf_home"
MODEL_NAME = "Quokka_230M"
DATASET_SIZE = 200_000  # either an integer or "ALL" (to train on the entire dataset)
PERCENT_TRAIN = 1.0

train_dataloader, val_dataloader, tokenizer = get_dataloaders(
    CACHE_DIR,
    DATASET_SIZE,
    PERCENT_TRAIN,
    openai_encoding="p50k_base",
)

quokka_config = Quokka230MConfig(
    vocab_size=tokenizer.n_vocab,
    max_seq_len=PretrainConfig.max_ctx_len,
    layer_dropout=0.0,
    feedforward_dropout=0.0,
)
model = Quokka(quokka_config, pretrained=False)
# To resume from a checkpoint, uncomment the load_state_dict lines below.
# checkpoint = torch.load('Quokka_230M_epoch2.pth')
# model.load_state_dict(checkpoint['model'])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = SophiaG(
    model.parameters(),
    lr=PretrainConfig.learning_rate,
    betas=PretrainConfig.betas,
    rho=PretrainConfig.rho,
    weight_decay=PretrainConfig.weight_decay,
)
# optimizer.load_state_dict(checkpoint['optimizer'])

# Count the batches once so the scheduler gets a fixed total number of steps.
steps_per_epoch = sum(1 for _ in train_dataloader)
PretrainConfig.total_steps = PretrainConfig.num_epochs * math.ceil(
    steps_per_epoch / PretrainConfig.accumulation_steps
)
warmup_steps = int(0.1 * PretrainConfig.total_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=PretrainConfig.total_steps,
)
# scheduler.load_state_dict(checkpoint['scheduler'])

# Mixed-precision gradient scaler
scaler = torch.cuda.amp.GradScaler()

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    scaler=scaler,
    device=device,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    tokenizer=tokenizer,
)

# start training
trainer.run(MODEL_NAME)
```
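The Trainer presumably writes checkpoints such as `Quokka_230M_epoch2.pth` on its own; if you ever need to save one manually, a dict with the same keys the commented-out resume lines expect would look like the hypothetical sketch below:

```python
# Hypothetical manual checkpoint; keys mirror the resume lines above.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
}
torch.save(checkpoint, f"{MODEL_NAME}_manual.pth")  # file name is illustrative
```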
- Generate text
```python
import torch

from quokka.config import Quokka230MConfig, PretrainConfig
from quokka.data import get_tokenizer
from quokka.model import Quokka
from quokka.text_generation import AutoregressiveWrapper

tokenizer = get_tokenizer("p50k_base")

config = Quokka230MConfig(tokenizer.n_vocab, PretrainConfig.max_ctx_len)
model = Quokka(config, pretrained=True)

checkpoint = torch.load('Quokka_epoch8.pth', weights_only=False)
model.load_state_dict(checkpoint['model'])
model.eval()

# Pad with an ID outside the tokenizer's vocabulary so it can be filtered out later.
pad_value = tokenizer.max_token_value + 1
wrapper = AutoregressiveWrapper(model, pad_value=pad_value)

# encode the prompt with bos
prompt_text = "let him cook"
prompt_tokens = torch.tensor([tokenizer.bos] + tokenizer.encode(prompt_text))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
prompt_tokens = prompt_tokens.to(device)

generated_tokens = wrapper.generate(
    prompts=prompt_tokens,
    temperature=0.7,
    seq_len=50,
    eos_token=tokenizer.eos,
).tolist()

# Drop padding and special tokens before decoding.
generated_tokens = [
    token
    for token in generated_tokens
    if token != pad_value and token not in tokenizer.special_tokens_set
]
generated_text = tokenizer.decode(generated_tokens)

print("Generated Text:")
print(generated_text)
```
- GPT3 Plus = GPT-3 with RoPE
```python
from quokka.config import (
    GPT3Plus_125M_Config,
    GPT3Plus_350M_Config,
    GPT3Plus_760M_Config,
    GPT3Plus_1B_Config,
    GPT3Plus_3B_Config,
    GPT3Plus_6B_Config,
    GPT3Plus_13B_Config,
    GPT3Plus_175B_Config,
)
```
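For readers who want to see what "with RoPE" means concretely, here is a minimal, self-contained sketch of rotary position embeddings. It is not Quokka's internal implementation; it just shows the rotation that replaces a learned position table.

```python
import torch

def rope(x: torch.Tensor, base: float = 10_000.0) -> torch.Tensor:
    """Apply rotary position embeddings to x of shape (batch, seq_len, heads, head_dim).
    Channel pairs are rotated by an angle that grows with token position."""
    _, seq_len, _, head_dim = x.shape
    half = head_dim // 2
    freqs = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) / half))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), freqs)  # (seq_len, half)
    cos = angles.cos()[None, :, None, :]  # broadcast over batch and heads
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

# Queries and keys are rotated before attention; positions are encoded as rotations
# rather than a learned table, which helps behavior on longer sequences.
q = torch.randn(1, 2048, 8, 64)
q_rotated = rope(q)
```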
- [x] Import GPT3 Plus (GPT-3 with RoPE)
- [ ] Import Llama 3.1 and Qwen 2.5 as components
- [ ] Reduce pretraining memory by half
- [ ] Finetuning
- [ ] Multimodal
- [ ] Documentation
Quokka, by AIstrova, is licensed under Apache 2.0.
Please attribute Quokka if you can, to advocate for transparency, ethics, and integrity in AI development and technologies.
An attribution as simple as Built with [Quokka](https://github.com/aistrova/quokka) or Inspired by [Quokka](https://github.com/aistrova/quokka) suffices.
Thanks to Fred for leading this initiative, innovating all the Educational Features, and implementing all the SoTA Features up to 10/18/2024.
---
An ex-Google CEO said that AI will make the rich richer and the poor poorer.
At AIstrova, we do not want that future. Let's reproduce top closed-source AI (across 6 modalities) and run it fast on consumer CPUs and GPUs, together!