Skip to content

Commit 34ec924

Browse files
committed
Fix L1Loss when target.requires_grad is True.
L1Loss had a completely different (and incorrect, see #43228) path when target.requires_grad was True. This PR does the following: 1) adds derivative support for target via the normal derivatives.yaml route 2) kill the different (and incorrect) path for when target.requires_grad was True 3) modify the L1Loss CriterionTests to verify that the target derivative is checked. ghstack-source-id: c8664d9 Pull Request resolved: #44471
1 parent 3849a07 commit 34ec924

File tree

3 files changed

+8
-11
lines changed

3 files changed

+8
-11
lines changed

tools/autograd/derivatives.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,6 +1205,7 @@
12051205

12061206
- name: l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
12071207
self: l1_loss_backward(grad, self, target, reduction)
1208+
target: l1_loss_backward(grad, target, self, reduction)
12081209

12091210
- name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
12101211
self: mse_loss_backward(grad, self, target, reduction)
@@ -1523,6 +1524,7 @@
15231524
- name: l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
15241525
grad_output: l1_loss_double_backward_grad_output(grad, self, target, reduction)
15251526
self: zeros_like(grad, at::MemoryFormat::Preserve)
1527+
target: zeros_like(grad, at::MemoryFormat::Preserve)
15261528

15271529
- name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor
15281530
grad_output: log_sigmoid_backward(grad, self, buffer)

torch/nn/functional.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2629,15 +2629,10 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
26292629
stacklevel=2)
26302630
if size_average is not None or reduce is not None:
26312631
reduction = _Reduction.legacy_get_string(size_average, reduce)
2632-
if target.requires_grad:
2633-
_Reduction.get_enum(reduction) # throw an error if reduction is invalid
2634-
ret = torch.abs(input - target)
2635-
if reduction != 'none':
2636-
ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
2637-
else:
2638-
expanded_input, expanded_target = torch.broadcast_tensors(input, target)
2639-
ret = torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
2640-
return ret
2632+
2633+
2634+
expanded_input, expanded_target = torch.broadcast_tensors(input, target)
2635+
return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
26412636

26422637

26432638
def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):

torch/testing/_internal/common_nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3879,7 +3879,7 @@ def padding3d_circular(input, pad):
38793879
dict(
38803880
module_name='L1Loss',
38813881
input_size=(2, 3, 4),
3882-
target_size=(2, 3, 4),
3882+
target_fn=lambda: torch.randn((2, 3, 4), requires_grad=True),
38833883
reference_fn=lambda i, t, _: 1. / i.numel() *
38843884
sum((a - b).abs().sum() for a, b in zip(i, t)),
38853885
),
@@ -4277,7 +4277,7 @@ def padding3d_circular(input, pad):
42774277
dict(
42784278
module_name='L1Loss',
42794279
input_size=(),
4280-
target_size=(),
4280+
target_fn=lambda: torch.randn((), requires_grad=True),
42814281
reference_fn=lambda i, t, _: 1. / i.numel() * (i - t).abs().sum(),
42824282
desc='scalar',
42834283
),

0 commit comments

Comments
 (0)