Raise error on device mismatch in addmm #43505
Dr. CI summary: as of commit e303390, there are no CI failures yet.
letting @ngimel handle this one
ngimel
left a comment
Looks good, thanks, I have minor comments
auto m2_strides = m2.strides();
auto m2_sizes = m2.sizes();

TORCH_CHECK(self.device() == kCPU && m1.device() == kCPU &&
Indeed, multi-dispatch will always send to CUDA if any one of the tensors is a CUDA tensor, so this check is not needed.
0a32767 to
d53064b
facebook-github-bot
left a comment
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fixes gh-42282
This adds a device-mismatch check to addmm on CPU and CUDA. The dispatcher seems to always select the CUDA version here if any of the inputs are on GPU, so in theory the CPU check is unnecessary, but it is probably better to err on the side of caution.
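To make the reasoning above concrete, here is a minimal, self-contained sketch of the idea. It is not the code from this PR and does not use ATen: `Device`, `FakeTensor`, `select_backend`, `check_same_device`, and `addmm_checked` are hypothetical names introduced only for illustration, and the backend selection rule is a simplified model of the dispatch behaviour described in this thread.

```cpp
#include <initializer_list>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for illustration only; these are not ATen types.
enum class Device { CPU, CUDA };

struct FakeTensor {
  std::string name;
  Device device;
};

inline const char* device_name(Device d) {
  return d == Device::CUDA ? "cuda" : "cpu";
}

// Simplified model of the dispatch rule discussed in this thread: the CUDA
// kernel is selected as soon as any argument lives on the GPU.
inline Device select_backend(std::initializer_list<FakeTensor> args) {
  for (const auto& t : args) {
    if (t.device == Device::CUDA) {
      return Device::CUDA;
    }
  }
  return Device::CPU;
}

// Throws if any argument is not on the device of the selected kernel.
inline void check_same_device(Device expected,
                              std::initializer_list<FakeTensor> args) {
  for (const auto& t : args) {
    if (t.device != expected) {
      std::string msg = "expected all tensors to be on ";
      msg += device_name(expected);
      msg += ", but '" + t.name + "' is on ";
      msg += device_name(t.device);
      throw std::runtime_error(msg);
    }
  }
}

// Hypothetical entry point: pick a backend, then validate the arguments.
// Returns the backend that would run the (omitted) matrix multiply.
inline Device addmm_checked(const FakeTensor& self, const FakeTensor& m1,
                            const FakeTensor& m2) {
  const Device backend = select_backend({self, m1, m2});
  check_same_device(backend, {self, m1, m2});
  return backend;
}
```

Under this model, mixed CPU and CUDA inputs always reach the CUDA path, so the check on that path is the one that fires. A CPU-side check can only trigger if a call reaches the CPU kernel some other way, which is why it acts as a safeguard rather than a necessity.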