-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[quant] Add quant APIs to save/load observer state_dict #44846
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: e1275f7 Pull Request resolved: #44846
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23746821](https://our.internmc.facebook.com/intern/diff/D23746821) [ghstack-poisoned]
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 296c73e Pull Request resolved: #44846
💊 CI failures summary and remediationsAs of commit efc434f (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 49 times. |
Codecov Report
@@ Coverage Diff @@
## gh/supriyar/181/base #44846 +/- ##
=======================================================
Coverage ? 68.07%
=======================================================
Files ? 396
Lines ? 51293
Branches ? 0
=======================================================
Hits ? 34920
Misses ? 16373
Partials ? 0 Continue to review full report at Codecov.
|
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23746821](https://our.internmc.facebook.com/intern/diff/D23746821) [ghstack-poisoned]
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 923da1a Pull Request resolved: #44846
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23746821](https://our.internmc.facebook.com/intern/diff/D23746821) [ghstack-poisoned]
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 8e69ddd Pull Request resolved: #44846
|
|
||
| # save state_dict of model | ||
| import io | ||
| obs_dict = torch.quantization.get_observer_state_dict(model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently we seem to be doing this only for activation observers. As we move to QAT, we will need similar functionality, but we will need to consider weight fake-quants too. Is there a plan to add weight observers to the state dict?
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23746821](https://our.internmc.facebook.com/intern/diff/D23746821) [ghstack-poisoned]
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23746821](https://our.internmc.facebook.com/intern/diff/D23746821) [ghstack-poisoned]
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23746821](https://our.internmc.facebook.com/intern/diff/D23746821) [ghstack-poisoned]
…line with a title. Use 1 line only, 67 chars or less> Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict python test/test_quantization.py TestQuantizeJitPasses.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 63d3db8 Pull Request resolved: #44846
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23746821](https://our.internmc.facebook.com/intern/diff/D23746821) [ghstack-poisoned]
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23746821](https://our.internmc.facebook.com/intern/diff/D23746821) [ghstack-poisoned]
…line with a title. Use 1 line only, 67 chars or less> Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict python test/test_quantization.py TestQuantizeJitPasses.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: df483ce Pull Request resolved: #44846
| missing_keys, unexpected_keys, error_msgs) | ||
|
|
||
| @torch.jit.export | ||
| def _load_from_state_dict_script(self, state_dict: Union[Dict[str, torch.Tensor], Dict[str, torch.Tensor]], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seems to call _load_from_state_dict, just curious on what would break if this function wasn't there? I might be missing something
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed that without this, the call to load_from_state_dict was calling the one defined in module.py. Once we script the model we lose access to the overridden load_from_state_dict so we need to call a different fn here to get around it.
| _is_observer_script_module(module, "torch.quantization.observer.MovingAveragePerChannelMinMaxObserver") | ||
| return False | ||
|
|
||
| def get_observer_state_dict(mod): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thoughts about making most of these FB-only, since the use case is pretty specific? Or are there OSS needs as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That may be true right now, but Raghu mentioned these will come in use when we develop the numeric suite and auto quantizer for OSS. One can argue that the OSS use cases might not need this functionality for script models, but it might be better to keep it all in the same place.
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23746821](https://our.internmc.facebook.com/intern/diff/D23746821) [ghstack-poisoned]
…line with a title. Use 1 line only, 67 chars or less> Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict python test/test_quantization.py TestQuantizeJitPasses.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: c27c255 Pull Request resolved: #44846
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: f0d591a Pull Request resolved: pytorch/pytorch#44846
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23746821](https://our.internmc.facebook.com/intern/diff/D23746821) [ghstack-poisoned]
…line with a title. Use 1 line only, 67 chars or less> Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict python test/test_quantization.py TestQuantizeJitPasses.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 69c958d Pull Request resolved: #44846
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23746821](https://our.internmc.facebook.com/intern/diff/D23746821) [ghstack-poisoned]
…line with a title. Use 1 line only, 67 chars or less> Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict python test/test_quantization.py TestQuantizeJitPasses.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 0b6130c Pull Request resolved: #44846
Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23746821](https://our.internmc.facebook.com/intern/diff/D23746821) [ghstack-poisoned]
…line with a title. Use 1 line only, 67 chars or less> Summary: The save function traverses the model state dict to pick out the observer stats load function traverse the module hierarchy to load the state dict into module attributes depending on observer type Test Plan: python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict python test/test_quantization.py TestQuantizeJitPasses.test_save_observer_state_dict Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 246eac0 Pull Request resolved: #44846
|
This pull request has been merged in 489af4d. |
Stack from ghstack:
Summary:
The save function traverses the model state dict to pick out the observer stats
load function traverse the module hierarchy to load the state dict into module attributes depending on observer type
Test Plan:
python test/test_quantization.py TestQuantizeFx.test_save_observer_state_dict
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D23746821