# Check for internal memory overlap in some indexing-type functions #43423

## Conversation
💊 **Dr. CI** — As of commit 84b1711: 💚 Looks good so far! There are no failures yet. *(Automated comment; revised 38 times.)*
cc @neerajprad @fritzo for changes to distribution tests. My impression was that broadcasted tensors are routinely used in distributions, so requiring non-self-overlapping outputs could lead to a significant memory penalty.
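For context, a minimal sketch (shapes assumed, not code from the PR) of how `broadcast_all` produces self-overlapping views:

```python
import torch
from torch.distributions.utils import broadcast_all

# broadcast_all expands its arguments to a common shape without copying,
# so a result can be a self-overlapping view (note the zero stride).
value = torch.zeros(3, 2)
probs = torch.tensor([0.1, 0.9])          # shape (2,)
value_b, probs_b = broadcast_all(value, probs)
print(probs_b.shape, probs_b.stride())    # torch.Size([3, 2]) (0, 1)
```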
```diff
-        value, probs = broadcast_all(value, self.probs.clone(memory_format=torch.contiguous_format))
+        value, probs = broadcast_all(value, self.probs)
+        probs = probs.clone(memory_format=torch.contiguous_format)
         probs[(probs == 1) & (value == 0)] = 0
```
> My impression was that broadcasted tensors are routinely used in distributions, so requiring non-self-overlapping outputs could lead to a significant memory penalty.
This is a bug fix. If `probs` is broadcasted here, then this line can write to locations where `value != 0` if it happens to overlap with a location where `value` is zero. The full tensor size will be required in the following calculation anyway, so I think this is fine.
The same reasoning applies to the change in `multinomial`.
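To illustrate the aliasing (a hypothetical example, not code from the PR): after `expand`, many logical elements share one storage cell, so an in-place masked write through one of them would also change every aliased element. Cloning after the broadcast avoids this.

```python
import torch

# After expand(), all three rows alias the same two storage elements,
# so an in-place write through one row would be visible in every row.
probs = torch.tensor([0.5, 1.0])
expanded = probs.expand(3, 2)                            # stride (0, 1)
print(expanded[0].data_ptr() == expanded[1].data_ptr())  # True

# Cloning after the broadcast materialises independent elements, so the
# masked write touches exactly the selected positions and nothing else.
value = torch.tensor([[1., 0.], [1., 1.], [1., 1.]])
safe = expanded.clone(memory_format=torch.contiguous_format)
safe[(safe == 1) & (value == 0)] = 0                     # only safe[0, 1] changes
```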
**Codecov Report**

```
@@            Coverage Diff             @@
##    gh/peterbell10/9/base    #43423   +/- ##
==========================================================
  Coverage            ?    69.31%
==========================================================
  Files               ?       378
  Lines               ?     46747
  Branches            ?         0
==========================================================
  Hits                ?     32405
  Misses              ?     14342
  Partials            ?         0
==========================================================
```

Continue to review the full report at Codecov.
LGTM, thanks for fixing this!
@ngimel I believe the only place where broadcasting would help reduce the memory footprint is in `MyDist(...).log_prob()`, but those methods already perform arithmetic operations like `(-probs).log1p()`, so I believe the max extra overhead of this PR would be +25%.
Can someone else please review the added …

I can review it after I finish reviewing #43422
**zou3519** left a comment
Some not-uncommon use cases of `masked_fill_` and `index_put_` are idempotent. For example, I've seen user code out in the wild that tries to zero out NaNs by doing the following, but I don't know how many users actually try to do this on broadcasted tensors:

```python
x[torch.isnan(x)] = 0
y.masked_fill_(torch.isnan(y), 0)
```

The other operators in this PR look fine, though.

My opinion is to not change `masked_fill_` and `index_put_`, at least not in this PR, because it is a common paradigm to do `x[x == blah] = value`. If we do decide we want to fix `masked_fill_` and `index_put_`, I think we should consider that change to be BC-breaking and have a deprecation cycle.
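A short sketch of the pattern in question (hypothetical values, not from the PR): on an ordinary tensor the NaN-zeroing write is safe, while on an expanded view the remedy is to `clone()` before writing in place.

```python
import torch

# On a tensor that owns non-overlapping storage, the masked write is
# idempotent and safe.
x = torch.tensor([1.0, float('nan'), 3.0])
x[torch.isnan(x)] = 0

# On an expanded (self-overlapping) view, materialise a copy before the
# in-place write -- the fix suggested by this PR's deprecation warning.
y = torch.tensor([float('nan'), 2.0]).expand(4, 2)
y = y.clone(memory_format=torch.contiguous_format)
y.masked_fill_(torch.isnan(y), 0)
```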
The original issue (#39639) was about …
```cpp
TORCH_WARN("Use of index_put_ on expanded tensors is deprecated. "
           "Please clone() the tensor before performing this operation.");
```
Another way to get here is via advanced indexing. We should talk about that as well in the error message so that the user can connect these two ops: "Use of index_put_ on expanded tensors is deprecated. Please clone() the tensor before performing this operation. You may also encounter this error message when assigning items to a tensor using advanced indexing; that is also deprecated."
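For reference, a small illustration (assumed shapes, not from the PR) of why advanced-indexing assignment hits the same code path: `t[mask] = v` is sugar for `index_put_`.

```python
import torch

t = torch.zeros(2, 3)
mask = torch.tensor([[True, False, False],
                     [False, True, False]])

# These two statements are equivalent: the indexing assignment
# dispatches to index_put_ under the hood.
t[mask] = 1.0
t.index_put_(mask.nonzero(as_tuple=True), torch.tensor(2.0))
```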
Good idea. Added a note in the warning for `index_put_` and `masked_fill_` as well.
Sure, let's do the deprecation warning in this PR then.
**zou3519** left a comment
Let's add a mention of advanced indexing into the `index_put_` deprecation. After that, this should be good to go.
Stack from ghstack.

Differential Revision: [D23298652](https://our.internmc.facebook.com/intern/diff/D23298652)