
Conversation

@zasdfgbnm (Collaborator) commented Aug 24, 2020

[bc-breaking note] Previously, when there were multiple max/min/median elements with the same value, the gradient propagated only to the first element with that value; now the gradient is evenly distributed between all such elements. This yields the minimum-norm subgradient.
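A minimal before/after sketch of the change (assumed illustration, not part of the PR diff):

import torch

x = torch.tensor([1., 3., 3., 2.], requires_grad=True)
x.max().backward()
print(x.grad)
# new behavior: tensor([0.0000, 0.5000, 0.5000, 0.0000])  -- gradient split evenly among the ties
# old behavior: the whole gradient went to a single tied element, e.g. tensor([0., 1., 0., 0.])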
cc: @ngimel @mruberry

  return grad_input;
Tensor evenly_dispatch_backward(Tensor grad, const Tensor & input, const Tensor & value) {
  auto mask = (input == value);
  auto count = mask.sum(input.scalar_type());
Collaborator

For tensors with more than 2^24 elements that all happen to be the max/min/median, this is going to be inaccurate, so maybe it's better to leave it in int64 or double, depending on which is faster?
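A quick illustrative check of the fp32 concern (assumed example): float32 cannot represent integers above 2^24 exactly, so a float32 count of that size stops incrementing.

import torch

n = float(2**24)
print(torch.tensor(n, dtype=torch.float32) + 1 == n)  # tensor(True): 2^24 + 1 rounds back to 2^24 in fp32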

Collaborator Author

It is a scalar, so I don't think it makes any difference between int64 and fp64. Let's just use the default (int64).

Tensor evenly_dispatch_backward(Tensor grad, const Tensor & input, const Tensor & value) {
  auto mask = (input == value);
  auto count = mask.sum(input.scalar_type());
  return at::zeros_like(input).masked_fill_(mask, grad / count);
Collaborator

I wonder if it would be faster to do mask.to(input.scalar_type()) * (grad / count) here?
Also, it might be worth special-casing count == 1, where we could do something more efficient than masked_fill_().

Collaborator

Good point, and we don't even need mask.to, because TensorIterator would take care of it, and on CUDA the type promotion will be implicit. grad / count has to be converted to input.scalar_type() for this to work, though.
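For illustration of the type-promotion point (assumed example, default promotion rules): a bool mask multiplied by a float scalar already produces a float result, so no explicit mask.to(...) is needed.

import torch

mask = torch.tensor([True, False, True])
g = torch.tensor(0.5)
print((mask * g).dtype, mask * g)  # torch.float32 tensor([0.5000, 0.0000, 0.5000])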

Collaborator Author

Whether to choose * or masked_fill_ depends on CPU vs. CUDA:

import torch
for device in ['cpu', 'cuda']:
    t = torch.randint(100, (1024, 1024, 64), device=device, dtype=torch.float)
    s = t.sum()
    m = t.max()
    mask = (t == m)
    go = t.new_tensor(1.)
    torch.cuda.synchronize()
    print(device)
    %timeit mask * (go / s); torch.cuda.synchronize() if device == 'cuda' else None
    %timeit torch.zeros_like(t).masked_fill_(mask, go / s); torch.cuda.synchronize() if device == 'cuda' else None
cpu
44.5 ms ± 199 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
25.2 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
cuda
848 µs ± 16.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.08 ms ± 19.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

so there should be some logic along the lines of:

if (cuda) {
  *
} else {
  masked_fill
}

and I don't think it should have a separate case for count == 1, because the only difference is grad / count vs grad, and grad is a scalar tensor; scalar_tensor.item() is not faster than scalar_tensor / another_scalar_tensor.
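A Python sketch of that device-dependent choice (the actual PR implements this in C++; the function name below is illustrative only):

import torch

def evenly_distribute_backward(grad, input, value):
    mask = (input == value)
    count = mask.sum()
    if input.is_cuda:
        # elementwise multiply is faster on CUDA; type promotion handles the bool mask
        return mask * (grad / count)
    # masked_fill_ is faster on CPU
    return torch.zeros_like(input).masked_fill_(mask, grad / count)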

@zasdfgbnm (Collaborator Author)

Should be ready now. See my replies to reviews.

@ngimel added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Aug 25, 2020
@facebook-github-bot (Contributor) left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zasdfgbnm mentioned this pull request on Aug 25, 2020
@zasdfgbnm deleted the grad-to-all branch on August 25, 2020 23:42
@facebook-github-bot (Contributor)

@ngimel merged this pull request in 348e78b.

@ngimel added the module: bc-breaking label (related to a BC-breaking change) on Aug 26, 2020
@michaelklachko

@zasdfgbnm @albanD is this also the case for kthvalue op?

@albanD (Collaborator) commented Oct 30, 2020

kthvalue is different, as it returns indices (just like max(dim=)), so to be consistent these ops only return gradients for the index that was chosen during the forward.
So no, I don't think this applies to the current kthvalue function.
If we had a version of kthvalue that did a full reduction and did not return indices, then yes, that would apply.
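A small check of the current kthvalue behavior, for illustration (assumed example):

import torch

x = torch.tensor([3., 3., 2.], requires_grad=True)
v, idx = x.kthvalue(3)  # kthvalue returns (values, indices), like max(dim=...)
v.backward()
print(idx, x.grad)      # the whole gradient lands on the single index chosen in the forward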

@michaelklachko

I see. What was the reason to fix this issue? Are there examples of where the old behavior could lead to training instability? And what should I do if I have to backprop through kthvalue during training and I want to spread the gradients evenly?

@albanD (Collaborator) commented Oct 30, 2020

Hi,

Instability is a strong word, but the old behavior could lead to unexpected results, in particular because the value that was getting all the gradient was chosen in a non-deterministic way and could change across devices.

But also, from a more principled point of view: when computing subgradients, we prefer the one with minimum norm (as it is always a descent direction), which in this case is the even distribution across all inputs that realize the value.
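A small worked check of the minimum-norm claim, as an illustration: for f(x) = max(x) at x = (3, 3, 2), the subgradients are g = (a, 1 - a, 0) with 0 <= a <= 1, and ||g||^2 = a^2 + (1 - a)^2 is smallest at a = 0.5, i.e. the even split.

import torch

a = torch.linspace(0.0, 1.0, 5)
print(torch.stack([a, a**2 + (1 - a)**2], dim=1))  # squared norm is minimized at a = 0.5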

To do the same thing with kthvalue I guess you will need to do it yourself:

import torch
from torch import autograd

# No dim can be specified and only a full reduction is done
class MyKthvalue(autograd.Function):
    @staticmethod
    def forward(ctx, inp, k):
        res = inp.contiguous().view(-1).kthvalue(k).values
        ctx.save_for_backward(inp, res)
        return res

    @staticmethod
    def backward(ctx, gO):
        inp, res = ctx.saved_tensors
        mask = (inp == res)
        count = mask.sum()
        # distribute the incoming gradient evenly across all entries equal to the k-th value
        gO = gO / count
        return mask * gO, None


a = torch.randint(0, 10, (4, 4, 4), dtype=torch.float, requires_grad=True)
k = 3

MyKthvalue.apply(a, k).sum().backward()

print(a)
print(a.grad)

@chanshing

Does this change ReLU and related layers that depend on the max function? If so, what are the implications?

@zasdfgbnm (Collaborator Author)

@chanshing ReLU is not changed
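For illustration (assumed example): ReLU is an elementwise op with its own derivative rather than a max reduction, so the change above does not touch it; at exactly 0, PyTorch's relu backward returns 0.

import torch

x = torch.tensor([-1., 0., 2.], requires_grad=True)
torch.relu(x).sum().backward()
print(x.grad)  # tensor([0., 0., 1.])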

pytorchmergebot pushed a commit that referenced this pull request Jun 15, 2025
Fixes #155048

The behavior of `min` and `max` was changed in #43519. The note about gradient behavior in the torch.amin and torch.amax docs is updated to reflect this change:

New note:
`amax, amin, max(dim), min(dim) evenly distributes gradient between equal values when there are multiple input elements with the same minimum or maximum value.`

cc - @spzala @svekars @soulitzer @sekyondaMeta @AlannaBurke @ezyang @gqchen @nikitaved @Varal7 @xmfan
Pull Request resolved: #155071
Approved by: https://github.com/soulitzer

Labels: Merged · module: bc-breaking · open source · triaged
