[fix] output of embedding_bag with non-contiguous weight
#44032
Conversation
💊 CI failures summary (Dr. CI), as of commit fe6961b:
ci.pytorch.org: 2 failed
Codecov Report
@@            Coverage Diff            @@
##           master   #44032   +/-   ##
=========================================
  Coverage   69.29%   69.29%
=========================================
  Files         381      381
  Lines       47214    47214
=========================================
+ Hits        32717    32718      +1
+ Misses      14497    14496      -1

Continue to review full report at Codecov.
glaringlee left a comment:
Thanks for adding tests.
I have a nit comment; please take a look:
* use method chaining (a generic sketch follows below).
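The exact line this nit refers to is not quoted in the thread, so the following is only a hypothetical illustration of the suggestion, with made-up tensor setup:

import torch

weight = torch.randn(3, 4)

# Without chaining: each step binds a named intermediate.
w = weight.clone()
w = w.detach()
w = w.requires_grad_(True)

# With method chaining: the same result in one expression.
w = weight.clone().detach().requires_grad_(True)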
@kshitij12345 Thanks a lot for fixing this. I will approve this once the CI tests pass without issues.
facebook-github-bot left a comment:
@glaringlee has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
glaringlee left a comment:
LGTM now.
@glaringlee Gentle ping :)

@kshitij12345

@glaringlee Oops, I forgot it was a federal holiday. Thanks!

glaringlee merged this pull request in 6dd53fb.
def test_embedding_bag_non_contiguous_weight(self, device, dtype):
    weight_tensor = torch.randn(4, 3, dtype=dtype, device=device)

    weight_tensor_non_contig = weight_tensor[:, :3]  # This is non-contiguous strided.
this is a contiguous tensor
Right! Great catch!!
It was supposed to be:

weight_tensor = torch.randn(3, 4, dtype=dtype, device=device)
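To see why the shape matters, here is a small standalone sketch (added for illustration, not part of the PR): a column slice of a (3, 4) tensor skips over memory and is non-contiguous, whereas [:, :3] on a (4, 3) tensor selects every element and stays contiguous.

import torch

a = torch.randn(4, 3)
print(a[:, :3].is_contiguous())  # True: the slice covers the whole tensor

b = torch.randn(3, 4)
print(b[:, :3].is_contiguous())  # False: the column slice is non-contiguous strided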
@kshitij12345 I'll put a fix for u.
Sure. Thanks! Btw I am free this evening so I can put it up as well. Let me know if I should.
Ah, no worries. I have a clean pytorch repo on hand, will patch this shortly.
#44382 has been opened to patch this.
auto* output_data = output.data_ptr<float>();

if (isFastPathIndexSelect(src, output)) {
  auto* src_data = src.contiguous().data_ptr<float>();
wait, this is still a bug, right? If the tensor is non-contiguous then contiguous() will return a new tensor and it will be immediately destroyed (because we don't keep a reference to it around). So src_data will point to the deallocated memory :(
I wonder why ASAN doesn't catch it.
It should be:
auto src_contig = src.contiguous();
auto* src_data = src_contig.data_ptr<float>();
@dzhulgakov oh, shoot... my bad, will fix soon.
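The lifetime hazard described above can also be sketched from the Python side (an added illustration, not from the thread): data_ptr() returns a raw address into a tensor's storage, which is only valid while some reference keeps that tensor alive.

import torch

t = torch.randn(3, 4)[:, :3]  # non-contiguous view
assert not t.is_contiguous()

# Bug pattern: .contiguous() returns a fresh tensor here, and nothing keeps
# it alive past this statement, so ptr is left pointing at freed storage.
ptr = t.contiguous().data_ptr()

# Fix pattern (what the suggested change does): bind the contiguous tensor
# to a name so it outlives every use of the raw pointer.
t_contig = t.contiguous()
ptr = t_contig.data_ptr()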
fix dangling ptr in #44032
Differential Revision: D23661007 (https://our.internmc.facebook.com/intern/diff/D23661007)
Fixes #43723
Use weight.contiguous() on the fast path, as it expects a contiguous tensor.
TODO:
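For background, a minimal sketch of the failure mode this PR addresses (the shapes, mode, and indices are assumptions for illustration, not the exact repro from #43723): with a non-contiguous weight, the fast path previously read the strided storage as if it were dense, so the output disagreed with the contiguous case.

import torch
import torch.nn.functional as F

weight = torch.randn(10, 8)[:, :4]  # non-contiguous weight matrix
assert not weight.is_contiguous()

input = torch.tensor([[0, 1], [2, 3]])  # two bags of two indices each

# Before the fix these two could differ on the fast path; after the fix
# they must match.
out_non_contig = F.embedding_bag(input, weight, mode='sum')
out_contig = F.embedding_bag(input, weight.contiguous(), mode='sum')
assert torch.allclose(out_non_contig, out_contig)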