From 4583d9964833ec8417037cf1b05472606ada6a80 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Fri, 11 Apr 2025 05:04:29 -0700
Subject: [PATCH] Enable MoE quantization using linear decomposition [WIP]

Summary:

This PR is a first step toward optimizing MoE inference using torchao. The goal
for this step is to enable existing quantization kernels and workflows to work
for MoE quantization by decomposing the grouped GEMM into a sequence of
unbalanced linear ops that can use the existing quantized kernels. To enable
this, we had to add support for quantizing these 3D weight tensors as well as
for slicing and indexing them.

Current tests are running locally and will be added once they are working.
Currently int8wo and int8dq work for multi- and single-token MoE inference,
while int4wo is being finished up.

TODO: move the test set into ao, move the quantizable MoE module code into ao,
and test on the HF model definition.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

testing

Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
---
 torchao/_models/mixtral-moe/generate.py         | 396 +++++++++++++++++
 torchao/_models/mixtral-moe/model.py            | 360 +++++++++++++++
 torchao/_models/mixtral-moe/run.sh              |   6 +
 .../scripts/convert_hf_checkpoint.py            | 100 +++++
 torchao/_models/mixtral-moe/scripts/download.py |  30 ++
 torchao/dtypes/affine_quantized_tensor_ops.py   |  50 ++-
 torchao/dtypes/uintx/plain_layout.py            |  11 +
 .../dtypes/uintx/tensor_core_tiled_layout.py    | 139 ++++--
 ...est_int8_dynamic_activation_intx_weight.py   |  45 ++
 .../linear_activation_quantized_tensor.py       |   2 +
 .../prototype/moe_quant/__init__.py             |   0
 .../prototype/moe_quant/llama4_quant.py         | 416 ++++++++++++++++++
 .../moe_quant/quantizable_moe_modules.py        | 101 +++++
 .../quantization/prototype/moe_quant/run.sh     |   1 +
 .../quantization/prototype/moe_quant/utils.py   | 241 ++++++++++
 torchao/quantization/quant_api.py               |   2 +-
 torchao/quantization/transform_module.py        |   1 +
 torchao/quantization/utils.py                   |  13 +-
 torchao/utils.py                                |   2 +-
 19 files changed, 1857 insertions(+), 59 deletions(-)
 create mode 100644 torchao/_models/mixtral-moe/generate.py
 create mode 100644 torchao/_models/mixtral-moe/model.py
 create mode 100644 torchao/_models/mixtral-moe/run.sh
 create mode 100644 torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py
 create mode 100644 torchao/_models/mixtral-moe/scripts/download.py
 create mode 100644 torchao/quantization/prototype/moe_quant/__init__.py
 create mode 100644 torchao/quantization/prototype/moe_quant/llama4_quant.py
 create mode 100644 torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py
 create mode 100644 torchao/quantization/prototype/moe_quant/run.sh
 create mode 100644 torchao/quantization/prototype/moe_quant/utils.py

diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py
new file mode 100644
index 0000000000..5a2167dc23
--- /dev/null
+++ b/torchao/_models/mixtral-moe/generate.py
@@ -0,0 +1,396 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
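+#
+# Benchmark/generation script for the Mixtral 8x7B MoE model (gpt-fast style),
+# used here to exercise the MoE quantization path added in this PR.
+# Illustrative invocation, mirroring run.sh in this directory (the checkpoint
+# location is an assumption about your local setup):
+#
+#   python generate.py \
+#       --checkpoint_path checkpoints/mistralai/Mixtral-8x7B-Instruct-v0.1/model.pth \
+#       --moe_quant int8wo-base --batch_size 1 --compile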
+import itertools +import sys +import time +from pathlib import Path +from typing import Optional, Tuple + +import torch +import torch._dynamo.config +import torch._inductor.config +torch.manual_seed(0) + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif "cpu" in device: + pass + else: + print(f"device={device} is not yet suppported") + + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future +torch._dynamo.config.capture_scalar_outputs = True + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from sentencepiece import SentencePieceProcessor + +from model import Transformer + +def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[:, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + +def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: + # input_pos: [B, S] + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs)[0] + +def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs) + +def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + next_token, next_prob = next_token.clone(), next_prob.clone() + + input_pos += 1 + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + new_probs.append(next_prob.clone()) + cur_token = next_token + + return new_tokens, new_probs + + +def model_forward(model, x, input_pos): + return model(x, input_pos) + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + batch_size: int, + *, + interactive: bool, + callback = lambda x: x, + **sampling_kwargs +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
+ """ + device, dtype = prompt.device, prompt.dtype + + + T = prompt.size(-1) + max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + new_tokens = max_seq_length - T + + # duplicate prompt for batch_size + prompt = prompt.repeat(batch_size, 1) + + # create an empty tensor of the expected final shape and fill in the current tokens + seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device) + seq[:, :T] = prompt + + with torch.device(device): + model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length) + + input_pos = torch.arange(0, T, device=device) + next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs) + seq[:, T] = next_token.squeeze() + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) + seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1) + + return seq + +def encode_tokens(tokenizer, string, bos=True, device='cuda'): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + +def _load_model(checkpoint_path, device, precision): + with torch.device('meta'): + model = Transformer.from_name(checkpoint_path.parent.name) + + try: + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + except: + model = Transformer.from_name(checkpoint_path.parent.name) + + model = model.to(device=device, dtype=precision) + return model.eval() + +B_INST, E_INST = "[INST]", "[/INST]" + +def main( + prompt: str = "Hello, my name is", + interactive: bool = False, + num_samples: int = 5, + max_new_tokens: int = 100, + batch_size: int = 1, + top_k: int = 200, + temperature: float = 0.8, + checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"), + compile: bool = True, + compile_prefill: bool = False, + moe_quant: Optional[str] = None, + profile: Optional[Path] = None, + device='cuda', +) -> None: + """Generates text samples based on a pre-trained Transformer model and tokenizer. 
+ """ + assert checkpoint_path.is_file(), checkpoint_path + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + + print(f"Using device={device}") + precision = torch.bfloat16 + is_chat = "chat" in str(checkpoint_path) + + print("Loading model ...") + t0 = time.time() + model = _load_model(checkpoint_path, device, precision) + + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + prompt_length = encoded.size(0) + + torch.manual_seed(1234) + model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) + + import torchao + from torchao.quantization.quant_api import ( + quantize_, + Int8WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int4WeightOnlyConfig, + Float8WeightOnlyConfig, + Float8DynamicActivationFloat8WeightConfig, + PerRow, + _replace_with_custom_fn_if_matches_filter, + ) + from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter + + if moe_quant: + torch._dynamo.config.capture_dynamic_output_shape_ops = True + config = None + if "int8wo" in moe_quant: + config = MoEQuantConfig(Int8WeightOnlyConfig()) + + elif "int8wo-base" in moe_quant: + config=1 + def int8wo_quant_convert_fn(module, config): + def quant_tensor(weight): + from torchao.quantization.quant_api import ( + MappingType, + to_affine_quantized_intx, + ) + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + block_size = [1 for x in range(weight.dim())] + block_size[-1] = weight.shape[-1] + block_size = tuple(block_size) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + ) + return new_weight + assert "ConditionalFeedForwardAOQuantizable" in str(type(module)) + assert hasattr(module, "w1") + assert hasattr(module, "w2") + assert hasattr(module, "w3") + + for weight_attr in ["w1", "w2", "w3"]: + param = getattr(module, weight_attr) + new_param = quant_tensor(param) + new_param = torch.nn.Parameter(new_param, requires_grad=False) + setattr(module, weight_attr, new_param) + del param + return module + + _replace_with_custom_fn_if_matches_filter( + model, + replacement_fn=int8wo_quant_convert_fn, + filter_fn=cond_ffn_filter, + extra_args=(Int8WeightOnlyConfig(),) + ) + + elif "int8dq" in moe_quant: + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) + + elif "int8dq-base" in moe_quant: + pass + + elif "int4wo" in moe_quant: + config = MoEQuantConfig(Int4WeightOnlyConfig()) + + elif "int4wo-base" in moe_quant: + pass + + elif "fp8wo" in moe_quant: + config = MoEQuantConfig(Float8WeightOnlyConfig()) + + elif "fp8dq" in moe_quant: + config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) + + else: + assert config is not None, f"expected moe_quant to match one of the options but got {moe_quant}" + + if isinstance(config, MoEQuantConfig): + quantize_(model, config, filter_fn=cond_ffn_filter) + + + if compile: + # moe quant + compile causes repeated warnings + import warnings + warnings.simplefilter("ignore", lineno=84) + warnings.simplefilter("ignore", lineno=105) + + torch._inductor.config.assert_indirect_indexing = False + + global decode_one_token, prefill + + 
if batch_size > 1 or (isinstance(moe_quant, str) and "base" not in moe_quant): + # if batch_size > 1: # MoE code has graph break for multi token path so can't fullgraph compile + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead") + else: + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + + if args.compile_prefill: + prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + + + aggregate_metrics = { + 'tokens_per_sec': [], + } + start = -1 if compile else 0 + + for i in range(start, num_samples): + device_sync(device=device) # MKG + if i >= 0 and interactive: + prompt = input("What is your prompt? ") + if is_chat: + prompt = f"{B_INST} {prompt.strip()} {E_INST}" + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + + if interactive and i >= 0: + buffer = [] + period_id = tokenizer.encode('.')[0] + done_generating = False + def callback(x): + nonlocal done_generating + if done_generating: + return + buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) + if x.item() == tokenizer.eos_id(): + done_generating = True + if len(buffer) == 4 or done_generating: + print(''.join(buffer), end='', flush=True) + buffer.clear() + # print(, end='', flush=True) + else: + callback = lambda x : x + t0 = time.perf_counter() + import contextlib + if (i != num_samples - 1 or not profile): + prof = contextlib.nullcontext() + else: + torch.profiler._utils._init_for_cuda_graphs() + prof = torch.profiler.profile() + with prof: + y = generate( + model, + encoded, + max_new_tokens, + batch_size, + interactive=interactive, + callback=callback, + temperature=temperature, + top_k=top_k, + ) + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue + if hasattr(prof, "export_chrome_trace"): + prof.export_chrome_trace(f"{profile}.json") + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + + if not interactive: + print(tokenizer.decode(y[0].tolist())) + else: + print() + tokens_generated = y.size(-1) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics['tokens_per_sec'].append(tokens_sec) + print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + + tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item() + print(f"Average tokens/sec: {tokpersec:.2f}") + if batch_size > 1: + print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}") + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Your CLI description.') + + parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') + parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') + parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') + parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') + parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with') + parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') + parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') + parser.add_argument('--checkpoint_path', type=Path, 
default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') + parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') + parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') + # parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8') + parser.add_argument('--moe_quant', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq') + parser.add_argument('--profile', type=Path, default=None, help='Profile path.') + parser.add_argument('--device', type=str, default="cuda", help='device to use') + + args = parser.parse_args() + print(args) + main( + args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k, + args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.moe_quant, args.profile, args.device + ) diff --git a/torchao/_models/mixtral-moe/model.py b/torchao/_models/mixtral-moe/model.py new file mode 100644 index 0000000000..3d76e0e326 --- /dev/null +++ b/torchao/_models/mixtral-moe/model.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + num_experts: int = 8 + num_activated_experts: int = 2 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] + assert len(config) == 1, name + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "Mixtral-8x7B-Instruct-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2), +} + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, 
v_out + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length): + if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + for b in self.layers: + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim) + + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base) + self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) + + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + x = self.tok_embeddings(idx) + + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask) + x = self.norm(x) + logits = self.output(x) + return logits + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.block_sparse_moe = MOEFeedForwardAOQuantizable(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.block_sparse_moe(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = 
k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +class ConditionalFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) + self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) + self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) + + def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: + w1_weights = self.w1[expert_indices] # [T, A, D, D] + w3_weights = self.w3[expert_indices] # [T, A, D, D] + w2_weights = self.w2[expert_indices] # [T, A, D, D] + x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights)) + x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) + expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) + return expert_outs + + +class MOEFeedForward(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) + self.cond_ffn = ConditionalFeedForward(config) + self.dim = config.dim + self.num_activated_experts = config.num_activated_experts + def forward(self, x: Tensor) -> Tensor: + x = x.view(-1, self.dim) + # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts + # x: [T, D] + scores = self.gate(x) # [T, E] + expert_weights = F.softmax(scores, dim=-1) + expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] + expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] + expert_outs = self.cond_ffn(x, expert_indices) + return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000 +) -> Tensor: + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) 
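+
+# Note: MOEFeedForwardAOQuantizable / ConditionalFeedForwardAOQuantizable below
+# replace the grouped-GEMM einsum path of ConditionalFeedForward above with
+# per-expert F.linear calls over 2D slices of the 3D expert weights (w1/w2/w3),
+# so existing quantized linear kernels can be reused, as described in the PR
+# summary. A rough sketch of how they get quantized (mirrors generate.py; the
+# particular base config is just an example):
+#
+#   from torchao.quantization.quant_api import quantize_, Int8WeightOnlyConfig
+#   from torchao.quantization.prototype.moe_quant.utils import (
+#       MoEQuantConfig,
+#       cond_ffn_filter,
+#   )
+#   quantize_(model, MoEQuantConfig(Int8WeightOnlyConfig()), filter_fn=cond_ffn_filter)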
+ + +# T tokens +# E experts +# D dim +# I intermediate dim +# A activated experts +# T'(e) tokens for expert e + +class MOEFeedForwardAOQuantizable(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) + self.cond_ffn = ConditionalFeedForwardAOQuantizable(config) + self.dim = config.dim + self.num_activated_experts = config.num_activated_experts + def forward(self, x: Tensor) -> Tensor: + batch_size = x.shape[0] + x = x.view(-1, self.dim) # x: [T, D] + scores = self.gate(x) # [T, E] + expert_weights = F.softmax(scores, dim=-1) + expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] + expert_weights /= expert_weights.sum(dim=-1, keepdim=True).to(x.dtype) # [T, A] + out = self.cond_ffn(x, expert_indices, expert_weights, self.num_activated_experts) + return out.reshape(batch_size, -1, self.dim) + + +class ConditionalFeedForwardAOQuantizable(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) # E, I, D + self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) # E, D, I + self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) # E, I, D + self.num_experts = config.num_experts + def forward( + self, x: Tensor, # T, D + expert_indices: Tensor, # T, A + expert_weights: Tensor, # T, A + num_activated_experts: int, + ) -> Tensor: + num_tokens, dim = x.shape + num_token_activations = num_tokens * num_activated_experts + + if x.shape[0]==1: #only 1 token (can be done without graph breaks when compiled) + outs = [] + expert_indices=expert_indices.squeeze() + # collect used experts + w1 = self.w1[expert_indices] + w2 = self.w2[expert_indices] + w3 = self.w3[expert_indices] + + # run token through each expert + for index in range(num_activated_experts): + y1 = F.silu(F.linear(x, w1[index])) + y3 = F.linear(x, w3[index]) + y2 = w2[index] + cur_out = F.linear( y1 * y3, y2) + outs.append(cur_out) + + # combine outputs + final_out = (torch.cat(outs, dim=0) * expert_weights.view(-1,1)).sum(dim=0).unsqueeze(-1) + return final_out + else: + expert_list = [x for x in range(self.num_experts)] + + # shuffle tokens into groups for each expert + ordered_token_activations = expert_indices.view(-1).argsort(stable=True) # [A] + ordered_token_indices = ordered_token_activations.div(num_activated_experts).floor().to(torch.int64) # [T] + + num_tokens_per_expert = torch.histc(expert_indices, bins=self.num_experts+1, min=-1, max=self.num_experts) # [E+1] (added leading 0 so can be used for indexing) + cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(torch.int64) # [E+1] + + @torch._dynamo.disable() + def group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list): + token_indices_per_expert = [ordered_token_indices[cum_tokens_per_expert[expert]:cum_tokens_per_expert[expert+1]] for expert in expert_list] # [T'(e1)], [T'(e2)] ... 
+ return token_indices_per_expert + token_indices_per_expert = group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list) + tokens_grouped_by_expert = [x[indices] for indices in token_indices_per_expert] + + # calculate outputs for each expert + outs = [] + for cur_x, expert in zip(tokens_grouped_by_expert,expert_list): + + w1=self.w1[expert] # I, D + w2=self.w2[expert] # D, I + w3=self.w3[expert] # I, D + + cur_out = F.linear( F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2) # [T'(e), D] + outs.append(cur_out) + + # weigh outputs + ordered_outs = torch.cat(outs, dim=0) # [T*A, D] + ordered_token_activation_weights = expert_weights.view(-1,1)[ordered_token_activations].view(-1,1) # [T*A, 1] + weighted_ordered_outs = ordered_outs*ordered_token_activation_weights # [T*A, D] + + # sum weighted token-activation outputs together for each token + final_out = torch.zeros_like(x) # [T, D] + final_out = final_out.scatter_add(dim=0, index=ordered_token_indices.unsqueeze(-1).expand(num_token_activations,dim).to(torch.int64), src=weighted_ordered_outs) + return final_out diff --git a/torchao/_models/mixtral-moe/run.sh b/torchao/_models/mixtral-moe/run.sh new file mode 100644 index 0000000000..482acb8a04 --- /dev/null +++ b/torchao/_models/mixtral-moe/run.sh @@ -0,0 +1,6 @@ +export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1 +export CHECKPOINT_PATH=/data/users/cdhernandez/gpt-fast/checkpoints/ + + + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8wo-base --compile diff --git a/torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py b/torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py new file mode 100644 index 0000000000..b120c5c56d --- /dev/null +++ b/torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
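+#
+# Merges the sharded HF *.pt checkpoint files, remaps the block_sparse_moe
+# expert weights onto the cond_ffn.{w1,w2,w3} layout used by model.py
+# (w1/w3 reshaped to [E, I, D], w2 to [E, D, I]), fuses wq/wk/wv into wqkv,
+# and writes a single model.pth. Example usage (default path shown; adjust to
+# wherever the checkpoint was downloaded):
+#
+#   python scripts/convert_hf_checkpoint.py \
+#       --checkpoint_dir checkpoints/mistralai/Mixtral-8x7B-v0.1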
+import glob +import json +import re +import sys +from pathlib import Path +from typing import Optional + +import torch + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from model import ModelArgs + + +@torch.inference_mode() +def convert_hf_checkpoint( + *, + checkpoint_dir: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1"), + model_name: Optional[str] = None, +) -> None: + if model_name is None: + model_name = checkpoint_dir.name + + config = ModelArgs.from_name(model_name) + print(f"Model config {config.__dict__}") + + weight_map = { + "tok_embeddings.weight": "tok_embeddings.weight", + "layers.{}.attention.wq.weight": "layers.{}.attention.wq.weight", + "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight", + "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight", + "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight", + "layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1", + "layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2", + "layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3", + "layers.{}.block_sparse_moe.gate.weight": "layers.{}.block_sparse_moe.gate.weight", + "layers.{}.attention_norm.weight": "layers.{}.attention_norm.weight", + "layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight", + "norm.weight": "norm.weight", + "output.weight": "output.weight", + } + + pt_files = glob.glob(str(checkpoint_dir / "*.pt")) + + merged_result = {} + for file in sorted(pt_files): + state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) + merged_result.update(state_dict) + final_result = {} + for key, value in merged_result.items(): + if "layers" in key: + abstract_key = re.sub(r'.(\d+).', '.{}.', key) + layer_num = re.search(r'\d+', key).group(0) + new_key = weight_map[abstract_key] + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = weight_map[key] + + final_result[new_key] = value + + for key in tuple(final_result.keys()): + if "wq" in key: + q = final_result[key] + k = final_result[key.replace("wq", "wk")] + v = final_result[key.replace("wq", "wv")] + final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) + del final_result[key] + del final_result[key.replace("wq", "wk")] + del final_result[key.replace("wq", "wv")] + elif "w1" in key or "w3" in key: + final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous() + elif "w2" in key: + final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous() + elif "gate" in key: + final_result[key] = final_result[key].contiguous() + + print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") + torch.save(final_result, checkpoint_dir / "model.pth") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.') + parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/mistralai/Mixtral-8x7B-v0.1")) + parser.add_argument('--model_name', type=str, default=None) + + args = parser.parse_args() + convert_hf_checkpoint( + checkpoint_dir=args.checkpoint_dir, + model_name=args.model_name, + ) diff --git a/torchao/_models/mixtral-moe/scripts/download.py b/torchao/_models/mixtral-moe/scripts/download.py new file mode 100644 index 0000000000..7dc828004f --- /dev/null +++ 
b/torchao/_models/mixtral-moe/scripts/download.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import os +from typing import Optional + +from requests.exceptions import HTTPError + + +def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: + from huggingface_hub import snapshot_download + os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) + try: + snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors") + except HTTPError as e: + if e.response.status_code == 401: + print("You need to pass a valid `--hf_token=...` to download private checkpoints.") + else: + raise e + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.') + parser.add_argument('--repo_id', type=str, default="checkpoints/mistralai/Mixtral-8x7B-Instruct-v0.1", help='Repository ID to download from.') + parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.') + + args = parser.parse_args() + hf_download(args.repo_id, args.hf_token) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 50ef8c9e89..be84430067 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -460,11 +460,13 @@ def _(func, types, args, kwargs): shape[dim] = end - start block_size = self.block_size assert ( - len(block_size) == 2 - ), f"Slice only works for 2d block_size right now, got: {block_size}" + len(block_size) in [2,3] + ), f"Slice only works for 2 and 3d block_size right now, got: {block_size}" # with slice, some shape dimension might be smaller than block_size dimension, so # we need to make sure there is no overflow - block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) + if len(block_size) == 2: + block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) + new = self.__class__( aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, @@ -473,10 +475,50 @@ def _(func, types, args, kwargs): self.quant_max, self.zero_point_domain, dtype=self.dtype, - strides=self.stride(), + strides=self.stride() if len(block_size)==2 else None, ) return return_and_correct_aliasing(func, args, kwargs, new) +@implements(aten.index.Tensor) +def _(func, types, args, kwargs): + self, indices = args + assert len(indices) == 1, f"op {func} currently only implemented for single dimensional indexing but got indices: {indices}" + + new_tensor_impl = aten.index.Tensor(self.tensor_impl, indices) + shape = tuple([indices[0].numel(), *self.shape[1:]]) + + block_size = self.block_size + new = self.__class__( + new_tensor_impl, + block_size, + shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + +@implements(aten.select.int) +def _(func, types, args, kwargs): + self, dim, index = fill_defaults(args, 3, [0, 0]) + assert dim==0, f"op {func} currently only implemented for dim=0 but got dim={dim}" + assert self.dim() == 3, f"op {func} currently only implemented for 3 dimensional tensors but got shape={self.shape}" + + new_tensor_impl = aten.select.int(self.tensor_impl, dim, index) + + shape = self.shape[1:] + block_size = 
self.block_size[1:] + new = self.__class__( + new_tensor_impl, + block_size, + shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) # this is needed for DTensor.from_local() and for flattening tensor @implements(aten.view.default) diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py index b4b9e06f1a..042ab04564 100644 --- a/torchao/dtypes/uintx/plain_layout.py +++ b/torchao/dtypes/uintx/plain_layout.py @@ -154,6 +154,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) return return_and_correct_aliasing(func, args, kwargs, new) + + elif func in [aten.select.int, aten.index.Tensor]: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: func(x, *args[1:], **kwargs) + ), + ) + elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 901c4c4640..3890010d61 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -93,11 +93,13 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) # groupwise int4 quantization - groupsize = weight_tensor.block_size[1] - y = torch.ops.aten._weight_int4pack_mm( - act_mat.contiguous(), packed_weight, groupsize, scale_and_zero - ) - + groupsize = weight_tensor.block_size[-1] + if act_mat.numel() == 0: + y=act_mat + else: + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] y = y[:, :orig_out_features] @@ -119,7 +121,7 @@ class TensorCoreTiledLayout(Layout): inner_k_tiles: int = 8 def pre_process(self, input: torch.Tensor) -> torch.Tensor: - orig_out_features, orig_in_features = input.shape + orig_out_features, orig_in_features = input.shape[-2:] in_features = find_multiple(orig_in_features, 1024) out_features = find_multiple(orig_out_features, 8) input = torch.nn.functional.pad( @@ -160,7 +162,7 @@ def post_process( zero_point: torch.Tensor, block_size: Tuple[int, ...], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - orig_out_features, orig_in_features = input.shape + orig_out_features, orig_in_features = input.shape[-2:] in_features = find_multiple(orig_in_features, 1024) out_features = find_multiple(orig_out_features, 8) input = torch.nn.functional.pad( @@ -168,10 +170,10 @@ def post_process( (0, in_features - orig_in_features, 0, out_features - orig_out_features), ) assert ( - len(block_size) == 2 - ), f"TensorCoreTiledLayout only supports len(block_size) == 2, got: {block_size}" - scale_pad_dim_0 = (out_features - orig_out_features) // block_size[0] - scale_pad_dim_1 = (in_features - orig_in_features) // block_size[1] + len(block_size) == 2 or len(block_size) == 3, + ), f"TensorCoreTiledLayout only supports len(block_size) == 2 or 3, got: {block_size}" + scale_pad_dim_0 = (out_features - orig_out_features) // block_size[-2] + scale_pad_dim_1 = (in_features - orig_in_features) // block_size[-1] scale = torch.nn.functional.pad(scale, (0, scale_pad_dim_1, 0, scale_pad_dim_0)) zero_point = torch.nn.functional.pad( zero_point, (0, scale_pad_dim_1, 0, scale_pad_dim_0) @@ -262,21 +264,29 @@ def 
from_plain( _layout: Layout, ): assert isinstance(_layout, TensorCoreTiledLayout) - - if TORCH_VERSION_AT_LEAST_2_5: - int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert ( - int_data.dtype == torch.uint8 - ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + + def quant_2d(int_data_2d): + if TORCH_VERSION_AT_LEAST_2_5: + int_data_2d = (int_data_2d[::, ::2] << 4 | int_data_2d[::, 1::2]).to(torch.uint8) + return torch.ops.aten._convert_weight_to_int4pack( + int_data_2d, _layout.inner_k_tiles + ) + if int_data.dim() == 3: # for moe quant + num_experts = int_data.shape[0] + packed_weight_list = [] + for expert in range(num_experts): + packed_weight_list.append(quant_2d(int_data[expert]).unsqueeze(0)) + packed_weight = torch.cat(packed_weight_list, dim=0) + scale = scale.reshape(int_data.shape[0], int_data.shape[-2], -1) + zero_point = zero_point.reshape(int_data.shape[0], int_data.shape[-2], -1) if zero_point is not None else None else: - assert ( - int_data.dtype == torch.int32 - ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - int_data, _layout.inner_k_tiles - ) - scale = scale.reshape(int_data.shape[0], -1) - zero_point = zero_point.reshape(int_data.shape[0], -1) + assert int_data.dim() == 2 + packed_weight = quant_2d(int_data) + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) if zero_point is not None else None from torchao.quantization.utils import pack_tinygemm_scales_and_zeros scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) @@ -336,6 +346,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" ) + if func in [aten.select.int, aten.index.Tensor]: + assert not (func is aten.select.int and args[1]!=0), "aten.select.int currently only has support for dim=0" + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: func(x, *args[1:], **kwargs) + ), + ) + + if func is aten.t.default: """we don't need to repack the weight and just rely on external shape being changed and record the status of transpose/no-transpose @@ -386,11 +408,16 @@ def block_size(self): scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) cur_shape = self.shape - assert len(cur_shape) == 4 + if len(cur_shape) == 5: + ones = [1,1] + cur_shape = cur_shape[1:] + else: + assert len(cur_shape) == 4 + ones = [1] inner_k_tiles = cur_shape[-1] * 2 original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) groupsize = int(original_shape[1] / scale.shape[-2]) - return (1, groupsize) + return tuple([*ones, groupsize]) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: from torchao.quantization.quant_primitives import ( @@ -399,35 +426,53 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + def dequant_4d(self): + cur_shape = self.shape + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + assert len(cur_shape) == 4 + inner_k_tiles = cur_shape[-1] * 2 + original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) + eye_shape = original_shape[1] + groupsize = 
int(original_shape[1] / scale.shape[-2]) + block_size = (1, groupsize) + original_dtype = torch.bfloat16 + assert len(block_size) == 2 and block_size[0] == 1 + dequantized = torch.ops.aten._weight_int4pack_mm( + torch.eye(eye_shape, device=self.device, dtype=original_dtype), + self.packed_weight, + groupsize, + self.scale_and_zero, + ) + dequantized = dequantized.t().contiguous() + return dequantized + + cur_shape = self.shape + + if len(cur_shape)==4: + dequantized = dequant_4d(self) + else: + assert len(cur_shape) == 5 + num_experts = cur_shape[0] + dequantized_list = [] + for expert in range(num_experts): + dequantized_list.append(dequant_4d(self[expert]).unsqueeze(0)) + dequantized = torch.cat(dequantized_list, dim=0) + + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + # TODO: move this to `unpack_tinygemm_scales_and_zeros`? + scale = scale.reshape(scale.shape[:-1]).contiguous() + zero = zero.reshape(zero.shape[:-1]).contiguous() - cur_shape = self.shape - assert len(cur_shape) == 4 - inner_k_tiles = cur_shape[-1] * 2 - original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) - eye_shape = original_shape[1] - groupsize = int(original_shape[1] / scale.shape[-2]) - block_size = (1, groupsize) device = self.device - original_dtype = torch.bfloat16 + target_dtype = torch.int32 quant_min = 0 quant_max = 15 zero_point_domain = ZeroPointDomain.FLOAT - assert len(block_size) == 2 and block_size[0] == 1 - dequantized = torch.ops.aten._weight_int4pack_mm( - torch.eye(eye_shape, device=device, dtype=original_dtype), - self.packed_weight, - groupsize, - self.scale_and_zero, - ) - dequantized = dequantized.t().contiguous() - # TODO: move this to `unpack_tinygemm_scales_and_zeros`? - scale = scale.reshape(scale.shape[:-1]).contiguous() - zero = zero.reshape(zero.shape[:-1]).contiguous() int_data = quantize_affine( dequantized, - block_size, + self.block_size, scale, zero, target_dtype, diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index 94fcebd9d4..e482bd62d8 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -428,6 +428,51 @@ def test_moved_error(self): granularity=PerGroup(64), ) + def test_moe_quant(self): + from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import MOEFeedForwardAOQuantizable + from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter, FakeExtraDimTensor + from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig, PackedLinearInt8DynamicActivationIntxWeightLayout, quantize_ + from torchao.quantization.utils import compute_error + + with torch.device("cuda"): + model = MOEFeedForwardAOQuantizable(512, 256, 8, 2).to(torch.bfloat16) + x = torch.randn(8, 512, dtype=torch.bfloat16) + + out = model(x).clone() + + # base_config = Int8DynamicActivationIntxWeightConfig() + base_config = Int8DynamicActivationIntxWeightConfig(layout = PackedLinearInt8DynamicActivationIntxWeightLayout()) + moe_config = MoEQuantConfig(base_config) + + quantize_(model, moe_config, cond_ffn_filter) + + + out_q = model(x).clone() + assert isinstance(model.experts.w1, FakeExtraDimTensor) + + mod_c = torch.compile(model, mode="reduce-overhead") + + mod_c(x) + mod_c(x) + + + out_qc = mod_c(x).clone() + + print(compute_error(out_q, out)) + print(compute_error(out_qc, out)) + + assert 
compute_error(out_q, out)>30 and compute_error(out_qc, out)>30, "error bad accuracy but everything ran" + + + + + + + + + + + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index e4343a086f..a4e9bce7d8 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -82,6 +82,8 @@ def __tensor_unflatten__( def _quantized_linear_op( input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor ): + if input_tensor.numel() == 0: + return input_tensor input_quant_func = weight_tensor.input_quant_func original_weight_tensor = weight_tensor.original_weight_tensor quant_kwargs = weight_tensor.quant_kwargs diff --git a/torchao/quantization/prototype/moe_quant/__init__.py b/torchao/quantization/prototype/moe_quant/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/quantization/prototype/moe_quant/llama4_quant.py b/torchao/quantization/prototype/moe_quant/llama4_quant.py new file mode 100644 index 0000000000..0ba21ce7ea --- /dev/null +++ b/torchao/quantization/prototype/moe_quant/llama4_quant.py @@ -0,0 +1,416 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# import torch +# from tabulate import tabulate +# from transformers import AutoModelForCausalLM, AutoTokenizer + +# try: +# from lm_eval.evaluator import evaluate +# from lm_eval.models.huggingface import HFLM +# from lm_eval.tasks import get_task_dict +# except ImportError: +# print(""" +# Error: The 'lm_eval' module was not found. 
+# To install, follow these steps: +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git +# """) +# raise # Re-raise the ImportError + +# from torchao.quantization import ( +# autoquant, +# fpx_weight_only, +# int4_weight_only, +# int8_dynamic_activation_int8_weight, +# int8_weight_only, +# quantize_, +# ) +# from torchao.sparsity import ( +# semi_sparse_weight, +# sparsify_, +# ) + +# torch._inductor.config.force_fuse_int_mm_with_mul = True +# torch._inductor.config.fx_graph_cache = True + + +# def pretty_print_nested_results(results, precision: int = 6): +# def format_value(value): +# if isinstance(value, float): +# return f"{value:.{precision}f}" +# return value + +# main_table = [] +# for task, metrics in results["results"].items(): +# subtable = [[k, format_value(v)] for k, v in metrics.items() if k != "alias"] +# subtable.sort(key=lambda x: x[0]) # Sort metrics alphabetically +# formatted_subtable = tabulate(subtable, tablefmt="grid") +# main_table.append([task, formatted_subtable]) + +# print(tabulate(main_table, headers=["Task", "Metrics"], tablefmt="grid")) + + +# def run_evaluation( +# repo_id, +# tasks, +# limit, +# device, +# precision, +# quantization, +# sparsity, +# compile, +# save, +# batch_size, +# max_length, +# ): +# tokenizer = AutoTokenizer.from_pretrained(repo_id) +# model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to( +# device +# ) + +# if quantization == "autoquant" and compile: +# model = torch.compile(model, mode="max-autotune", fullgraph=True) + +# if quantization == "int8dq": +# quantize_(model, int8_dynamic_activation_int8_weight()) +# elif quantization == "int8wo": +# quantize_(model, int8_weight_only()) +# elif quantization == "int4wo": +# # note cannot quantize this model on cpu and run it on cuda at this time +# quantize_(model.to(device=device), int4_weight_only()) +# elif quantization == "fp6": +# quantize_(model, fpx_weight_only(3, 2)) +# elif quantization == "autoquant": +# model = autoquant(model.to(device=device)) +# elif quantization == "awq": +# from torchao.prototype.awq.example import get_calib_dataset +# from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + +# if not TORCH_VERSION_AT_LEAST_2_3: +# print("AWQ quantization requires torch2.3+") +# exit() +# from torchao.prototype.awq import ( +# AWQObservedLinear, +# awq_uintx, +# insert_awq_observer_, +# ) + +# quant_dtype = torch.uint4 +# group_size = 64 +# calibration_limit = 10 +# calibration_seq_length = 1024 +# model = model.to(device) +# insert_awq_observer_( +# model, +# calibration_limit, +# calibration_seq_length, +# quant_dtype=quant_dtype, +# group_size=group_size, +# ) +# with torch.no_grad(): +# calibration_data = get_calib_dataset( +# tokenizer=tokenizer, +# n_samples=calibration_limit, +# block_size=calibration_seq_length, +# ) +# for batch in calibration_data: +# model(batch.to(device)) +# del batch +# is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) +# quantize_( +# model, +# awq_uintx(quant_dtype=quant_dtype, group_size=group_size), +# is_observed_linear, +# ) + +# if quantization != "autoquant" and compile: +# model = torch.compile(model, mode="max-autotune", fullgraph=True) + +# if sparsity == "semi_sparse": + +# def all_linear(mod, name): +# if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: +# return True +# return False + +# torch.sparse.semi_structured._FORCE_CUTLASS = False +# sparsify_(model, semi_sparse_weight(), filter_fn=all_linear) +# elif sparsity == "semi_sparse_mlp_only": + +# def 
all_linear(mod, name): +# if ( +# isinstance(mod, torch.nn.Linear) +# and "lm_head" not in name +# and "mlp" in name +# ): +# return True +# return False + +# torch.sparse.semi_structured._FORCE_CUTLASS = False +# sparsify_(model, semi_sparse_weight(), filter_fn=all_linear) + +# if sparsity and compile: +# model = torch.compile(model, mode="max-autotune", fullgraph=True) + +# with torch.no_grad(): +# result = evaluate( +# HFLM( +# pretrained=model.to(device), +# tokenizer=tokenizer, +# batch_size=batch_size, +# max_length=max_length, +# ), +# get_task_dict(tasks), +# limit=limit, +# ) + +# pretty_print_nested_results(result) + +# if save: +# # This doesn't work yet: https://github.com/huggingface/transformers/issues/32364 +# # model.save_pretrained("quantized_model_test", safe_serialization=False) +# file_name = repo_id.split("/")[-1] + "-" + quantization + ".pt" +# torch.save(model.state_dict(), file_name) + + +# if __name__ == "__main__": +# import argparse + +# parser = argparse.ArgumentParser(description="Run HF Model Evaluation") +# parser.add_argument( +# "--repo_id", +# type=str, +# default="meta-llama/Meta-Llama-3-8B", +# help="Repository ID to download from HF.", +# ) +# parser.add_argument( +# "--tasks", +# nargs="+", +# type=str, +# default=["wikitext"], +# help="List of lm-eluther tasks to evaluate usage: --tasks task1 task2", +# ) +# parser.add_argument( +# "--limit", type=int, default=None, help="Number of eval samples to evaluate" +# ) +# parser.add_argument( +# "--precision", +# type=lambda x: getattr(torch, x.split(".")[-1]), +# default=torch.bfloat16, +# help="dtype precision to use", +# ) +# parser.add_argument( +# "--device", type=str, default="cuda", help="Device to use for evaluation" +# ) +# parser.add_argument( +# "-q", +# "--quantization", +# default="None", +# choices=["int8dq", "int8wo", "int4wo", "autoquant", "awq", "None"], +# help="Which quantization technique to apply", +# ) +# parser.add_argument( +# "-s", +# "--sparsity", +# default="None", +# choices=["semi_sparse", "semi_sparse_mlp_only", "None"], +# help="Which sparsity technique to apply", +# ) +# parser.add_argument( +# "--compile", action="store_true", help="Whether to compile the model." +# ) +# parser.add_argument( +# "--save", action="store_true", help="Whether to save the model." 
+# )
+# parser.add_argument(
+#     "--batch_size",
+#     type=int,
+#     default=1,
+#     help="Batch size to use for evaluation, note int8wo and int4wo work best with small batch sizes, int8dq works better with large batch sizes",
+# )
+# parser.add_argument(
+#     "--max_length",
+#     type=int,
+#     default=None,
+#     help="Length of text to process at one time",
+# )
+
+# args = parser.parse_args()
+# run_evaluation(
+#     args.repo_id,
+#     args.tasks,
+#     args.limit,
+#     args.device,
+#     args.precision,
+#     args.quantization,
+#     args.sparsity,
+#     args.compile,
+#     args.save,
+#     args.batch_size,
+#     args.max_length,
+# )
+
+# Dimension legend used below:
+# T      tokens
+# E      experts
+# D      hidden dim
+# I      intermediate (expert) dim
+# A      activated experts per token
+# T'(e)  tokens routed to expert e
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+
+
+class MOEFeedForwardAOQuantizable(nn.Module):
+    def __init__(self, hidden_dim, expert_dim, num_experts, top_k, act_fn=F.silu, shared_expert=None) -> None:
+        super().__init__()
+        self.router = nn.Linear(hidden_dim, num_experts, bias=False)
+        self.experts = ConditionalFeedForwardAOQuantizable(num_experts, hidden_dim, expert_dim, act_fn)
+        self.hidden_dim = hidden_dim
+        self.top_k = top_k
+        self.shared_expert = shared_expert
+
+    def forward(self, x: Tensor) -> Tensor:
+        batch_size = x.shape[0]
+        x = x.view(-1, self.hidden_dim)  # x: [T, D]
+        scores = self.router(x)  # [T, E]
+        scores = F.softmax(scores, dim=-1)
+        scores, expert_indices = torch.topk(scores, self.top_k, dim=-1)  # [T, A], [T, A]
+        scores /= scores.sum(dim=-1, keepdim=True).to(x.dtype)  # [T, A]
+
+        out = self.experts(x, expert_indices, scores, self.top_k)
+        if self.shared_expert is not None:
+            out += self.shared_expert(x)
+        return out.reshape(batch_size, -1, self.hidden_dim)
+
+
+class ConditionalFeedForwardAOQuantizable(nn.Module):
+    def __init__(self, num_experts, hidden_dim, expert_dim, act_fn):
+        super().__init__()
+        self.w1 = nn.Parameter(torch.empty(num_experts, expert_dim, hidden_dim))  # E, I, D
+        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, expert_dim))  # E, D, I
+        self.w3 = nn.Parameter(torch.empty(num_experts, expert_dim, hidden_dim))  # E, I, D
+        self.num_experts = num_experts
+        self.act_fn = act_fn
+
+    def forward(
+        self,
+        x: Tensor,  # [T, D]
+        expert_indices: Tensor,  # [T, A]
+        expert_weights: Tensor,  # [T, A]
+        num_activated_experts: int,
+    ) -> Tensor:
+        num_tokens, dim = x.shape
+        num_token_activations = num_tokens * num_activated_experts
+
+        if x.shape[0] == 1:  # only 1 token (can be done without graph breaks when compiled)
+            outs = []
+            expert_indices = expert_indices.squeeze()
+            # collect used experts
+            w1 = self.w1[expert_indices]
+            w2 = self.w2[expert_indices]
+            w3 = self.w3[expert_indices]
+
+            # run the token through each activated expert
+            for index in range(num_activated_experts):
+                cur_out = F.linear(self.act_fn(F.linear(x, w1[index])) * F.linear(x, w3[index]), w2[index])
+                outs.append(cur_out)
+
+            # combine outputs
+            final_out = (torch.cat(outs, dim=0) * expert_weights.view(-1, 1)).sum(dim=0).unsqueeze(-1)
+            return final_out
+        else:
+            expert_list = [x for x in range(self.num_experts)]
+
+            # shuffle tokens into groups for each expert
+            ordered_token_activations = expert_indices.view(-1).argsort(stable=True)  # [T*A]
+            ordered_token_indices = ordered_token_activations.div(num_activated_experts).floor().to(torch.int64)  # [T*A]
+
+            num_tokens_per_expert = torch.histc(expert_indices, bins=self.num_experts + 1, min=-1, max=self.num_experts)  # [E+1] (leading 0 bucket so it can be used for indexing)
+            cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(torch.int64)  # [E+1]
+
+            @torch._dynamo.disable()
+            def group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list):
+                token_indices_per_expert = [
+                    ordered_token_indices[cum_tokens_per_expert[expert] : cum_tokens_per_expert[expert + 1]]
+                    for expert in expert_list
+                ]  # [T'(e1)], [T'(e2)], ...
+                return token_indices_per_expert
+
+            token_indices_per_expert = group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list)
+            tokens_grouped_by_expert = [x[indices] for indices in token_indices_per_expert]
+
+            # calculate outputs for each expert
+            outs = []
+            for cur_x, expert in zip(tokens_grouped_by_expert, expert_list):
+                w1 = self.w1[expert]  # I, D
+                w2 = self.w2[expert]  # D, I
+                w3 = self.w3[expert]  # I, D
+
+                cur_out = F.linear(self.act_fn(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2)  # [T'(e), D]
+                outs.append(cur_out)
+
+            # weigh outputs
+            ordered_outs = torch.cat(outs, dim=0)  # [T*A, D]
+            ordered_token_activation_weights = expert_weights.view(-1, 1)[ordered_token_activations].view(-1, 1)  # [T*A, 1]
+            weighted_ordered_outs = ordered_outs * ordered_token_activation_weights  # [T*A, D]
+
+            # sum weighted token-activation outputs together for each token
+            final_out = torch.zeros_like(x)  # [T, D]
+            final_out = final_out.scatter_add(
+                dim=0,
+                index=ordered_token_indices.unsqueeze(-1).expand(num_token_activations, dim).to(torch.int64),
+                src=weighted_ordered_outs,
+            )
+            return final_out
+
+
+from transformers import AutoTokenizer
+from transformers.models.llama4 import Llama4ForCausalLM, Llama4TextMoe
+
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+
+model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+model = Llama4ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+
+
+def llama4_moe_filter_fn(module, fqn):
+    return isinstance(module, Llama4TextMoe)
+
+
+def convert_fn(module):
+    # gather the metadata needed to build the quantizable replacement
+    hidden_dim = module.hidden_dim
+    expert_dim = module.experts.expert_dim
+    num_experts = module.num_experts
+    top_k = module.top_k
+    act_fn = module.experts.act_fn
+    shared_expert = module.shared_expert
+    new_mod = MOEFeedForwardAOQuantizable(
+        hidden_dim,
+        expert_dim,
+        num_experts,
+        top_k,
+        act_fn,
+        shared_expert,
+    )
+
+    # reuse the original router; copying the expert weights into
+    # new_mod.experts.w1/w2/w3 from the HF checkpoint is still TODO for this WIP script
+    new_mod.router = module.router
+    return new_mod
+
+
+# TODO (WIP): enable once convert_fn also copies the expert weights
+# _replace_with_custom_fn_if_matches_filter(model, convert_fn, llama4_moe_filter_fn)
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+prompt = "I am"
+inputs = tokenizer(prompt, return_tensors="pt")
+
+generate_ids = model.generate(inputs.input_ids, max_length=10)
+out = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+print(out)
diff --git a/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py
new file mode 100644
index 0000000000..7b0f04cae3
--- /dev/null
+++ b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py
@@ -0,0 +1,101 @@
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+
+class MOEFeedForwardAOQuantizable(nn.Module):
+    def __init__(self, hidden_dim, expert_dim, num_experts, top_k, act_fn=F.silu, shared_expert=None) -> None:
+        super().__init__()
+        self.router = nn.Linear(hidden_dim, num_experts, bias=False)
+        self.experts = ConditionalFeedForwardAOQuantizable(num_experts, hidden_dim, expert_dim, act_fn)
+        self.hidden_dim = hidden_dim
+        self.top_k = top_k
+        self.shared_expert = shared_expert
+
+    def forward(self, x: Tensor) -> Tensor:
+        batch_size = 
x.shape[0] + x = x.view(-1, self.hidden_dim) # x: [T, D] + scores = self.router(x) # [T, E] + scores = F.softmax(scores, dim=-1) + scores, expert_indices = torch.topk(scores, self.top_k, dim=-1) # [T, A], [T, A] + scores /= scores.sum(dim=-1, keepdim=True).to(x.dtype) # [T, A] + + out = self.experts(x, expert_indices, scores, self.top_k) + if self.shared_expert: + out += self.shared_expert(x) + return out.reshape(batch_size, -1, self.hidden_dim) + + +class ConditionalFeedForwardAOQuantizable(nn.Module): + def __init__(self, num_experts, hidden_dim, expert_dim, act_fn): + super().__init__() + self.w1 = nn.Parameter(torch.randn(num_experts, expert_dim, hidden_dim)) # E, I, D + self.w2 = nn.Parameter(torch.randn(num_experts, hidden_dim, expert_dim)) # E, D, I + self.w3 = nn.Parameter(torch.randn(num_experts, expert_dim, hidden_dim)) # E, I, D + self.num_experts = num_experts + self.act_fn = act_fn + def forward( + self, x: Tensor, # T, D + expert_indices: Tensor, # T, A + expert_weights: Tensor, # T, A + num_activated_experts: int, + ) -> Tensor: + num_tokens, dim = x.shape + num_token_activations = num_tokens * num_activated_experts + + if x.shape[0]==1: #only 1 token (can be done without graph breaks when compiled) + outs = [] + expert_indices=expert_indices.squeeze() + # collect used experts + w1 = self.w1[expert_indices] + w2 = self.w2[expert_indices] + w3 = self.w3[expert_indices] + + # run token through each expert + for index in range(num_activated_experts): + cur_out = F.linear( self.act_fn(F.linear(x, w1[index])) * F.linear(x, w3[index]), w2[index]) + outs.append(cur_out) + + # combine outputs + final_out = (torch.cat(outs, dim=0) * expert_weights.view(-1,1)).sum(dim=0).unsqueeze(-1) + return final_out + else: + expert_list = [x for x in range(self.num_experts)] + + # shuffle tokens into groups for each expert + ordered_token_activations = expert_indices.view(-1).argsort(stable=True) # [A] + ordered_token_indices = ordered_token_activations.div(num_activated_experts).floor().to(torch.int64) # [T] + + num_tokens_per_expert = torch.histc(expert_indices, bins=self.num_experts+1, min=-1, max=self.num_experts) # [E+1] (added leading 0 so can be used for indexing) + cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(torch.int64) # [E+1] + + @torch._dynamo.disable() + def group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list): + token_indices_per_expert = [ordered_token_indices[cum_tokens_per_expert[expert]:cum_tokens_per_expert[expert+1]] for expert in expert_list] # [T'(e1)], [T'(e2)] ... 
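+                # Illustrative example (values not from the code): with E=4 experts, T=3 tokens and A=2,
+                # expert_indices might be [[0, 2], [1, 2], [0, 1]]. Flattening and stable-argsorting
+                # groups the token-activations by expert, histc (with a leading bucket for index -1)
+                # counts activations per expert, and its cumulative sum gives the slice boundaries used
+                # in the comprehension above: expert 0 owns ordered_token_indices[0:2] (tokens 0 and 2),
+                # expert 1 owns [2:4], expert 2 owns [4:6], and expert 3 gets an empty slice.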
+                return token_indices_per_expert
+
+            token_indices_per_expert = group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list)
+            tokens_grouped_by_expert = [x[indices] for indices in token_indices_per_expert]
+
+            # calculate outputs for each expert
+            outs = []
+            for cur_x, expert in zip(tokens_grouped_by_expert, expert_list):
+                w1 = self.w1[expert]  # I, D
+                w2 = self.w2[expert]  # D, I
+                w3 = self.w3[expert]  # I, D
+
+                y1 = self.act_fn(F.linear(cur_x, w1))
+                y3 = F.linear(cur_x, w3)
+
+                cur_out = F.linear(y1 * y3, w2)  # [T'(e), D]
+                outs.append(cur_out)
+
+            # weigh outputs
+            ordered_outs = torch.cat(outs, dim=0)  # [T*A, D]
+            ordered_token_activation_weights = expert_weights.view(-1, 1)[ordered_token_activations].view(-1, 1)  # [T*A, 1]
+            weighted_ordered_outs = ordered_outs * ordered_token_activation_weights  # [T*A, D]
+
+            # sum weighted token-activation outputs together for each token
+            final_out = torch.zeros_like(x)  # [T, D]
+            final_out = final_out.scatter_add(
+                dim=0,
+                index=ordered_token_indices.unsqueeze(-1).expand(num_token_activations, dim).to(torch.int64),
+                src=weighted_ordered_outs,
+            )
+            return final_out
diff --git a/torchao/quantization/prototype/moe_quant/run.sh b/torchao/quantization/prototype/moe_quant/run.sh
new file mode 100644
index 0000000000..d7b819a828
--- /dev/null
+++ b/torchao/quantization/prototype/moe_quant/run.sh
@@ -0,0 +1 @@
+python llama4_quant.py  # --repo_id "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
diff --git a/torchao/quantization/prototype/moe_quant/utils.py b/torchao/quantization/prototype/moe_quant/utils.py
new file mode 100644
index 0000000000..215fee0e34
--- /dev/null
+++ b/torchao/quantization/prototype/moe_quant/utils.py
@@ -0,0 +1,241 @@
+import torch
+
+from torch.utils._python_dispatch import (
+    return_and_correct_aliasing,
+)
+
+aten = torch.ops.aten
+
+from torchao.utils import fill_defaults
+
+from torchao.quantization.quant_api import AOBaseConfig, register_quantize_module_handler, dataclass
+from typing import Union, List, Tuple, Optional
+
+
+class DummyModule(torch.nn.Module):
+    """The TorchAO quantization functions tend to operate on modules, so to apply a transform to a bare tensor we load it into a
+    DummyModule, apply the transformation to the module, and then extract the transformed tensor.
+    """
+    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
+        super().__init__()
+        self.weight = weight
+        self.bias = bias
+
+
+class FakeExtraDimTensor(torch.Tensor):
+    """A subclass of torch.Tensor that simulates a tensor of n+1 dimensions, akin to concatenating several tensors along the 0th dimension.
+    It takes a list of tensors with the same dtype, device and shape and creates a representation of shape (num_tensors, *orig_shape). It can handle a
+    variety of ops like detach and clone but, most importantly, supports any slicing and indexing along the extra dimension.
+    This is most useful when you have another tensor subclass that you'd like to concatenate together but don't want to support all the necessary
+    pieces of 3D scaffolding required to make it work.
+
+    The structure of this tensor subclass is a linked list of tensors, with each instance of FakeExtraDimTensor containing a head tensor and a tail consisting of
+    either another instance of FakeExtraDimTensor or None if we've reached the end of the linked list. 
This implementation structure is necessary to + support compilation of this tensor subclass since compile requires each tensor component of the tensor subclass to have its own attribute. + """ + def __new__( + cls, + tensors: Union[Tuple[torch.Tensor], List[torch.Tensor]], + tensor_tail: Optional["FakeExtraDimTensor"]=None, + ): + assert len(tensors)>0 or tensor_tail is not None + num_tensors = len(tensors) + if tensor_tail is not None: + num_tensors += tensor_tail.num_tensors + test_tensor = tensor_tail.head_tensor + else: + test_tensor = tensors[0] + + dtype = test_tensor.dtype + shape = test_tensor.shape + device = test_tensor.device + layout = test_tensor.layout + for tensor in tensors: + assert tensor.dtype==dtype, f"all tensors in FakeExtraDimTensor must have same dtype but got {tensor.dtype} and {dtype}" + assert tensor.shape==shape, f"all tensors in FakeExtraDimTensor must have same shape but got {tensor.shape} and {shape}" + assert tensor.device == device, f"all tensors in FakeExtraDimTensor must have same device but got {tensor.device} and {device}" + assert tensor.layout == layout, f"all tensors in FakeExtraDimTensor must have same layout but got {tensor.layout} and {layout}" + kwargs = {} + kwargs["dtype"] = dtype + kwargs["layout"] = layout + kwargs["device"] = device + kwargs["requires_grad"]=False + new_shape = (num_tensors, *shape) + return torch.Tensor._make_wrapper_subclass(cls, new_shape, **kwargs) + + def __repr__( + self, + ): + return f"{self.__class__.__name__}(shape={self.shape}, containing {self.num_tensors}: {self.head_tensor})" + + def __init__( + self, + tensors: Union[Tuple[torch.Tensor], List[torch.Tensor]], + tensor_tail: Optional["FakeExtraDimTensor"]=None, + ): + tensors = list(tensors) + assert len(tensors)>0 or tensor_tail is not None + + # count num_tensors and make tensor_list + self.num_tensors = len(tensors) + if tensor_tail is not None: + self.num_tensors += tensor_tail.num_tensors + tail_list = tensor_tail.tensor_list + else: + tail_list = [] + self.tensor_list = tensors + tail_list + + # 3 cases + # 0) tensors has 0 elements -> take element from tail then do case 1 instead + # 1) tensors has 1 element, -> pop element and tail is None + # 2) tensors has >1 elements, -> pop element and recurse + + # convert case 0 to case 1 by taking 1 element from tail + if len(tensors) == 0 and tensor_tail is not None: + tensors = [tensor_tail.head_tensor,] + tensor_tail = tensor_tail.tensor_tail + + if len(tensors) > 1: + # case (1): remove first element from tensors, then recurse + self.head_tensor = tensors[0] # remove one + self.tensor_tail = self.__class__(tensors[1:], tensor_tail) # recurse + elif len(tensors) == 1: + # case (2) take final element from tensors, attach tensor_tail then stop recursion + self.head_tensor = tensors[0] + self.tensor_tail = tensor_tail + + def _apply_fn_to_data(self, fn): + self.head_tensor = fn(self.head_tensor) + if self.tensor_tail is not None: + self.tensor_tail = self.tensor_tail._apply_fn_to_data(fn) + return self.__class__([self.head_tensor], self.tensor_tail) + + def __tensor_flatten__(self): + if self.tensor_tail is None: + return ["head_tensor",], [self.num_tensors] + else: + return ["head_tensor", "tensor_tail",], [self.num_tensors] + + # def __get_item__(self, indices): + # if isinstance(indices, torch.Tensor): + + # elif isinstance(indices, int): + # return + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride, + ): + head_tensor = 
tensor_data_dict["head_tensor"] + tensor_tail = tensor_data_dict.get("tensor_tail", None) + return cls([head_tensor], tensor_tail) + + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + kwargs = {} if kwargs is None else kwargs + if func is torch.nn.functional.linear: + x, w, bias = ( + args[0], args[1], args[2] if len(args) > 2 else None, + ) + assert w.num_tensors == 1, "FakeExtraDimTensor used in a linear op when it had multiple tensors" + return func(x, w.head_tensor, bias) + try: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + except Exception as e: + print(f"ERR: subclass {cls} doesn't implement {func}, got error: {e}") + + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func == aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim==0: + return return_and_correct_aliasing( + func, + args, + kwargs, + cls(self.tensor_list[start:end:step]) + ) + + elif func == aten.select.int: + self, dim, index = fill_defaults(args, 3, [0, 0]) + if dim==0: + return return_and_correct_aliasing( + func, + args, + kwargs, + cls([self.tensor_list[index]]) + ) + elif func == aten.index.Tensor: + self, indices, dim = fill_defaults(args, 3, [0]) + if dim==0: + # this handles a weird bug where indices gets turned into a list + # between the function dispatch and torch dispatch but just for this function + if isinstance(indices, list) and len(indices)==1: + indices = indices[0] + return return_and_correct_aliasing( + func, + args, + kwargs, + cls([self.tensor_list[index] for index in indices]) + ) + try: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data( + lambda x: func(x, *args[1:], **kwargs) + ) + ) + except Exception as e: + print( + f"function {func} failed for FakeExtraDimTensor, following error occured when trying to" + "run function on its elements: " + ) + raise e + + + +@dataclass +class MoEQuantConfig(AOBaseConfig): + """Configuration for applying quantization to MoE + Args: + `base_config`: normal AO Config + """ + base_config: AOBaseConfig + +@register_quantize_module_handler(MoEQuantConfig) +def moe_quant_fn(module, config: MoEQuantConfig): + import warnings + warnings.simplefilter("ignore", lineno=84) + warnings.simplefilter("ignore", lineno=105) + assert "ConditionalFeedForwardAOQuantizable" in str(type(module)) + from torchao.quantization.quant_api import _QUANTIZE_CONFIG_HANDLER + + for weight_attr in ["w1", "w2", "w3"]: + param = getattr(module, weight_attr) + assert isinstance(config.base_config, AOBaseConfig), ( + f"MoEQuantConfig expected to be initialized with an AOBaseConfig but got {type(config.base_config)}" + +"this can happen if you initiaze with MoEQuantConfig(AOConfig) rather than MoEQuantConfig(AOConfig())" + ) + handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + + # break 3D tensor + tensors = [param[i] for i in range(param.shape[0])] + # put tensors into modules since the handlers target modules not tensors + dummy_modules = [DummyModule(tensor) for tensor in tensors] + # apply handler to each module + out_mods = list(map(lambda x: handler(x, config.base_config), dummy_modules)) + # pack quantized subclasses into FakeExtraDimTensor + new_param = FakeExtraDimTensor([mod.weight for mod in out_mods]) + new_param = torch.nn.Parameter(new_param, requires_grad=False) + setattr(module, weight_attr, new_param) + del param + return module + + +def 
moe_filter(module, fqn): + return "MOEFeedForwardAOQuantizable" in str(type(module)) + +def cond_ffn_filter(module, fqn): + return "ConditionalFeedForwardAOQuantizable" in str(type(module)) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 995030df67..3967e7e658 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -300,7 +300,7 @@ def _replace_with_custom_fn_if_matches_filter( device, extra_args, ) - if new_child is not child: + if new_child is not child and new_child is not None: setattr(model, name, new_child) if device is not None: model.to(device=device) # move parent module to device diff --git a/torchao/quantization/transform_module.py b/torchao/quantization/transform_module.py index b6fac49ae9..a1147f8459 100644 --- a/torchao/quantization/transform_module.py +++ b/torchao/quantization/transform_module.py @@ -47,5 +47,6 @@ def _transform( @functools.wraps(config_type) def decorator(func): _QUANTIZE_CONFIG_HANDLER[config_type] = func + return func # needed to make the functions usable externally return decorator diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index f5bdfa9193..f999667c30 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -366,22 +366,23 @@ def get_groupwise_affine_qparams( def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16): guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size()) guard_dtype_size(zeros, "zeros", dtype=dtype) + dim = scales.dim() return ( torch.cat( [ - scales.reshape(scales.size(0), scales.size(1), 1), - zeros.reshape(zeros.size(0), zeros.size(1), 1), + scales.unsqueeze(-1), + zeros.unsqueeze(-1), ], - 2, + dim, ) - .transpose(0, 1) + .transpose(-3, -2) .contiguous() ) def unpack_tinygemm_scales_and_zeros(scales_and_zeros): - assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 - return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) + assert scales_and_zeros.shape[-1] == 2 + return torch.split(scales_and_zeros.transpose(-3, -2), 1, -1) def convert_weight_to_int4pack_xpu(weight, zero_point_domain_is_int=False): diff --git a/torchao/utils.py b/torchao/utils.py index c8465274ea..e7d69e697a 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -170,7 +170,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs): return measurement.mean * 1e6 -def find_multiple(n: int, *args: Tuple[int]) -> int: +def find_multiple(n: int, *args: int) -> int: k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) # type: ignore[9] if n % k == 0: return n
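
Usage sketch for the new MoE quantization path (a minimal illustration, not part of the diff above): wrap an existing torchao config in MoEQuantConfig and apply it with quantize_ to the ConditionalFeedForwardAOQuantizable modules via cond_ffn_filter. The snippet assumes int8_weight_only() returns an AOBaseConfig instance in this torchao version; the module sizes are made up and it only exercises the single-token path.

    import torch

    from torchao.quantization import int8_weight_only, quantize_
    from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import (
        MOEFeedForwardAOQuantizable,
    )
    from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter

    # toy MoE block standing in for a converted model (sizes are illustrative only)
    moe = MOEFeedForwardAOQuantizable(
        hidden_dim=256, expert_dim=512, num_experts=8, top_k=2
    ).to(torch.bfloat16)

    # apply the wrapped base config expert-by-expert to the 3D w1/w2/w3 parameters;
    # the quantized per-expert weights get packed back into a FakeExtraDimTensor
    quantize_(moe, MoEQuantConfig(int8_weight_only()), filter_fn=cond_ffn_filter)

    # single-token smoke test through the quantized experts
    x = torch.randn(1, 1, 256, dtype=torch.bfloat16)
    out = moe(x)  # expected shape: [1, 1, 256]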