Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tools/autograd/derivatives.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,7 @@

- name: l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
self: l1_loss_backward(grad, self, target, reduction)
target: l1_loss_backward(grad, target, self, reduction)

- name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
self: mse_loss_backward(grad, self, target, reduction)
Expand Down Expand Up @@ -1523,6 +1524,7 @@
- name: l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
grad_output: l1_loss_double_backward_grad_output(grad, self, target, reduction)
self: zeros_like(grad, at::MemoryFormat::Preserve)
target: zeros_like(grad, at::MemoryFormat::Preserve)

- name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor
grad_output: log_sigmoid_backward(grad, self, buffer)
Expand Down
13 changes: 4 additions & 9 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2629,15 +2629,10 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
stacklevel=2)
if size_average is not None or reduce is not None:
reduction = _Reduction.legacy_get_string(size_average, reduce)
if target.requires_grad:
_Reduction.get_enum(reduction) # throw an error if reduction is invalid
ret = torch.abs(input - target)
if reduction != 'none':
ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
else:
expanded_input, expanded_target = torch.broadcast_tensors(input, target)
ret = torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
return ret


expanded_input, expanded_target = torch.broadcast_tensors(input, target)
return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))


def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
Expand Down
4 changes: 2 additions & 2 deletions torch/testing/_internal/common_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3879,7 +3879,7 @@ def padding3d_circular(input, pad):
dict(
module_name='L1Loss',
input_size=(2, 3, 4),
target_size=(2, 3, 4),
target_fn=lambda: torch.randn((2, 3, 4), requires_grad=True),
reference_fn=lambda i, t, _: 1. / i.numel() *
sum((a - b).abs().sum() for a, b in zip(i, t)),
),
Expand Down Expand Up @@ -4277,7 +4277,7 @@ def padding3d_circular(input, pad):
dict(
module_name='L1Loss',
input_size=(),
target_size=(),
target_fn=lambda: torch.randn((), requires_grad=True),
reference_fn=lambda i, t, _: 1. / i.numel() * (i - t).abs().sum(),
desc='scalar',
),
Expand Down