Conversation

@kshitij12345
Collaborator

@kshitij12345 kshitij12345 commented Jun 20, 2020

Reference: #38349 #9190

TODO

  • Add Tests
  • Update Docs

@kshitij12345 kshitij12345 changed the title torch.where : Scalar Support [WIP] torch.where : Scalar Support Jun 20, 2020
@dr-ci

dr-ci bot commented Jun 20, 2020

💊 CI failures summary and remediations

As of commit ea11981 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


@vadimkantorov
Contributor

A lot of real-world use cases can be found in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Loss.cpp

It would be good if they were refactored to not allocate zeros tensors.

Another question is that they often allocate indicator tensors (usually something like x == 1, x == -1, x > 0, x == 0). Without automatic fusion, or a comparison-op argument indicating how to interpret a float tensor (or a default convention, e.g. 0 is false and non-zero is true, like in C), it's impossible to get rid of them and maybe not worth it - also discussed in #9190
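
As a rough illustration of the refactor being suggested, here is a minimal Python sketch of a hinge-embedding-style loss (hypothetical names, not code from the PR), assuming the scalar overload added here:

import torch

def hinge_embedding_loss_sketch(input, target, margin=1.0):
    # Before: an explicit zeros tensor had to be allocated just to feed where()
    #   zeros = torch.zeros_like(input)
    #   output_margin = torch.where(target != 1, (margin - input).clamp_min(0), zeros)
    # With scalar support the allocation goes away:
    margin_clamp = (margin - input).clamp_min(0)
    output_margin = torch.where(target != 1, margin_clamp, 0)
    output_self = torch.where(target != -1, input, 0)
    return (output_margin + output_self).mean()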

@vadimkantorov
Contributor

Will type promotion also be supported, i.e. using 0 or 1 in place of 0.0 and 1.0?

@kshitij12345
Collaborator Author

@vadimkantorov

A lot of real-world use cases can be found in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Loss.cpp
It would be good if they were refactored to not allocate zeros tensors.

Thanks! Would be a good test as well.

Another question is that they often allocate indicator tensors (usually something like x == 1, x == -1, x > 0, x == 0). Without automatic fusion, or a comparison-op argument indicating how to interpret a float tensor (or a default convention, e.g. 0 is false and non-zero is true, like in C), it's impossible to get rid of them and maybe not worth it - also discussed in #9190

Sounds interesting but surely out-of-scope for this PR.

Will type promotion also be supported, i.e. using 0 or 1 in place of 0.0 and 1.0?

I am not planning to do it in this PR as it would also affect the behaviour of Tensor-Tensor overload.

@vadimkantorov
Contributor

vadimkantorov commented Jun 20, 2020

I am not planning to do it in this PR as it would also affect the behaviour of Tensor-Tensor overload.

Got it. For the future, type promotion would be useful for the tensor-tensor overload too (although not as pressing, given the new scalar support).

auto margin_clamp = (margin - self).clamp_min_(0);
auto output_margin = at::where(target != 1, margin_clamp, zeros);
auto output_self = at::where(target != -1, self, zeros);
auto output_margin = at::where(target != 1, margin_clamp, 0);
Contributor

Will there not be a problem because of the int scalar type? margin_clamp probably has a float tensor dtype - or will ATen promote types here?

Collaborator Author

At the moment, if there is one scalar and one tensor, the scalar is cast to the tensor's dtype; this felt natural when implementing. But looking back, I guess this isn't good behaviour, as it may lead to the scalar's type being promoted or demoted based on the tensor.

Contributor

Ah, cool. So there is scalar -> tensor type promotion. For me that's what I would expect.
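
A quick illustration of the casting behaviour described in this thread (a sketch, not from the PR):

import torch

x = torch.rand(3)                  # float32 tensor
out = torch.where(x > 0.5, x, 0)   # int literal as the "other" value
print(out.dtype)                   # torch.float32: the scalar is cast to x's dtype

# Per the discussion above, at the time of this PR a float scalar paired with an
# integer tensor would likewise be cast down to the tensor's dtype; full type
# promotion was left for a follow-up.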

@vadimkantorov
Contributor

Also, not sure if torch.lerp currently supports scalars or not

@vadimkantorov
Contributor

vadimkantorov commented Jun 20, 2020

Also, at some point there were problems with torch.where producing NaN gradients; it would be good to have a note in the docs about the current state of affairs (the canonical erroring example is entropy computation in the presence of 0: torch.where(p > 0, p*p.log(), 0).sum(-1)):
#18287 #23395

Many of these issues were closed circularly (one links to another); the only currently open issue is #23156

@kshitij12345 kshitij12345 marked this pull request as ready for review July 11, 2020 12:28
@kshitij12345 kshitij12345 changed the title [WIP] torch.where : Scalar Support torch.where : Scalar Support Jul 11, 2020
@kshitij12345
Collaborator Author

@mruberry Please review :)

@vadimkantorov
Contributor

vadimkantorov commented Jul 11, 2020

Will something like

x = torch.rand(3, 4, 5).to('cuda')
y = torch.where(x > 0.5, x, torch.tensor(0))

work?
i.e. cpu->cuda scalar copy

@kshitij12345
Collaborator Author

This PR doesn't actually touch the current version.

>>> a = torch.randn(3,4).to('cuda')
>>> torch.where(a > 0.5, a, torch.tensor(0.))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected condition, x and y to be on the same device, but condition is on cuda:0 and x and y are on cuda:0 and cpu respectively

However, after this PR:

>>> a = torch.randn(3,4).to('cuda')
>>> torch.where(a > 0.5, a, 0)
tensor([[0.0000, 0.0000, 0.5484, 0.0000],
        [0.8243, 0.0000, 1.0878, 1.2405],
        [0.0000, 0.0000, 0.0000, 1.7502]], device='cuda:0')

@mruberry mruberry self-requested a review July 12, 2020 07:16
@mruberry mruberry added the module: numpy Related to numpy support, and also numpy compatibility of our operators label Jul 12, 2020
@kshitij12345
Collaborator Author

@mruberry
Gentle ping :)


namespace {

inline at::Tensor scalar_to_tensor_default_dtype(
Collaborator

Better add a comment explaining what this function is for

Collaborator

Originally I had a few suggestions for how to modernize this function, but now I have a different idea: what about copying it and putting it next to "scalar_to_tensor"?

inline at::Tensor scalar_to_tensor(Scalar s, const Device device = at::kCPU) {

Except the dtypes would be different, as you're acquiring them.

Collaborator Author

Have moved the function and added some comments.

Please review.

@mruberry mruberry self-requested a review July 21, 2020 17:56
Collaborator

@mruberry mruberry left a comment

Cool. Nice work, @kshitij12345. I made a few small comments about code organization and docs. I'm curious to see how type promotion will affect the tests, and wondering if in the next iteration we could get some tests that compare directly with np.where, too.

Just ping me when this is ready.

@kshitij12345
Collaborator Author

@mruberry

Have addressed the changes.
Do let me know if there is a better way to phrase the comment in ScalarOps.h.

For next iteration (type-promotion), we would most likely have tests directly against numpy!

@kshitij12345
Collaborator Author

@mruberry Gentle Ping:)

@mruberry
Collaborator

@mruberry

Have addressed the changes.
Do let me know if there is a better way to phrase the comment in ScalarOps.h.

For next iteration (type-promotion), we would most likely have tests directly against numpy!

Sorry to keep you waiting, @kshitij12345, it's been a very busy week! Thank you for your patience.

Let's not hold this PR up anymore on small issues. We can look at improving that comment, if we want to, in the future.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mruberry
Collaborator

Internal tests flagged a few potential issues / lint problems, @kshitij12345. I'll sort through them tomorrow. Should be a few simple fixes.

@kshitij12345
Collaborator Author

@mruberry
Sure. Thank You!
Do let me know if I can help with anything.

@facebook-github-bot
Contributor

@mruberry merged this pull request in 31d41f9.

@kshitij12345 kshitij12345 deleted the develop/where/scalar branch July 31, 2020 08:19
@vadimkantorov
Contributor

Also, at some point there were problems with torch.where producing NaN gradients; it would be good to have a note in the docs about the current state of affairs (the canonical erroring example is entropy computation in the presence of 0: torch.where(p > 0, p*p.log(), 0).sum(-1)):
#18287 #23395

Many of these issues were closed circularly (one links to another); the only currently open issue is #23156

Does the entropy case work now? Is torch.where NaN-safe for this case?

@mruberry
Collaborator

mruberry commented Aug 4, 2020

Does the entropy case work now? Is torch.where NaN-safe for this case?

No. The behavior of torch.where is unchanged. This is just sugar allowing scalars to be interpreted as tensor arguments.
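
Roughly, the sugar amounts to wrapping the scalar in a tensor with the other argument's dtype and device yourself; a sketch of the semantics (not the internal implementation):

import torch

cond = torch.rand(3) > 0.5
x = torch.rand(3)
out = torch.where(cond, x, 0)
same = torch.where(cond, x, torch.tensor(0., dtype=x.dtype, device=x.device))
assert torch.equal(out, same)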

@vadimkantorov
Contributor

vadimkantorov commented Aug 4, 2020

Does the entropy case work now? Is torch.where NaN-safe for this case?

No. The behavior of torch.where is unchanged. This is just sugar allowing scalars to be interpreted as tensor arguments.

It would be good to have this as a test, even if it's known to fail at the moment. A note in the docs would also be good. This is a very common mistake and source of surprises (and the entropy computation use case specifically is a good/common example of this, I think)...

import torch

def where_is_not_nan_safe():
  p = torch.zeros(4).requires_grad_()
  # The forward pass is fine (the zeros branch is selected), but the backward
  # pass still differentiates p*p.log() at p == 0, and the masked gradient
  # becomes 0 * (log(0) + 1) = 0 * (-inf) = nan.
  e = torch.where(p > 0, p*p.log(), torch.zeros_like(p)).sum(-1)
  assert torch.isnan(torch.autograd.grad(e, (p,))[0]).all()

@mruberry
Collaborator

mruberry commented Aug 4, 2020

The linked issue is less about torch.where and more about how gradients are represented, so I don't think we'd take a test for this in particular.

@vadimkantorov
Contributor

You're right. Given the current gradient processing and no special handling in torch.where, this is just how things are. But if torch.where did special handling, it could be a way around the current gradient processing. All I wanted to point out is that trying to implement entropy and work around the naive p * p.log() with torch.where is quite frequent (and currently failing), to the point that it might deserve a note in the docs, and at best be fixed by special handling in torch.where (if technically possible).
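
For reference, a commonly used workaround for this pattern (only a sketch, not part of this PR) is the double-where trick: feed log() a safe value where p == 0 so the unselected branch never produces an inf for the chain rule to multiply by zero:

import torch

p = torch.tensor([0.0, 0.25, 0.75], requires_grad=True)

# Naive version (NaN gradient where p == 0, as in the snippet above):
#   e = torch.where(p > 0, p * p.log(), torch.zeros_like(p)).sum()

safe_p = torch.where(p > 0, p, torch.ones_like(p))
e = torch.where(p > 0, p * safe_p.log(), torch.zeros_like(p)).sum()
e.backward()
print(p.grad)   # finite everywhere, e.g. tensor([ 0.0000, -0.3863,  0.7123])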
