diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 3f9bc4d345..13aa6fd6e7 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -40,6 +40,10 @@
     LinearActivationQuantizedTensor,
     PerGroup,
 )
+from torchao.quantization.qat import (
+    FakeQuantizedLinear,
+    QATConfig,
+)
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8StaticActivationFloat8WeightConfig,
@@ -1199,6 +1203,32 @@ def __init__(self):
         assert isinstance(m.nested.linear.weight, AffineQuantizedTensor)
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
 
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
+    def test_fqn_config_quantized_nested_module_module_swap(self):
+        class NestedModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(16, 16)
+
+        class TopLevelModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.nested = NestedModule()
+                self.linear1 = torch.nn.Linear(16, 16)
+
+        m = TopLevelModule()
+        config = QATConfig(Int4WeightOnlyConfig(), step="prepare")
+        quant_config = FqnToConfig(
+            {
+                "nested.linear": config,
+                "linear1": config,
+            }
+        )
+        quantize_(m, quant_config, filter_fn=None)
+
+        assert isinstance(m.nested.linear, FakeQuantizedLinear)
+        assert isinstance(m.linear1, FakeQuantizedLinear)
+
     @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_fqn_config_quantized_nested_module_param(self):
         class NestedModule(torch.nn.Module):
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index bc6b7ccc8d..963142f574 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -480,19 +480,23 @@ def quantize_(
             raise ValueError(
                 "Custom filter_fn and FqnToConfig were both specified. Only filter_fn=None is supported when FqnToConfig is specified."
             )
-
-        for module_fqn, module in model.named_modules():
+        named_modules = dict(model.named_modules())
+        for module_fqn, module in named_modules.items():
             if (
                 fqn_matches_fqn_config(module_fqn, config)
                 or _module_param_matches_fqn_config(module, module_fqn, config)
                 or ("_default" in config.fqn_to_config and _is_linear(module))
             ):
-                # this replaces inplace, so no need to reassign
-                _fqn_to_config_handler(module, module_fqn, config)
+                replacement = _fqn_to_config_handler(module, module_fqn, config)
                 if device is not None:
-                    module.to(device=device)
-        return
-    if isinstance(config, AOBaseConfig):
+                    replacement = replacement.to(device=device)
+                # handle module swap
+                if replacement is not module and module_fqn != "":
+                    child_name = module_fqn.split(".")[-1]
+                    parent_fqn = module_fqn.removesuffix(child_name).removesuffix(".")
+                    parent_module = named_modules[parent_fqn]
+                    setattr(parent_module, child_name, replacement)
+    elif isinstance(config, AOBaseConfig):
         filter_fn = _is_linear if filter_fn is None else filter_fn
         handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
         # for each linear in the model, apply the transform if filtering passes
@@ -503,7 +507,6 @@ def quantize_(
             device=device,
             extra_args=(config,),
         )
-
     else:
         raise AssertionError(
             """Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/issues/1690 for instructions on how to pass in workflow configuration instead."""