diff --git a/examples/toy_train_fsdp.py b/examples/toy_train_fsdp.py
new file mode 100644
index 0000000..5e273d7
--- /dev/null
+++ b/examples/toy_train_fsdp.py
@@ -0,0 +1,403 @@
+from loguru import logger
+from datasets import load_dataset
+from torch.utils.data import DataLoader, Dataset
+from torch.utils.data.distributed import DistributedSampler
+import os
+import math
+import torch
+import argparse
+import torch.multiprocessing as mp
+import functools
+from transformers import (
+    Qwen2Config,
+    Qwen2ForCausalLM,
+    Qwen2Tokenizer,
+    get_cosine_schedule_with_warmup,
+)
+from tqdm import tqdm
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+import torch.distributed as dist
+from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
+
+
+class MoonDataset(Dataset):
+
+    def __init__(self, dataset_name, dataset, tokenizer, max_length=512):
+        self.dataset_name = dataset_name
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.texts = dataset["train"]["text"]
+        self.max_length = max_length
+        self.tokens = []
+        self._tokenize_texts()
+
+    def _tokenize_texts(self):
+        if os.path.exists(f"{self.dataset_name}.bin"):
+            self.tokens = torch.load(f"{self.dataset_name}.bin")
+        else:
+            for text in tqdm(self.texts, desc="Tokenizing texts"):
+                encoded = self.tokenizer.encode(text, add_special_tokens=True)
+                self.tokens.extend(encoded)
+            torch.save(self.tokens, f"{self.dataset_name}.bin")
+
+    def __len__(self):
+        return len(self.tokens) // self.max_length
+
+    def __getitem__(self, idx):
+        start_idx = idx * self.max_length
+        end_idx = start_idx + self.max_length
+        token_slice = self.tokens[start_idx:end_idx]
+        data = torch.tensor(token_slice, dtype=torch.long)
+        return data
+
+
+# This code snippet is a modified version adapted from the following GitHub repository:
+# https://github.com/KellerJordan/Muon/blob/master/muon.py
+@torch.compile
+def zeropower_via_newtonschulz5(G, steps):
+    """
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert len(G.shape) == 2
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(0) > G.size(1):
+        X = X.T
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm() + 1e-7)
+    # Perform the NS iterations
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+
+    if G.size(0) > G.size(1):
+        X = X.T
+    return X
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Some warnings:
+    - We believe this optimizer is unlikely to work well for training with small batch size.
+    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
+
+    Arguments:
+        muon_params: The parameters to be optimized by Muon.
+        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
+        wd: The weight decay, applied by both Muon and the internal AdamW.
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_params: The parameters to be optimized by the internal AdamW, i.e. parameters that are
+            {0, 1}-D or belong to the embedding or lm_head layers.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+    """
+
+    def __init__(
+        self,
+        lr=1e-3,
+        wd=0.1,
+        muon_params=None,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=5,
+        adamw_params=None,
+        adamw_betas=(0.9, 0.95),
+        adamw_eps=1e-8,
+    ):
+
+        defaults = dict(
+            lr=lr,
+            wd=wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+        )
+
+        params = list(muon_params)
+        adamw_params = list(adamw_params) if adamw_params is not None else []
+        params.extend(adamw_params)
+        super().__init__(params, defaults)
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        for p in muon_params:
+            # Use Muon for every parameter in muon_params; all of them must be 2D
+            assert p.ndim == 2, p.ndim
+            self.state[p]["use_muon"] = True
+        for p in adamw_params:
+            # Do not use Muon for parameters in adamw_params
+            self.state[p]["use_muon"] = False
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate based on the size of the parameter matrix,
+        # as described in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def step(self, closure=None):
+        """Perform a single optimization step.
+
+        Args:
+            closure (Callable, optional): A closure that reevaluates the model
+                and returns the loss.
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + # import pdb; pdb.set_trace() + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + # generate weight updates in distributed fashion + for p in params: + # sanity check + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + ############################ + # AdamW backup # + ############################ + + params = [ + p for p in group["params"] if not self.state[p]["use_muon"] + ] + lr = group['lr'] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss + + +def get_model_cpu(model_name, hidden_size): + if model_name == "qwen": + config = Qwen2Config( + attention_dropout=0.0, + bos_token_id=151643, + eos_token_id=151643, + hidden_act="silu", + hidden_size=hidden_size, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=513, + max_window_layers=12, + model_type="qwen2", + num_attention_heads=16, + num_hidden_layers=12, + num_key_value_heads=16, + rms_norm_eps=1e-06, + rope_theta=1000000.0, + sliding_window=1024, + tie_word_embeddings=True, + torch_dtype="bfloat16", + use_cache=True, + use_mrope=False, + use_sliding_window=False, + vocab_size=151936, + ) + model = Qwen2ForCausalLM(config) + else: + assert 0, f"model {model_name} not supported" + return model + + +def get_dataloader(model_name, dataset_name, rank, world_size): + name2path = { + "openwebtext-100k": "Elriggs/openwebtext-100k", + } + train_dataset = load_dataset(name2path[dataset_name], + trust_remote_code=True) + if model_name == "qwen": + tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", + trust_remote_code=True) + else: + assert 0, f"model {model_name} not supported" + + train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) + sampler = DistributedSampler(dataset=train_dataset, + rank=rank, + num_replicas=world_size, + shuffle=True) + kwargs = { + 'batch_size': 16, + 'sampler': sampler, + 'num_workers': 4, + 'pin_memory': True + } + train_loader = DataLoader(train_dataset, **kwargs) + return train_loader + + +def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): + if 
+    if optimizer_name == "adamw":
+        return torch.optim.AdamW(model.parameters(),
+                                 lr=lr,
+                                 weight_decay=wd,
+                                 betas=(0.9, 0.95))
+    elif optimizer_name == "muon":
+        muon_params = [
+            p for name, p in model.named_parameters() if p.ndim >= 2
+            and "embed_tokens" not in name and "lm_head" not in name
+        ]
+        adamw_params = [
+            p for name, p in model.named_parameters()
+            if not (p.ndim >= 2 and "embed_tokens" not in name
+                    and "lm_head" not in name)
+        ]
+
+        return Muon(
+            lr=lr,
+            wd=wd,
+            muon_params=muon_params,
+            adamw_params=adamw_params,
+        )
+    else:
+        assert 0, "optimizer not supported"
+
+
+def setup(rank, world_size):
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12365'
+
+    # initialize the process group
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+
+def cleanup():
+    dist.destroy_process_group()
+
+
+def fsdp_main(rank, world_size, args):
+
+    print((rank, world_size))
+    setup(rank, world_size)
+    model = get_model_cpu(args.model, args.hidden_size)
+    my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy,
+                                            min_num_params=100)
+    torch.cuda.set_device(rank)
+    model.to(rank)
+    model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
+    model.train()
+    optimizer = get_optimizer(args.optimizer, model, lr=args.lr, wd=args.wd)
+
+    epoch = 1
+    train_loader = get_dataloader(args.model, args.dataset, rank, world_size)
+    logger.info(('train data length:', len(train_loader)))
+
+    lr_scheduler = get_cosine_schedule_with_warmup(
+        optimizer=optimizer,
+        num_warmup_steps=100,
+        num_training_steps=len(train_loader) * epoch,
+        num_cycles=0.5,
+    )
+    for epoch in range(epoch):
+        for step, batch in enumerate(train_loader):
+            batch = batch.to(rank)
+            input_ids = batch
+
+            # labels=input_ids lets the HF causal LM shift the labels internally
+            outputs = model(input_ids=input_ids, labels=input_ids)
+            loss = outputs.loss
+            loss.backward()
+
+            # average the loss across ranks so the logged value reflects all workers
+            dist.all_reduce(loss, op=dist.ReduceOp.AVG)
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.zero_grad()
+            if rank == 0:
+                logger.info(
+                    f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']} Training loss: {loss.item()}"
+                )
+    cleanup()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="qwen")
+    parser.add_argument("--optimizer", type=str, default="muon")
+    parser.add_argument("--lr", type=float, default=1e-3)
+    parser.add_argument("--wd", type=float, default=0.1)
+    parser.add_argument("--dataset", type=str, default="openwebtext-100k")
+    parser.add_argument("--hidden_size", type=int, default=1024)
+    args = parser.parse_args()
+    WORLD_SIZE = torch.cuda.device_count()
+    mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
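+
+# Example launch (an assumption about the environment, not part of the training logic:
+# a host with one or more CUDA GPUs and NCCL available; mp.spawn uses every GPU reported
+# by torch.cuda.device_count()):
+#   python examples/toy_train_fsdp.py --model qwen --optimizer muon --lr 1e-3 --wd 0.1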