-
Notifications
You must be signed in to change notification settings - Fork 386
Make FqnToConfig handle module swap configs #3492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -480,19 +480,23 @@ def quantize_( | |
| raise ValueError( | ||
| "Custom filter_fn and FqnToConfig were both specified. Only filter_fn=None is supported when FqnToConfig is specified." | ||
| ) | ||
|
|
||
| for module_fqn, module in model.named_modules(): | ||
| named_modules = dict(model.named_modules()) | ||
| for module_fqn, module in named_modules.items(): | ||
| if ( | ||
| fqn_matches_fqn_config(module_fqn, config) | ||
| or _module_param_matches_fqn_config(module, module_fqn, config) | ||
| or ("_default" in config.fqn_to_config and _is_linear(module)) | ||
| ): | ||
| # this replaces inplace, so no need to reassign | ||
| _fqn_to_config_handler(module, module_fqn, config) | ||
| replacement = _fqn_to_config_handler(module, module_fqn, config) | ||
| if device is not None: | ||
| module.to(device=device) | ||
| return | ||
| if isinstance(config, AOBaseConfig): | ||
| replacement = replacement.to(device=device) | ||
| # handle module swap | ||
| if replacement is not module and module_fqn != "": | ||
| child_name = module_fqn.split(".")[-1] | ||
| parent_fqn = module_fqn.removesuffix(child_name).removesuffix(".") | ||
| parent_module = named_modules[parent_fqn] | ||
| setattr(parent_module, child_name, replacement) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we're missing a return here at the end
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I changed the next case to elif, is that OK?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh missed that, yeah lgtm |
||
| elif isinstance(config, AOBaseConfig): | ||
| filter_fn = _is_linear if filter_fn is None else filter_fn | ||
| handler = _QUANTIZE_CONFIG_HANDLER[type(config)] | ||
| # for each linear in the model, apply the transform if filtering passes | ||
|
|
@@ -503,7 +507,6 @@ def quantize_( | |
| device=device, | ||
| extra_args=(config,), | ||
| ) | ||
|
|
||
| else: | ||
| raise AssertionError( | ||
| """Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/issues/1690 for instructions on how to pass in workflow configuration instead.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can implement something like
so then we can keep it iterative? Only thing is that we have to get the modules in top-down order (this does appear to be how named_modules is implemented)