Commit 3de2c0b

gchanan authored and facebook-github-bot committed
Fix L1Loss when target.requires_grad is True. (#44471)
Summary:
Pull Request resolved: #44471

L1Loss had a completely different (and incorrect, see #43228) path when target.requires_grad was True. This PR does the following:

1) adds derivative support for target via the normal derivatives.yaml route,
2) kills the different (and incorrect) path for when target.requires_grad was True,
3) modifies the L1Loss CriterionTests to verify that the target derivative is checked.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D23626008

Pulled By: gchanan

fbshipit-source-id: 2828be16b56b8dabe114962223d71b0e9a85f0f5
1 parent ea55820 commit 3de2c0b
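
For context, a minimal sketch of what this change enables (not part of the commit): with a derivative for target registered, torch.autograd.gradcheck can verify l1_loss against numerical gradients with both arguments requiring grad.

import torch
import torch.nn.functional as F

# Both input and target require grad; gradcheck compares the analytical
# derivatives (including the new one w.r.t. target) against finite differences.
inp = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
tgt = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(lambda a, b: F.l1_loss(a, b, reduction='mean'), (inp, tgt))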

3 files changed: +8 −11 lines


tools/autograd/derivatives.yaml

Lines changed: 2 additions & 0 deletions

@@ -1202,6 +1202,7 @@
 
 - name: l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
   self: l1_loss_backward(grad, self, target, reduction)
+  target: l1_loss_backward(grad, target, self, reduction)
 
 - name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
   self: mse_loss_backward(grad, self, target, reduction)

@@ -1520,6 +1521,7 @@
 - name: l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
   grad_output: l1_loss_double_backward_grad_output(grad, self, target, reduction)
   self: zeros_like(grad, at::MemoryFormat::Preserve)
+  target: zeros_like(grad, at::MemoryFormat::Preserve)
 
 - name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor
   grad_output: log_sigmoid_backward(grad, self, buffer)
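
The new target formula reuses the existing backward kernel with self and target swapped: since d|self − target|/d target = −d|self − target|/d self, calling l1_loss_backward with the arguments exchanged flips the sign of the per-element gradient, which is exactly the derivative w.r.t. target. The second hunk gives l1_loss_backward a zero derivative w.r.t. target, mirroring the existing self entry (sign(self − target) is piecewise constant, so its derivative is zero almost everywhere). A rough numerical illustration of the symmetry (my example, not from the commit):

import torch
import torch.nn.functional as F

# For reduction='mean':
#   d l1_loss / d self   =  sign(self - target) / numel
#   d l1_loss / d target =  sign(target - self) / numel = -(d l1_loss / d self)
a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
g_a, g_b = torch.autograd.grad(F.l1_loss(a, b, reduction='mean'), (a, b))
assert torch.allclose(g_b, -g_a)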

torch/nn/functional.py

Lines changed: 4 additions & 9 deletions

@@ -2629,15 +2629,10 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
                       stacklevel=2)
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
-    if target.requires_grad:
-        _Reduction.get_enum(reduction)  # throw an error if reduction is invalid
-        ret = torch.abs(input - target)
-        if reduction != 'none':
-            ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
-    else:
-        expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-        ret = torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
-    return ret
+
+
+    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
+    return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
 
 
 def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
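
With the Python fallback gone, F.l1_loss follows a single code path: broadcast the operands and dispatch to the ATen kernel, with the target gradient coming from the derivative registered above. A small usage sketch (mine, not from the commit):

import torch
import torch.nn.functional as F

inp = torch.randn(2, 3, 4, requires_grad=True)
tgt = torch.randn(2, 3, 4, requires_grad=True)

loss = F.l1_loss(inp, tgt)   # single path: broadcast, then ATen l1_loss kernel
loss.backward()
print(tgt.grad.shape)        # torch.Size([2, 3, 4]) -- target now receives a gradient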

torch/testing/_internal/common_nn.py

Lines changed: 2 additions & 2 deletions

@@ -3879,7 +3879,7 @@ def padding3d_circular(input, pad):
     dict(
         module_name='L1Loss',
         input_size=(2, 3, 4),
-        target_size=(2, 3, 4),
+        target_fn=lambda: torch.randn((2, 3, 4), requires_grad=True),
         reference_fn=lambda i, t, _: 1. / i.numel() *
         sum((a - b).abs().sum() for a, b in zip(i, t)),
     ),

@@ -4277,7 +4277,7 @@ def padding3d_circular(input, pad):
     dict(
         module_name='L1Loss',
         input_size=(),
-        target_size=(),
+        target_fn=lambda: torch.randn((), requires_grad=True),
         reference_fn=lambda i, t, _: 1. / i.numel() * (i - t).abs().sum(),
         desc='scalar',
     ),
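
These test entries now build the target with target_fn and requires_grad=True, so the CriterionTest gradient checks cover the target derivative as well (point 3 of the summary). A standalone sketch of what such an entry exercises, using only the public nn.L1Loss API rather than the actual test harness:

import torch
import torch.nn as nn

inp = torch.randn(2, 3, 4)
tgt = torch.randn((2, 3, 4), requires_grad=True)   # mirrors the new target_fn

out = nn.L1Loss()(inp, tgt)
# reference_fn: 1 / numel * sum of per-sample |i - t| sums (default 'mean' reduction)
ref = sum((a - b).abs().sum() for a, b in zip(inp, tgt)) / inp.numel()
assert torch.allclose(out, ref)

out.backward()
assert tgt.grad is not None   # the derivative w.r.t. target is exercised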
