Add API logging for interpolate() #88212
Conversation
This PR needs a label. If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`. If not, please add the `topic: not user facing` label. For more information, see https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.
torch/nn/functional.py (Outdated)

```python
# Logging logic
is_channels_last = input.is_contiguous(memory_format=torch.channels_last)
is_image_or_mask = input.ndim == 4 and input.shape[1] < 4
log_string = f"torch.nn.functional.interpolate_dtype={input.dtype}_mode={mode}_antialias={antialias}_channelslast={is_channels_last}_imageormask={is_image_or_mask}"
```
I don't think this is OK. Creating such a string is quite expensive, and I don't think we want all our users to pay that price.
I understand this may not be desirable on more ubiquitous ops like sum() or max(), but is this a problem in practice for interpolate(), which is in general a fairly slow operation? On a typical input, the string construction would only account for ~0.1% of the interpolate() call:
```python
[ins] In [13]: import torch

[ins] In [14]: input = torch.rand(1, 3, 350, 280); antialias = True; mode = "bilinear"

[ins] In [16]: torch.set_num_threads(1)

[ins] In [17]: %timeit torch.nn.functional.interpolate(input, size=(224, 224), mode=mode, antialias=antialias)
1.05 ms ± 449 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

[ins] In [18]: %%timeit
          ...: is_channels_last = input.is_contiguous(memory_format=torch.channels_last)
          ...: is_image_or_mask = input.ndim == 4 and input.shape[1] < 4
          ...: log_string = (
          ...:     f"torch.nn.functional.interpolate_dtype={input.dtype}_mode={mode}_antialias={antialias}_"
          ...:     f"channelslast={is_channels_last}_imageormask={is_image_or_mask}"
          ...: )
          ...:
1.15 µs ± 1.97 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
```

If you feel strongly against this, would you mind sharing your thoughts on this thread? According to our internal discussion with @jisaacso and @colin2328, this kind of logging should be possible?
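The micro-benchmark above depends on a live tensor; the same string-construction cost can be sketched with only the standard library. A minimal reproduction, assuming stand-in values for the tensor attributes (`dtype`, `mode`, etc. are plain Python values here, not real torch objects):

```python
import timeit

# Hypothetical stand-ins for the tensor attributes used in the real log string;
# torch is not imported, so only the f-string construction cost is measured.
dtype = "torch.float32"   # stand-in for input.dtype
mode = "bilinear"
antialias = True
is_channels_last = False
is_image_or_mask = True

def build_log_string():
    return (
        f"torch.nn.functional.interpolate_dtype={dtype}_mode={mode}_antialias={antialias}_"
        f"channelslast={is_channels_last}_imageormask={is_image_or_mask}"
    )

# Time one million constructions, mirroring %%timeit's loop count above.
seconds = timeit.timeit(build_log_string, number=1_000_000)
print(f"{seconds / 1_000_000 * 1e9:.0f} ns per string")
```

On the order of a hundred nanoseconds per string on typical hardware, consistent with the microsecond-scale numbers in the session above once the two `is_contiguous`/shape checks are included.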
Sent there.
albanD left a comment
From internal discussion, this is an ok short term solution while we are building a better infra for this.
This will be removed when the new infra is in place.
```python
log_string = (
    f"torch.nn.functional.interpolate_dtype={input.dtype}_mode={mode}_antialias={antialias}_"
    f"channelslast={is_channels_last}_imageormask={is_image_or_mask}"
)
torch._C._log_api_usage_once(log_string)
```
To be clear @albanD, this logger has been open source in PyTorch for years (#20745). @NicolasHug is not introducing a new logging mechanism, but rather using it in a new place. I'm happy to have the convo about changing the implementation of the logger long term, but I don't want to block use cases in the short term.
cc @dzhulgakov who introduced me to the logger years ago. From Dima, when I asked about perf, he said
"in C++ I think it's super fast as it logs the first time only (thread local)
in Python it still has to go through the whole call stack and does one hashmap lookup afair"
@dzhulgakov do you have any other historical context that's useful here?
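The "logs the first time only" behavior Dima describes can be sketched in a few lines. This is not PyTorch's actual implementation (the real one lives in C++ in `torch/csrc/Module.cpp` and goes through the registered logging backend); it is a minimal Python illustration of the idea, with a lock standing in for the thread-local bookkeeping:

```python
import threading

_seen: set = set()
_lock = threading.Lock()

def log_api_usage_once(event: str) -> bool:
    """Emit `event` to the logging backend only the first time it is seen.

    Returns True when the event was actually logged (first occurrence).
    Hypothetical sketch of the idea behind torch._C._log_api_usage_once.
    """
    with _lock:
        if event in _seen:
            # Uniqueness check: repeat calls cost one set lookup and nothing else.
            return False
        _seen.add(event)
    print(f"API_USAGE {event}")  # stand-in for the real logging backend
    return True
```

So repeated calls with the same log string only pay for the membership test, which is why the per-call overhead stays small after the first hit.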
It's doing a uniqueness check for the log message, so definitely has some overhead: https://github.com/pytorch/pytorch/blob/master/torch/csrc/Module.cpp#L1159
What is the need for this particular logging? If you're trying to capture usage at scale, you could probably look at ATen level operator logging (which at Meta is hooked up to Scuba)
Thanks for chiming in @dzhulgakov. The reason for this logging is to identify which low-level kernels are most often used internally (and by whom), so we can prioritize perf improvements. Would you mind sharing more context about the ATen operator logging you're mentioning? Are you saying that all the calls to torch._C._nn.XYZ like these ones are logged somewhere internally?
Yes, all ATen operators are getting logged with low sampling probability. It's done by installing an observer as described here: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#fleet-wide-operator-profiling . At Meta, look up pytorch_operator_stats in Scuba.
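The sampled logging described above amounts to recording each operator call with some small probability, so the per-call overhead is negligible at fleet scale. A toy sketch of that idea (the real mechanism installs a C++ observer per the linked doc; the names and the sampling rate below are invented for illustration):

```python
import random

SAMPLING_PROBABILITY = 1e-4  # assumed value; real deployments pick their own rate

def maybe_record_operator(op_name: str, sink: list, rng=random.random) -> None:
    # Record this operator call with low probability; observed counts can be
    # re-scaled by 1 / SAMPLING_PROBABILITY to estimate true call volumes.
    if rng() < SAMPLING_PROBABILITY:
        sink.append(op_name)
```

The `rng` parameter is only there to make the behavior deterministic in tests; in production the sampling decision would come from the observer infrastructure itself.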
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as stale.
Stack from ghstack (oldest at bottom):
This PR calls `torch._C._log_api_usage_once()` for `torch.nn.functional.interpolate()`. See [this internal post](https://fb.workplace.com/groups/1144215345733672/permalink/2382133531941841/) for more context about the level of detail needed in the logger.

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345 @saketh-are