From 34711655be3321f4b77b5568486d6c37be60c830 Mon Sep 17 00:00:00 2001
From: Less Wright
Date: Fri, 5 Apr 2024 14:47:58 -0700
Subject: [PATCH] Add FusedRMSNorm (Triton kernel, +15% eager), Add NPLayerNorm, Enable config selectable Norm Type (#181)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR has multiple aspects:

1 - Adds a new Triton-based fused RMSNorm kernel I wrote. I've verified its numerical accuracy on both forward and backward with a unit test (an illustrative parity check is sketched after the patch). It improves MFU by +15% with FSDP2 on a 7B model in eager mode, and by +1.2% when compiled:

[Screenshot 2024-03-29 at 5 18 14 PM]

2 - Adds norms.py to house all four norm types, standardized as [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]. norms.py has a create_norm function that creates the appropriate norm (a usage sketch follows the patch).

3 - Adds np_layernorm, which is layernorm with no affine transformation.

4 - Updates model.py to support plug and play of any supported norm. Thus, instead of this type of if/then logic in the model class:

[Screenshot 2024-03-30 at 1 52 07 PM]

We simply have this:

[Screenshot 2024-03-30 at 1 52 23 PM]

This allows easy plug and play of any norm type with no fiddling around in the model code.

5 - Updates run_llama_train.sh to select a random port instead of the previous fixed port number. (Thanks @yifuwang for this tip!)

6 - Users can now quickly select the norm of their choice via the config file:

[Screenshot 2024-03-30 at 3 01 43 PM]

7 - Adds a NotImplementedError if users try to run TP + fused_rmsnorm, to avoid any confusion (per @tianyu-l feedback):

~~~
NotImplementedError: fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm.
~~~
---
 .flake8                                      |   3 +-
 run_llama_train.sh                           |   2 +-
 torchtrain/config_manager.py                 |   6 +
 torchtrain/models/llama/model.py             |  67 +---
 torchtrain/models/norms.py                   | 314 +++++++++++++++++++
 torchtrain/parallelisms/parallelize_llama.py |   5 +
 train.py                                     |   1 +
 train_configs/debug_model.toml               |   1 +
 train_configs/llama_13b.toml                 |   1 +
 train_configs/llama_70b.toml                 |   1 +
 train_configs/llama_7b.toml                  |   1 +
 11 files changed, 346 insertions(+), 56 deletions(-)
 create mode 100644 torchtrain/models/norms.py

diff --git a/.flake8 b/.flake8
index 1bf6e4b0d..8fb7c1063 100644
--- a/.flake8
+++ b/.flake8
@@ -7,8 +7,9 @@ max-line-length = 120
 # N812 ignored because import torch.nn.functional as F is PyTorch convention
 # N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
 # E731 allow usage of assigning lambda expressions
+# N803,N806 allow caps and mixed case in function params. This is to work with Triton kernel coding style.
ignore = - E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731 + E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806 # shebang has extra meaning in fbcode lints, so I think it's not worth trying # to line this up with executable bit EXE001, diff --git a/run_llama_train.sh b/run_llama_train.sh index e9797f304..5d7a75df2 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -24,6 +24,6 @@ if [ $# -ne 0 ]; then overrides="$*" fi -torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ train.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py index af48fc5c1..581ca8de7 100644 --- a/torchtrain/config_manager.py +++ b/torchtrain/config_manager.py @@ -117,6 +117,12 @@ def __init__(self): default="debugmodel", help="which model config to train", ) + self.parser.add_argument( + "--model.norm_type", + type=str, + default="rmsnorm", + help="Layer Normalization type to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]", + ) self.parser.add_argument( "--model.tokenizer_path", type=str, diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py index 50f467216..afce0ba8b 100644 --- a/torchtrain/models/llama/model.py +++ b/torchtrain/models/llama/model.py @@ -7,6 +7,7 @@ import torch import torch.nn.functional as F from torch import nn +from torchtrain.models.norms import create_norm @dataclass @@ -25,57 +26,7 @@ class ModelArgs: depth_init: bool = ( True # initialization uses each unique layer_id or total model layer count ) - - -class RMSNorm(torch.nn.Module): - """ - Initialize the RMSNorm normalization layer. - - Args: - dim (int): The dimension of the input tensor. - eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. - - Attributes: - eps (float): A small value added to the denominator for numerical stability. - weight (nn.Parameter): Learnable scaling parameter. - - """ - - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.empty(dim)) - self.reset_parameters() - - def _norm(self, x: torch.Tensor): - """ - Apply the RMSNorm normalization to the input tensor. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The normalized tensor. - - """ - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x: torch.Tensor): - """ - Forward pass through the RMSNorm layer. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The output tensor after applying RMSNorm. 
- - """ - output = self._norm(x.float()).type_as(x) - return output * self.weight - - def reset_parameters(self): - torch.nn.init.ones_(self.weight) + norm_type: str = "rmsnorm" def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): @@ -381,8 +332,13 @@ def __init__(self, layer_id: int, model_args: ModelArgs): ) self.layer_id = layer_id self.num_layers = model_args.n_layers - self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps) + + self.attention_norm = create_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + self.ffn_norm = create_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) if model_args.depth_init: self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 @@ -447,7 +403,10 @@ def __init__(self, model_args: ModelArgs): for layer_id in range(model_args.n_layers): self.layers.append(TransformerBlock(layer_id, model_args)) - self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.norm = create_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) self.init_weights() diff --git a/torchtrain/models/norms.py b/torchtrain/models/norms.py new file mode 100644 index 000000000..1eaed4992 --- /dev/null +++ b/torchtrain/models/norms.py @@ -0,0 +1,314 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import math + +import torch +import torch.nn as nn + +import triton +import triton.language as tl + + +def create_norm(norm_type: str, dim: int, eps: float = 1e-6): + """ + Creates the specified normalization layer based on the norm_type. + + Args: + norm_type (str): The type of normalization layer to create. + Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm + dim (int): The dimension of the normalization layer. + eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. + + Returns: + The created normalization layer. + + Raises: + NotImplementedError: If an unknown norm_type is provided. + """ + norm_type = norm_type.lower() # Normalize to lowercase + + if norm_type == "layernorm": + return nn.LayerNorm(dim, eps=eps, bias=False) + elif norm_type == "np_layernorm": + return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) + elif norm_type == "rmsnorm": + return RMSNorm(dim, eps=eps) + elif norm_type == "fused_rmsnorm": + return FusedRMSNorm(dim, eps=eps) + else: + raise NotImplementedError(f"Unknown norm_type: '{norm_type}'") + + +class FusedRMSNorm(nn.Module): + """Fused RMS Norm, wraps a fused Triton Kernel""" + + def __init__( + self, + dim: int, + eps: float = 1e-6, + ): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + self.fused_rms_norm_fn = fused_rms_norm_fn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """leverages Triton Fused RMS Norm kernel""" + return self.fused_rms_norm_fn( + x, + self.weight, + eps=self.eps, + ) + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +class RMSNorm(nn.Module): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 
+ + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +# FusedRMSNorm in Triton + +# Credit +# Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py +# Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], +) +@triton.jit +def _rms_norm_fwd_kernel( + X, + stride_x, + Y, + stride_y, + W, + Rstd, + eps, + M, # num rows + N, # num cols + block_N: tl.constexpr, +): + + row = tl.program_id(0) + cols = tl.arange(0, block_N) + + # Load input data and weights + mask = cols < N + x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + + # Compute mean and variance + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + + # Store the reciprocal standard deviation + tl.store(Rstd + row, rstd) + + # Normalize and apply linear transformation + x_hat = x * rstd + y = x_hat * w + + # Write output + tl.store(Y + row * stride_y + cols, y, mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], +) +@triton.jit +def _rms_norm_bwd_kernel_sm( + X, + stride_x, + W, + DY, + stride_dy, + DX, + stride_dx, + Rstd, + DW, + eps, + M, # num rows + N, # num cols + rows_per_program, + block_N: tl.constexpr, +): + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, block_N) + mask = cols < N + + # Load weights + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + + # Accumulate gradients for weights + dw = tl.zeros((block_N,), dtype=tl.float32) + + row_end = min(row_start + rows_per_program, M) + for row in range(row_start, row_end): + # Load input, output gradient, and reciprocal standard deviation + x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32) + rstd = tl.load(Rstd + row) + + # Compute normalized input and gradients + x_hat = x * rstd + wdy = w * dy + dw += dy * x_hat + c1 = tl.sum(x_hat * wdy, axis=0) / N + dx = (wdy - x_hat * c1) * rstd + + # Store input gradient + tl.store(DX + row * stride_dx + cols, dx, mask=mask) + + # Store weight gradients + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + + +class TritonFusedRMSNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, eps): + x_shape_start = x.shape + + # Flatten input + x = x.view(-1, x.shape[-1]) + if x.stride(-1) != 1: + 
x = x.contiguous() + if weight.stride(-1) != 1: + weight = weight.contiguous() + + M, N = x.shape + y = torch.empty_like(x) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) + + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") + + grid = lambda meta: (M,) + _rms_norm_fwd_kernel[grid]( + x, + x.stride(0), + y, + y.stride(0), + weight, + rstd, + eps, + M, + N, + block_N, + ) + + ctx.eps = eps + ctx.save_for_backward(x, weight, rstd) + ctx.x_shape_start = x_shape_start + + y = y.reshape(x_shape_start) + return y + + @staticmethod + def backward(ctx, dy): + x, weight, rstd = ctx.saved_tensors + eps = ctx.eps + x_shape_start = ctx.x_shape_start + + # Flatten input and output gradients + dy = dy.view(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + + M, N = dy.shape + dx = torch.empty_like(x) + dw = torch.empty_like(weight) + + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) + rows_per_sm = math.ceil(M / sm_count) + + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") + + grid = lambda meta: (sm_count,) + _rms_norm_bwd_kernel_sm[grid]( + x, + x.stride(0), + weight, + dy, + dy.stride(0), + dx, + dx.stride(0), + rstd, + _dw, + eps, + M, + N, + rows_per_sm, + block_N, + ) + dw = _dw.sum(0).to(weight.dtype) + dx = dx.view(x_shape_start) + return dx, dw, None + + +# expose fusedRMSNorm as a function +def fused_rms_norm_fn( + x, + weight, + eps=1e-6, +): + return TritonFusedRMSNorm.apply( + x, + weight, + eps, + ) diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index 1e7a9ed57..34dca31f2 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -137,6 +137,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): raise NotImplementedError("PP not implemented yet.") if parallel_dims.tp_enabled: + if job_config.model.norm_type == "fused_rmsnorm": + raise NotImplementedError( + "fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm." 
+ ) + tp_mesh = world_mesh["tp"] row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy( job_config diff --git a/train.py b/train.py index e98bc1807..369183f77 100644 --- a/train.py +++ b/train.py @@ -163,6 +163,7 @@ def loss_fn(pred, labels): # build model (using meta init) model_cls = model_name_to_cls[model_name] model_config = models_config[model_name][job_config.model.flavor] + model_config.norm_type = job_config.model.norm_type model_config.vocab_size = tokenizer.n_words with torch.device("meta"): diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index a498cca79..17759fab5 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -18,6 +18,7 @@ save_tb_folder = "tb" [model] name = "llama" flavor = "debugmodel" +norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model" [optimizer] diff --git a/train_configs/llama_13b.toml b/train_configs/llama_13b.toml index 0d6ac45c7..a3c4f1e72 100644 --- a/train_configs/llama_13b.toml +++ b/train_configs/llama_13b.toml @@ -17,6 +17,7 @@ save_tb_folder = "tb" [model] name = "llama" flavor = "13B" +norm_type = "rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm] tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model" [optimizer] diff --git a/train_configs/llama_70b.toml b/train_configs/llama_70b.toml index d6bcd1ac9..20c1def8f 100644 --- a/train_configs/llama_70b.toml +++ b/train_configs/llama_70b.toml @@ -17,6 +17,7 @@ save_tb_folder = "tb" [model] name = "llama" flavor = "70B" +norm_type = "rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm] tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model" [optimizer] diff --git a/train_configs/llama_7b.toml b/train_configs/llama_7b.toml index b2fe966d8..33921bbb5 100644 --- a/train_configs/llama_7b.toml +++ b/train_configs/llama_7b.toml @@ -17,6 +17,7 @@ save_tb_folder = "tb" [model] name = "llama" flavor = "7B" +norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model" [optimizer]
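
Illustrative usage sketch (not part of the patch): how the config-selectable norm is meant to be consumed, assuming torchtrain.models.norms is importable as laid out above. The tensor shapes are invented for the example, and "fused_rmsnorm" would additionally require a CUDA device since it dispatches to the Triton kernel.

~~~
import torch

from torchtrain.models.norms import create_norm

# Any of "layernorm", "np_layernorm", "rmsnorm", "fused_rmsnorm" can be passed,
# mirroring the model.norm_type knob in the TOML configs above.
norm = create_norm("rmsnorm", dim=4096, eps=1e-6)

x = torch.randn(2, 16, 4096)  # (batch, seq_len, dim) -- illustrative shapes
y = norm(x)
assert y.shape == x.shape
~~~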
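The unit test mentioned in item 1 is not included in this patch; a forward/backward parity check between the eager RMSNorm and the Triton FusedRMSNorm could look roughly like the sketch below. This is only an assumption-laden example: it needs a CUDA device (the Triton kernel runs on GPU), and the tolerances are arbitrary.

~~~
import torch

from torchtrain.models.norms import FusedRMSNorm, RMSNorm

dim = 4096
x = torch.randn(8, 128, dim, device="cuda", requires_grad=True)
x_ref = x.detach().clone().requires_grad_(True)

fused = FusedRMSNorm(dim).to("cuda")  # Triton-backed path
ref = RMSNorm(dim).to("cuda")         # eager reference

# Forward parity
y_fused = fused(x)
y_ref = ref(x_ref)
torch.testing.assert_close(y_fused, y_ref, rtol=1e-3, atol=1e-3)

# Backward parity for input and weight gradients
grad_out = torch.randn_like(y_ref)
y_fused.backward(grad_out)
y_ref.backward(grad_out)
torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(fused.weight.grad, ref.weight.grad, rtol=1e-3, atol=1e-3)
~~~

The weight-gradient comparison also exercises the per-SM partial accumulation in _rms_norm_bwd_kernel_sm, since dw is reduced from the (sm_count, N) scratch buffer before being returned.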