checking device types of input and weights at RNN #10185
weiyangfb
commented
Aug 2, 2018
- fixes CPU hidden state tensor in GPU lstm layer causes CUDA corruption #9534
facebook-github-bot
left a comment
weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot retest this please
fmassa
left a comment
LGTM, but I wonder if those checks shouldn't be in the functional interface instead?
e5d097e to 2658720
Also, it might be a better idea to push the checks to the cuDNN path, because otherwise we'll end up repeating them later anyway in the autograd code.
@apaszke I moved the checks to the cuDNN path and also kept them in the non-cuDNN path.
apaszke
left a comment
I'm confused now. What I meant is that the device checks are strictly necessary only in the cuDNN path, but what you did here is add them in both paths, which makes the cuDNN path go through the checks twice.
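To make the point above concrete, here is a minimal sketch of the layout being argued for: one explicit check at the entry of the fused backend, and no extra explicit check on the composite path, where each elementary operation already validates its own operands. Every name in it (`FakeTensor`, `check_devices`, `fused_backend_rnn`, `composite_rnn`, `rnn`) is hypothetical and written only for this illustration; it is not the code in this PR and does not use ATen.

```cpp
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical model: a tensor is reduced to a device label such as "cpu" or "cuda:0".
struct FakeTensor {
  std::string device;
};

// Counts explicit validation passes, purely for illustration.
static int explicit_checks_run = 0;

// Explicit up-front check for a backend that would not validate its arguments itself.
void check_devices(const FakeTensor& input, const std::vector<FakeTensor>& others) {
  ++explicit_checks_run;
  for (const auto& t : others) {
    if (t.device != input.device) {
      throw std::runtime_error("Input and hidden/parameter tensors are not on the same device: " +
                               input.device + " vs " + t.device);
    }
  }
}

// Stand-in for a fused backend call: the single explicit check lives here.
void fused_backend_rnn(const FakeTensor& input, const std::vector<FakeTensor>& others) {
  check_devices(input, others);
  // ... a fused kernel would run here ...
}

// Stand-in for an elementary op that verifies its own operands.
void elementwise_op(const FakeTensor& a, const FakeTensor& b) {
  if (a.device != b.device) {
    throw std::runtime_error("Device mismatch in op: " + a.device + " vs " + b.device);
  }
}

// Stand-in for the composite path: no extra explicit check, the ops catch mismatches.
void composite_rnn(const FakeTensor& input, const std::vector<FakeTensor>& others) {
  for (const auto& t : others) {
    elementwise_op(input, t);
  }
}

// Dispatcher: does not check by itself, so the fused path is validated exactly once.
void rnn(const FakeTensor& input, const std::vector<FakeTensor>& others, bool use_fused) {
  if (use_fused) {
    fused_backend_rnn(input, others);
  } else {
    composite_rnn(input, others);
  }
}
```

The trade-off raised later in the thread still applies to this sketch: the composite path reports whatever message the first failing operation produces, which may be less specific than a dedicated up-front check.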
e617a89 to e6ce1a0
@apaszke I think it is reasonable to check device consistency in both the cuDNN and non-cuDNN code, though.
@ssnl But the devices will be checked in the native path anyway, since every single function we call will verify them.
#11680 is also related.
@apaszke Yes, I agree that the non-cuDNN path check is redundant. But it would be nice to give users a better error message. I'm fine with either, actually.
6ee12bb to 04491ff
@ssnl Note that the error message we can give in the C++ API is not very helpful anyway. The weights have very complex names in Python, and I don't think we'll be able to reproduce them easily.
723b058 to 26e166e
2. add check_device() function
26e166e to e8fff55
apaszke
left a comment
Would be good to clean up check_tensors to use at::Device instead of unnecessarily comparing everything manually.
aten/src/ATen/native/RNN.h
Outdated
auto check_tensors = [&](const std::string& name, const Tensor& t) {
  if (!t.defined()) return;
  auto t_device = t.device();
  bool t_device_is_cuda = t_device.is_cuda();
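For illustration, the cleanup suggested in the review could look roughly like the sketch below: compare whole device values with `!=` instead of inspecting `is_cuda()` and the device index separately. `Device`, `FakeTensor`, `device_str`, and `check_same_device` are hypothetical stand-ins written for this example, not the ATen types or the code in this PR; the only assumption is that a device value supports equality comparison.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for a device descriptor (NOT the real at::Device).
enum class DeviceType { CPU, CUDA };

struct Device {
  DeviceType type;
  int index;  // -1 means "no explicit index", as for a plain CPU device

  // Comparing the whole value replaces separate type and index comparisons.
  bool operator==(const Device& other) const {
    return type == other.type && index == other.index;
  }
  bool operator!=(const Device& other) const { return !(*this == other); }
};

// Renders a device as "cpu", "cuda", or "cuda:<index>" for error messages.
std::string device_str(const Device& d) {
  std::string s = (d.type == DeviceType::CUDA) ? "cuda" : "cpu";
  if (d.index >= 0) s += ":" + std::to_string(d.index);
  return s;
}

// Hypothetical tensor handle: an undefined tensor has no meaningful device.
struct FakeTensor {
  bool is_defined;
  Device dev;
  bool defined() const { return is_defined; }
  Device device() const { return dev; }
};

// Throws if a defined tensor is not on the expected device; skips undefined ones.
void check_same_device(const std::string& name, const FakeTensor& t,
                       const Device& expected) {
  if (!t.defined()) return;
  if (t.device() != expected) {
    throw std::runtime_error("Expected " + name + " to be on " +
                             device_str(expected) + ", but it is on " +
                             device_str(t.device()));
  }
}
```

With a single equality test, a CPU hidden state passed alongside a CUDA input is rejected by the same branch that rejects a tensor on a different CUDA index, so no per-field case analysis is needed.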
aten/src/ATen/native/RNN.h
Outdated
}

for (auto p : params) {
  // if (!p.defined()) continue;
Summary:
- fixes #9534
Pull Request resolved: pytorch/pytorch#10185
Differential Revision: D9141222
Pulled By: weiyangfb
fbshipit-source-id: bb652e42cc15917019df080d6bce2926b18f3476