diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py new file mode 100644 index 0000000000..1a070efeee --- /dev/null +++ b/test/quantization/test_moe_quant.py @@ -0,0 +1,255 @@ +import torch +import unittest +from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import MOEFeedForwardAOQuantizable +from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter, FakeExtraDimTensor +from torchao.quantization.quant_api import ( + Int8WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int4WeightOnlyConfig, + Float8WeightOnlyConfig, + Int8DynamicActivationIntxWeightConfig, + PackedLinearInt8DynamicActivationIntxWeightLayout, + quantize_, + AffineQuantizedTensor, + LinearActivationQuantizedTensor, + Float8DynamicActivationFloat8WeightConfig, +) +from torchao.quantization.utils import compute_error +from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl +from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl +from parameterized import param, parameterized + +class TestMoEQuantCompile(unittest.TestCase): + DEFAULT_PARAMS = (512, 256, 8, 2) # hidden_dim, expert_dim, num_experts, top_k + + @torch.no_grad() + def _test_impl_moe_quant(self, + config, + num_tokens=1, + model_params=None, + base_class=AffineQuantizedTensor, + tensor_impl_class=None, + dtype=torch.bfloat16, + device="cuda", + fullgraph=False + ): + """ + Tests moe quant for techniques using fake extra dim + """ + if model_params is None: + model_params=self.DEFAULT_PARAMS + + input_shape = (num_tokens, model_params[0]) + model = MOEFeedForwardAOQuantizable(*model_params).to(dtype).to(device) + input = torch.randn(input_shape, dtype=torch.bfloat16, device=device) + + out = model(input) + + import copy + new_mod = copy.deepcopy(model) + + quantize_(model, config, cond_ffn_filter) + + if isinstance(config, MoEQuantConfig): + self.assertIsInstance(model.experts.w1, FakeExtraDimTensor) + if base_class is not None: + self.assertIsInstance(model.experts.w1.head_tensor, base_class) + if tensor_impl_class is not None: + self.assertIsInstance(model.experts.w1.head_tensor.tensor_impl, tensor_impl_class) + else: + if base_class is not None: + self.assertIsInstance(model.experts.w1, base_class) + if tensor_impl_class is not None: + self.assertIsInstance(model.experts.w1.tensor_impl, tensor_impl_class) + + out_q = model(input) + + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + model_c = torch.compile(model, mode="reduce-overhead", fullgraph=fullgraph) + + model_c(input) + model_c(input) + out_qc = model_c(input).clone() + + for i in range(10): + input = torch.randn(input_shape, dtype=torch.bfloat16, device=device) + model_c(input) + + self.assertGreaterEqual(compute_error(out_q, out), 10) + self.assertGreaterEqual(compute_error(out_qc, out), 10) + print(compute_error(out_q, out), compute_error(out_qc, out)) + + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parameterized.expand([ + ("single_token", 1, False), + ("multiple_tokens", 8, False), + ]) + def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): + config = MoEQuantConfig(Int4WeightOnlyConfig()) + tensor_impl_class = TensorCoreTiledAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph + ) + + 
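The helper above reduces to the following standalone usage pattern; this is a minimal sketch assuming the prototype APIs introduced in this PR (MoEQuantConfig, cond_ffn_filter, MOEFeedForwardAOQuantizable, FakeExtraDimTensor) and the example dimensions from DEFAULT_PARAMS (512 hidden, 256 expert, 8 experts, top-2), not an additional test case:

import torch
from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import MOEFeedForwardAOQuantizable
from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter
from torchao.quantization.quant_api import Int8WeightOnlyConfig, quantize_

# Toy MoE block: hidden_dim=512, expert_dim=256, 8 experts, top-2 routing (example values only).
model = MOEFeedForwardAOQuantizable(512, 256, 8, 2).to(torch.bfloat16).to("cuda")
x = torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")

# Wrapping a per-linear config in MoEQuantConfig routes the 3D expert weights through
# FakeExtraDimTensor; cond_ffn_filter restricts quantization to the expert FFN parameters.
quantize_(model, MoEQuantConfig(Int8WeightOnlyConfig()), cond_ffn_filter)

# Token-to-expert grouping produces data-dependent shapes, so enable capture of
# scalar and dynamic-shape outputs before compiling.
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
out = torch.compile(model, mode="reduce-overhead")(x)

The tests here then check that model.experts.w1 becomes a FakeExtraDimTensor (for MoEQuantConfig) and that the quantized eager and compiled outputs stay within roughly 10 dB SQNR of the bfloat16 baseline.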
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parameterized.expand([ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ]) + def test_int4wo_base(self, name, num_tokens, fullgraph): + config = Int4WeightOnlyConfig() + tensor_impl_class = TensorCoreTiledAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parameterized.expand([ + ("single_token", 1, False), + ("multiple_tokens", 8, False), + ]) + def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): + config = MoEQuantConfig(Int8WeightOnlyConfig()) + tensor_impl_class = PlainAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parameterized.expand([ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ]) + def test_int8wo_base(self, name, num_tokens, fullgraph): + config = Int8WeightOnlyConfig() + tensor_impl_class = PlainAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parameterized.expand([ + ("multiple_tokens", 32, False), + ]) + def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) + base_class = LinearActivationQuantizedTensor + + self._test_impl_moe_quant( + model_params=(512, 256, 2, 2), + config=config, + num_tokens=num_tokens, + base_class=base_class, + fullgraph=fullgraph + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parameterized.expand([ + ("multiple_tokens", 32, False), + ]) + def test_int8dq_base(self, name, num_tokens, fullgraph): + config = Int8DynamicActivationInt8WeightConfig() + base_class = LinearActivationQuantizedTensor + + self._test_impl_moe_quant( + model_params=(512, 256, 2, 2), + config=config, + num_tokens=num_tokens, + base_class=base_class, + fullgraph=fullgraph + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parameterized.expand([ + ("single_token", 1, False), + ("multiple_tokens", 8, False), + + ]) + def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph): + config = MoEQuantConfig(Float8WeightOnlyConfig()) + tensor_impl_class = Float8AQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parameterized.expand([ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ]) + def test_fp8wo_base(self, name, num_tokens, fullgraph): + config = Float8WeightOnlyConfig() + tensor_impl_class = Float8AQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parameterized.expand([ + ("single_token", 1, False), + ("multiple_tokens", 8, False), + + ]) + def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph): + config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig()) + base_class = LinearActivationQuantizedTensor + + 
self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + base_class=base_class, + fullgraph=fullgraph + ) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parameterized.expand([ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ]) + def test_fp8dq_base(self, name, num_tokens, fullgraph): + config = Float8DynamicActivationFloat8WeightConfig() + base_class = LinearActivationQuantizedTensor + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + base_class=base_class, + fullgraph=fullgraph + ) + + + + + + + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py new file mode 100644 index 0000000000..a48717059e --- /dev/null +++ b/torchao/_models/mixtral-moe/generate.py @@ -0,0 +1,364 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import itertools +import sys +import time +from pathlib import Path +from typing import Optional, Tuple + +import torch +import torch._dynamo.config +import torch._inductor.config +torch.manual_seed(0) + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif "cpu" in device: + pass + else: + print(f"device={device} is not yet suppported") + + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future +torch._dynamo.config.capture_scalar_outputs = True + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from sentencepiece import SentencePieceProcessor + +from model import Transformer + +def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[:, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + +def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: + # input_pos: [B, S] + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs)[0] + +def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs) + +def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel(enable_flash=False, 
enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + next_token, next_prob = next_token.clone(), next_prob.clone() + + input_pos += 1 + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + new_probs.append(next_prob.clone()) + cur_token = next_token + + return new_tokens, new_probs + + +def model_forward(model, x, input_pos): + return model(x, input_pos) + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + batch_size: int, + *, + interactive: bool, + callback = lambda x: x, + **sampling_kwargs +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + """ + device, dtype = prompt.device, prompt.dtype + + + T = prompt.size(-1) + max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + new_tokens = max_seq_length - T + + # duplicate prompt for batch_size + prompt = prompt.repeat(batch_size, 1) + + # create an empty tensor of the expected final shape and fill in the current tokens + seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device) + seq[:, :T] = prompt + + with torch.device(device): + model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length) + + input_pos = torch.arange(0, T, device=device) + next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs) + seq[:, T] = next_token.squeeze() + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) + seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1) + + return seq + +def encode_tokens(tokenizer, string, bos=True, device='cuda'): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + +def _load_model(checkpoint_path, device, precision): + with torch.device('meta'): + model = Transformer.from_name(checkpoint_path.parent.name) + + try: + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + except: + model = Transformer.from_name(checkpoint_path.parent.name) + + model = model.to(device=device, dtype=precision) + return model.eval() + +B_INST, E_INST = "[INST]", "[/INST]" + +def main( + prompt: str = "Hello, my name is", + interactive: bool = False, + num_samples: int = 5, + max_new_tokens: int = 100, + batch_size: int = 1, + top_k: int = 200, + temperature: float = 0.8, + checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"), + compile: bool = True, + compile_prefill: bool = False, + moe_quant: Optional[str] = None, + profile: Optional[Path] = None, + device='cuda', +) -> None: + """Generates text samples based on a pre-trained Transformer model and tokenizer. 
+ """ + assert checkpoint_path.is_file(), checkpoint_path + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + + print(f"Using device={device}") + precision = torch.bfloat16 + is_chat = "chat" in str(checkpoint_path) + + print("Loading model ...") + t0 = time.time() + model = _load_model(checkpoint_path, device, precision) + + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + prompt_length = encoded.size(0) + + torch.manual_seed(1234) + model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) + + import torchao + from torchao.quantization.quant_api import ( + quantize_, + Int8WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int4WeightOnlyConfig, + Float8WeightOnlyConfig, + Float8DynamicActivationFloat8WeightConfig, + PerRow, + Int8DynamicActivationIntxWeightConfig, + PackedLinearInt8DynamicActivationIntxWeightLayout + ) + from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter + + if moe_quant: + torch._dynamo.config.capture_dynamic_output_shape_ops = True + config = None + if "int8wo-base" in moe_quant: + config = Int8WeightOnlyConfig() + + elif "int8wo" in moe_quant: + config = MoEQuantConfig(Int8WeightOnlyConfig()) + + elif "int8dq-base" in moe_quant: + config = Int8DynamicActivationInt8WeightConfig() + + elif "int8dq" in moe_quant: + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) + + elif "int4wo-base" in moe_quant: + config = Int4WeightOnlyConfig() + + elif "int4wo" in moe_quant: + config = MoEQuantConfig(Float8WeightOnlyConfig()) + + elif "fp8wo-base" in moe_quant: + config = Int4WeightOnlyConfig() + + elif "fp8wo" in moe_quant: + config = MoEQuantConfig(Float8WeightOnlyConfig()) + + elif "fp8dq-base" in moe_quant: + config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + + elif "fp8dq" in moe_quant: + config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) + + elif "intxdq" in moe_quant: + config = Int8DynamicActivationIntxWeightConfig(layout = PackedLinearInt8DynamicActivationIntxWeightLayout()) + else: + assert config is not None, f"expected moe_quant to match one of the options but got {moe_quant}" + + if config is not None: + quantize_(model, config, filter_fn=cond_ffn_filter) + + + if compile: + # moe quant + compile causes repeated warnings + import warnings + warnings.simplefilter("ignore", lineno=84) + warnings.simplefilter("ignore", lineno=105) + + torch._inductor.config.assert_indirect_indexing = False + + global decode_one_token, prefill + + if batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant): + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + else: + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead") + + if args.compile_prefill: + prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + + + aggregate_metrics = { + 'tokens_per_sec': [], + } + start = -1 if compile else 0 + + for i in range(start, num_samples): + device_sync(device=device) # MKG + if i >= 0 and interactive: + prompt = input("What is your prompt? 
") + if is_chat: + prompt = f"{B_INST} {prompt.strip()} {E_INST}" + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + + if interactive and i >= 0: + buffer = [] + period_id = tokenizer.encode('.')[0] + done_generating = False + def callback(x): + nonlocal done_generating + if done_generating: + return + buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) + if x.item() == tokenizer.eos_id(): + done_generating = True + if len(buffer) == 4 or done_generating: + print(''.join(buffer), end='', flush=True) + buffer.clear() + # print(, end='', flush=True) + else: + callback = lambda x : x + t0 = time.perf_counter() + import contextlib + if (i != num_samples - 1 or not profile): + prof = contextlib.nullcontext() + else: + torch.profiler._utils._init_for_cuda_graphs() + prof = torch.profiler.profile() + with prof: + y = generate( + model, + encoded, + max_new_tokens, + batch_size, + interactive=interactive, + callback=callback, + temperature=temperature, + top_k=top_k, + ) + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue + if hasattr(prof, "export_chrome_trace"): + prof.export_chrome_trace(f"{profile}.json") + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + + if not interactive: + pass + # print(tokenizer.decode(y[0].tolist())) + else: + print() + tokens_generated = y.size(-1) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics['tokens_per_sec'].append(tokens_sec) + # print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + # print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + + tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item() + print(f"Average tokens/sec: {tokpersec:.2f}") + if batch_size > 1: + print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}") + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Your CLI description.') + + parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') + parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') + parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') + parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') + parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with') + parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') + parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') + parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') + parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') + parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') + # parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8') + parser.add_argument('--moe_quant', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq') + parser.add_argument('--profile', type=Path, default=None, help='Profile path.') + parser.add_argument('--device', 
type=str, default="cuda", help='device to use') + + args = parser.parse_args() + print(args) + main( + args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k, + args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.moe_quant, args.profile, args.device + ) diff --git a/torchao/_models/mixtral-moe/model.py b/torchao/_models/mixtral-moe/model.py new file mode 100644 index 0000000000..297226c446 --- /dev/null +++ b/torchao/_models/mixtral-moe/model.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F +from torchao.quantization.prototype.moe_quant.utils import FakeExtraDimTensor + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + num_experts: int = 8 + num_activated_experts: int = 2 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] + assert len(config) == 1, name + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "Mixtral-8x7B-Instruct-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2), +} + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length): + if self.max_seq_length >= max_seq_length and 
self.max_batch_size >= max_batch_size: + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + for b in self.layers: + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim) + + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base) + self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) + + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + x = self.tok_embeddings(idx) + + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask) + x = self.norm(x) + logits = self.output(x) + return logits + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.block_sparse_moe = MOEFeedForwardAOQuantizable(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.block_sparse_moe(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +class 
ConditionalFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) + self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) + self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) + + def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: + w1_weights = self.w1[expert_indices] # [T, A, D, D] + w3_weights = self.w3[expert_indices] # [T, A, D, D] + w2_weights = self.w2[expert_indices] # [T, A, D, D] + x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights)) + x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) + expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) + return expert_outs + + +class MOEFeedForward(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) + self.cond_ffn = ConditionalFeedForward(config) + self.dim = config.dim + self.num_activated_experts = config.num_activated_experts + def forward(self, x: Tensor) -> Tensor: + x = x.view(-1, self.dim) + # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts + # x: [T, D] + scores = self.gate(x) # [T, E] + expert_weights = F.softmax(scores, dim=-1) + expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] + expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] + expert_outs = self.cond_ffn(x, expert_indices) + return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000 +) -> Tensor: + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) + + +# T tokens +# E experts +# D dim +# I intermediate dim +# A activated experts +# T'(e) tokens for expert e + +class MOEFeedForwardAOQuantizable(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) + self.cond_ffn = ConditionalFeedForwardAOQuantizable(config) + self.dim = config.dim + self.num_activated_experts = config.num_activated_experts + def forward(self, x: Tensor) -> Tensor: + batch_size = x.shape[0] + x = x.view(-1, self.dim) # x: [T, D] + scores = self.gate(x) # [T, E] + expert_weights = F.softmax(scores, dim=-1) + expert_weights, expert_indices = 
torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] + expert_weights /= expert_weights.sum(dim=-1, keepdim=True).to(x.dtype) # [T, A] + out = self.cond_ffn(x, expert_indices, expert_weights, self.num_activated_experts) + return out.reshape(batch_size, -1, self.dim) + + +class ConditionalFeedForwardAOQuantizable(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) # E, I, D + self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) # E, D, I + self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) # E, I, D + self.num_experts = config.num_experts + def forward( + self, x: Tensor, # T, D + expert_indices: Tensor, # T, A + expert_weights: Tensor, # T, A + num_activated_experts: int, + ) -> Tensor: + num_tokens, dim = x.shape + num_token_activations = num_tokens * num_activated_experts + if x.shape[0]==1 and not isinstance(self.w1, FakeExtraDimTensor): #only 1 token (can be done without graph breaks when compiled) + outs = [] + expert_indices=expert_indices.squeeze() + # collect used experts + w1 = self.w1[expert_indices] + w2 = self.w2[expert_indices] + w3 = self.w3[expert_indices] + + # run token through each expert + for index in range(num_activated_experts): + y1 = F.silu(F.linear(x, w1[index])) + y3 = F.linear(x, w3[index]) + y2 = w2[index] + cur_out = F.linear( y1 * y3, y2) + outs.append(cur_out) + + # combine outputs + final_out = (torch.cat(outs, dim=0) * expert_weights.view(-1,1)).sum(dim=0).unsqueeze(-1) + return final_out + else: + expert_list = [x for x in range(self.num_experts)] + + # shuffle tokens into groups for each expert + ordered_token_activations = expert_indices.view(-1).argsort(stable=True) # [A] + ordered_token_indices = ordered_token_activations.div(num_activated_experts).floor().to(torch.int64) # [T] + + num_tokens_per_expert = torch.histc(expert_indices, bins=self.num_experts+1, min=-1, max=self.num_experts) # [E+1] (added leading 0 so can be used for indexing) + cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(torch.int64) # [E+1] + + @torch._dynamo.disable() + def group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list): + token_indices_per_expert = [ordered_token_indices[cum_tokens_per_expert[expert]:cum_tokens_per_expert[expert+1]] for expert in expert_list] # [T'(e1)], [T'(e2)] ... 
+ return token_indices_per_expert + token_indices_per_expert = group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list) + tokens_grouped_by_expert = [x[indices] for indices in token_indices_per_expert] + + # calculate outputs for each expert + outs = [] + for cur_x, expert in zip(tokens_grouped_by_expert,expert_list): + + w1=self.w1[expert] # I, D + w2=self.w2[expert] # D, I + w3=self.w3[expert] # I, D + + cur_out = F.linear( F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2) # [T'(e), D] + outs.append(cur_out) + + # weigh outputs + ordered_outs = torch.cat(outs, dim=0) # [T*A, D] + ordered_token_activation_weights = expert_weights.view(-1,1)[ordered_token_activations].view(-1,1) # [T*A, 1] + weighted_ordered_outs = ordered_outs*ordered_token_activation_weights # [T*A, D] + + # sum weighted token-activation outputs together for each token + final_out = torch.zeros_like(x) # [T, D] + final_out = final_out.scatter_add(dim=0, index=ordered_token_indices.unsqueeze(-1).expand(num_token_activations,dim).to(torch.int64), src=weighted_ordered_outs) + return final_out diff --git a/torchao/_models/mixtral-moe/run.sh b/torchao/_models/mixtral-moe/run.sh new file mode 100644 index 0000000000..7b9c3bf567 --- /dev/null +++ b/torchao/_models/mixtral-moe/run.sh @@ -0,0 +1,39 @@ +export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1 +export CHECKPOINT_PATH=/data/users/cdhernandez/gpt-fast/checkpoints/ + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8wo --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8wo --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8wo-base --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8wo-base --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int4wo --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int4wo --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int4wo-base --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int4wo-base --compile + +# EXPERT CHOICE +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq --compile +# # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq --compile +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq-base --compile +# # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq-base --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8wo --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8wo --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8wo-base --compile +python 
generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8wo-base --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8dq --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8dq --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8dq-base --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8dq-base --compile + +# ARM +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant intxdq --device cpu +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant intxdq --compile --device cpu diff --git a/torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py b/torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py new file mode 100644 index 0000000000..b120c5c56d --- /dev/null +++ b/torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import glob +import json +import re +import sys +from pathlib import Path +from typing import Optional + +import torch + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from model import ModelArgs + + +@torch.inference_mode() +def convert_hf_checkpoint( + *, + checkpoint_dir: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1"), + model_name: Optional[str] = None, +) -> None: + if model_name is None: + model_name = checkpoint_dir.name + + config = ModelArgs.from_name(model_name) + print(f"Model config {config.__dict__}") + + weight_map = { + "tok_embeddings.weight": "tok_embeddings.weight", + "layers.{}.attention.wq.weight": "layers.{}.attention.wq.weight", + "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight", + "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight", + "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight", + "layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1", + "layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2", + "layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3", + "layers.{}.block_sparse_moe.gate.weight": "layers.{}.block_sparse_moe.gate.weight", + "layers.{}.attention_norm.weight": "layers.{}.attention_norm.weight", + "layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight", + "norm.weight": "norm.weight", + "output.weight": "output.weight", + } + + pt_files = glob.glob(str(checkpoint_dir / "*.pt")) + + merged_result = {} + for file in sorted(pt_files): + state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) + merged_result.update(state_dict) + final_result = {} + for key, value in merged_result.items(): + if "layers" in key: + abstract_key = re.sub(r'.(\d+).', '.{}.', key) + layer_num = re.search(r'\d+', key).group(0) + new_key = weight_map[abstract_key] + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = weight_map[key] + + final_result[new_key] = value + + for key in tuple(final_result.keys()): + if "wq" in key: + q = final_result[key] + k = final_result[key.replace("wq", 
"wk")] + v = final_result[key.replace("wq", "wv")] + final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) + del final_result[key] + del final_result[key.replace("wq", "wk")] + del final_result[key.replace("wq", "wv")] + elif "w1" in key or "w3" in key: + final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous() + elif "w2" in key: + final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous() + elif "gate" in key: + final_result[key] = final_result[key].contiguous() + + print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") + torch.save(final_result, checkpoint_dir / "model.pth") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.') + parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/mistralai/Mixtral-8x7B-v0.1")) + parser.add_argument('--model_name', type=str, default=None) + + args = parser.parse_args() + convert_hf_checkpoint( + checkpoint_dir=args.checkpoint_dir, + model_name=args.model_name, + ) diff --git a/torchao/_models/mixtral-moe/scripts/download.py b/torchao/_models/mixtral-moe/scripts/download.py new file mode 100644 index 0000000000..7dc828004f --- /dev/null +++ b/torchao/_models/mixtral-moe/scripts/download.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import os +from typing import Optional + +from requests.exceptions import HTTPError + + +def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: + from huggingface_hub import snapshot_download + os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) + try: + snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors") + except HTTPError as e: + if e.response.status_code == 401: + print("You need to pass a valid `--hf_token=...` to download private checkpoints.") + else: + raise e + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.') + parser.add_argument('--repo_id', type=str, default="checkpoints/mistralai/Mixtral-8x7B-Instruct-v0.1", help='Repository ID to download from.') + parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.') + + args = parser.parse_args() + hf_download(args.repo_id, args.hf_token) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 50ef8c9e89..be84430067 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -460,11 +460,13 @@ def _(func, types, args, kwargs): shape[dim] = end - start block_size = self.block_size assert ( - len(block_size) == 2 - ), f"Slice only works for 2d block_size right now, got: {block_size}" + len(block_size) in [2,3] + ), f"Slice only works for 2 and 3d block_size right now, got: {block_size}" # with slice, some shape dimension might be smaller than block_size dimension, so # we need to make sure there is no overflow - block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) + if len(block_size) == 2: + block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) + new = self.__class__( 
aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, @@ -473,10 +475,50 @@ def _(func, types, args, kwargs): self.quant_max, self.zero_point_domain, dtype=self.dtype, - strides=self.stride(), + strides=self.stride() if len(block_size)==2 else None, ) return return_and_correct_aliasing(func, args, kwargs, new) +@implements(aten.index.Tensor) +def _(func, types, args, kwargs): + self, indices = args + assert len(indices) == 1, f"op {func} currently only implemented for single dimensional indexing but got indices: {indices}" + + new_tensor_impl = aten.index.Tensor(self.tensor_impl, indices) + shape = tuple([indices[0].numel(), *self.shape[1:]]) + + block_size = self.block_size + new = self.__class__( + new_tensor_impl, + block_size, + shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + +@implements(aten.select.int) +def _(func, types, args, kwargs): + self, dim, index = fill_defaults(args, 3, [0, 0]) + assert dim==0, f"op {func} currently only implemented for dim=0 but got dim={dim}" + assert self.dim() == 3, f"op {func} currently only implemented for 3 dimensional tensors but got shape={self.shape}" + + new_tensor_impl = aten.select.int(self.tensor_impl, dim, index) + + shape = self.shape[1:] + block_size = self.block_size[1:] + new = self.__class__( + new_tensor_impl, + block_size, + shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) # this is needed for DTensor.from_local() and for flattening tensor @implements(aten.view.default) diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index cd1d9a8ac0..15c7dd7599 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -159,6 +159,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs): raise ValueError( f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" ) + elif func is aten.select.int or func is aten.index.Tensor: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), + ) elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py index b4b9e06f1a..042ab04564 100644 --- a/torchao/dtypes/uintx/plain_layout.py +++ b/torchao/dtypes/uintx/plain_layout.py @@ -154,6 +154,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) return return_and_correct_aliasing(func, args, kwargs, new) + + elif func in [aten.select.int, aten.index.Tensor]: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: func(x, *args[1:], **kwargs) + ), + ) + elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 901c4c4640..3890010d61 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -93,11 +93,13 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) # groupwise int4 quantization - groupsize = 
weight_tensor.block_size[1] - y = torch.ops.aten._weight_int4pack_mm( - act_mat.contiguous(), packed_weight, groupsize, scale_and_zero - ) - + groupsize = weight_tensor.block_size[-1] + if act_mat.numel() == 0: + y=act_mat + else: + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] y = y[:, :orig_out_features] @@ -119,7 +121,7 @@ class TensorCoreTiledLayout(Layout): inner_k_tiles: int = 8 def pre_process(self, input: torch.Tensor) -> torch.Tensor: - orig_out_features, orig_in_features = input.shape + orig_out_features, orig_in_features = input.shape[-2:] in_features = find_multiple(orig_in_features, 1024) out_features = find_multiple(orig_out_features, 8) input = torch.nn.functional.pad( @@ -160,7 +162,7 @@ def post_process( zero_point: torch.Tensor, block_size: Tuple[int, ...], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - orig_out_features, orig_in_features = input.shape + orig_out_features, orig_in_features = input.shape[-2:] in_features = find_multiple(orig_in_features, 1024) out_features = find_multiple(orig_out_features, 8) input = torch.nn.functional.pad( @@ -168,10 +170,10 @@ def post_process( (0, in_features - orig_in_features, 0, out_features - orig_out_features), ) assert ( - len(block_size) == 2 - ), f"TensorCoreTiledLayout only supports len(block_size) == 2, got: {block_size}" - scale_pad_dim_0 = (out_features - orig_out_features) // block_size[0] - scale_pad_dim_1 = (in_features - orig_in_features) // block_size[1] + len(block_size) == 2 or len(block_size) == 3, + ), f"TensorCoreTiledLayout only supports len(block_size) == 2 or 3, got: {block_size}" + scale_pad_dim_0 = (out_features - orig_out_features) // block_size[-2] + scale_pad_dim_1 = (in_features - orig_in_features) // block_size[-1] scale = torch.nn.functional.pad(scale, (0, scale_pad_dim_1, 0, scale_pad_dim_0)) zero_point = torch.nn.functional.pad( zero_point, (0, scale_pad_dim_1, 0, scale_pad_dim_0) @@ -262,21 +264,29 @@ def from_plain( _layout: Layout, ): assert isinstance(_layout, TensorCoreTiledLayout) - - if TORCH_VERSION_AT_LEAST_2_5: - int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert ( - int_data.dtype == torch.uint8 - ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + + def quant_2d(int_data_2d): + if TORCH_VERSION_AT_LEAST_2_5: + int_data_2d = (int_data_2d[::, ::2] << 4 | int_data_2d[::, 1::2]).to(torch.uint8) + return torch.ops.aten._convert_weight_to_int4pack( + int_data_2d, _layout.inner_k_tiles + ) + if int_data.dim() == 3: # for moe quant + num_experts = int_data.shape[0] + packed_weight_list = [] + for expert in range(num_experts): + packed_weight_list.append(quant_2d(int_data[expert]).unsqueeze(0)) + packed_weight = torch.cat(packed_weight_list, dim=0) + scale = scale.reshape(int_data.shape[0], int_data.shape[-2], -1) + zero_point = zero_point.reshape(int_data.shape[0], int_data.shape[-2], -1) if zero_point is not None else None else: - assert ( - int_data.dtype == torch.int32 - ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - int_data, _layout.inner_k_tiles - ) - scale = scale.reshape(int_data.shape[0], -1) - zero_point = zero_point.reshape(int_data.shape[0], -1) + assert 
int_data.dim() == 2 + packed_weight = quant_2d(int_data) + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) if zero_point is not None else None from torchao.quantization.utils import pack_tinygemm_scales_and_zeros scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) @@ -336,6 +346,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" ) + if func in [aten.select.int, aten.index.Tensor]: + assert not (func is aten.select.int and args[1]!=0), "aten.select.int currently only has support for dim=0" + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: func(x, *args[1:], **kwargs) + ), + ) + + if func is aten.t.default: """we don't need to repack the weight and just rely on external shape being changed and record the status of transpose/no-transpose @@ -386,11 +408,16 @@ def block_size(self): scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) cur_shape = self.shape - assert len(cur_shape) == 4 + if len(cur_shape) == 5: + ones = [1,1] + cur_shape = cur_shape[1:] + else: + assert len(cur_shape) == 4 + ones = [1] inner_k_tiles = cur_shape[-1] * 2 original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) groupsize = int(original_shape[1] / scale.shape[-2]) - return (1, groupsize) + return tuple([*ones, groupsize]) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: from torchao.quantization.quant_primitives import ( @@ -399,35 +426,53 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + def dequant_4d(self): + cur_shape = self.shape + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + assert len(cur_shape) == 4 + inner_k_tiles = cur_shape[-1] * 2 + original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) + eye_shape = original_shape[1] + groupsize = int(original_shape[1] / scale.shape[-2]) + block_size = (1, groupsize) + original_dtype = torch.bfloat16 + assert len(block_size) == 2 and block_size[0] == 1 + dequantized = torch.ops.aten._weight_int4pack_mm( + torch.eye(eye_shape, device=self.device, dtype=original_dtype), + self.packed_weight, + groupsize, + self.scale_and_zero, + ) + dequantized = dequantized.t().contiguous() + return dequantized + + cur_shape = self.shape + + if len(cur_shape)==4: + dequantized = dequant_4d(self) + else: + assert len(cur_shape) == 5 + num_experts = cur_shape[0] + dequantized_list = [] + for expert in range(num_experts): + dequantized_list.append(dequant_4d(self[expert]).unsqueeze(0)) + dequantized = torch.cat(dequantized_list, dim=0) + + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + # TODO: move this to `unpack_tinygemm_scales_and_zeros`? 
+ scale = scale.reshape(scale.shape[:-1]).contiguous() + zero = zero.reshape(zero.shape[:-1]).contiguous() - cur_shape = self.shape - assert len(cur_shape) == 4 - inner_k_tiles = cur_shape[-1] * 2 - original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) - eye_shape = original_shape[1] - groupsize = int(original_shape[1] / scale.shape[-2]) - block_size = (1, groupsize) device = self.device - original_dtype = torch.bfloat16 + target_dtype = torch.int32 quant_min = 0 quant_max = 15 zero_point_domain = ZeroPointDomain.FLOAT - assert len(block_size) == 2 and block_size[0] == 1 - dequantized = torch.ops.aten._weight_int4pack_mm( - torch.eye(eye_shape, device=device, dtype=original_dtype), - self.packed_weight, - groupsize, - self.scale_and_zero, - ) - dequantized = dequantized.t().contiguous() - # TODO: move this to `unpack_tinygemm_scales_and_zeros`? - scale = scale.reshape(scale.shape[:-1]).contiguous() - zero = zero.reshape(zero.shape[:-1]).contiguous() int_data = quantize_affine( dequantized, - block_size, + self.block_size, scale, zero, target_dtype, diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index 94fcebd9d4..938b03a62c 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -428,6 +428,51 @@ def test_moved_error(self): granularity=PerGroup(64), ) + def test_moe_quant_intx(self): + from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import MOEFeedForwardAOQuantizable + from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter, FakeExtraDimTensor + from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig, PackedLinearInt8DynamicActivationIntxWeightLayout, quantize_ + from torchao.quantization.utils import compute_error + + with torch.device("cpu"): + model = MOEFeedForwardAOQuantizable(512, 256, 8, 2).to(torch.bfloat16) + x = torch.randn(8, 512, dtype=torch.bfloat16) + + out = model(x).clone() + + # base_config = Int8DynamicActivationIntxWeightConfig() + base_config = Int8DynamicActivationIntxWeightConfig(layout = PackedLinearInt8DynamicActivationIntxWeightLayout()) + moe_config = MoEQuantConfig(base_config) + + quantize_(model, moe_config, cond_ffn_filter) + + + out_q = model(x).clone() + assert isinstance(model.experts.w1, FakeExtraDimTensor) + + mod_c = torch.compile(model, mode="reduce-overhead") + + mod_c(x) + mod_c(x) + + + out_qc = mod_c(x).clone() + + print(compute_error(out_q, out)) + print(compute_error(out_qc, out)) + + assert compute_error(out_q, out)>30 and compute_error(out_qc, out)>30, "error bad accuracy but everything ran" + + + + + + + + + + + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index e4343a086f..fdb57ea07b 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -82,6 +82,8 @@ def __tensor_unflatten__( def _quantized_linear_op( input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor ): + if input_tensor.numel() == 0: + return input_tensor input_quant_func = weight_tensor.input_quant_func original_weight_tensor = weight_tensor.original_weight_tensor quant_kwargs = weight_tensor.quant_kwargs @@ -242,6 +244,31 @@ def _(func, 
types, args, kwargs): ), ) +@implements(aten.select.int) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + LinearActivationQuantizedTensor( + func(args[0].original_weight_tensor, *args[1:]), + args[0].input_quant_func, + args[0].quant_kwargs, + ), + ) + +@implements(aten.index.Tensor) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + LinearActivationQuantizedTensor( + func(args[0].original_weight_tensor, *args[1:]), + args[0].input_quant_func, + args[0].quant_kwargs, + ), + ) # this is needed for DTensor.from_local() and for flattening tensor @implements(aten.view.default) diff --git a/torchao/quantization/prototype/moe_quant/__init__.py b/torchao/quantization/prototype/moe_quant/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/quantization/prototype/moe_quant/llama4_quant.py b/torchao/quantization/prototype/moe_quant/llama4_quant.py new file mode 100644 index 0000000000..c5092fd336 --- /dev/null +++ b/torchao/quantization/prototype/moe_quant/llama4_quant.py @@ -0,0 +1,429 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# import torch +# from tabulate import tabulate +# from transformers import AutoModelForCausalLM, AutoTokenizer + +# try: +# from lm_eval.evaluator import evaluate +# from lm_eval.models.huggingface import HFLM +# from lm_eval.tasks import get_task_dict +# except ImportError: +# print(""" +# Error: The 'lm_eval' module was not found. +# To install, follow these steps: +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git +# """) +# raise # Re-raise the ImportError + +# from torchao.quantization import ( +# autoquant, +# fpx_weight_only, +# int4_weight_only, +# int8_dynamic_activation_int8_weight, +# int8_weight_only, +# quantize_, +# ) +# from torchao.sparsity import ( +# semi_sparse_weight, +# sparsify_, +# ) + +# torch._inductor.config.force_fuse_int_mm_with_mul = True +# torch._inductor.config.fx_graph_cache = True + + +# def pretty_print_nested_results(results, precision: int = 6): +# def format_value(value): +# if isinstance(value, float): +# return f"{value:.{precision}f}" +# return value + +# main_table = [] +# for task, metrics in results["results"].items(): +# subtable = [[k, format_value(v)] for k, v in metrics.items() if k != "alias"] +# subtable.sort(key=lambda x: x[0]) # Sort metrics alphabetically +# formatted_subtable = tabulate(subtable, tablefmt="grid") +# main_table.append([task, formatted_subtable]) + +# print(tabulate(main_table, headers=["Task", "Metrics"], tablefmt="grid")) + + +# def run_evaluation( +# repo_id, +# tasks, +# limit, +# device, +# precision, +# quantization, +# sparsity, +# compile, +# save, +# batch_size, +# max_length, +# ): +# tokenizer = AutoTokenizer.from_pretrained(repo_id) +# model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to( +# device +# ) + +# if quantization == "autoquant" and compile: +# model = torch.compile(model, mode="max-autotune", fullgraph=True) + +# if quantization == "int8dq": +# quantize_(model, int8_dynamic_activation_int8_weight()) +# elif quantization == "int8wo": +# quantize_(model, int8_weight_only()) +# elif quantization == "int4wo": +# # note cannot quantize this model on cpu and run it on cuda at this time +# quantize_(model.to(device=device), 
int4_weight_only()) +# elif quantization == "fp6": +# quantize_(model, fpx_weight_only(3, 2)) +# elif quantization == "autoquant": +# model = autoquant(model.to(device=device)) +# elif quantization == "awq": +# from torchao.prototype.awq.example import get_calib_dataset +# from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + +# if not TORCH_VERSION_AT_LEAST_2_3: +# print("AWQ quantization requires torch2.3+") +# exit() +# from torchao.prototype.awq import ( +# AWQObservedLinear, +# awq_uintx, +# insert_awq_observer_, +# ) + +# quant_dtype = torch.uint4 +# group_size = 64 +# calibration_limit = 10 +# calibration_seq_length = 1024 +# model = model.to(device) +# insert_awq_observer_( +# model, +# calibration_limit, +# calibration_seq_length, +# quant_dtype=quant_dtype, +# group_size=group_size, +# ) +# with torch.no_grad(): +# calibration_data = get_calib_dataset( +# tokenizer=tokenizer, +# n_samples=calibration_limit, +# block_size=calibration_seq_length, +# ) +# for batch in calibration_data: +# model(batch.to(device)) +# del batch +# is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) +# quantize_( +# model, +# awq_uintx(quant_dtype=quant_dtype, group_size=group_size), +# is_observed_linear, +# ) + +# if quantization != "autoquant" and compile: +# model = torch.compile(model, mode="max-autotune", fullgraph=True) + +# if sparsity == "semi_sparse": + +# def all_linear(mod, name): +# if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: +# return True +# return False + +# torch.sparse.semi_structured._FORCE_CUTLASS = False +# sparsify_(model, semi_sparse_weight(), filter_fn=all_linear) +# elif sparsity == "semi_sparse_mlp_only": + +# def all_linear(mod, name): +# if ( +# isinstance(mod, torch.nn.Linear) +# and "lm_head" not in name +# and "mlp" in name +# ): +# return True +# return False + +# torch.sparse.semi_structured._FORCE_CUTLASS = False +# sparsify_(model, semi_sparse_weight(), filter_fn=all_linear) + +# if sparsity and compile: +# model = torch.compile(model, mode="max-autotune", fullgraph=True) + +# with torch.no_grad(): +# result = evaluate( +# HFLM( +# pretrained=model.to(device), +# tokenizer=tokenizer, +# batch_size=batch_size, +# max_length=max_length, +# ), +# get_task_dict(tasks), +# limit=limit, +# ) + +# pretty_print_nested_results(result) + +# if save: +# # This doesn't work yet: https://github.com/huggingface/transformers/issues/32364 +# # model.save_pretrained("quantized_model_test", safe_serialization=False) +# file_name = repo_id.split("/")[-1] + "-" + quantization + ".pt" +# torch.save(model.state_dict(), file_name) + + +# if __name__ == "__main__": +# import argparse + +# parser = argparse.ArgumentParser(description="Run HF Model Evaluation") +# parser.add_argument( +# "--repo_id", +# type=str, +# default="meta-llama/Meta-Llama-3-8B", +# help="Repository ID to download from HF.", +# ) +# parser.add_argument( +# "--tasks", +# nargs="+", +# type=str, +# default=["wikitext"], +# help="List of lm-eluther tasks to evaluate usage: --tasks task1 task2", +# ) +# parser.add_argument( +# "--limit", type=int, default=None, help="Number of eval samples to evaluate" +# ) +# parser.add_argument( +# "--precision", +# type=lambda x: getattr(torch, x.split(".")[-1]), +# default=torch.bfloat16, +# help="dtype precision to use", +# ) +# parser.add_argument( +# "--device", type=str, default="cuda", help="Device to use for evaluation" +# ) +# parser.add_argument( +# "-q", +# "--quantization", +# default="None", +# choices=["int8dq", "int8wo", "int4wo", 
"autoquant", "awq", "None"], +# help="Which quantization technique to apply", +# ) +# parser.add_argument( +# "-s", +# "--sparsity", +# default="None", +# choices=["semi_sparse", "semi_sparse_mlp_only", "None"], +# help="Which sparsity technique to apply", +# ) +# parser.add_argument( +# "--compile", action="store_true", help="Whether to compile the model." +# ) +# parser.add_argument( +# "--save", action="store_true", help="Whether to save the model." +# ) +# parser.add_argument( +# "--batch_size", +# type=int, +# default=1, +# help="Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes", +# ) +# parser.add_argument( +# "--max_length", +# type=int, +# default=None, +# help="Length of text to process at one time", +# ) + +# args = parser.parse_args() +# run_evaluation( +# args.repo_id, +# args.tasks, +# args.limit, +# args.device, +# args.precision, +# args.quantization, +# args.sparsity, +# args.compile, +# args.save, +# args.batch_size, +# args.max_length, +# ) + +# T tokens +# E experts +# D dim +# I intermediate dim +# A activated experts +# T'(e) tokens for expert e + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + +from transformers import Llama4ForCausalLM, AutoTokenizer + +from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + + +class MOEFeedForwardAOQuantizable(nn.Module): + def __init__(self, hidden_dim, expert_dim, num_experts, top_k, act_fn=F.silu, shared_expert=None) -> None: + super().__init__() + self.router = nn.Linear(hidden_dim, num_experts, bias=False) + self.experts = ConditionalFeedForwardAOQuantizable(num_experts, hidden_dim, expert_dim, act_fn) + self.hidden_dim = hidden_dim + self.top_k = top_k + self.shared_expert = shared_expert + + def forward(self, x: Tensor) -> Tensor: + batch_size = x.shape[0] + x = x.view(-1, self.hidden_dim) # x: [T, D] + scores = self.router(x) # [T, E] + scores = F.softmax(scores, dim=-1) + scores, expert_indices = torch.topk(scores, self.top_k, dim=-1) # [T, A], [T, A] + scores /= scores.sum(dim=-1, keepdim=True).to(x.dtype) # [T, A] + + out = self.experts(x, expert_indices, scores, self.top_k) + try: + if self.shared_expert: + out += self.shared_expert(x) + except Exception as e: + print(e) + return out.reshape(batch_size, -1, self.hidden_dim), scores + + +class ConditionalFeedForwardAOQuantizable(nn.Module): + def __init__(self, num_experts, hidden_dim, expert_dim, act_fn): + super().__init__() + self.w1 = nn.Parameter(torch.empty(num_experts, expert_dim, hidden_dim)) # E, I, D + self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, expert_dim)) # E, D, I + self.w3 = nn.Parameter(torch.empty(num_experts, expert_dim, hidden_dim)) # E, I, D + self.num_experts = num_experts + self.act_fn = act_fn + self.hidden_dim = hidden_dim + self.expert_dim = expert_dim + def forward( + self, x: Tensor, # T, D + expert_indices: Tensor, # T, A + expert_weights: Tensor, # T, A + num_activated_experts: int, + ) -> Tensor: + num_tokens, _hidden_dim = x.shape + num_token_activations = num_tokens * num_activated_experts + + if x.shape[0]==1: #only 1 token (can be done without graph breaks when compiled) + outs = [] + expert_indices=expert_indices.squeeze() + # collect used experts + w1 = self.w1[expert_indices].view(num_token_activations, 
self.expert_dim, self.hidden_dim) + w2 = self.w2[expert_indices].view(num_token_activations, self.hidden_dim, self.expert_dim) + w3 = self.w3[expert_indices].view(num_token_activations, self.expert_dim, self.hidden_dim) + + # run token through each expert + for index in range(num_activated_experts): + cur_out = F.linear( self.act_fn(F.linear(x, w1[index])) * F.linear(x, w3[index]), w2[index]) + outs.append(cur_out) + + # combine outputs + final_out = (torch.cat(outs, dim=0) * expert_weights.view(-1,1)).sum(dim=0).reshape(x.shape) + return final_out + else: + expert_list = [x for x in range(self.num_experts)] + + # shuffle tokens into groups for each expert + ordered_token_activations = expert_indices.view(-1).argsort(stable=True) # [A] + ordered_token_indices = ordered_token_activations.div(num_activated_experts).floor().to(torch.int64) # [T] + device = expert_indices.device + num_tokens_per_expert = torch.histc(expert_indices.to("cuda"), bins=self.num_experts+1, min=-1, max=self.num_experts).to(device) # [E+1] (added leading 0 so can be used for indexing) + cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(torch.int64) # [E+1] + + @torch._dynamo.disable() + def group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list): + token_indices_per_expert = [ordered_token_indices[cum_tokens_per_expert[expert]:cum_tokens_per_expert[expert+1]] for expert in expert_list] # [T'(e1)], [T'(e2)] ... + return token_indices_per_expert + token_indices_per_expert = group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list) + tokens_grouped_by_expert = [x[indices] for indices in token_indices_per_expert] + + # calculate outputs for each expert + outs = [] + for cur_x, expert in zip(tokens_grouped_by_expert,expert_list): + + w1=self.w1[expert] # I, D + w2=self.w2[expert] # D, I + w3=self.w3[expert] # I, D + + cur_out = F.linear( self.act_fn(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2) # [T'(e), D] + outs.append(cur_out) + + # weigh outputs + ordered_outs = torch.cat(outs, dim=0) # [T*A, D] + ordered_token_activation_weights = expert_weights.view(-1,1)[ordered_token_activations].view(-1,1) # [T*A, 1] + weighted_ordered_outs = ordered_outs*ordered_token_activation_weights # [T*A, D] + + # sum weighted token-activation outputs together for each token + final_out = torch.zeros_like(x) # [T, D] + final_out = final_out.scatter_add(dim=0, index=ordered_token_indices.unsqueeze(-1).expand(num_token_activations, self.hidden_dim).to(torch.int64), src=weighted_ordered_outs) + return final_out + + + +def llama4_moe_filter_fn(module, fqn): + return isinstance(module, Llama4TextMoe) + +def convert_fn(module): + # get data + hidden_dim = module.hidden_dim + expert_dim = module.experts.expert_dim + num_experts = module.num_experts + top_k = module.top_k + act_fn = module.experts.act_fn + shared_expert = module.shared_expert + new_mod = MOEFeedForwardAOQuantizable( + hidden_dim, + expert_dim, + num_experts, + top_k, + act_fn, + shared_expert, + ) + + router = module.router + up_proj = module.experts.gate_up_proj + w1, w3 = up_proj.permute(0,2,1).chunk(2, dim=1) + w2 = module.experts.down_proj.permute(0,2,1) + + new_mod.router = router + new_mod.experts.w1 = nn.Parameter(w1, requires_grad=False) + new_mod.experts.w2 = nn.Parameter(w2, requires_grad=False) + new_mod.experts.w3 = nn.Parameter(w3, requires_grad=False) + return new_mod + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +model = Llama4ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) 
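+# NOTE: the script replaces Llama4TextMoe modules with MOEFeedForwardAOQuantizable below but does
+# not quantize them here. A possible follow-up, sketched with the prototype APIs added in this
+# patch (any supported base config could be substituted), would be:
+#   from torchao.quantization.quant_api import quantize_, Int8WeightOnlyConfig
+#   from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter
+#   quantize_(model, MoEQuantConfig(Int8WeightOnlyConfig()), cond_ffn_filter)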
+tokenizer = AutoTokenizer.from_pretrained(model_id) + +_replace_with_custom_fn_if_matches_filter( + model, + convert_fn, + llama4_moe_filter_fn, +) + +model = model + + + +prompt = "He is here, teh one who will tear apart the very stars in heaven." +inputs = tokenizer(prompt, return_tensors="pt") + +# import fbvscode; fbvscode.set_trace() +generate_ids = model.generate(inputs.input_ids, max_length=30) +out = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] +print(out) diff --git a/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py new file mode 100644 index 0000000000..b4b59116d6 --- /dev/null +++ b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py @@ -0,0 +1,108 @@ +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from torchao.quantization.prototype.moe_quant.utils import FakeExtraDimTensor + + +class MOEFeedForwardAOQuantizable(nn.Module): + def __init__(self, hidden_dim, expert_dim, num_experts, top_k, act_fn=F.silu, shared_expert=None) -> None: + super().__init__() + self.router = nn.Linear(hidden_dim, num_experts, bias=False) + self.experts = ConditionalFeedForwardAOQuantizable(num_experts, hidden_dim, expert_dim, act_fn) + self.hidden_dim = hidden_dim + self.top_k = top_k + self.shared_expert= shared_expert + + def forward(self, x: Tensor) -> Tensor: + batch_size = x.shape[0] + x = x.view(-1, self.hidden_dim) # x: [T, D] + scores = self.router(x) # [T, E] + scores = F.softmax(scores, dim=-1) + scores, expert_indices = torch.topk(scores, self.top_k, dim=-1) # [T, A], [T, A] + scores /= scores.sum(dim=-1, keepdim=True).to(x.dtype) # [T, A] + + out = self.experts(x, expert_indices, scores, self.top_k) + if self.shared_expert: + out += self.shared_expert(x) + return out.reshape(batch_size, -1, self.hidden_dim) + + +class ConditionalFeedForwardAOQuantizable(nn.Module): + def __init__(self, num_experts, hidden_dim, expert_dim, act_fn): + super().__init__() + self.w1 = nn.Parameter(torch.randn(num_experts, expert_dim, hidden_dim)) # E, I, D + self.w2 = nn.Parameter(torch.randn(num_experts, hidden_dim, expert_dim)) # E, D, I + self.w3 = nn.Parameter(torch.randn(num_experts, expert_dim, hidden_dim)) # E, I, D + self.num_experts = num_experts + self.act_fn = act_fn + def forward( + self, x: Tensor, # T, D + expert_indices: Tensor, # T, A + expert_weights: Tensor, # T, A + num_activated_experts: int, + ) -> Tensor: + num_tokens, dim = x.shape + num_token_activations = num_tokens * num_activated_experts + + if x.shape[0]==1 and not isinstance(self.w1, FakeExtraDimTensor): + #only 1 token (can be done without graph breaks when compiled) + outs = [] + expert_indices=expert_indices.squeeze() + + # collect used experts + w1 = self.w1[expert_indices] + w2 = self.w2[expert_indices] + w3 = self.w3[expert_indices] + + # run token through each expert + for index in range(num_activated_experts): + y1 = F.silu(F.linear(x, w1[index])) + y3 = F.linear(x, w3[index]) + y2 = w2[index] + + cur_out = F.linear( y1 * y3, y2) + outs.append(cur_out) + + # combine outputs + final_out = (torch.cat(outs, dim=0) * expert_weights.view(-1,1)).sum(dim=0).unsqueeze(-1) + return final_out + else: + expert_list = [x for x in range(self.num_experts)] + + # shuffle tokens into groups for each expert + ordered_token_activations = expert_indices.view(-1).argsort(stable=True) # [A] + ordered_token_indices = 
ordered_token_activations.div(num_activated_experts).floor().to(torch.int64) # [T] + + num_tokens_per_expert = torch.histc(expert_indices, bins=self.num_experts+1, min=-1, max=self.num_experts) # [E+1] (added leading 0 so can be used for indexing) + cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(torch.int64) # [E+1] + + @torch._dynamo.disable() + def group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list): + token_indices_per_expert = [ordered_token_indices[cum_tokens_per_expert[expert]:cum_tokens_per_expert[expert+1]].to(torch.int64) for expert in expert_list] # [T'(e1)], [T'(e2)] ... + return token_indices_per_expert + token_indices_per_expert = group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list) + tokens_grouped_by_expert = [x[indices] for indices in token_indices_per_expert] + + # calculate outputs for each expert + outs = [] + for cur_x, expert in zip(tokens_grouped_by_expert,expert_list): + w1=self.w1[expert] # I, D + w2=self.w2[expert] # D, I + w3=self.w3[expert] # I, D + + y1 = F.silu(F.linear(cur_x, w1)) + y3 = F.linear(cur_x, w3) + y2 = w2 + + cur_out = F.linear(y1 * y3, y2) # [T'(e), D] + outs.append(cur_out) + + # weigh outputs + ordered_outs = torch.cat(outs, dim=0) # [T*A, D] + ordered_token_activation_weights = expert_weights.view(-1,1)[ordered_token_activations].view(-1,1) # [T*A, 1] + weighted_ordered_outs = ordered_outs*ordered_token_activation_weights # [T*A, D] + + # sum weighted token-activation outputs together for each token + final_out = torch.zeros_like(x) # [T, D] + final_out = final_out.scatter_add(dim=0, index=ordered_token_indices.unsqueeze(-1).expand(num_token_activations,dim).to(torch.int64), src=weighted_ordered_outs) + return final_out diff --git a/torchao/quantization/prototype/moe_quant/run.sh b/torchao/quantization/prototype/moe_quant/run.sh new file mode 100644 index 0000000000..d7b819a828 --- /dev/null +++ b/torchao/quantization/prototype/moe_quant/run.sh @@ -0,0 +1 @@ +python llama4_quant.py #--repo_id "/meta-llama/Llama-4-Maverick-17B-128E-Instruct" diff --git a/torchao/quantization/prototype/moe_quant/utils.py b/torchao/quantization/prototype/moe_quant/utils.py new file mode 100644 index 0000000000..debaa879bc --- /dev/null +++ b/torchao/quantization/prototype/moe_quant/utils.py @@ -0,0 +1,239 @@ +import torch + +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +aten = torch.ops.aten + +from torchao.utils import fill_defaults + +from torchao.quantization.quant_api import AOBaseConfig, register_quantize_module_handler, dataclass +from typing import Union, List, Tuple, Optional + + +class DummyModule(torch.nn.Module): + """This is used because the TorchAO quantization functions tend to operate on modules so to apply the transform to a tensor, we can load a + DummyModule with the target tensor and then apply the transformation to the module and then extract the transformed tensor. + """ + def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor]=None): + super().__init__() + self.weight = weight + self.bias = bias + +class FakeExtraDimTensor(torch.Tensor): + """This is a subclass of torch.Tensor that simulates a tensor of n+1 dimensions, akin to concatenating several tensors along the 0th dimension. + It takes a list of tensors with the same dtype, device and shape and creates a representation of shape (num_tensors, orig_shape). 
It can handle a + variety of ops like detach and clone but most importantly, supports any slicing and indexing along the extra dimension. + This is most useful when you have another tensor subclass that you'd like to concatenate together but don't want to support all the necessary + pieces of 3D scaffolding required to make it work. + + The structure of this tensor subclass is a linked_list of tensors with each instance of FakeExtraDimTensor containing a head tensor and a tail consisting of + either another intance of FakeExtraDimTensor or None if we've reached the end of the linked list. This implementation structure is necessary to + support compilation of this tensor subclass since compile requires each tensor component of the tensor subclass to have its own attribute. + """ + def __new__( + cls, + tensors: Union[Tuple[torch.Tensor], List[torch.Tensor]], + tensor_tail: Optional["FakeExtraDimTensor"]=None, + ): + assert len(tensors)>0 or tensor_tail is not None + num_tensors = len(tensors) + if tensor_tail is not None: + num_tensors += tensor_tail.num_tensors + test_tensor = tensor_tail.head_tensor + else: + test_tensor = tensors[0] + + dtype = test_tensor.dtype + shape = test_tensor.shape + device = test_tensor.device + layout = test_tensor.layout + for tensor in tensors: + assert tensor.dtype==dtype, f"all tensors in FakeExtraDimTensor must have same dtype but got {tensor.dtype} and {dtype}" + assert tensor.shape==shape, f"all tensors in FakeExtraDimTensor must have same shape but got {tensor.shape} and {shape}" + assert tensor.device == device, f"all tensors in FakeExtraDimTensor must have same device but got {tensor.device} and {device}" + assert tensor.layout == layout, f"all tensors in FakeExtraDimTensor must have same layout but got {tensor.layout} and {layout}" + kwargs = {} + kwargs["dtype"] = dtype + kwargs["layout"] = layout + kwargs["device"] = device + kwargs["requires_grad"]=False + new_shape = (num_tensors, *shape) + return torch.Tensor._make_wrapper_subclass(cls, new_shape, **kwargs) + + def __repr__( + self, + ): + return f"{self.__class__.__name__}(shape={self.shape}, containing {self.num_tensors}: {self.head_tensor})" + + def __init__( + self, + tensors: Union[Tuple[torch.Tensor], List[torch.Tensor]], + tensor_tail: Optional["FakeExtraDimTensor"]=None, + ): + tensors = list(tensors) + assert len(tensors)>0 or tensor_tail is not None + + # count num_tensors and make tensor_list + self.num_tensors = len(tensors) + if tensor_tail is not None: + self.num_tensors += tensor_tail.num_tensors + tail_list = tensor_tail.tensor_list + else: + tail_list = [] + self.tensor_list = tensors + tail_list + + # 3 cases + # 0) tensors has 0 elements -> take element from tail then do case 1 instead + # 1) tensors has 1 element, -> pop element and tail is None + # 2) tensors has >1 elements, -> pop element and recurse + + # convert case 0 to case 1 by taking 1 element from tail + if len(tensors) == 0 and tensor_tail is not None: + tensors = [tensor_tail.head_tensor,] + tensor_tail = tensor_tail.tensor_tail + + if len(tensors) > 1: + # case (1): remove first element from tensors, then recurse + self.head_tensor = tensors[0] # remove one + self.tensor_tail = self.__class__(tensors[1:], tensor_tail) # recurse + elif len(tensors) == 1: + # case (2) take final element from tensors, attach tensor_tail then stop recursion + self.head_tensor = tensors[0] + self.tensor_tail = tensor_tail + + def _apply_fn_to_data(self, fn): + self.head_tensor = fn(self.head_tensor) + if self.tensor_tail is not None: 
+ self.tensor_tail = self.tensor_tail._apply_fn_to_data(fn) + return self.__class__([self.head_tensor], self.tensor_tail) + + def __tensor_flatten__(self): + try: + if self.tensor_tail is None: + return ["head_tensor",], [self.num_tensors] + else: + return ["head_tensor", "tensor_tail",], [self.num_tensors] + except: + import fbvscode; fbvscode.set_trace() + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride, + ): + head_tensor = tensor_data_dict["head_tensor"] + tensor_tail = tensor_data_dict.get("tensor_tail", None) + return cls([head_tensor], tensor_tail) + + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + kwargs = {} if kwargs is None else kwargs + if func is torch.nn.functional.linear: + x, w, bias = ( + args[0], args[1], args[2] if len(args) > 2 else None, + ) + assert w.num_tensors == 1, "FakeExtraDimTensor used in a linear op when it had multiple tensors" + return func(x, w.head_tensor, bias) + try: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + except Exception as e: + print(f"ERR: subclass {cls} doesn't implement {func}, got error: {e}") + + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func == aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim==0: + return return_and_correct_aliasing( + func, + args, + kwargs, + cls(self.tensor_list[start:end:step]) + ) + + elif func == aten.select.int: + self, dim, index = fill_defaults(args, 3, [0, 0]) + if dim==0: + return return_and_correct_aliasing( + func, + args, + kwargs, + cls([self.tensor_list[index]]) + ) + elif func == aten.index.Tensor: + self, indices, dim = fill_defaults(args, 3, [0]) + if dim==0: + try: + # this handles a weird bug where indices gets turned into a list + # between the function dispatch and torch dispatch but just for this function + if isinstance(indices, list) and len(indices)==1: + indices = indices[0] + return return_and_correct_aliasing( + func, + args, + kwargs, + cls([self.tensor_list[index] for index in indices]) + ) + except Exception as e: + import fbvscode; fbvscode.set_trace() + try: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data( + lambda x: func(x, *args[1:], **kwargs) + ) + ) + except Exception as e: + print( + f"function {func} failed for FakeExtraDimTensor, following error occured when trying to" + "run function on its elements: " + ) + raise e + +@dataclass +class MoEQuantConfig(AOBaseConfig): + """Configuration for applying quantization to MoE + Args: + `base_config`: normal AO Config + """ + base_config: AOBaseConfig + +@register_quantize_module_handler(MoEQuantConfig) +def moe_quant_fn(module, config: MoEQuantConfig): + import warnings + warnings.simplefilter("ignore", lineno=84) + warnings.simplefilter("ignore", lineno=105) + assert "ConditionalFeedForwardAOQuantizable" in str(type(module)) + from torchao.quantization.quant_api import _QUANTIZE_CONFIG_HANDLER + + for weight_attr in ["w1", "w2", "w3"]: + param = getattr(module, weight_attr) + assert isinstance(config.base_config, AOBaseConfig), ( + f"MoEQuantConfig expected to be initialized with an AOBaseConfig but got {type(config.base_config)}" + +"this can happen if you initiaze with MoEQuantConfig(AOConfig) rather than MoEQuantConfig(AOConfig())" + ) + handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + + # break 3D tensor + tensors = 
[param[i] for i in range(param.shape[0])] + # put tensors into modules since the handlers target modules not tensors + dummy_modules = [DummyModule(tensor) for tensor in tensors] + # apply handler to each module + out_mods = list(map(lambda x: handler(x, config.base_config), dummy_modules)) + # pack quantized subclasses into FakeExtraDimTensor + new_param = FakeExtraDimTensor([mod.weight for mod in out_mods]) + new_param = torch.nn.Parameter(new_param, requires_grad=False) + setattr(module, weight_attr, new_param) + del param + return module + + +def moe_filter(module, fqn): + return "MOEFeedForwardAOQuantizable" in str(type(module)) + +def cond_ffn_filter(module, fqn): + return "ConditionalFeedForwardAOQuantizable" in str(type(module)) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 995030df67..4c09e22810 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -300,7 +300,7 @@ def _replace_with_custom_fn_if_matches_filter( device, extra_args, ) - if new_child is not child: + if new_child is not child and new_child is not None: setattr(model, name, new_child) if device is not None: model.to(device=device) # move parent module to device @@ -993,32 +993,25 @@ class Int4WeightOnlyConfig(AOBaseConfig): # TODO maybe change other callsites int4_weight_only = Int4WeightOnlyConfig - -@register_quantize_module_handler(Int4WeightOnlyConfig) -def _int4_weight_only_transform( - module: torch.nn.Module, config: Int4WeightOnlyConfig -) -> torch.nn.Module: +def _int4_weight_only_quantize_tensor(weight, config): # TODO(future PR): perhaps move this logic to a different file, to keep the API # file clean of implementation details # for now, make these local variables to allow the rest of the function # to be a direct copy-paste - weight = module.weight group_size = config.group_size layout = config.layout use_hqq = config.use_hqq zero_point_domain = config.zero_point_domain - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() if weight.shape[-1] % group_size != 0: logger.info( f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}" ) - return module + return weight mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) + block_size = tuple([1 for _ in range(weight.dim()-1)]+[group_size]) target_dtype = torch.int32 quant_min = 0 quant_max = 15 @@ -1070,9 +1063,31 @@ def _int4_weight_only_transform( _layout=layout, use_hqq=use_hqq, ) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module + return new_weight + +@register_quantize_module_handler(Int4WeightOnlyConfig) +def _int4_weight_only_transform( + module: torch.nn.Module, config: Int4WeightOnlyConfig +) -> torch.nn.Module: + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + if "ConditionalFeedForwardAOQuantizable" in str(type(module)): + for weight_attr in ["w1", "w2", "w3"]: + weight = getattr(module, weight_attr) + new_weight = _int4_weight_only_quantize_tensor(weight, config) + new_weight = torch.nn.Parameter(new_weight, requires_grad=False) + setattr(module, weight_attr, new_weight) + return module + else: + assert hasattr(module, "weight"), ( + f"applying int8 weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + new_weight = 
_int4_weight_only_quantize_tensor(module.weight, config) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module @dataclass @@ -1089,20 +1104,15 @@ class Int8WeightOnlyConfig(AOBaseConfig): int8_weight_only = Int8WeightOnlyConfig -@register_quantize_module_handler(Int8WeightOnlyConfig) -def _int8_weight_only_transform(module: torch.nn.Module, config: Int8WeightOnlyConfig): - group_size = config.group_size - weight = module.weight - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - +def _int8_weight_only_quantize_tensor(weight, config): mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 + group_size = config.group_size if group_size is None: - group_size = weight.shape[1] - block_size = (1, group_size) + group_size = weight.shape[-1] + block_size = tuple([1 for x in range(weight.dim()-1)] + [group_size]) new_weight = to_affine_quantized_intx( weight, mapping_type, @@ -1111,9 +1121,29 @@ def _int8_weight_only_transform(module: torch.nn.Module, config: Int8WeightOnlyC eps=eps, zero_point_dtype=zero_point_dtype, ) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module + return new_weight + +@register_quantize_module_handler(Int8WeightOnlyConfig) +def _int8_weight_only_transform(module: torch.nn.Module, config: Int8WeightOnlyConfig): + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + if "ConditionalFeedForwardAOQuantizable" in str(type(module)): + for weight_attr in ["w1", "w2", "w3"]: + weight = getattr(module, weight_attr) + new_weight = _int8_weight_only_quantize_tensor(weight, config) + new_weight = torch.nn.Parameter(new_weight, requires_grad=False) + setattr(module, weight_attr, new_weight) + return module + else: + assert hasattr(module, "weight"), ( + f"applying int8 weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + new_weight = _int8_weight_only_quantize_tensor(module.weight, config) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor: @@ -1226,34 +1256,26 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): # for BC int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig - -@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) -def _int8_dynamic_activation_int8_weight_transform( - module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig -) -> torch.nn.Module: +def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): layout = config.layout act_mapping_type = config.act_mapping_type weight_only_decode = config.weight_only_decode - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - weight = module.weight - - in_features = weight.shape[1] + in_features = weight.shape[-1] # int8 dynamic quantization only has benefit when in_feature > 16 if in_features <= 16: logger.info( f"Skipping applying int8_dynamic_activation_int8_weight to weight of shape {weight.shape}" f" because `in_feature` is <= 16: {in_features}" ) - return module + return weight # weight settings 
mapping_type = MappingType.SYMMETRIC weight_zero_point_domain = ZeroPointDomain.NONE def get_weight_block_size(x): - return (1, x.shape[1]) + return tuple([1 for _ in range(x.dim()-1)] + [x.shape[-1]]) target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps @@ -1269,7 +1291,7 @@ def get_weight_block_size(x): input_quant_func = _int8_asymm_per_token_quant block_size = get_weight_block_size(weight) - weight = to_affine_quantized_intx( + new_weight = to_affine_quantized_intx( weight, mapping_type, block_size, @@ -1279,9 +1301,31 @@ def get_weight_block_size(x): _layout=layout, zero_point_domain=weight_zero_point_domain, ) - weight = to_linear_activation_quantized(weight, input_quant_func) - module.weight = torch.nn.Parameter(weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) + new_weight = to_linear_activation_quantized(new_weight, input_quant_func) + return new_weight + +@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) +def _int8_dynamic_activation_int8_weight_transform( + module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig +) -> torch.nn.Module: + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + if "ConditionalFeedForwardAOQuantizable" in str(type(module)): + for weight_attr in ["w1", "w2", "w3"]: + weight = getattr(module, weight_attr) + new_weight = _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config) + new_weight = torch.nn.Parameter(new_weight, requires_grad=False) + setattr(module, weight_attr, new_weight) + return module + else: + assert hasattr(module, "weight"), ( + f"applying int8 dynamic activation int8 weight quant requires module to have weight attribute" + + "but {module} does not have one" + ) + new_weight = _int8_dynamic_activation_int8_weight_quantize_tensor(module.weight, config) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -1318,18 +1362,9 @@ class Float8WeightOnlyConfig(AOBaseConfig): # for BC float8_weight_only = Float8WeightOnlyConfig - -@register_quantize_module_handler(Float8WeightOnlyConfig) -def _float8_weight_only_transform( - module: torch.nn.Module, config: Float8WeightOnlyConfig -) -> torch.nn.Module: +def _float8_weight_only_quant_tensor(weight, config): from torchao.dtypes import to_affine_quantized_floatx - - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - - weight = module.weight - block_size = (1, weight.shape[1]) + block_size = tuple([1 for _ in range(weight.dim()-1)] + [weight.shape[-1]]) new_weight = to_affine_quantized_floatx( input_float=weight, block_size=block_size, @@ -1337,9 +1372,33 @@ def _float8_weight_only_transform( scale_dtype=None, _layout=Float8Layout(mm_config=None), ) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module + return new_weight + + +@register_quantize_module_handler(Float8WeightOnlyConfig) +def _float8_weight_only_transform( + module: torch.nn.Module, config: Float8WeightOnlyConfig +) -> torch.nn.Module: + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + if "ConditionalFeedForwardAOQuantizable" in str(type(module)): + for weight_attr in ["w1", "w2", "w3"]: + weight = getattr(module, weight_attr) + new_weight = _float8_weight_only_quant_tensor(weight, config) + 
new_weight = torch.nn.Parameter(new_weight, requires_grad=False) + setattr(module, weight_attr, new_weight) + return module + else: + assert hasattr(module, "weight"), ( + f"applying int8 weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + new_weight = _float8_weight_only_quant_tensor(module.weight, config) + + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module _fp8_granularities = Union[PerTensor, PerRow] @@ -1441,10 +1500,10 @@ def _fp8_mm_compat(weight: torch.Tensor) -> bool: bool: True if the tensor can be quantized to float8, False otherwise """ assert ( - weight.dim() == 2 - ), f"float8 quantization only works for 2-D tensors, got {weight.dim()}D tensor" + weight.dim() in [2, 3] + ), f"float8 quantization only works for 2/3-D tensors, got {weight.dim()}D tensor" - out_dim, in_dim = weight.shape + out_dim, in_dim = weight.shape[-2:] is_compatible = (in_dim % 16 == 0) and (out_dim % 16 == 0) if not is_compatible: @@ -1490,35 +1549,26 @@ def __post_init__(self): # for bc float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig - -@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig) -def _float8_dynamic_activation_float8_weight_transform( - module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig -): - assert ( - is_sm_at_least_89() or is_MI300() - ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - +def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype granularity = config.granularity mm_config = config.mm_config - weight = module.weight activation_granularity, weight_granularity = _normalize_granularity(granularity) if not _fp8_mm_compat(weight): # TODO(future PR): this should really throw an exception instead of silently # not doing what the user asked - return module + return weight if isinstance(weight_granularity, PerRow): assert ( weight.dtype == torch.bfloat16 ), "PerRow quantization only works for bfloat16 precision input weight" - block_size = get_block_size(weight.shape, weight_granularity) + block_size = get_block_size(weight.shape[-2:], weight_granularity) + if weight.dim() == 3: + block_size = tuple([1]+list(block_size)) quantized_weight = to_affine_quantized_floatx( input_float=weight, block_size=block_size, @@ -1536,10 +1586,35 @@ def _float8_dynamic_activation_float8_weight_transform( quantized_weight = to_linear_activation_quantized( quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs ) + return quantized_weight - module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module +@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig) +def _float8_dynamic_activation_float8_weight_transform( + module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig +): + assert ( + is_sm_at_least_89() or is_MI300() + ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + + if "ConditionalFeedForwardAOQuantizable" in str(type(module)): + for weight_attr in ["w1", 
"w2", "w3"]: + weight = getattr(module, weight_attr) + quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config) + new_weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + setattr(module, weight_attr, new_weight) + return module + else: + assert hasattr(module, "weight"), ( + f"applying float8 dynamic activation quant requires module to have weight attribute" + + "but {module} does not have one" + ) + quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor(module.weight, config) + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module @dataclass diff --git a/torchao/quantization/transform_module.py b/torchao/quantization/transform_module.py index b6fac49ae9..a1147f8459 100644 --- a/torchao/quantization/transform_module.py +++ b/torchao/quantization/transform_module.py @@ -47,5 +47,6 @@ def _transform( @functools.wraps(config_type) def decorator(func): _QUANTIZE_CONFIG_HANDLER[config_type] = func + return func # needed to make the functions usable externally return decorator diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index f5bdfa9193..f999667c30 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -366,22 +366,23 @@ def get_groupwise_affine_qparams( def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16): guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size()) guard_dtype_size(zeros, "zeros", dtype=dtype) + dim = scales.dim() return ( torch.cat( [ - scales.reshape(scales.size(0), scales.size(1), 1), - zeros.reshape(zeros.size(0), zeros.size(1), 1), + scales.unsqueeze(-1), + zeros.unsqueeze(-1), ], - 2, + dim, ) - .transpose(0, 1) + .transpose(-3, -2) .contiguous() ) def unpack_tinygemm_scales_and_zeros(scales_and_zeros): - assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 - return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) + assert scales_and_zeros.shape[-1] == 2 + return torch.split(scales_and_zeros.transpose(-3, -2), 1, -1) def convert_weight_to_int4pack_xpu(weight, zero_point_domain_is_int=False): diff --git a/torchao/utils.py b/torchao/utils.py index c8465274ea..e7d69e697a 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -170,7 +170,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs): return measurement.mean * 1e6 -def find_multiple(n: int, *args: Tuple[int]) -> int: +def find_multiple(n: int, *args: int) -> int: k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) # type: ignore[9] if n % k == 0: return n