torch.where : Scalar Support #40336
Conversation
💊 CI failures summary (Dr. CI): as of commit ea11981, looks good so far; there are no failures yet. 💚
A lot of real-world use cases can be found in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Loss.cpp. It would be good if they were refactored to not allocate zeros tensors. Another question is that they often allocate indicator tensors (often it's like …
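For illustration, here is a minimal Python sketch (not the actual ATen C++ from Loss.cpp; variable names are made up) of the kind of refactor being suggested, assuming the scalar overload this PR adds:

```python
import torch

x = torch.randn(10)
cond = x > 0

# Before: allocate a full zeros tensor just to serve as the "other" branch.
out_old = torch.where(cond, x, torch.zeros_like(x))

# After this PR: pass the scalar directly and skip the allocation.
out_new = torch.where(cond, x, 0.)

assert torch.equal(out_old, out_new)
```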
Will type promotion also be supported, i.e. using 0 or 1 in place of 0.0 and 1.0?
Thanks! Would be a good test as well.
Sounds interesting but surely out-of-scope for this PR.
I am not planning to do it in this PR, as it would also affect the behaviour of the Tensor-Tensor overload.
Got it. For the future, type promotion would be useful for the Tensor-Tensor overload too (although not as pressing, given the new scalar support).
aten/src/ATen/native/Loss.cpp (outdated)

```diff
  auto margin_clamp = (margin - self).clamp_min_(0);
- auto output_margin = at::where(target != 1, margin_clamp, zeros);
- auto output_self = at::where(target != -1, self, zeros);
+ auto output_margin = at::where(target != 1, margin_clamp, 0);
```
Won't there be a problem because of the int scalar type? margin_clamp probably has a float tensor type. Or will ATen promote the types here?
At the moment, if there is one scalar and one tensor, the scalar is cast to the tensor's dtype; this felt natural when implementing it. Looking back, though, I guess this isn't good behaviour, as it may lead to the scalar's type being promoted or demoted based on the tensor.
Ah, cool. So there is scalar -> tensor type promotion. That's what I would expect.
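A minimal sketch of the behaviour discussed above, assuming the scalar is cast to the tensor's dtype as the author describes:

```python
import torch

x = torch.rand(3)                 # float32 tensor
out = torch.where(x > 0.5, x, 0)  # the int scalar 0 is cast to x's dtype
print(out.dtype)                  # torch.float32
```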
Also, I'm not sure whether torch.lerp currently supports scalars or not.
Also, at some point there were problems with torch.where producing NaN gradients; it would be good to have a note in the docs about the current state of affairs (the canonical erroring example is entropy computation in the presence of 0). Many of these issues were closed circularly (one links to another); the only currently open issue is #23156.
@mruberry Please review :)
Will something like the following work?

```python
x = torch.rand(3, 4, 5).to('cuda')
y = torch.where(x > 0.5, x, torch.tensor(0))
```
This PR doesn't actually touch the current version:

```python
>>> a = torch.randn(3,4).to('cuda')
>>> torch.where(a > 0.5, a, torch.tensor(0.))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected condition, x and y to be on the same device, but condition is on cuda:0 and x and y are on cuda:0 and cpu respectively
```

However, after this PR:

```python
>>> a = torch.randn(3,4).to('cuda')
>>> torch.where(a > 0.5, a, 0)
tensor([[0.0000, 0.0000, 0.5484, 0.0000],
        [0.8243, 0.0000, 1.0878, 1.2405],
        [0.0000, 0.0000, 0.0000, 1.7502]], device='cuda:0')
```
@mruberry
```cpp
namespace {

inline at::Tensor scalar_to_tensor_default_dtype(
```
Better to add a comment explaining what this function is for.
Originally I had a few suggestions for how to modernize this function, but now I have a different idea: what about copying it and putting it next to `scalar_to_tensor`?

pytorch/aten/src/ATen/ScalarOps.h, line 12 in 2b02d15:

```cpp
inline at::Tensor scalar_to_tensor(Scalar s, const Device device = at::kCPU) {
```

Except the dtypes would be different, as you're acquiring them.
I have moved the function and added some comments.
Please review.
mruberry left a comment:
Cool. Nice work, @kshitij12345. I made a few small comments about code organization and docs. I'm curious to see how type promotion will affect the tests, and wondering if in the next iteration we could get some tests that compare directly with np.where, too.
Just ping me when this is ready.
* fix doc for argument
* fix example
I have addressed the changes. For the next iteration (type promotion), we would most likely have tests directly against NumPy!
@mruberry Gentle ping :)
Sorry to keep you waiting, @kshitij12345, it's been a very busy week! Thank you for your patience. Let's not hold this PR up anymore on small issues. We can look at improving that comment, if we want to, in the future.
facebook-github-bot left a comment:
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Internal tests flagged a few potential issues / lint problems, @kshitij12345. I'll sort through them tomorrow. Should be a few simple fixes.
@mruberry
Does the entropy case work now? Is torch.where NaN-safe for this case?
No. The behavior of …
It would be good to have this as a test, even if it's known to fail at the moment. A note in the docs would also be good. This is a very common mistake (and the entropy computation use case specifically is a good, common example of it, I think) and a source of surprises...

```python
import torch

def where_is_not_nan_safe():
    p = torch.zeros(4).requires_grad_()
    e = torch.where(p > 0, p * p.log(), torch.zeros_like(p)).sum(-1)
    assert torch.isnan(torch.autograd.grad(e, (p,))[0]).all()
```
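For context (not part of this PR), a common workaround for the NaN gradient is the "double where" pattern: sanitize the input of the unsafe branch before it is evaluated, so no non-finite intermediate enters the backward graph. A minimal sketch:

```python
import torch

def where_nan_safe_workaround():
    p = torch.zeros(4).requires_grad_()
    # Replace the zeros with a harmless value before taking the log...
    safe_p = torch.where(p > 0, p, torch.ones_like(p))
    # ...so p.log() is never evaluated at 0 and the backward pass stays finite.
    e = torch.where(p > 0, safe_p * safe_p.log(), torch.zeros_like(p)).sum(-1)
    assert not torch.isnan(torch.autograd.grad(e, (p,))[0]).any()

where_nan_safe_workaround()
```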
The linked issue is less about …
You're right. Given the current gradient processing and no special handling in torch.where, that's how things are. But if torch.where did special processing, it could be a way around the current gradient processing. All I wanted to point out is that trying to implement entropy and work around the naive …
Reference: #38349 #9190
TODO