training with pytorch lightning #230

Open: wants to merge 71 commits into base: main

Commits (71)
135675d
adding requirements
ekg Dec 28, 2023
c6463f2
FSDP training draft
ekg Dec 28, 2023
186bf87
debugging
ekg Dec 28, 2023
52828a5
Revert "debugging"
ekg Dec 28, 2023
6900450
rewrite to use fairscale fsdp
ekg Dec 29, 2023
94490e9
tweak import
ekg Dec 31, 2023
07c7661
this doesnt work so we will try with lightning
ekg Dec 31, 2023
2162a41
rewrite to use lightning
ekg Dec 31, 2023
b8d47b0
install pytorch-lightning
ekg Dec 31, 2023
f6ca729
use newer format for call to the trainer
ekg Dec 31, 2023
d85522e
require tensorboard
ekg Dec 31, 2023
3788cf5
provide loss based on cross entropy loss
ekg Dec 31, 2023
c63b8b7
use automatic mixed precision
ekg Dec 31, 2023
8c34d55
adjust use of AMP
ekg Dec 31, 2023
64d00ca
training appears to work
ekg Dec 31, 2023
700d810
using all possible contexts
ekg Dec 31, 2023
338b1c4
add stride parameter
ekg Dec 31, 2023
9cd2b7f
simple generation script
ekg Jan 1, 2024
df54375
load from checkpoint
ekg Jan 1, 2024
bf25e3d
fix import
ekg Jan 1, 2024
076ea81
not good strategy
ekg Jan 1, 2024
fe888e6
add perplexity logging
ekg Jan 1, 2024
990c564
control output directory and use save_pretrained
ekg Jan 1, 2024
8bb6ab8
forward to mixer seq mamba save and load
ekg Jan 1, 2024
1ec72c0
hmm
ekg Jan 1, 2024
b3a5120
add proper model loader
ekg Jan 1, 2024
a48373f
use save_pretrained in training
ekg Jan 1, 2024
8781a64
make the model loadable from file
ekg Jan 1, 2024
2a54d2d
is it working?
ekg Jan 1, 2024
c1a7bc3
use real variable
ekg Jan 2, 2024
0b78105
try to get real text
ekg Jan 2, 2024
14dfaaf
tweak setup
ekg Jan 2, 2024
d974130
work towards memory mapped input
ekg Jan 2, 2024
70aec11
try to improve inference with better dynamic range
ekg Jan 2, 2024
bbdcd2e
extreme version of dynamic range in inference
ekg Jan 2, 2024
cf012f9
there was no obvious difference
ekg Jan 2, 2024
5097260
FSDP strategy to shard model
ekg Jan 2, 2024
87079a9
correct typo
ekg Jan 2, 2024
20d5422
oops
ekg Jan 2, 2024
e93ed26
try tokenizer parallelism
ekg Jan 2, 2024
4816409
save full checkpoints
ekg Jan 3, 2024
71a0fd7
add saving method for working with fsdp trainer
ekg Jan 3, 2024
e551b80
use the fsdp saver
ekg Jan 4, 2024
55c5e99
fix save pattern to use the right function
ekg Jan 4, 2024
314ca5a
be explicit
ekg Jan 4, 2024
b1ddc30
try to save a chkpt differently
ekg Jan 4, 2024
24acea9
hmm
ekg Jan 4, 2024
4307181
well that was useless
ekg Jan 4, 2024
bfd7625
well...
ekg Jan 4, 2024
07673b7
opps
ekg Jan 4, 2024
7f953c6
what happens
ekg Jan 4, 2024
c9814c4
manual mode save
ekg Jan 4, 2024
28e0bc1
ok
ekg Jan 4, 2024
67d5fa2
un-oops
ekg Jan 5, 2024
744c181
why not
ekg Jan 5, 2024
ae30ae3
why not
ekg Jan 5, 2024
4caf9b3
magari
ekg Jan 5, 2024
d11121c
rewrite the state dict because of stuff
ekg Jan 5, 2024
1b9250c
import ordereddict
ekg Jan 5, 2024
d6e79c5
try to build a 2.8b model with fsdp
ekg Jan 5, 2024
99778ca
does this decrease memory
ekg Jan 5, 2024
4d69658
not really
ekg Jan 5, 2024
ec8abec
print stuff out
ekg Jan 5, 2024
7359088
cleanup print statements
ekg Jan 5, 2024
61f7fb3
count tokens/s
ekg Jan 5, 2024
28b06f1
in time
ekg Jan 5, 2024
f302984
rewrite to be byte level
ekg Jan 26, 2024
cfdcb9b
little cleanup
ekg Jan 26, 2024
4151fda
use the word "context" rather than block
ekg Jan 26, 2024
869c957
Merge branch 'main' of https://github.com/state-spaces/mamba into tra…
ekg Feb 28, 2024
e93dfc7
Merge branch 'main' of https://github.com/state-spaces/mamba into tra…
ekg Mar 8, 2024
20 changes: 15 additions & 5 deletions mamba_ssm/models/mixer_seq_simple.py
@@ -239,11 +239,21 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_
         return CausalLMOutput(logits=lm_logits)

     @classmethod
-    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
-        config_data = load_config_hf(pretrained_model_name)
-        config = MambaConfig(**config_data)
-        model = cls(config, device=device, dtype=dtype, **kwargs)
-        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
+    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, on_hf=True, **kwargs):
+        if on_hf:
+            config_data = load_config_hf(pretrained_model_name)
+            config = MambaConfig(**config_data)
+            model = cls(config, device=device, dtype=dtype, **kwargs)
+            model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
+        else:
+            # otherwise, load the state dict from the pretrained model file
+            model_path = os.path.join(pretrained_model_name, 'pytorch_model.bin')
+            config_path = os.path.join(pretrained_model_name, 'config.json')
+            with open(config_path, 'r') as f:
+                config_data = json.load(f)
+            config = MambaConfig(**config_data)
+            model = cls(config, device=device, dtype=dtype, **kwargs)
+            model.load_state_dict(torch.load(model_path, map_location=device))
         return model

     def save_pretrained(self, save_directory):
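
For orientation, a minimal usage sketch of the new on_hf flag, assuming a local checkpoint directory (the path below is a placeholder for whatever train.py wrote with save_pretrained):

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# on_hf=True (the default) keeps the existing Hugging Face hub behaviour;
# on_hf=False reads config.json and pytorch_model.bin from a local directory.
model = MambaLMHeadModel.from_pretrained("./mamba_model", on_hf=False, device=device)
model.eval()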
8 changes: 8 additions & 0 deletions requirements.txt
@@ -0,0 +1,8 @@
packaging
wheel
fairscale
tensorboard
numpy
pytorch-lightning
torch==2.1.0
transformers==4.35.0
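
Assuming the usual pip workflow for this repository (the editable install line is an assumption about how mamba_ssm itself gets built locally, not part of this change):

pip install -r requirements.txt
pip install -e .   # build and install mamba_ssm from the repository root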
66 changes: 66 additions & 0 deletions train/prompt.py
@@ -0,0 +1,66 @@
import argparse
import time
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Setting up the parser for command line arguments
parser = argparse.ArgumentParser(description="mamba model generation tool")
parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the trained model checkpoint")
parser.add_argument("--prompt", type=str, default=None, help="Initial text to start generation")
parser.add_argument("--genlen", type=int, default=100, help="Length of the generation")
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for controlled randomness")
parser.add_argument("--topk", type=int, default=1, help="Top-k sampling strategy")
parser.add_argument("--topp", type=float, default=1.0, help="Top-p (nucleus) sampling strategy")
parser.add_argument("--repetition_penalty", type=float, default=1.0, help="Penalty for repetition")
args = parser.parse_args()

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

# Loading the model from the specified checkpoint
#print(f"Loading model from the checkpoint: {args.checkpoint_path}")
# Note: this uses the gpt-neox tokenizer, while train.py builds a byte-level model
# (vocab_size=256); prompts for such a checkpoint would need byte-level encoding instead.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(args.checkpoint_path, on_hf=False).to(device)
model.eval()

# Preparing the prompt (fixed seed for reproducible sampling)
torch.random.manual_seed(0)
if args.prompt is None:
    # no prompt given: start generation from random token ids
    input_ids = torch.randint(1, 1000, (1, args.genlen), dtype=torch.long, device=device)
else:
    input_ids = tokenizer.encode(args.prompt, return_tensors="pt").to(device)

# Generation settings
max_length = input_ids.shape[1] + args.genlen

fn = lambda: model.generate(
input_ids=input_ids,
max_length=max_length,
cg=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=False,
temperature=args.temperature,
top_k=args.topk,
top_p=args.topp,
repetition_penalty=args.repetition_penalty,
)

# Generate and decode the text
out = fn()
if args.prompt is not None:
for elem in tokenizer.batch_decode(out.sequences.tolist()):
print(elem)
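
A usage sketch for the script above; the checkpoint directory, prompt, and sampling values are placeholders rather than recommended settings:

python train/prompt.py \
    --checkpoint_path ./mamba_model \
    --prompt "Once upon a time" \
    --genlen 200 \
    --temperature 0.9 \
    --topk 40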

178 changes: 178 additions & 0 deletions train/train.py
@@ -0,0 +1,178 @@
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import random
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import FSDPStrategy
import time
from collections import OrderedDict
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig

class TextDataset(Dataset):
def __init__(self, file_path, context_size, eos_token=0):
self.context_size = context_size
self.eos_token = eos_token
self.data_file = file_path
self.mmap_array = np.memmap(self.data_file, dtype='uint8', mode='r')
self.seed = 42

def __len__(self):
return len(self.mmap_array) - self.context_size + 1

def __getitem__(self, idx):
generator = random.Random(self.seed + idx)
start = generator.randint(0, len(self.mmap_array) - 1)
end = start + self.context_size
if end > len(self.mmap_array):
padding_size = end - len(self.mmap_array)
data_slice = np.concatenate(
(self.mmap_array[start:], np.full(padding_size, self.eos_token, dtype='uint8'))
)
else:
data_slice = self.mmap_array[start:end]
return torch.tensor(data_slice, dtype=torch.long)

def collate_batch(batch):
return pad_sequence(batch, batch_first=True, padding_value=0)

class MambaDataModule(pl.LightningDataModule):
def __init__(self, file_path, context_size, batch_size, num_workers, split_ratio=0.8):
super().__init__()
self.file_path = file_path
self.context_size = context_size
self.batch_size = batch_size
self.num_workers = num_workers
self.split_ratio = split_ratio

def setup(self, stage=None):
dataset = TextDataset(self.file_path, self.context_size)
train_size = int(len(dataset) * self.split_ratio)
val_size = len(dataset) - train_size
self.train_dataset, self.val_dataset = random_split(dataset, [train_size, val_size])

def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, collate_fn=collate_batch, pin_memory=True)

def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, collate_fn=collate_batch, pin_memory=True)

class MambaModel(pl.LightningModule):
def __init__(self, mamba_config):
super().__init__()
self.model = MambaLMHeadModel(mamba_config)
self.last_step_end_time = time.time()

def forward(self, input_ids):
return self.model(input_ids)

def training_step(self, batch, batch_idx):
start_time = self.last_step_end_time
input_ids = batch
with torch.cuda.amp.autocast(): # mixed precision training
outputs = self(input_ids)
labels = input_ids[:, 1:].contiguous()
logits = outputs.logits[:, :-1, :].contiguous()
loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1))
perplexity = torch.exp(loss)
tokens_in_batch = input_ids.numel()

self.last_step_end_time = time.time()
elapsed_time = self.last_step_end_time - start_time
if elapsed_time > 0:
tokens_per_second = tokens_in_batch / elapsed_time
self.log('tokens_per_second', tokens_per_second, on_step=True, on_epoch=False, sync_dist=True)

self.log('train_loss', loss, sync_dist=True)
self.log('train_perplexity', perplexity, sync_dist=True)
return loss

def validation_step(self, batch, batch_idx):
input_ids = batch
with torch.cuda.amp.autocast(): # mixed precision training
outputs = self(input_ids)
labels = input_ids[:, 1:].contiguous()
logits = outputs.logits[:, :-1, :].contiguous()
loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1))
self.log('val_loss', loss, sync_dist=True)

def configure_optimizers(self):
optimizer = optim.AdamW(self.parameters(), lr=3e-5)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, verbose=True)
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler, "monitor": "val_loss"}

def save_pretrained(self, *args, **kwargs):
return self.model.save_pretrained(*args, **kwargs)

def main(args):
pl.seed_everything(42)

os.makedirs(args.output_dir, exist_ok=True)

mamba_config = MambaConfig(
d_model=2560,
n_layer=64,
vocab_size=256, # byte level
ssm_cfg={},
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
pad_vocab_size_multiple=8
)

model = MambaModel(mamba_config)
data_module = MambaDataModule(args.file_path, args.context_size, args.batch_size, args.num_workers)

checkpoint_dir = os.path.join(args.output_dir, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_dir, monitor='val_loss', save_top_k=3, mode='min')
lr_monitor = LearningRateMonitor(logging_interval='epoch')

logger = TensorBoardLogger("tb_logs", name="mamba_model")

trainer = pl.Trainer(
max_epochs=args.num_epochs,
logger=logger,
log_every_n_steps=1,
accelerator='gpu',
strategy=FSDPStrategy(state_dict_type="full"),
use_distributed_sampler=False,
devices=args.num_gpus,
callbacks=[checkpoint_callback, lr_monitor],
precision='16-mixed'
)

trainer.fit(model, datamodule=data_module)

if trainer.is_global_zero:
print(f"Saving model to {os.path.join(args.output_dir, args.model_name)}")
checkpoint = torch.load(checkpoint_callback.best_model_path)
model = MambaLMHeadModel(mamba_config).to('cpu')
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
if k.startswith('model.'):
k = k[6:]
new_state_dict[k] = v
model.load_state_dict(new_state_dict)
model.save_pretrained(os.path.join(args.output_dir, args.model_name))

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use")
parser.add_argument("--num_epochs", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--num_workers", type=int, default=4)
parser.add_argument("--context_size", type=int, default=1024)
parser.add_argument("--file_path", type=str, required=True, help="Path to the input text file")
parser.add_argument("--model_name", type=str, default="mamba_model")
parser.add_argument("--output_dir", type=str, default="./")
args = parser.parse_args()

main(args)
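
As a usage sketch (the data path, GPU count, and batch size are placeholders, not tuned values):

python train/train.py \
    --file_path ./data/corpus.bin \
    --num_gpus 4 \
    --batch_size 8 \
    --context_size 1024 \
    --num_epochs 1 \
    --output_dir ./runs/mamba \
    --model_name mamba_model

The converted weights are written to <output_dir>/<model_name> via save_pretrained, which is the directory prompt.py then loads with on_hf=False.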