Skip to content

Commit f18ac1b

Browse files
authored
Release version of fixed nll decomp (#95853)
* Fix nll_loss decomposition to properly ignore `ignore_index`
* Remove branch
1 parent c04134c commit f18ac1b

File tree

2 files changed

+38
-26
lines changed

2 files changed

+38
-26
lines changed

test/inductor/test_torchinductor.py

Lines changed: 27 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4127,18 +4127,38 @@ def fn(a, b):
41274127

41284128
self.common(fn, (torch.zeros([4, 256, 296, 304]), torch.zeros([2292, 5])))
41294129

4130-
@requires_decomp(aten.nll_loss_forward)
41314130
def test_nll_loss_forward(self):
41324131
def fn(a, b):
41334132
return aten.nll_loss_forward(a, b, None, 1, -100)
41344133

4135-
self.common(
4136-
fn,
4137-
(
4138-
torch.randn([5, 5]),
4139-
torch.zeros([5], dtype=torch.int64),
4140-
),
4134+
labels = (
4135+
torch.zeros([5], dtype=torch.int64),
4136+
torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64),
41414137
)
4138+
inps = (torch.randn(5, 5), torch.randn(5, 5))
4139+
for a, b in zip(inps, labels):
4140+
self.common(
4141+
fn,
4142+
(a, b),
4143+
)
4144+
4145+
def test_nll_loss_backward(self):
4146+
def fn(a, b, c):
4147+
return aten.nll_loss_backward(
4148+
a, b, c, None, 1, -100, torch.tensor(1.0, device=self.device)
4149+
)
4150+
4151+
labels = (
4152+
torch.zeros([5], dtype=torch.int64),
4153+
torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64),
4154+
)
4155+
inps = (torch.randn(5, 5), torch.randn(5, 5))
4156+
grad_outs = (torch.randn(()), torch.randn(()))
4157+
for a, b, c in zip(grad_outs, inps, labels):
4158+
self.common(
4159+
fn,
4160+
(a, b, c),
4161+
)
41424162

41434163
def test_isinf(self):
41444164
def fn(x):

torch/_decomp/decompositions.py

Lines changed: 11 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -405,8 +405,9 @@ def _nll_loss_backward(
405405
grad_output = grad_output / total_weight
406406

407407
target = target.unsqueeze(channel_dim)
408+
safe_target = torch.where(target != ignore_index, target, 0)
408409
grad_input = torch.zeros_like(self)
409-
grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)
410+
grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
410411

411412
if grad_input.dim() > grad_output.dim() > 0:
412413
grad_output = grad_output.unsqueeze(channel_dim)
@@ -417,9 +418,7 @@ def _nll_loss_backward(
417418
weight = weight.reshape(new_shape)
418419
grad_output = grad_output * weight
419420

420-
has_ignore_index = ignore_index >= 0
421-
if has_ignore_index:
422-
grad_output = torch.where(target != ignore_index, grad_output, 0)
421+
grad_output = torch.where(target != ignore_index, grad_output, 0)
423422

424423
return grad_input * grad_output
425424

@@ -2798,37 +2797,30 @@ def nll_loss_forward(
27982797
if weight is not None:
27992798
w = weight.unsqueeze(0) if n_dims > 1 else weight
28002799
self = self * w
2801-
2802-
target_ = target.unsqueeze(channel_dim)
2800+
safe_target = torch.where(target != ignore_index, target, 0)
2801+
safe_target_ = safe_target.unsqueeze(channel_dim)
28032802
# target can be [N, 1] or [1]
28042803

2805-
result = -torch.gather(self, channel_dim, target_).squeeze(channel_dim)
2804+
result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
28062805

2807-
if ignore_index >= 0:
2808-
result = torch.where(target != ignore_index, result, 0)
2806+
result = torch.where(target != ignore_index, result, 0)
28092807

28102808
if reduction == Reduction.NONE.value and n_dims > 1:
28112809
total_weight = self.new_full((), 0.0)
28122810
return result, total_weight
28132811

28142812
if weight is not None:
28152813
w = weight.unsqueeze(0).expand(self.shape) if n_dims > 1 else weight
2816-
wsum = torch.gather(w, channel_dim, target_).squeeze(channel_dim)
2817-
if ignore_index >= 0:
2818-
wsum = torch.where(target != ignore_index, wsum, 0)
2814+
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
2815+
wsum = torch.where(target != ignore_index, wsum, 0)
28192816
total_weight = wsum.sum()
2820-
elif ignore_index >= 0:
2821-
total_weight = (target != ignore_index).sum().to(self)
28222817
else:
2823-
total_weight = self.new_full((), 1.0 * result.numel())
2818+
total_weight = (target != ignore_index).sum().to(self)
28242819

28252820
if reduction == Reduction.SUM.value:
28262821
result = result.sum()
28272822
elif reduction == Reduction.MEAN.value:
2828-
if weight is None:
2829-
result = result.sum() / total_weight if ignore_index >= 0 else result.mean()
2830-
else:
2831-
result = result.sum() / total_weight
2823+
result = result.sum() / total_weight
28322824

28332825
return result, total_weight
28342826

0 commit comments

Comments (0)