Speed up HistogramObserver by vectorizing critical path #41041
nice! In general looks good, can we just add to the test plan:
Differential Revision: [D22400755](https://our.internmc.facebook.com/intern/diff/D22400755) [ghstack-poisoned]
torch/quantization/observer.py
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
return norm

src_bin = torch.arange(self.bins).numpy()
PyTorch doesn't have a NumPy dependency for its functionality (although we do for some tests), and we shouldn't use NumPy functionality in lieu of our own. Uses of NumPy should be restricted to testing and NumPy interop.
Thanks for the feedback -- I changed my code to get rid of the numpy dependency.
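For illustration, here is a minimal sketch of how the source-to-destination bin mapping can stay entirely in torch, with no `.numpy()` call. It is not the code from this PR: the function name `map_src_bins_to_dst` and its parameters (`bin_width`, `dst_bin_width`, `dst_nbins`) are assumptions made for the example.

```python
import torch


def map_src_bins_to_dst(bins, next_start_bin, bin_width, dst_bin_width, dst_nbins):
    """Hypothetical helper: map every source bin to destination bins in one shot.

    Returns the destination-bin index of the beginning and of the end of
    each source bin, as float tensors, using only torch operations.
    """
    # One entry per source bin; float64 keeps the divisions below exact
    # for small integer inputs.
    src_bin = torch.arange(bins, dtype=torch.float64)

    # Distance from the start of the first destination bin to the
    # beginning and end of each source bin.
    src_bin_begin = (src_bin - next_start_bin) * bin_width
    src_bin_end = src_bin_begin + bin_width

    # Destination bin that each boundary falls into, clamped to the valid range.
    dst_bin_of_begin = torch.clamp(
        torch.floor(src_bin_begin / dst_bin_width), 0, dst_nbins - 1
    )
    dst_bin_of_end = torch.clamp(
        torch.floor(src_bin_end / dst_bin_width), 0, dst_nbins - 1
    )
    return dst_bin_of_begin, dst_bin_of_end


begin, end = map_src_bins_to_dst(
    bins=8, next_start_bin=0, bin_width=1.0, dst_bin_width=2.0, dst_nbins=4
)
print(begin.tolist(), end.tolist())
```

Because every step is a tensor operation, the per-bin Python loop disappears and nothing outside torch is needed.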
delta_end = src_bin_end - dst_bin_of_end_center
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
return norm
This can be optimized further with the following approximation:
Quantization error = (StepSize^2/12) * Q + sum(P[i] * (BinCenter[i] - next_start_bin)^2) + sum(P[i] * (BinCenter[i] - next_end_bin)^2),
Q = sum(hist[next_start_bin:next_end_bin]),
where the first sum is over the bins below next_start_bin and the second sum is over the bins above next_end_bin. With this approximation we only need to compute two indices: where next_start_bin and next_end_bin map to in terms of the original histogram indices.
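To make the arithmetic of this approximation concrete, here is a plain-Python sketch (deliberately without torch, so the individual terms are easy to follow). Several details are assumptions of the sketch rather than part of the comment above: `next_end_bin` is treated as inclusive, bin centers are measured from the histogram minimum as `(i + 0.5) * bin_width`, clipped bins are charged their squared distance to the nearest edge of the kept range, and `approx_quantization_error` is a hypothetical name.

```python
def approx_quantization_error(hist, bin_width, next_start_bin, next_end_bin, dst_nbins):
    """Hypothetical sketch of the approximation: rounding noise plus clipping error."""
    # Edges of the kept range, in value units relative to the histogram minimum.
    # Assumption: next_end_bin is inclusive.
    lo_edge = next_start_bin * bin_width
    hi_edge = (next_end_bin + 1) * bin_width
    step_size = (hi_edge - lo_edge) / dst_nbins

    # Mass inside the kept range pays uniform rounding noise StepSize^2 / 12.
    q = sum(hist[next_start_bin:next_end_bin + 1])
    rounding = (step_size ** 2 / 12.0) * q

    # Mass below the range is clipped to the lower edge.
    clip_low = sum(
        hist[i] * ((i + 0.5) * bin_width - lo_edge) ** 2
        for i in range(next_start_bin)
    )
    # Mass above the range is clipped to the upper edge.
    clip_high = sum(
        hist[i] * ((i + 0.5) * bin_width - hi_edge) ** 2
        for i in range(next_end_bin + 1, len(hist))
    )
    return rounding + clip_low + clip_high


error = approx_quantization_error(
    hist=[1, 2, 3, 4], bin_width=1.0, next_start_bin=1, next_end_bin=2, dst_nbins=2
)
print(error)
```

The two clipping sums only depend on where the range boundaries fall, so in a tensor implementation they could presumably be served from precomputed prefix sums instead of being recomputed for every candidate range.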
Roughly a 22x speedup over the code this replaces, measured on ResNet18 on a devvm (CPU only) with the default parameters for HistogramObserver (i.e. 2048 bins). The script I ran to test this is here.
Roughly a 14x speedup when tested using the benchmark from #42138 (also CPU only).
Stack from ghstack:
Differential Revision: D22400755