30 changes: 30 additions & 0 deletions test/quantization/test_quant_api.py
@@ -40,6 +40,10 @@
LinearActivationQuantizedTensor,
PerGroup,
)
from torchao.quantization.qat import (
FakeQuantizedLinear,
QATConfig,
)
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8StaticActivationFloat8WeightConfig,
@@ -1199,6 +1203,32 @@ def __init__(self):
assert isinstance(m.nested.linear.weight, AffineQuantizedTensor)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)

@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_fqn_config_quantized_nested_module_module_swap(self):
class NestedModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(16, 16)

class TopLevelModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.nested = NestedModule()
self.linear1 = torch.nn.Linear(16, 16)

m = TopLevelModule()
config = QATConfig(Int4WeightOnlyConfig(), step="prepare")
quant_config = FqnToConfig(
{
"nested.linear": config,
"linear1": config,
}
)
quantize_(m, quant_config, filter_fn=None)

assert isinstance(m.nested.linear, FakeQuantizedLinear)
assert isinstance(m.linear1, FakeQuantizedLinear)

@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_fqn_config_quantized_nested_module_param(self):
class NestedModule(torch.nn.Module):
19 changes: 11 additions & 8 deletions torchao/quantization/quant_api.py
@@ -480,19 +480,23 @@ def quantize_(
raise ValueError(
"Custom filter_fn and FqnToConfig were both specified. Only filter_fn=None is supported when FqnToConfig is specified."
)

for module_fqn, module in model.named_modules():
named_modules = dict(model.named_modules())
for module_fqn, module in named_modules.items():
if (
fqn_matches_fqn_config(module_fqn, config)
or _module_param_matches_fqn_config(module, module_fqn, config)
or ("_default" in config.fqn_to_config and _is_linear(module))
):
# this replaces inplace, so no need to reassign
_fqn_to_config_handler(module, module_fqn, config)
Contributor:
Maybe we can implement something like

def set_module_by_fqn(model, fqn, new_module):
    # set model[fqn] to new_module

so then we can keep it iterative? Only thing is that we have to get the modules in top-down order (this does appear to be how named_modules is implemented)

    for module_fqn, module in model.named_modules():
        if (
            fqn_matches_fqn_config(module_fqn, config)
            or _module_param_matches_fqn_config(module, module_fqn, config)
            or ("_default" in config.fqn_to_config and _is_linear(module))
        ):
            new_module = _fqn_to_config_handler(module, module_fqn, config)
            # reassign module
            set_module_by_fqn(model, module_fqn, new_module)
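
For reference, a minimal sketch of such a helper, assuming the standard torch.nn.Module.get_submodule API; the name set_module_by_fqn comes from the suggestion above and is not part of the PR:

import torch

def set_module_by_fqn(model: torch.nn.Module, fqn: str, new_module: torch.nn.Module) -> None:
    # Split "a.b.c" into the parent fqn "a.b" and the child name "c"; a
    # top-level fqn such as "linear1" has an empty parent and is set on
    # the model itself.
    parent_fqn, _, child_name = fqn.rpartition(".")
    parent = model.get_submodule(parent_fqn) if parent_fqn else model
    setattr(parent, child_name, new_module)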

replacement = _fqn_to_config_handler(module, module_fqn, config)
if device is not None:
module.to(device=device)
return
if isinstance(config, AOBaseConfig):
replacement = replacement.to(device=device)
# handle module swap
if replacement is not module and module_fqn != "":
child_name = module_fqn.split(".")[-1]
parent_fqn = module_fqn.removesuffix(child_name).removesuffix(".")
parent_module = named_modules[parent_fqn]
setattr(parent_module, child_name, replacement)
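
The parent/child split above behaves as follows on concrete fqns (illustrative values, not taken from the PR):

module_fqn = "nested.linear"
child_name = module_fqn.split(".")[-1]                              # "linear"
parent_fqn = module_fqn.removesuffix(child_name).removesuffix(".")  # "nested"

# A direct child of the model has an empty parent fqn, and named_modules()
# maps "" to the top-level model, so the setattr lands on the model itself.
module_fqn = "linear1"
child_name = module_fqn.split(".")[-1]                              # "linear1"
parent_fqn = module_fqn.removesuffix(child_name).removesuffix(".")  # ""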
Contributor:
I think we're missing a return here at the end

Contributor Author:
I changed the next case to elif, is that OK?

Contributor:
oh missed that, yeah lgtm

elif isinstance(config, AOBaseConfig):
filter_fn = _is_linear if filter_fn is None else filter_fn
handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
# for each linear in the model, apply the transform if filtering passes
@@ -503,7 +507,6 @@
device=device,
extra_args=(config,),
)

else:
raise AssertionError(
"""Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/issues/1690 for instructions on how to pass in workflow configuration instead."""