diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py new file mode 100644 index 0000000000..842468a769 --- /dev/null +++ b/test/quantization/test_moe_quant.py @@ -0,0 +1,361 @@ +import unittest + +import torch +from parameterized import parameterized + +from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl +from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl +from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import ( + MOEFeedForwardAOQuantizable, +) +from torchao.quantization.prototype.moe_quant.utils import ( + FakeExtraDimTensor, + MoEQuantConfig, + UseFakeExtraDimTensor, + cond_ffn_filter, +) +from torchao.quantization.quant_api import ( + AffineQuantizedTensor, + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + LinearActivationQuantizedTensor, + quantize_, +) +from torchao.quantization.utils import compute_error +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, + is_sm_at_least_90, +) + + +class TestMoEQuantCompile(unittest.TestCase): + DEFAULT_PARAMS = (512, 256, 8, 2) # hidden_dim, expert_dim, num_experts, top_k + + @torch.no_grad() + def _test_impl_moe_quant( + self, + config, + num_tokens=1, + model_params=None, + base_class=AffineQuantizedTensor, + tensor_impl_class=None, + dtype=torch.bfloat16, + device="cuda", + fullgraph=False, + ): + """ + Tests moe quant for techniques using fake extra dim + """ + if model_params is None: + model_params = self.DEFAULT_PARAMS + + input_shape = (num_tokens, model_params[0]) + model = ( + MOEFeedForwardAOQuantizable(*model_params, empty_init=False) + .to(dtype) + .to(device) + ) + input = torch.randn(input_shape, dtype=torch.bfloat16, device=device) + + out = model(input) + + quantize_(model, config, cond_ffn_filter) + + if ( + isinstance(config, MoEQuantConfig) + and config.use_fake_extra_dim_tensor == UseFakeExtraDimTensor.TRUE + ): + self.assertIsInstance(model.experts.w1, FakeExtraDimTensor) + if base_class is not None: + self.assertIsInstance(model.experts.w1.head_tensor, base_class) + if tensor_impl_class is not None: + self.assertIsInstance( + model.experts.w1.head_tensor.tensor_impl, tensor_impl_class + ) + else: + if base_class is not None: + self.assertIsInstance(model.experts.w1, base_class) + if tensor_impl_class is not None: + self.assertIsInstance(model.experts.w1.tensor_impl, tensor_impl_class) + + out_q = model(input) + + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + model_c = torch.compile(model, mode="reduce-overhead", fullgraph=fullgraph) + + model_c(input) + model_c(input) + out_qc = model_c(input).clone() + + for i in range(10): + input = torch.randn(input_shape, dtype=torch.bfloat16, device=device) + model_c(input) + + self.assertGreaterEqual(compute_error(out_q, out), 10) + self.assertGreaterEqual(compute_error(out_qc, out), 10) + + @parameterized.expand( + [ + ("single_token", 1, False), + ("multiple_tokens", 8, False), + ] + ) + def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest("Test only enabled for 2.5+") + + config = MoEQuantConfig( + Int4WeightOnlyConfig(), 
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE + ) + tensor_impl_class = TensorCoreTiledAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph, + ) + + @parameterized.expand( + [ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ] + ) + def test_int4wo_base(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not is_sm_at_least_90(): + self.skipTest("Requires CUDA capability >= 9.0") + if not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest("Test only enabled for 2.5+") + + config = MoEQuantConfig(Int4WeightOnlyConfig()) + tensor_impl_class = TensorCoreTiledAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph, + ) + + @parameterized.expand( + [ + ("single_token", 1, False), + ("multiple_tokens", 8, False), + ] + ) + def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest("Test only enabled for 2.5+") + + config = MoEQuantConfig( + Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE + ) + tensor_impl_class = PlainAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph, + ) + + @parameterized.expand( + [ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ] + ) + def test_int8wo_base(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not TORCH_VERSION_AT_LEAST_2_6: + self.skipTest("Test only enabled for 2.6+") + + config = MoEQuantConfig(Int8WeightOnlyConfig()) + tensor_impl_class = PlainAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph, + ) + + @parameterized.expand( + [ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ] + ) + def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): + if not TORCH_VERSION_AT_LEAST_2_6: + self.skipTest("Test only enabled for 2.6+") + + config = MoEQuantConfig(Int8WeightOnlyConfig()) + tensor_impl_class = PlainAQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph, + device="cpu", + ) + + @parameterized.expand( + [ + ("multiple_tokens", 32, False), + ] + ) + def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest("Test only enabled for 2.5+") + + config = MoEQuantConfig( + Int8DynamicActivationInt8WeightConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) + base_class = LinearActivationQuantizedTensor + + self._test_impl_moe_quant( + model_params=(512, 256, 2, 2), + config=config, + num_tokens=num_tokens, + base_class=base_class, + fullgraph=fullgraph, + ) + + @parameterized.expand( + [ + ("multiple_tokens", 32, False), + ] + ) + def test_int8dq_base(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest("Test only enabled for 2.5+") + + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) + base_class = 
LinearActivationQuantizedTensor + + self._test_impl_moe_quant( + model_params=(512, 256, 2, 2), + config=config, + num_tokens=num_tokens, + base_class=base_class, + fullgraph=fullgraph, + ) + + @parameterized.expand( + [ + ("single_token", 1, False), + ("multiple_tokens", 8, False), + ] + ) + def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not is_sm_at_least_90(): + self.skipTest("Requires CUDA capability >= 9.0") + + config = MoEQuantConfig( + Float8WeightOnlyConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) + tensor_impl_class = Float8AQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph, + ) + + @parameterized.expand( + [ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ] + ) + def test_fp8wo_base(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not is_sm_at_least_90(): + self.skipTest("Requires CUDA capability >= 9.0") + + config = MoEQuantConfig(Float8WeightOnlyConfig()) + tensor_impl_class = Float8AQTTensorImpl + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + tensor_impl_class=tensor_impl_class, + fullgraph=fullgraph, + ) + + @parameterized.expand( + [ + ("single_token", 1, False), + ("multiple_tokens", 8, False), + ] + ) + def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not is_sm_at_least_90(): + self.skipTest("Requires CUDA capability >= 9.0") + + config = MoEQuantConfig( + Float8DynamicActivationFloat8WeightConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) + base_class = LinearActivationQuantizedTensor + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + base_class=base_class, + fullgraph=fullgraph, + ) + + @parameterized.expand( + [ + ("single_token", 1, True), + ("multiple_tokens", 8, False), + ] + ) + def test_fp8dq_base(self, name, num_tokens, fullgraph): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not is_sm_at_least_90(): + self.skipTest("Requires CUDA capability >= 9.0") + + config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig()) + base_class = LinearActivationQuantizedTensor + + self._test_impl_moe_quant( + config=config, + num_tokens=num_tokens, + base_class=base_class, + fullgraph=fullgraph, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/_models/mixtral-moe/README.md b/torchao/_models/mixtral-moe/README.md new file mode 100644 index 0000000000..22c318aab9 --- /dev/null +++ b/torchao/_models/mixtral-moe/README.md @@ -0,0 +1,8 @@ +## Mixtral-MoE + +This folder contains code and scripts for benchmarking the Mixtral-MoE model. +Running + +`sh scripts/prepare.sh` + +should download the model and `sh run.sh` will run teh benchmarks. diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py new file mode 100644 index 0000000000..0dcd86e74f --- /dev/null +++ b/torchao/_models/mixtral-moe/generate.py @@ -0,0 +1,506 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
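+# Example usage (taken from run.sh in this folder): a single benchmark run with
+# int8 weight-only MoE quantization looks like
+#   python generate.py --checkpoint_path checkpoints/mistralai/Mixtral-8x7B-Instruct-v0.1/model.pth --batch_size 1 --moe_quant int8wo --compile
+# Other accepted --moe_quant values (int4wo, fp8wo, fp8dq, intxdq and their -base variants) are listed further below.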
+import itertools +import sys +import time +from pathlib import Path +from typing import Optional, Tuple + +import torch +import torch._dynamo.config +import torch._inductor.config + +from torchao.utils import get_model_size_in_bytes + +torch.manual_seed(0) + + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif "cpu" in device: + pass + else: + print(f"device={device} is not yet suppported") + + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future +torch._dynamo.config.capture_scalar_outputs = True + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from model import Transformer +from sentencepiece import SentencePieceProcessor + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[:, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def prefill( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> torch.Tensor: + # input_pos: [B, S] + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs)[0] + + +def decode_one_token( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs) + + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + callback=lambda _: _, + **sampling_kwargs, +): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + next_token, next_prob = next_token.clone(), next_prob.clone() + + input_pos += 1 + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + new_probs.append(next_prob.clone()) + cur_token = next_token + + return new_tokens, new_probs + + +def model_forward(model, x, input_pos): + return model(x, input_pos) + + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + batch_size: int, + *, + interactive: bool, + callback=lambda x: x, + **sampling_kwargs, +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
+ """ + device, _ = prompt.device, prompt.dtype + + T = prompt.size(-1) + max_seq_length = ( + min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + ) + new_tokens = max_seq_length - T + + # duplicate prompt for batch_size + prompt = prompt.repeat(batch_size, 1) + + # create an empty tensor of the expected final shape and fill in the current tokens + seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device) + seq[:, :T] = prompt + + with torch.device(device): + model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length) + + input_pos = torch.arange(0, T, device=device) + next_token = prefill( + model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs + ) + seq[:, T] = next_token.squeeze() + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + generated_tokens, _ = decode_n_tokens( + model, + next_token.view(batch_size, -1), + input_pos, + new_tokens - 1, + callback=callback, + **sampling_kwargs, + ) + seq = torch.cat((seq[:, : T + 1], *generated_tokens), dim=-1) + + return seq + + +def encode_tokens(tokenizer, string, bos=True, device="cuda"): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + + +def _load_model(checkpoint_path, device, precision): + with torch.device("meta"): + model = Transformer.from_name(checkpoint_path.parent.name) + + try: + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + except: + model = Transformer.from_name(checkpoint_path.parent.name) + + model = model.to(device=device, dtype=precision) + return model.eval() + + +B_INST, E_INST = "[INST]", "[/INST]" + + +def main( + prompt: str = "Hello, my name is", + interactive: bool = False, + num_samples: int = 5, + max_new_tokens: int = 100, + batch_size: int = 1, + top_k: int = 200, + temperature: float = 0.8, + checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"), + compile: bool = True, + compile_prefill: bool = False, + moe_quant: Optional[str] = None, + profile: Optional[Path] = None, + memory_profile: Optional[Path] = None, + device="cuda", +) -> None: + """Generates text samples based on a pre-trained Transformer model and tokenizer.""" + assert checkpoint_path.is_file(), checkpoint_path + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + print(f"Using device={device}") + precision = torch.bfloat16 + is_chat = "chat" in str(checkpoint_path) + + if device == "cuda" and memory_profile is not None: + torch.cuda.memory._record_memory_history( + True, trace_alloc_max_entries=500000, trace_alloc_record_context=True + ) + + print("Loading model ...") + t0 = time.time() + model = _load_model(checkpoint_path, "cpu", precision) + + print(f"Time to load model: {time.time() - t0:.02f} seconds") + t0 = time.time() + + tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + prompt_length = encoded.size(0) + + torch.manual_seed(1234) + model_size = sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain(model.parameters(), model.buffers()) + ] + ) + + from torchao.quantization.prototype.moe_quant.utils import ( + MoEQuantConfig, + UseFakeExtraDimTensor, + cond_ffn_filter, + ) + from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, + 
Float8WeightOnlyConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int8DynamicActivationIntxWeightConfig, + Int8WeightOnlyConfig, + PackedLinearInt8DynamicActivationIntxWeightLayout, + PerRow, + quantize_, + ) + + if moe_quant: + torch._dynamo.config.capture_dynamic_output_shape_ops = True + config = None + if "int8wo-base" in moe_quant: + config = MoEQuantConfig(Int8WeightOnlyConfig()) + + elif "int8wo" in moe_quant: + config = MoEQuantConfig( + Int8WeightOnlyConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) + + elif "int8dq-base" in moe_quant: + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) + + elif "int8dq" in moe_quant: + config = MoEQuantConfig( + Int8DynamicActivationInt8WeightConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) + + elif "int4wo-base" in moe_quant: + config = MoEQuantConfig(Int4WeightOnlyConfig()) + + elif "int4wo" in moe_quant: + config = MoEQuantConfig( + Int4WeightOnlyConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) + + elif "fp8wo-base" in moe_quant: + config = MoEQuantConfig(Float8WeightOnlyConfig()) + + elif "fp8wo" in moe_quant: + config = MoEQuantConfig( + Float8WeightOnlyConfig(), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) + + elif "fp8dq-base" in moe_quant: + config = MoEQuantConfig( + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + ) + + elif "fp8dq" in moe_quant: + config = MoEQuantConfig( + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) + + elif "intxdq" in moe_quant: + config = MoEQuantConfig( + Int8DynamicActivationIntxWeightConfig( + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), + ), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + ) + else: + assert config is not None, ( + f"expected moe_quant to match one of the options but got {moe_quant}" + ) + + if config is not None: + quantize_(model, config, filter_fn=cond_ffn_filter, device=device) + print( + f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds" + ) + + model.to(device=device) + device_sync(device=device) + + if compile: + # moe quant + compile causes repeated warnings + import warnings + + warnings.simplefilter("ignore", lineno=84) + warnings.simplefilter("ignore", lineno=105) + + torch._inductor.config.assert_indirect_indexing = False + + global decode_one_token, prefill + + if batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant): + decode_one_token = torch.compile( + decode_one_token, mode="reduce-overhead", fullgraph=True + ) + else: + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead") + + if args.compile_prefill: + prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + + aggregate_metrics = { + "tokens_per_sec": [], + } + start = -1 if compile else 0 + + for i in range(start, num_samples): + device_sync(device=device) # MKG + if i >= 0 and interactive: + prompt = input("What is your prompt? 
") + if is_chat: + prompt = f"{B_INST} {prompt.strip()} {E_INST}" + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + + if interactive and i >= 0: + buffer = [] + period_id = tokenizer.encode(".")[0] + done_generating = False + + def callback(x): + nonlocal done_generating + if done_generating: + return + buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) + if x.item() == tokenizer.eos_id(): + done_generating = True + if len(buffer) == 4 or done_generating: + print("".join(buffer), end="", flush=True) + buffer.clear() + # print(, end='', flush=True) + else: + callback = lambda x: x + t0 = time.perf_counter() + import contextlib + + if i != num_samples - 1 or not profile: + prof = contextlib.nullcontext() + else: + torch.profiler._utils._init_for_cuda_graphs() + prof = torch.profiler.profile() + with prof: + y = generate( + model, + encoded, + max_new_tokens, + batch_size, + interactive=interactive, + callback=callback, + temperature=temperature, + top_k=top_k, + ) + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue + if hasattr(prof, "export_chrome_trace"): + prof.export_chrome_trace(f"{profile}.json") + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + + if not interactive: + pass + print(tokenizer.decode(y[0].tolist())) + else: + print() + tokens_generated = y.size(-1) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + print( + f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec" + ) + print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + + if i == 0 and device == "cuda" and memory_profile is not None: + snapshot = torch.cuda.memory._snapshot() + with open(f"{memory_profile}.pickle", "wb") as f: + from pickle import dump + + dump(snapshot, f) + print( + f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use", + "python pytorch/torch/cuda/_memory_viz.py trace_plot -o .html", + ) + break + + tokpersec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item() + print(f"Average tokens/sec: {tokpersec:.2f}") + if batch_size > 1: + print(f"Average tokens/sec including batches {batch_size * tokpersec:.2f}") + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") + print(f"model size: {get_model_size_in_bytes(model) / 1e9:.02f}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Your CLI description.") + + parser.add_argument( + "--prompt", type=str, default="Hello, my name is", help="Input prompt." + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Whether to launch in interactive mode", + ) + parser.add_argument("--num_samples", type=int, default=5, help="Number of samples.") + parser.add_argument( + "--max_new_tokens", type=int, default=200, help="Maximum number of new tokens." + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size to benchmark with" + ) + parser.add_argument("--top_k", type=int, default=200, help="Top-k for sampling.") + parser.add_argument( + "--temperature", type=float, default=0.8, help="Temperature for sampling." + ) + parser.add_argument( + "--checkpoint_path", + type=Path, + default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), + help="Model checkpoint path.", + ) + parser.add_argument( + "--compile", action="store_true", help="Whether to compile the model." 
+ ) + parser.add_argument( + "--compile_prefill", + action="store_true", + help="Whether to compile the prefill (improves prefill perf, but higher compile times)", + ) + # parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8') + parser.add_argument( + "--moe_quant", + type=str, + help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq", + ) + parser.add_argument("--profile", type=Path, default=None, help="Profile path.") + parser.add_argument( + "--memory_profile", type=Path, default=None, help="filename for memory profile." + ) + parser.add_argument("--device", type=str, default="cuda", help="device to use") + + args = parser.parse_args() + print(args) + main( + args.prompt, + args.interactive, + args.num_samples, + args.max_new_tokens, + args.batch_size, + args.top_k, + args.temperature, + args.checkpoint_path, + args.compile, + args.compile_prefill, + args.moe_quant, + args.profile, + args.memory_profile, + args.device, + ) diff --git a/torchao/_models/mixtral-moe/model.py b/torchao/_models/mixtral-moe/model.py new file mode 100644 index 0000000000..46a4ce79be --- /dev/null +++ b/torchao/_models/mixtral-moe/model.py @@ -0,0 +1,464 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + +from torchao.quantization.prototype.moe_quant.utils import FakeExtraDimTensor + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + num_experts: int = 8 + num_activated_experts: int = 2 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [ + config + for config in transformer_configs + if config in str(name).upper() or config in str(name) + ] + assert len(config) == 1, name + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "Mixtral-8x7B-Instruct-v0.1": dict( + block_size=32768, + n_layer=32, + n_head=32, + n_local_heads=8, + dim=4096, + intermediate_size=14336, + rope_base=1000000.0, + num_experts=8, + num_activated_experts=2, + ), +} + + +class KVCache(nn.Module): + def __init__( + self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16 + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = 
self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.layers = nn.ModuleList( + TransformerBlock(config) for _ in range(config.n_layer) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length): + if ( + self.max_seq_length >= max_seq_length + and self.max_batch_size >= max_batch_size + ): + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, max_seq_length, self.config.n_local_heads, head_dim + ) + + self.freqs_cis = precompute_freqs_cis( + self.config.block_size, + self.config.dim // self.config.n_head, + self.config.rope_base, + ) + self.causal_mask = torch.tril( + torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) + ) + + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + x = self.tok_embeddings(idx) + + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask) + x = self.norm(x) + logits = self.output(x) + return logits + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.block_sparse_moe = MOEFeedForwardAOQuantizable(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward( + self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor + ) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.block_sparse_moe(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = 
self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +# class ConditionalFeedForward(nn.Module): +# def __init__(self, config): +# super().__init__() +# self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) +# self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) +# self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) + +# def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: +# w1_weights = self.w1[expert_indices] # [T, A, D, D] +# w3_weights = self.w3[expert_indices] # [T, A, D, D] +# w2_weights = self.w2[expert_indices] # [T, A, D, D] +# x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights)) +# x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) +# expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) +# return expert_outs + + +# class MOEFeedForward(nn.Module): +# def __init__(self, config) -> None: +# super().__init__() +# self.gate = nn.Linear(config.dim, config.num_experts, bias=False) +# self.cond_ffn = ConditionalFeedForward(config) +# self.dim = config.dim +# self.num_activated_experts = config.num_activated_experts +# def forward(self, x: Tensor) -> Tensor: +# x = x.view(-1, self.dim) +# # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts +# # x: [T, D] +# scores = self.gate(x) # [T, E] +# expert_weights = F.softmax(scores, dim=-1) +# expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] +# expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] +# expert_outs = self.cond_ffn(x, expert_indices) +# return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * 
freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) + + +# T tokens +# E experts +# D dim +# I intermediate dim +# A activated experts +# T'(e) tokens for expert e + + +class MOEFeedForwardAOQuantizable(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) + self.cond_ffn = ConditionalFeedForwardAOQuantizable(config) + self.dim = config.dim + self.num_activated_experts = config.num_activated_experts + + def forward(self, x: Tensor) -> Tensor: + batch_size = x.shape[0] + x = x.view(-1, self.dim) # x: [T, D] + scores = self.gate(x) # [T, E] + expert_weights = F.softmax(scores, dim=-1) + expert_weights, expert_indices = torch.topk( + expert_weights, self.num_activated_experts, dim=-1 + ) # [T, A], [T, A] + expert_weights /= expert_weights.sum(dim=-1, keepdim=True).to(x.dtype) # [T, A] + out = self.cond_ffn( + x, expert_indices, expert_weights, self.num_activated_experts + ) + return out.reshape(batch_size, -1, self.dim) + + +class ConditionalFeedForwardAOQuantizable(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.w1 = nn.Parameter( + torch.empty(config.num_experts, config.intermediate_size, config.dim) + ) # E, I, D + self.w2 = nn.Parameter( + torch.empty(config.num_experts, config.dim, config.intermediate_size) + ) # E, D, I + self.w3 = nn.Parameter( + torch.empty(config.num_experts, config.intermediate_size, config.dim) + ) # E, I, D + self.num_experts = config.num_experts + + def forward( + self, + x: Tensor, # T, D + expert_indices: Tensor, # T, A + expert_weights: Tensor, # T, A + num_activated_experts: int, + ) -> Tensor: + num_tokens, dim = x.shape + num_token_activations = num_tokens * num_activated_experts + if x.shape[0] == 1 and not isinstance( + self.w1, FakeExtraDimTensor + ): # only 1 token (can be done without graph breaks when compiled) + outs = [] + expert_indices = expert_indices.view(num_activated_experts) + # collect used experts + w1 = self.w1[expert_indices] + w2 = self.w2[expert_indices] + w3 = self.w3[expert_indices] + + # run token through each expert + for index in range(num_activated_experts): + y1 = F.silu(F.linear(x, w1[index])) + y3 = F.linear(x, w3[index]) + y2 = w2[index] + cur_out = F.linear(y1 * y3, y2) + outs.append(cur_out) + + # combine outputs + final_out = ( + (torch.cat(outs, dim=0) * expert_weights.view(-1, 1)) + .sum(dim=0) + .unsqueeze(-1) + ) + return final_out + else: + expert_list = [x for x in range(self.num_experts)] + + # shuffle tokens into groups for each expert + ordered_token_activations = expert_indices.view(-1).argsort( + stable=True + ) # [A] + ordered_token_indices = ( + ordered_token_activations.div(num_activated_experts) + .floor() + .to(torch.int64) + ) # [T] + + if not expert_indices.is_cuda: # histc doesn't work on cpu for integers + num_tokens_per_expert = torch.bincount( + expert_indices.view(-1) + 1, minlength=self.num_experts + 1 + ) + else: + num_tokens_per_expert = torch.histc( + expert_indices, + bins=self.num_experts + 1, + min=-1, + max=self.num_experts, + ) # [E+1] (added leading 0 so can be used for indexing) + cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to( + torch.int64 + ) # [E+1] + + @torch._dynamo.disable() + def group_tokens_by_expert( + ordered_token_indices, cum_tokens_per_expert, expert_list + ): + 
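# Note (assumption): kept out of torch.compile (see the decorator above) because it
+ # slices out a Python list of variable-length index tensors using data-dependent
+ # per-expert token counts, which the compiler cannot trace without graph breaks.
+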
token_indices_per_expert = [ + ordered_token_indices[ + cum_tokens_per_expert[expert] : cum_tokens_per_expert[ + expert + 1 + ] + ] + for expert in expert_list + ] # [T'(e1)], [T'(e2)] ... + return token_indices_per_expert + + token_indices_per_expert = group_tokens_by_expert( + ordered_token_indices, cum_tokens_per_expert, expert_list + ) + tokens_grouped_by_expert = [ + x[indices] for indices in token_indices_per_expert + ] + + # calculate outputs for each expert + outs = [] + for cur_x, expert in zip(tokens_grouped_by_expert, expert_list): + w1 = self.w1[expert] # I, D + w2 = self.w2[expert] # D, I + w3 = self.w3[expert] # I, D + + cur_out = F.linear( + F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2 + ) # [T'(e), D] + outs.append(cur_out) + + # weigh outputs + ordered_outs = torch.cat(outs, dim=0) # [T*A, D] + ordered_token_activation_weights = expert_weights.view(-1, 1)[ + ordered_token_activations + ].view(-1, 1) # [T*A, 1] + weighted_ordered_outs = ( + ordered_outs * ordered_token_activation_weights + ) # [T*A, D] + + # sum weighted token-activation outputs together for each token + final_out = torch.zeros_like(x) # [T, D] + final_out = final_out.scatter_add( + dim=0, + index=ordered_token_indices.unsqueeze(-1) + .expand(num_token_activations, dim) + .to(torch.int64), + src=weighted_ordered_outs, + ) + return final_out diff --git a/torchao/_models/mixtral-moe/run.sh b/torchao/_models/mixtral-moe/run.sh new file mode 100644 index 0000000000..d9e3a50405 --- /dev/null +++ b/torchao/_models/mixtral-moe/run.sh @@ -0,0 +1,39 @@ +export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1 +export CHECKPOINT_PATH=checkpoints/ + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8wo --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8wo --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8wo-base --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8wo-base --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int4wo --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int4wo --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int4wo-base --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int4wo-base --compile + +# # EXPERT CHOICE +# # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq --compile +# # # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq --compile +# # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq-base --compile +# # # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq-base --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8wo --compile +python generate.py --checkpoint_path 
$CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8wo --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8wo-base --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8wo-base --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8dq --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8dq --compile + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8dq-base --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8dq-base --compile + +# ARM +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant intxdq --device cpu +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant intxdq --compile --device cpu diff --git a/torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py b/torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py new file mode 100644 index 0000000000..6a39578e32 --- /dev/null +++ b/torchao/_models/mixtral-moe/scripts/convert_hf_checkpoint.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import glob +import re +import sys +from pathlib import Path +from typing import Optional + +import torch + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from model import ModelArgs + + +@torch.inference_mode() +def convert_hf_checkpoint( + *, + checkpoint_dir: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1"), + model_name: Optional[str] = None, +) -> None: + if model_name is None: + model_name = checkpoint_dir.name + + config = ModelArgs.from_name(model_name) + print(f"Model config {config.__dict__}") + + weight_map = { + "tok_embeddings.weight": "tok_embeddings.weight", + "layers.{}.attention.wq.weight": "layers.{}.attention.wq.weight", + "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight", + "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight", + "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight", + "layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1", + "layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2", + "layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3", + "layers.{}.block_sparse_moe.gate.weight": "layers.{}.block_sparse_moe.gate.weight", + "layers.{}.attention_norm.weight": "layers.{}.attention_norm.weight", + "layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight", + "norm.weight": "norm.weight", + "output.weight": "output.weight", + } + + pt_files = glob.glob(str(checkpoint_dir / "*.pt")) + + merged_result = {} + for file in sorted(pt_files): + state_dict = torch.load( + str(file), map_location="cpu", mmap=True, weights_only=True + ) + merged_result.update(state_dict) + final_result = {} + for key, value in merged_result.items(): + if "layers" in key: + abstract_key = re.sub(r".(\d+).", ".{}.", key) + layer_num = re.search(r"\d+", key).group(0) + new_key = weight_map[abstract_key] + if new_key is None: + continue + new_key = 
new_key.format(layer_num) + else: + new_key = weight_map[key] + + final_result[new_key] = value + + for key in tuple(final_result.keys()): + if "wq" in key: + q = final_result[key] + k = final_result[key.replace("wq", "wk")] + v = final_result[key.replace("wq", "wv")] + final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) + del final_result[key] + del final_result[key.replace("wq", "wk")] + del final_result[key.replace("wq", "wv")] + elif "w1" in key or "w3" in key: + final_result[key] = ( + final_result[key] + .reshape(config.num_experts, config.intermediate_size, config.dim) + .contiguous() + ) + elif "w2" in key: + final_result[key] = ( + final_result[key] + .reshape(config.num_experts, config.intermediate_size, config.dim) + .permute(0, 2, 1) + .contiguous() + ) + elif "gate" in key: + final_result[key] = final_result[key].contiguous() + + print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") + torch.save(final_result, checkpoint_dir / "model.pth") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.") + parser.add_argument( + "--checkpoint_dir", + type=Path, + default=Path("checkpoints/mistralai/Mixtral-8x7B-v0.1"), + ) + parser.add_argument("--model_name", type=str, default=None) + + args = parser.parse_args() + convert_hf_checkpoint( + checkpoint_dir=args.checkpoint_dir, + model_name=args.model_name, + ) diff --git a/torchao/_models/mixtral-moe/scripts/download.py b/torchao/_models/mixtral-moe/scripts/download.py new file mode 100644 index 0000000000..8a451b001d --- /dev/null +++ b/torchao/_models/mixtral-moe/scripts/download.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import os +from typing import Optional + +from requests.exceptions import HTTPError + + +def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: + from huggingface_hub import snapshot_download + + os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) + try: + snapshot_download( + repo_id, + local_dir=f"checkpoints/{repo_id}", + local_dir_use_symlinks=False, + token=hf_token, + ignore_patterns="*.safetensors", + ) + except HTTPError as e: + if e.response.status_code == 401: + print( + "You need to pass a valid `--hf_token=...` to download private checkpoints." + ) + else: + raise e + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.") + parser.add_argument( + "--repo_id", + type=str, + default="mistralai/Mixtral-8x7B-Instruct-v0.1", + help="Repository ID to download from.", + ) + parser.add_argument( + "--hf_token", type=str, default=None, help="HuggingFace API token." 
+ ) + + args = parser.parse_args() + hf_download(args.repo_id, args.hf_token) diff --git a/torchao/_models/mixtral-moe/scripts/prepare.sh b/torchao/_models/mixtral-moe/scripts/prepare.sh new file mode 100644 index 0000000000..8ca60b165b --- /dev/null +++ b/torchao/_models/mixtral-moe/scripts/prepare.sh @@ -0,0 +1,2 @@ +python scripts/download.py --repo_id mistralai/Mixtral-8x7B-Instruct-v0.1 +python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/mistralai/Mixtral-8x7B-Instruct-v0.1 diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index bf1bdacb68..1d70f5c7f3 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -476,12 +476,15 @@ def _(func, types, args, kwargs): shape = list(self.shape) shape[dim] = end - start block_size = self.block_size - assert len(block_size) == 2, ( - f"Slice only works for 2d block_size right now, got: {block_size}" - ) + assert len(block_size) in [ + 2, + 3, + ], f"Slice only works for 2 and 3d block_size right now, got: {block_size}" # with slice, some shape dimension might be smaller than block_size dimension, so # we need to make sure there is no overflow - block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) + if len(block_size) == 2: + block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) + new = self.__class__( aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, @@ -490,7 +493,53 @@ def _(func, types, args, kwargs): self.quant_max, self.zero_point_domain, dtype=self.dtype, - strides=self.stride(), + strides=self.stride() if len(block_size) == 2 else None, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.index.Tensor) +def _(func, types, args, kwargs): + self, indices = args + assert len(indices) == 1, ( + f"op {func} currently only implemented for single dimensional indexing but got indices: {indices}" + ) + new_tensor_impl = aten.index.Tensor(self.tensor_impl, indices) + shape = tuple([indices[0].numel(), *self.shape[1:]]) + + block_size = self.block_size + new = self.__class__( + new_tensor_impl, + block_size, + shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.select.int) +def _(func, types, args, kwargs): + self, dim, index = fill_defaults(args, 3, [0, 0]) + assert dim == 0, f"op {func} currently only implemented for dim=0 but got dim={dim}" + assert self.dim() == 3, ( + f"op {func} currently only implemented for 3 dimensional tensors but got shape={self.shape}" + ) + + new_tensor_impl = aten.select.int(self.tensor_impl, dim, index) + + shape = self.shape[1:] + block_size = self.block_size[1:] + new = self.__class__( + new_tensor_impl, + block_size, + shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, ) return return_and_correct_aliasing(func, args, kwargs, new) diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 872179bd9a..5914f00102 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -56,6 +56,9 @@ class Float8Layout(Layout): mm_config: Optional[Float8MMConfig] = None +_fallback_warning_shown = False + + @register_layout(Float8Layout) class Float8AQTTensorImpl(AQTTensorImpl): """ @@ -100,12 +103,35 @@ def __init__( def _apply_fn_to_data(self, fn): """Applys a fn to 
all tensor components stored on this class""" - return self.__class__( - fn(self.float8_data), - fn(self.scale), - self.transposed, - self._layout, - ) + global _fallback_warning_shown + + try: + return self.__class__( + fn(self.float8_data), + fn(self.scale), + self.transposed, + self._layout, + ) + except RuntimeError as e: + if '"index_cuda" not implemented for ' in str(e): + if not _fallback_warning_shown: + import warnings + + warnings.warn( + f"When trying to index Float8AQTTensorImpl, got known error {e}, will use slower fallback but " + + "note: You can torch.compile the model to avoid this problem.", + UserWarning, + ) + _fallback_warning_shown = True + + return self.__class__( # do indexing in bfloat16 then convert back + fn(self.float8_data.to(torch.bfloat16)).to(self.float8_data.dtype), + fn(self.scale), + self.transposed, + self._layout, + ) + else: + raise e def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) @@ -159,6 +185,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs): raise ValueError( f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" ) + elif func in [aten.select.int, aten.index.Tensor]: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), + ) elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: diff --git a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py index 6caa0784d8..dc7b073f32 100644 --- a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -359,6 +359,7 @@ def _impl_2d_aten(input_tensor, weight_tensor): m, k = input_tensor.shape n, k_ = weight_tensor.shape + assert k_ == k group_size = weight_tensor.tensor_impl.get_layout().group_size packed_weight = weight_tensor.tensor_impl.packed_weight @@ -366,6 +367,9 @@ def _impl_2d_aten(input_tensor, weight_tensor): input_tensor, packed_weight, group_size, k, n ) + if input_tensor.numel() == 0: + return input_tensor + target = weight_tensor.tensor_impl.get_layout().target if weight_tensor.tensor_impl.get_layout().has_bias: diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py index 516136bca7..3551214d7e 100644 --- a/torchao/dtypes/uintx/plain_layout.py +++ b/torchao/dtypes/uintx/plain_layout.py @@ -154,6 +154,14 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) return return_and_correct_aliasing(func, args, kwargs, new) + elif func in [aten.select.int, aten.index.Tensor]: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), + ) + elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 9c37f58ada..3bf9ef6b72 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -93,11 +93,13 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) # groupwise int4 quantization - groupsize = weight_tensor.block_size[1] - y = 
torch.ops.aten._weight_int4pack_mm( - act_mat.contiguous(), packed_weight, groupsize, scale_and_zero - ) - + groupsize = weight_tensor.block_size[-1] + if act_mat.numel() == 0: # handling for empty input + y = act_mat + else: + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] y = y[:, :orig_out_features] @@ -119,7 +121,7 @@ class TensorCoreTiledLayout(Layout): inner_k_tiles: int = 8 def pre_process(self, input: torch.Tensor) -> torch.Tensor: - orig_out_features, orig_in_features = input.shape + orig_out_features, orig_in_features = input.shape[-2:] in_features = find_multiple(orig_in_features, 1024) out_features = find_multiple(orig_out_features, 8) input = torch.nn.functional.pad( @@ -160,18 +162,18 @@ def post_process( zero_point: torch.Tensor, block_size: Tuple[int, ...], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - orig_out_features, orig_in_features = input.shape + orig_out_features, orig_in_features = input.shape[-2:] in_features = find_multiple(orig_in_features, 1024) out_features = find_multiple(orig_out_features, 8) input = torch.nn.functional.pad( input, (0, in_features - orig_in_features, 0, out_features - orig_out_features), ) - assert len(block_size) == 2, ( - f"TensorCoreTiledLayout only supports len(block_size) == 2, got: {block_size}" + assert len(block_size) == 2 or len(block_size) == 3, ( + f"TensorCoreTiledLayout only supports len(block_size) == 2 or 3, got: {block_size}" ) - scale_pad_dim_0 = (out_features - orig_out_features) // block_size[0] - scale_pad_dim_1 = (in_features - orig_in_features) // block_size[1] + scale_pad_dim_0 = (out_features - orig_out_features) // block_size[-2] + scale_pad_dim_1 = (in_features - orig_in_features) // block_size[-1] scale = torch.nn.functional.pad(scale, (0, scale_pad_dim_1, 0, scale_pad_dim_0)) zero_point = torch.nn.functional.pad( zero_point, (0, scale_pad_dim_1, 0, scale_pad_dim_0) @@ -262,21 +264,44 @@ def from_plain( _layout: Layout, ): assert isinstance(_layout, TensorCoreTiledLayout) + assert int_data.dtype == torch.int32, ( + "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + ) + + def quant_2d(int_data_2d): + if TORCH_VERSION_AT_LEAST_2_5: + int_data_2d = (int_data_2d[::, ::2] << 4 | int_data_2d[::, 1::2]).to( + torch.uint8 + ) + else: + assert int_data_2d.dtype == torch.int32, ( + "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + ) + return torch.ops.aten._convert_weight_to_int4pack( + int_data_2d.contiguous(), _layout.inner_k_tiles + ) - if TORCH_VERSION_AT_LEAST_2_5: - int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert int_data.dtype == torch.uint8, ( - "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + if int_data.dim() == 3: # for moe quant + num_experts = int_data.shape[0] + packed_weight_list = [] + for expert in range(num_experts): + packed_weight_list.append(quant_2d(int_data[expert]).unsqueeze(0)) + packed_weight = torch.cat(packed_weight_list, dim=0) + scale = scale.reshape(int_data.shape[0], int_data.shape[-2], -1) + zero_point = ( + zero_point.reshape(int_data.shape[0], int_data.shape[-2], -1) + if zero_point is not None + else None ) else: - assert int_data.dtype == torch.int32, ( - "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + assert int_data.dim() == 2 + packed_weight = quant_2d(int_data) + scale = 
scale.reshape(int_data.shape[0], -1) + zero_point = ( + zero_point.reshape(int_data.shape[0], -1) + if zero_point is not None + else None ) - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - int_data, _layout.inner_k_tiles - ) - scale = scale.reshape(int_data.shape[0], -1) - zero_point = zero_point.reshape(int_data.shape[0], -1) from torchao.quantization.utils import pack_tinygemm_scales_and_zeros scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) @@ -336,6 +361,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" ) + if func in [aten.select.int, aten.index.Tensor]: + assert not (func is aten.select.int and args[1] != 0), ( + "aten.select.int currently only has support for dim=0" + ) + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), + ) + if func is aten.t.default: """we don't need to repack the weight and just rely on external shape being changed and record the status of transpose/no-transpose @@ -416,11 +452,16 @@ def block_size(self): scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) cur_shape = self.shape - assert len(cur_shape) == 4 + if len(cur_shape) == 5: + ones = [1, 1] + cur_shape = cur_shape[1:] + else: + assert len(cur_shape) == 4 + ones = [1] inner_k_tiles = cur_shape[-1] * 2 original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) groupsize = int(original_shape[1] / scale.shape[-2]) - return (1, groupsize) + return tuple([*ones, groupsize]) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: from torchao.quantization.quant_primitives import ( @@ -429,35 +470,50 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros - scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + def dequant_4d(self): + cur_shape = self.shape + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + assert len(cur_shape) == 4 + inner_k_tiles = cur_shape[-1] * 2 + original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) + eye_shape = original_shape[1] + groupsize = int(original_shape[1] / scale.shape[-2]) + block_size = (1, groupsize) + original_dtype = torch.bfloat16 + assert len(block_size) == 2 and block_size[0] == 1 + dequantized = torch.ops.aten._weight_int4pack_mm( + torch.eye(eye_shape, device=self.device, dtype=original_dtype), + self.packed_weight, + groupsize, + self.scale_and_zero, + ) + dequantized = dequantized.t().contiguous() + return dequantized cur_shape = self.shape - assert len(cur_shape) == 4 - inner_k_tiles = cur_shape[-1] * 2 - original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) - eye_shape = original_shape[1] - groupsize = int(original_shape[1] / scale.shape[-2]) - block_size = (1, groupsize) - device = self.device - original_dtype = torch.bfloat16 + + if len(cur_shape) == 4: + dequantized = dequant_4d(self) + else: + assert len(cur_shape) == 5 + num_experts = cur_shape[0] + dequantized_list = [] + for expert in range(num_experts): + dequantized_list.append(dequant_4d(self[expert]).unsqueeze(0)) + dequantized = torch.cat(dequantized_list, dim=0) + + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + # TODO: move this to `unpack_tinygemm_scales_and_zeros`? 
+ scale = scale.reshape(scale.shape[:-1]).contiguous() + zero = zero.reshape(zero.shape[:-1]).contiguous() + target_dtype = torch.int32 quant_min = 0 quant_max = 15 zero_point_domain = ZeroPointDomain.FLOAT - assert len(block_size) == 2 and block_size[0] == 1 - dequantized = torch.ops.aten._weight_int4pack_mm( - torch.eye(eye_shape, device=device, dtype=original_dtype), - self.packed_weight, - groupsize, - self.scale_and_zero, - ) - dequantized = dequantized.t().contiguous() - # TODO: move this to `unpack_tinygemm_scales_and_zeros`? - scale = scale.reshape(scale.shape[:-1]).contiguous() - zero = zero.reshape(zero.shape[:-1]).contiguous() int_data = quantize_affine( dequantized, - block_size, + self.block_size, scale, zero, target_dtype, diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index da6c98cd6f..d1236e9183 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -629,6 +629,53 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer( sqnr2 = compute_error(prepared_out, converted_out2).item() self.assertTrue(sqnr2 == float("inf")) + def test_moe_quant_intx(self): + from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import ( + MOEFeedForwardAOQuantizable, + ) + from torchao.quantization.prototype.moe_quant.utils import ( + FakeExtraDimTensor, + MoEQuantConfig, + UseFakeExtraDimTensor, + cond_ffn_filter, + ) + from torchao.quantization.quant_api import ( + Int8DynamicActivationIntxWeightConfig, + PackedLinearInt8DynamicActivationIntxWeightLayout, + quantize_, + ) + from torchao.quantization.utils import compute_error + + with torch.device("cpu"): + model = MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False).to( + torch.float32 + ) + x = torch.randn(8, 512, dtype=torch.float32) + + out = model(x).clone() + + base_config = Int8DynamicActivationIntxWeightConfig( + layout=PackedLinearInt8DynamicActivationIntxWeightLayout() + ) + moe_config = MoEQuantConfig( + base_config, use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE + ) + + quantize_(model, moe_config, cond_ffn_filter) + + out_q = model(x).clone() + assert isinstance(model.experts.w1, FakeExtraDimTensor) + + mod_c = torch.compile(model, mode="reduce-overhead") + + mod_c(x) + mod_c(x) + + out_qc = mod_c(x).clone() + + self.assertGreater(compute_error(out_q, out), 30) + self.assertGreater(compute_error(out_qc, out), 30) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index e4343a086f..aa946c064f 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -82,6 +82,8 @@ def __tensor_unflatten__( def _quantized_linear_op( input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor ): + if input_tensor.numel() == 0: + return input_tensor input_quant_func = weight_tensor.input_quant_func original_weight_tensor = weight_tensor.original_weight_tensor quant_kwargs = weight_tensor.quant_kwargs @@ -243,6 +245,34 @@ def _(func, types, args, kwargs): ) +@implements(aten.select.int) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + LinearActivationQuantizedTensor( + func(args[0].original_weight_tensor, *args[1:]), + 
args[0].input_quant_func, + args[0].quant_kwargs, + ), + ) + + +@implements(aten.index.Tensor) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + LinearActivationQuantizedTensor( + func(args[0].original_weight_tensor, *args[1:]), + args[0].input_quant_func, + args[0].quant_kwargs, + ), + ) + + + # this is needed for DTensor.from_local() and for flattening tensor @implements(aten.view.default) def _(func, types, args, kwargs):
diff --git a/torchao/quantization/prototype/moe_quant/README.md b/torchao/quantization/prototype/moe_quant/README.md
new file mode 100644
index 0000000000..d774fae8fd
--- /dev/null
+++ b/torchao/quantization/prototype/moe_quant/README.md
@@ -0,0 +1,51 @@
+# MoE Quantization
+
+Our goal with this prototype implementation of MoE quantization is to enable the use of existing linear quantization techniques for MoE quantization. While it would likely be more performant to use a fused kernel for quantized MoE, by decomposing the MoE operation into a sequence of linear operations we can reuse the existing tools and UX that work for linear quantization and apply them to MoE.
+
+Examples of the usage of these APIs can be found in both llama4_quant.py and ao/torchao/_models/mixtral-moe/generate.py.
+
+## Quantization API
+
+The API for MoE quantization is very similar to the linear quantization API, provided the MoE module is decomposed into linear operations and is quantizable and compilable. In practice this requires using the modules found in quantizable_moe_modules.py or something similar. Once this change has been made, the API is as follows for a few different quantization techniques:
+
+```python
+
+from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter
+from torchao.quantization.quant_api import quantize_, Int8WeightOnlyConfig
+
+quantize_(model, MoEQuantConfig(Int8WeightOnlyConfig()), filter_fn=cond_ffn_filter)
+model = torch.compile(model, mode="reduce-overhead")
+# you can also use fullgraph=True for single token inference
+```
+
+This API is the same as for normal linear quantization but with a MoE-specific filter function. It works for several different quantization techniques where the quantized tensor subclass has been adapted to work with 3D tensors; specifically Int8WeightOnlyConfig, Int4WeightOnlyConfig, Float8DynamicActivationFloat8WeightConfig, and Int8DynamicActivationInt8WeightConfig. It should be noted that due to the requirement of a minimum tensor input size (>16), Int8DynamicActivationInt8WeightConfig is best used for expert choice MoE rather than the token choice routing that the rest of the framework in this folder supports.
+
+
+## Alternative Quantization API
+
+To make the above API work, each tensor subclass had to be edited to work as a 3D tensor. However, the only ops that actually need to be supported are a few indexing and slicing ops on the 0th dimension; the majority of the work was removing hard-coded assumptions about tensor dimensionality. This means it is possible to instead create a new tensor subclass that pretends to be a 3D tensor by storing a series of 2D tensors and simulating the slicing and indexing ops until it eventually returns the single desired 2D quantized tensor subclass. This can be achieved by setting the use_fake_extra_dim_tensor flag of the MoEQuantConfig:
+
+```python
+
+from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter, MoEQuantConfig, UseFakeExtraDimTensor
+from torchao.quantization.quant_api import quantize_, Int8DynamicActivationIntxWeightConfig
+
+config = MoEQuantConfig(
+    Int8DynamicActivationIntxWeightConfig(),
+    # this is the only difference from the above API
+    use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
+)
+
+quantize_(model, config, filter_fn=cond_ffn_filter)
+model = torch.compile(model, mode="reduce-overhead")
+```
+
+It should also be noted that the default value of use_fake_extra_dim_tensor is AS_FALLBACK, which means the base method is tried first and, if it fails, the more general but less performant FakeExtraDimTensor method is used instead.
+
+While this approach turns out not to be especially performant, it does allow for slightly better memory characteristics since all the tensors are held separately and aren't actually modified or indexed. It is flexible enough to work with all of the existing linear quantization techniques that make use of quantized tensor subclasses, without any changes being made to those classes. It is compilable, though neither single-token nor multi-token inference works with fullgraph compilation.
+
+## Model API
+
+In practice the MoE implementations of known models tend not to be easy to quantize, and even of those that are, many either compile with many graph breaks or are impossible to torch.compile at all.
+
+The modules in the quantizable_moe_modules.py file were carefully written to satisfy both of those necessary characteristics, so applying MoE quantization to your own model first requires a module swap from the existing MoE module type to these more flexible ones. While there isn't a one-size-fits-all way to do this, an example of how it was done for Hugging Face's Llama4 implementation can be found in llama4_quant.py, which can be seen as a proof of concept.
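+
+The general shape of such a conversion is sketched below. This is illustrative only: `MyModelMoE`, `model`, and the attributes read off the source module are placeholders for whatever your model actually provides (llama4_quant.py is the concrete version of this pattern for Llama4TextMoe):
+
+```python
+import torch.nn as nn
+
+from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import (
+    MOEFeedForwardAOQuantizable,
+)
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+
+
+def moe_filter_fn(module, fqn):
+    # match your model's MoE module type (placeholder class name)
+    return isinstance(module, MyModelMoE)
+
+
+def convert_fn(module):
+    # hidden_dim, expert_dim, num_experts and top_k are read from the source module
+    new_mod = MOEFeedForwardAOQuantizable(
+        module.hidden_dim, module.expert_dim, module.num_experts, module.top_k
+    )
+    new_mod.router = module.router
+    # expert weights must be 3D: w1 and w3 are [E, I, D], w2 is [E, D, I]
+    new_mod.experts.w1 = nn.Parameter(module.w1, requires_grad=False)
+    new_mod.experts.w2 = nn.Parameter(module.w2, requires_grad=False)
+    new_mod.experts.w3 = nn.Parameter(module.w3, requires_grad=False)
+    return new_mod
+
+
+_replace_with_custom_fn_if_matches_filter(model, convert_fn, moe_filter_fn)
+# after the swap, the quantization APIs above apply directly to `model`
+```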
diff --git a/torchao/quantization/prototype/moe_quant/__init__.py b/torchao/quantization/prototype/moe_quant/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torchao/quantization/prototype/moe_quant/llama4_quant.py b/torchao/quantization/prototype/moe_quant/llama4_quant.py
new file mode 100644
index 0000000000..67ad2ab464
--- /dev/null
+++ b/torchao/quantization/prototype/moe_quant/llama4_quant.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
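+#
+# Proof-of-concept for the module swap described in the MoE quant README: replace
+# Hugging Face's Llama4TextMoe modules with MOEFeedForwardAOQuantizable, apply
+# MoEQuantConfig(Int4WeightOnlyConfig()) to the conditional feedforward experts,
+# then torch.compile the model and run generation. Requires the Llama-4-Scout
+# checkpoint from transformers and a CUDA device.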
+ +# T tokens +# E experts +# D dim +# I intermediate dim +# A activated experts +# T'(e) tokens for expert e + +import torch +import torch.nn as nn +from transformers import AutoTokenizer, Llama4ForCausalLM +from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + +from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import ( + MOEFeedForwardAOQuantizable, +) +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + + +def llama4_moe_filter_fn(module, fqn): + return isinstance(module, Llama4TextMoe) + + +def convert_fn(module): + # get data + hidden_dim = module.hidden_dim + expert_dim = module.experts.expert_dim + num_experts = module.num_experts + top_k = module.top_k + act_fn = module.experts.act_fn + shared_expert = module.shared_expert + return_scores = True + new_mod = MOEFeedForwardAOQuantizable( + hidden_dim, + expert_dim, + num_experts, + top_k, + act_fn, + shared_expert, + return_scores, + ) + + router = module.router + up_proj = module.experts.gate_up_proj + w1, w3 = up_proj.permute(0, 2, 1).chunk(2, dim=1) + w2 = module.experts.down_proj.permute(0, 2, 1) + + new_mod.router = router + new_mod.experts.w1 = nn.Parameter(w1, requires_grad=False) + new_mod.experts.w2 = nn.Parameter(w2, requires_grad=False) + new_mod.experts.w3 = nn.Parameter(w3, requires_grad=False) + return new_mod + + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +model = Llama4ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +_replace_with_custom_fn_if_matches_filter( + model, + convert_fn, + llama4_moe_filter_fn, +) + +model = model + +from torchao.quantization import Int4WeightOnlyConfig, quantize_ +from torchao.quantization.prototype.moe_quant.utils import ( + MoEQuantConfig, + cond_ffn_filter, +) + +quantize_(model, MoEQuantConfig(Int4WeightOnlyConfig()), cond_ffn_filter, device="cuda") + +model.cuda() + +model = torch.compile(model, mode="reduce-overhead") + +prompt = "He is here, the one who will tear apart the very stars" +inputs = tokenizer(prompt, return_tensors="pt") +model.generate(inputs.input_ids.cuda(), max_length=30) +model.generate(inputs.input_ids.cuda(), max_length=30) +generate_ids = model.generate(inputs.input_ids.cuda(), max_length=50) +out = tokenizer.batch_decode( + generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False +)[0] +print(out) diff --git a/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py new file mode 100644 index 0000000000..516341a3a8 --- /dev/null +++ b/torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py @@ -0,0 +1,191 @@ +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from torchao.quantization.prototype.moe_quant.utils import FakeExtraDimTensor + + +class MOEFeedForwardAOQuantizable(nn.Module): + def __init__( + self, + hidden_dim, + expert_dim, + num_experts, + top_k, + act_fn=F.silu, + shared_expert=None, + return_scores=False, + empty_init=True, + ) -> None: + super().__init__() + self.router = nn.Linear(hidden_dim, num_experts, bias=False) + self.experts = ConditionalFeedForwardAOQuantizable( + num_experts, hidden_dim, expert_dim, act_fn, empty_init + ) + self.hidden_dim = hidden_dim + self.top_k = top_k + self.shared_expert = shared_expert + self.return_scores = return_scores + + def forward(self, x: Tensor) -> Tensor: + batch_size = x.shape[0] + x = x.view(-1, 
self.hidden_dim) # x: [T, D] + scores = self.router(x) # [T, E] + scores = F.softmax(scores, dim=-1) + scores, expert_indices = torch.topk( + scores, self.top_k, dim=-1 + ) # [T, A], [T, A] + scores /= scores.sum(dim=-1, keepdim=True).to(x.dtype) # [T, A] + + out = self.experts(x, expert_indices, scores, self.top_k) + if self.shared_expert: + out += self.shared_expert(x) + + if self.return_scores: + return out.reshape(batch_size, -1, self.hidden_dim), scores + else: + return out.reshape(batch_size, -1, self.hidden_dim) + + +class ConditionalFeedForwardAOQuantizable(nn.Module): + def __init__(self, num_experts, hidden_dim, expert_dim, act_fn, empty_init=True): + super().__init__() + if empty_init: + self.w1 = nn.Parameter( + torch.empty(num_experts, expert_dim, hidden_dim) + ) # E, I, D + self.w2 = nn.Parameter( + torch.empty(num_experts, hidden_dim, expert_dim) + ) # E, D, I + self.w3 = nn.Parameter( + torch.empty(num_experts, expert_dim, hidden_dim) + ) # E, I, D + else: + self.w1 = nn.Parameter( + torch.randn(num_experts, expert_dim, hidden_dim) + ) # E, I, D + self.w2 = nn.Parameter( + torch.randn(num_experts, hidden_dim, expert_dim) + ) # E, D, I + self.w3 = nn.Parameter( + torch.randn(num_experts, expert_dim, hidden_dim) + ) # E, I, D + self.num_experts = num_experts + self.act_fn = act_fn + self.hidden_dim = hidden_dim + self.expert_dim = expert_dim + + def forward( + self, + x: Tensor, # T, D + expert_indices: Tensor, # T, A + expert_weights: Tensor, # T, A + top_k: int, + ) -> Tensor: + num_tokens, _hidden_dim = x.shape + num_token_activations = num_tokens * top_k + + if x.shape[0] == 1 and not isinstance( + self.w1, FakeExtraDimTensor + ): # only 1 token (can be done without graph breaks when compiled) + outs = [] + expert_indices = expert_indices.view(top_k) + # collect used experts + w1 = self.w1[expert_indices] + w2 = self.w2[expert_indices] + w3 = self.w3[expert_indices] + # run token through each expert + for index in range(top_k): + y1 = F.silu(F.linear(x, w1[index])) + y3 = F.linear(x, w3[index]) + y2 = w2[index] + + cur_out = F.linear(y1 * y3, y2) + outs.append(cur_out) + + # combine outputs + final_out = ( + (torch.cat(outs, dim=0) * expert_weights.view(-1, 1)) + .sum(dim=0) + .reshape(x.shape) + ) + return final_out + else: + expert_list = [x for x in range(self.num_experts)] + + # shuffle tokens into groups for each expert + ordered_token_activations = expert_indices.view(-1).argsort( + stable=True + ) # [A] + ordered_token_indices = ( + ordered_token_activations.div(top_k).floor().to(torch.int64) + ) # [T] + if not expert_indices.is_cuda: # histc doesn't work on cpu for integers + num_tokens_per_expert = torch.bincount( + expert_indices.view(-1) + 1, minlength=self.num_experts + 1 + ) + else: + num_tokens_per_expert = torch.histc( + expert_indices, + bins=self.num_experts + 1, + min=-1, + max=self.num_experts, + ) # [E+1] (added leading 0 so can be used for indexing) + cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to( + torch.int64 + ) # [E+1] + + @torch._dynamo.disable() + def group_tokens_by_expert( + ordered_token_indices, cum_tokens_per_expert, expert_list + ): + token_indices_per_expert = [ + ordered_token_indices[ + cum_tokens_per_expert[expert] : cum_tokens_per_expert[ + expert + 1 + ] + ].to(torch.int64) + for expert in expert_list + ] # [T'(e1)], [T'(e2)] ... 
+ return token_indices_per_expert + + token_indices_per_expert = group_tokens_by_expert( + ordered_token_indices, cum_tokens_per_expert, expert_list + ) + tokens_grouped_by_expert = [ + x[indices] for indices in token_indices_per_expert + ] + + # calculate outputs for each expert + outs = [] + for cur_x, expert in zip(tokens_grouped_by_expert, expert_list): + w1 = self.w1[expert] # I, D + w2 = self.w2[expert] # D, I + w3 = self.w3[expert] # I, D + + y1 = F.silu(F.linear(cur_x, w1)) + y3 = F.linear(cur_x, w3) + y2 = w2 + + cur_out = F.linear(y1 * y3, y2) # [T'(e), D] + outs.append(cur_out) + + # weigh outputs + ordered_outs = torch.cat(outs, dim=0) # [T*A, D] + ordered_token_activation_weights = expert_weights.view(-1, 1)[ + ordered_token_activations + ].view(-1, 1) # [T*A, 1] + weighted_ordered_outs = ( + ordered_outs * ordered_token_activation_weights + ) # [T*A, D] + + # sum weighted token-activation outputs together for each token + final_out = torch.zeros_like(x) # [T, D] + final_out = final_out.scatter_add( + dim=0, + index=ordered_token_indices.unsqueeze(-1) + .expand(num_token_activations, self.hidden_dim) + .to(torch.int64), + src=weighted_ordered_outs, + ) + return final_out diff --git a/torchao/quantization/prototype/moe_quant/utils.py b/torchao/quantization/prototype/moe_quant/utils.py new file mode 100644 index 0000000000..16fa8c8d33 --- /dev/null +++ b/torchao/quantization/prototype/moe_quant/utils.py @@ -0,0 +1,308 @@ +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +aten = torch.ops.aten + +from enum import Enum, auto +from typing import List, Optional, Tuple, Union + +from torchao.quantization.quant_api import ( + _QUANTIZE_CONFIG_HANDLER, + AOBaseConfig, + dataclass, + register_quantize_module_handler, +) +from torchao.utils import fill_defaults + + +class DummyModule(torch.nn.Module): + """This is used because the TorchAO quantization functions tend to operate on modules so to apply the transform to a tensor, we can load a + DummyModule with the target tensor and then apply the transformation to the module and then extract the transformed tensor. + """ + + def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): + super().__init__() + self.weight = weight + self.bias = bias + + +class FakeExtraDimTensor(torch.Tensor): + """This is a subclass of torch.Tensor that simulates a tensor of n+1 dimensions, akin to concatenating several tensors along the 0th dimension. + It takes a list of tensors with the same dtype, device and shape and creates a representation of shape (num_tensors, orig_shape). It can handle a + variety of ops like detach and clone but most importantly, supports any slicing and indexing along the extra dimension. + This is most useful when you have another tensor subclass that you'd like to concatenate together but don't want to support all the necessary + pieces of 3D scaffolding required to make it work. + + The structure of this tensor subclass is a linked_list of tensors with each instance of FakeExtraDimTensor containing a head tensor and a tail consisting of + either another intance of FakeExtraDimTensor or None if we've reached the end of the linked list. This implementation structure is necessary to + support compilation of this tensor subclass since compile requires each tensor component of the tensor subclass to have its own attribute. 
+ """ + + def __new__( + cls, + tensors: Union[Tuple[torch.Tensor], List[torch.Tensor]], + tensor_tail: Optional["FakeExtraDimTensor"] = None, + ): + assert len(tensors) > 0 or tensor_tail is not None + num_tensors = len(tensors) + if tensor_tail is not None: + num_tensors += tensor_tail.num_tensors + test_tensor = tensor_tail.head_tensor + else: + test_tensor = tensors[0] + + dtype = test_tensor.dtype + shape = test_tensor.shape + device = test_tensor.device + layout = test_tensor.layout + for tensor in tensors: + assert tensor.dtype == dtype, ( + f"all tensors in FakeExtraDimTensor must have same dtype but got {tensor.dtype} and {dtype}" + ) + assert tensor.shape == shape, ( + f"all tensors in FakeExtraDimTensor must have same shape but got {tensor.shape} and {shape}" + ) + assert tensor.device == device, ( + f"all tensors in FakeExtraDimTensor must have same device but got {tensor.device} and {device}" + ) + assert tensor.layout == layout, ( + f"all tensors in FakeExtraDimTensor must have same layout but got {tensor.layout} and {layout}" + ) + kwargs = {} + kwargs["dtype"] = dtype + kwargs["layout"] = layout + kwargs["device"] = device + kwargs["requires_grad"] = False + new_shape = (num_tensors, *shape) + return torch.Tensor._make_wrapper_subclass(cls, new_shape, **kwargs) + + def __repr__( + self, + ): + return f"{self.__class__.__name__}(shape={self.shape}, containing {self.num_tensors}: {self.head_tensor})" + + def __init__( + self, + tensors: Union[Tuple[torch.Tensor], List[torch.Tensor]], + tensor_tail: Optional["FakeExtraDimTensor"] = None, + ): + tensors = list(tensors) + assert len(tensors) > 0 or tensor_tail is not None + + # count num_tensors and make tensor_list + self.num_tensors = len(tensors) + if tensor_tail is not None: + self.num_tensors += tensor_tail.num_tensors + tail_list = tensor_tail.tensor_list + else: + tail_list = [] + self.tensor_list = tensors + tail_list + + # 3 cases + # 0) tensors has 0 elements -> take element from tail then do case 1 instead + # 1) tensors has 1 element, -> pop element and tail is None + # 2) tensors has >1 elements, -> pop element and recurse + + # convert case 0 to case 1 by taking 1 element from tail + if len(tensors) == 0 and tensor_tail is not None: + tensors = [ + tensor_tail.head_tensor, + ] + tensor_tail = tensor_tail.tensor_tail + + if len(tensors) > 1: + # case (1): remove first element from tensors, then recurse + self.head_tensor = tensors[0] # remove one + self.tensor_tail = self.__class__(tensors[1:], tensor_tail) # recurse + elif len(tensors) == 1: + # case (2) take final element from tensors, attach tensor_tail then stop recursion + self.head_tensor = tensors[0] + self.tensor_tail = tensor_tail + + def _apply_fn_to_data(self, fn): + self.head_tensor = fn(self.head_tensor) + if self.tensor_tail is not None: + self.tensor_tail = self.tensor_tail._apply_fn_to_data(fn) + return self.__class__([self.head_tensor], self.tensor_tail) + + def __tensor_flatten__(self): + if self.tensor_tail is None: + return [ + "head_tensor", + ], [self.num_tensors] + else: + return [ + "head_tensor", + "tensor_tail", + ], [self.num_tensors] + + @classmethod + def __tensor_unflatten__( + cls, + tensor_data_dict, + tensor_attributes, + outer_size, + outer_stride, + ): + head_tensor = tensor_data_dict["head_tensor"] + tensor_tail = tensor_data_dict.get("tensor_tail", None) + return cls([head_tensor], tensor_tail) + + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + kwargs = {} if kwargs is None else kwargs + if func is 
torch.nn.functional.linear: + x, w, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert w.num_tensors == 1, ( + "FakeExtraDimTensor used in a linear op when it had multiple tensors" + ) + return func(x, w.head_tensor, bias) + try: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + except Exception as e: + print(f"ERR: subclass {cls} doesn't implement {func}, got error: {e}") + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func == aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + return return_and_correct_aliasing( + func, args, kwargs, cls(self.tensor_list[start:end:step]) + ) + + elif func == aten.select.int: + self, dim, index = fill_defaults(args, 3, [0, 0]) + if dim == 0: + return return_and_correct_aliasing( + func, args, kwargs, cls([self.tensor_list[index]]) + ) + elif func == aten.index.Tensor: + self, indices, dim = fill_defaults(args, 3, [0]) + if dim == 0: + # this handles a weird bug where indices gets turned into a list + # between the function dispatch and torch dispatch but just for this function + if isinstance(indices, list) and len(indices) == 1: + indices = indices[0] + return return_and_correct_aliasing( + func, + args, + kwargs, + cls([self.tensor_list[index] for index in indices]), + ) + try: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), + ) + except Exception as e: + print( + f"function {func} failed for FakeExtraDimTensor, following error occured when trying to" + "run function on its elements: " + ) + raise e + + +class UseFakeExtraDimTensor(Enum): + """Enum that indicate whether to use FakeExtraDimTensor""" + + TRUE = auto() + FALSE = auto() + AS_FALLBACK = auto() + + +@dataclass +class MoEQuantConfig(AOBaseConfig): + """Configuration for applying quantization to MoE + Args: + `base_config`: normal AO Config + """ + + base_config: AOBaseConfig + use_fake_extra_dim_tensor: UseFakeExtraDimTensor = UseFakeExtraDimTensor.AS_FALLBACK + set_inductor_config: bool = True + + +# Module-level flag to track if we've already printed the error +_moe_quant_tensor_has_printed_error = False + + +def _moe_quant_tensor(weight, config): + def _moe_quant_tensor_base(weight, config): + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + dummy_mod = DummyModule(weight) + quant_mod = base_config_handler(dummy_mod, config.base_config) + return quant_mod.weight + + def _moe_quant_tensor_fake_extra_dim_tensor(weight, config): + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + # break 3D tensor + tensors = [weight[i] for i in range(weight.shape[0])] + # put tensors into modules since the handlers target modules not tensors + dummy_modules = [DummyModule(tensor) for tensor in tensors] + # apply handler to each module + quant_mods = list( + map(lambda x: base_config_handler(x, config.base_config), dummy_modules) + ) + # pack quantized subclasses into FakeExtraDimTensor + quant_weight = FakeExtraDimTensor([mod.weight for mod in quant_mods]) + return quant_weight + + global _moe_quant_tensor_has_printed_error + + use_fake = config.use_fake_extra_dim_tensor + if use_fake == UseFakeExtraDimTensor.FALSE: + return _moe_quant_tensor_base(weight, config) + elif use_fake == UseFakeExtraDimTensor.AS_FALLBACK: + try: + return _moe_quant_tensor_base(weight, config) + except 
Exception as e: + if not _moe_quant_tensor_has_printed_error: + print(f"tried to do moe_quant but got error: {e}") + _moe_quant_tensor_has_printed_error = True + return _moe_quant_tensor_fake_extra_dim_tensor(weight, config) + else: # This handles UseFakeExtraDimTensor.TRUE + return _moe_quant_tensor_fake_extra_dim_tensor(weight, config) + + +@register_quantize_module_handler(MoEQuantConfig) +def moe_quant_fn(module, config: MoEQuantConfig): + import warnings + + warnings.simplefilter("ignore", lineno=84) + warnings.simplefilter("ignore", lineno=105) + assert "ConditionalFeedForwardAOQuantizable" in str(type(module)) + + for weight_attr in ["w1", "w2", "w3"]: + param = getattr(module, weight_attr) + assert param.dim() == 3, ( + f"when applying moe_quant to {module} expected 3D tensor for {weight_attr} but got {param.dim()}" + ) + assert isinstance(config.base_config, AOBaseConfig), ( + f"MoEQuantConfig expected to be initialized with an AOBaseConfig but got {type(config.base_config)}" + + "this can happen if you initiaze with MoEQuantConfig(AOConfig) rather than MoEQuantConfig(AOConfig())" + ) + new_param = _moe_quant_tensor(param, config) + new_param = torch.nn.Parameter(new_param, requires_grad=False) + setattr(module, weight_attr, new_param) + del param + return module + + +def moe_filter(module, fqn): + return "MOEFeedForwardAOQuantizable" in str(type(module)) + + +def cond_ffn_filter(module, fqn): + return "ConditionalFeedForwardAOQuantizable" in str(type(module)) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 890c2e2038..c20c37a194 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -300,7 +300,7 @@ def _replace_with_custom_fn_if_matches_filter( device, extra_args, ) - if new_child is not child: + if new_child is not child and new_child is not None: setattr(model, name, new_child) if device is not None: model.to(device=device) # move parent module to device @@ -1050,31 +1050,25 @@ class Int4WeightOnlyConfig(AOBaseConfig): int4_weight_only = Int4WeightOnlyConfig -@register_quantize_module_handler(Int4WeightOnlyConfig) -def _int4_weight_only_transform( - module: torch.nn.Module, config: Int4WeightOnlyConfig -) -> torch.nn.Module: +def _int4_weight_only_quantize_tensor(weight, config): # TODO(future PR): perhaps move this logic to a different file, to keep the API # file clean of implementation details # for now, make these local variables to allow the rest of the function # to be a direct copy-paste - weight = module.weight group_size = config.group_size layout = config.layout use_hqq = config.use_hqq zero_point_domain = config.zero_point_domain - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() if weight.shape[-1] % group_size != 0: logger.info( f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}" ) - return module + return weight mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) + block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size]) target_dtype = torch.int32 quant_min = 0 quant_max = 15 @@ -1126,6 +1120,21 @@ def _int4_weight_only_transform( _layout=layout, use_hqq=use_hqq, ) + return new_weight + + +@register_quantize_module_handler(Int4WeightOnlyConfig) +def _int4_weight_only_transform( + module: torch.nn.Module, config: Int4WeightOnlyConfig +) -> torch.nn.Module: + if config.set_inductor_config: + 
torchao.quantization.utils.recommended_inductor_config_setter() + + assert hasattr(module, "weight"), ( + "applying int8 weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + new_weight = _int4_weight_only_quantize_tensor(module.weight, config) module.weight = torch.nn.Parameter(new_weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -1145,20 +1154,15 @@ class Int8WeightOnlyConfig(AOBaseConfig): int8_weight_only = Int8WeightOnlyConfig -@register_quantize_module_handler(Int8WeightOnlyConfig) -def _int8_weight_only_transform(module: torch.nn.Module, config: Int8WeightOnlyConfig): - group_size = config.group_size - weight = module.weight - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - +def _int8_weight_only_quantize_tensor(weight, config): mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 + group_size = config.group_size if group_size is None: - group_size = weight.shape[1] - block_size = (1, group_size) + group_size = weight.shape[-1] + block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) new_weight = to_affine_quantized_intx( weight, mapping_type, @@ -1167,6 +1171,19 @@ def _int8_weight_only_transform(module: torch.nn.Module, config: Int8WeightOnlyC eps=eps, zero_point_dtype=zero_point_dtype, ) + return new_weight + + +@register_quantize_module_handler(Int8WeightOnlyConfig) +def _int8_weight_only_transform(module: torch.nn.Module, config: Int8WeightOnlyConfig): + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + assert hasattr(module, "weight"), ( + "applying int8 weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + new_weight = _int8_weight_only_quantize_tensor(module.weight, config) module.weight = torch.nn.Parameter(new_weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -1283,33 +1300,26 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig -@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) -def _int8_dynamic_activation_int8_weight_transform( - module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig -) -> torch.nn.Module: +def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): layout = config.layout act_mapping_type = config.act_mapping_type weight_only_decode = config.weight_only_decode - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - - weight = module.weight - in_features = weight.shape[1] + in_features = weight.shape[-1] # int8 dynamic quantization only has benefit when in_feature > 16 if in_features <= 16: logger.info( f"Skipping applying int8_dynamic_activation_int8_weight to weight of shape {weight.shape}" f" because `in_feature` is <= 16: {in_features}" ) - return module + return weight # weight settings mapping_type = MappingType.SYMMETRIC weight_zero_point_domain = ZeroPointDomain.NONE def get_weight_block_size(x): - return (1, x.shape[1]) + return tuple([1 for _ in range(x.dim() - 1)] + [x.shape[-1]]) target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps @@ -1325,7 +1335,7 @@ def get_weight_block_size(x): input_quant_func = _int8_asymm_per_token_quant block_size = 
get_weight_block_size(weight) - weight = to_affine_quantized_intx( + new_weight = to_affine_quantized_intx( weight, mapping_type, block_size, @@ -1335,8 +1345,25 @@ def get_weight_block_size(x): _layout=layout, zero_point_domain=weight_zero_point_domain, ) - weight = to_linear_activation_quantized(weight, input_quant_func) - module.weight = torch.nn.Parameter(weight, requires_grad=False) + new_weight = to_linear_activation_quantized(new_weight, input_quant_func) + return new_weight + + +@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) +def _int8_dynamic_activation_int8_weight_transform( + module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig +) -> torch.nn.Module: + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + assert hasattr(module, "weight"), ( + "applying int8 dynamic activation int8 weight quant requires module to have weight attribute" + + "but {module} does not have one" + ) + new_weight = _int8_dynamic_activation_int8_weight_quantize_tensor( + module.weight, config + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -1375,17 +1402,10 @@ class Float8WeightOnlyConfig(AOBaseConfig): float8_weight_only = Float8WeightOnlyConfig -@register_quantize_module_handler(Float8WeightOnlyConfig) -def _float8_weight_only_transform( - module: torch.nn.Module, config: Float8WeightOnlyConfig -) -> torch.nn.Module: +def _float8_weight_only_quant_tensor(weight, config): from torchao.dtypes import to_affine_quantized_floatx - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - - weight = module.weight - block_size = (1, weight.shape[1]) + block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]]) new_weight = to_affine_quantized_floatx( input_float=weight, block_size=block_size, @@ -1393,6 +1413,22 @@ def _float8_weight_only_transform( scale_dtype=None, _layout=Float8Layout(mm_config=None), ) + return new_weight + + +@register_quantize_module_handler(Float8WeightOnlyConfig) +def _float8_weight_only_transform( + module: torch.nn.Module, config: Float8WeightOnlyConfig +) -> torch.nn.Module: + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + assert hasattr(module, "weight"), ( + "applying int8 weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + new_weight = _float8_weight_only_quant_tensor(module.weight, config) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -1496,11 +1532,12 @@ def _fp8_mm_compat(weight: torch.Tensor) -> bool: Returns: bool: True if the tensor can be quantized to float8, False otherwise """ - assert weight.dim() == 2, ( - f"float8 quantization only works for 2-D tensors, got {weight.dim()}D tensor" - ) + assert weight.dim() in [ + 2, + 3, + ], f"float8 quantization only works for 2/3-D tensors, got {weight.dim()}D tensor" - out_dim, in_dim = weight.shape + out_dim, in_dim = weight.shape[-2:] is_compatible = (in_dim % 16 == 0) and (out_dim % 16 == 0) if not is_compatible: @@ -1547,34 +1584,26 @@ def __post_init__(self): float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig -@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig) -def 
_float8_dynamic_activation_float8_weight_transform( - module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig -): - assert is_sm_at_least_89() or is_MI300(), ( - "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" - ) - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - +def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype granularity = config.granularity mm_config = config.mm_config - weight = module.weight activation_granularity, weight_granularity = _normalize_granularity(granularity) if not _fp8_mm_compat(weight): # TODO(future PR): this should really throw an exception instead of silently # not doing what the user asked - return module + return weight if isinstance(weight_granularity, PerRow): assert weight.dtype == torch.bfloat16, ( "PerRow quantization only works for bfloat16 precision input weight" ) - block_size = get_block_size(weight.shape, weight_granularity) + block_size = get_block_size(weight.shape[-2:], weight_granularity) + if weight.dim() == 3: + block_size = tuple([1] + list(block_size)) quantized_weight = to_affine_quantized_floatx( input_float=weight, block_size=block_size, @@ -1592,7 +1621,26 @@ def _float8_dynamic_activation_float8_weight_transform( quantized_weight = to_linear_activation_quantized( quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs ) + return quantized_weight + +@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig) +def _float8_dynamic_activation_float8_weight_transform( + module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig +): + assert is_sm_at_least_89() or is_MI300(), ( + "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" + ) + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + assert hasattr(module, "weight"), ( + "applying float8 dynamic activation quant requires module to have weight attribute" + + f"but {module} does not have one" + ) + quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( + module.weight, config + ) module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module diff --git a/torchao/quantization/transform_module.py b/torchao/quantization/transform_module.py index b6fac49ae9..339d46be35 100644 --- a/torchao/quantization/transform_module.py +++ b/torchao/quantization/transform_module.py @@ -47,5 +47,6 @@ def _transform( @functools.wraps(config_type) def decorator(func): _QUANTIZE_CONFIG_HANDLER[config_type] = func + return func # needed to make the functions usable externally return decorator diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 0c30fba713..a9cad8060e 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -365,22 +365,23 @@ def get_groupwise_affine_qparams( def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16): guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size()) guard_dtype_size(zeros, "zeros", dtype=dtype) + dim = scales.dim() return ( torch.cat( [ - scales.reshape(scales.size(0), scales.size(1), 1), - zeros.reshape(zeros.size(0), zeros.size(1), 1), + scales.unsqueeze(-1), + zeros.unsqueeze(-1), ], - 2, + dim, ) - .transpose(0, 1) + .transpose(-3, -2) .contiguous() ) def 
unpack_tinygemm_scales_and_zeros(scales_and_zeros): - assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 - return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) + assert scales_and_zeros.shape[-1] == 2 + return torch.split(scales_and_zeros.transpose(-3, -2), 1, -1) def convert_weight_to_int4pack_xpu(weight, zero_point_domain_is_int=False): diff --git a/torchao/utils.py b/torchao/utils.py index db269b4cb0..280da4e632 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -10,7 +10,7 @@ from functools import reduce from importlib.metadata import version from math import gcd -from typing import Any, Callable, Tuple +from typing import Any, Callable import torch import torch.nn.utils.parametrize as parametrize @@ -170,7 +170,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs): return measurement.mean * 1e6 -def find_multiple(n: int, *args: Tuple[int]) -> int: +def find_multiple(n: int, *args: int) -> int: k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) # type: ignore[9] if n % k == 0: return n