tools/autograd/derivatives.yaml (2 changes: 2 additions & 0 deletions)

@@ -1208,6 +1208,7 @@

 - name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
   self: mse_loss_backward(grad, self, target, reduction)
+  target: mse_loss_backward(grad, target, self, reduction)

 - name: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor
   self: multi_margin_loss_backward(grad, self, target, p, margin, weight, reduction)
@@ -1554,6 +1555,7 @@
 - name: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
   grad_output: mse_loss_double_backward_grad_output(grad, grad_output, self, target, reduction)
   self: mse_loss_double_backward(grad * grad_output, self, reduction)
+  target: -mse_loss_double_backward(grad * grad_output, target, reduction)

 - name: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor
   grad_output: nll_loss(grad, target, weight, reduction, ignore_index)
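With these two entries in place, autograd can differentiate mse_loss with respect to target as well as input; since the loss depends only on (input - target), the target gradient is the negation of the input gradient, which is exactly what the swapped-argument and negated formulas above encode. A minimal sketch of the resulting behaviour (assuming a PyTorch build that includes this change):

import torch
import torch.nn.functional as F

# Sketch, not part of the PR: gradients now flow to `target` as well as `input`.
input = torch.randn(4, requires_grad=True)
target = torch.randn(4, requires_grad=True)

loss = F.mse_loss(input, target, reduction='sum')
loss.backward()

# For reduction='sum', d/d(input) sum((input - target)**2) = 2 * (input - target),
# and the gradient w.r.t. `target` is its negation -- matching the
# `mse_loss_backward(grad, target, self, reduction)` entry above.
diff = (input - target).detach()
print(torch.allclose(input.grad, 2 * diff))    # expected: True
print(torch.allclose(target.grad, -2 * diff))  # expected: True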
torch/csrc/api/include/torch/nn/functional/loss.h (23 changes: 7 additions & 16 deletions)

@@ -107,22 +107,13 @@ inline Tensor mse_loss(
"This will likely lead to incorrect results due to broadcasting. ",
"Please ensure they have the same size.");
}
torch::Tensor ret;
if (target.requires_grad()) {
ret = torch::pow(input - target, 2);
if (!c10::get_if<enumtype::kNone>(&reduction)) {
ret = c10::get_if<enumtype::kMean>(&reduction) ? torch::mean(ret) : torch::sum(ret);
}
} else {
std::vector<torch::Tensor> broadcast_tensors = torch::broadcast_tensors({input, target});
auto expanded_input = broadcast_tensors[0];
auto expanded_target = broadcast_tensors[1];
ret = torch::mse_loss(
expanded_input,
expanded_target,
enumtype::reduction_get_enum(reduction));
}
return ret;
std::vector<torch::Tensor> broadcast_tensors = torch::broadcast_tensors({input, target});
auto expanded_input = broadcast_tensors[0];
auto expanded_target = broadcast_tensors[1];
return torch::mse_loss(
expanded_input,
expanded_target,
enumtype::reduction_get_enum(reduction));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
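With the requires_grad() special case gone, the C++ functional always broadcasts input and target and dispatches to the native torch::mse_loss, leaving the target gradient to the derivatives.yaml entries above. A small sketch of the broadcasting behaviour this preserves, shown through the equivalent Python functional (assuming a build with this change; the size-mismatch warning from the context above still fires):

import warnings
import torch
import torch.nn.functional as F

# Sketch: both frontends broadcast `input` and `target` before calling the
# native kernel, so a broadcastable size mismatch still works (after a warning).
input = torch.randn(2, 3, requires_grad=True)
target = torch.randn(3)  # broadcasts against the (2, 3) input

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the size-mismatch UserWarning
    loss = F.mse_loss(input, target)

# Equivalent to computing the loss on explicitly expanded tensors.
expected = F.mse_loss(input, target.expand_as(input))
print(torch.allclose(loss, expected))  # expected: True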
torch/nn/functional.py (12 changes: 3 additions & 9 deletions)

@@ -2661,15 +2661,9 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
                       stacklevel=2)
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
-    if target.requires_grad:
-        _Reduction.get_enum(reduction)  # throw an error if reduction is invalid
-        ret = (input - target) ** 2
-        if reduction != 'none':
-            ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
-    else:
-        expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-        ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
-    return ret
+
+    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
+    return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))


 def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
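The Python functional drops the eager (input - target) ** 2 decomposition, which only existed because the native op could not previously differentiate with respect to target. A hedged sanity check (assuming a build with this change) that the removed decomposition and the native path agree on the target gradient:

import torch
import torch.nn.functional as F

# Sketch: compare the deleted Python decomposition against the native kernel.
input = torch.randn(5)
target = torch.randn(5, requires_grad=True)

# Old decomposed path (what the removed branch computed for reduction='mean').
((input - target) ** 2).mean().backward()
grad_decomposed = target.grad.clone()
target.grad = None

# New path: the native op, with autograd supplying d(loss)/d(target).
F.mse_loss(input, target, reduction='mean').backward()
print(torch.allclose(target.grad, grad_decomposed))  # expected: True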
torch/testing/_internal/common_nn.py (17 changes: 7 additions & 10 deletions)

@@ -3955,7 +3955,7 @@ def padding3d_circular(input, pad):
     dict(
         module_name='MSELoss',
         input_size=(2, 3, 4, 5),
-        target_size=(2, 3, 4, 5),
+        target_fn=lambda: torch.randn((2, 3, 4, 5), requires_grad=True),
         reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / (i.numel()
                                       if get_reduction(m) == 'mean' else 1)),
         check_sum_reduction=True,
@@ -4302,7 +4302,7 @@ def padding3d_circular(input, pad):
     dict(
         module_name='MSELoss',
         input_size=(),
-        target_size=(),
+        target_fn=lambda: torch.randn((), requires_grad=True),
         reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() /
                                       (i.numel() if get_reduction(m) == 'mean' else 1)),
         check_sum_reduction=True,
Expand Down Expand Up @@ -5070,8 +5070,6 @@ def __call__(self, test_case):
         self._do_extra_tests(test_case, module, input, target)

     def _do_extra_tests(self, test_case, module, input, target):
-        test_case.assertFalse(target.requires_grad)
-
         params = tuple(x for x in module.parameters())
         if not isinstance(input, tuple):
             inputs = (input,) + params
@@ -5084,14 +5082,13 @@ def apply_fn(input, *params):
             def apply_fn(input1, input2, *params):
                 return module(input1, input2, target)

-        # TODO: we don't pass `target` as part of inputs because we don't
-        # currently compute the gradient w.r.t. target for loss functions.
-        gradcheck(apply_fn, inputs)
+        if target.requires_grad:
+            inputs = inputs + (target,)

[Inline review comment from a Collaborator on the two added lines above]
I think this works, but it is a bit fragile. Your apply_fn above actually ignores the target passed in via *params; it only works because it captures the same target from the parent scope, and gradcheck allows that. I think it would be clearer to take target from apply_fn's inputs explicitly.

-        if not self.check_gradgrad:
-            return
+        gradcheck(apply_fn, inputs)

-        gradgradcheck(apply_fn, inputs)
+        if self.check_gradgrad:
+            gradgradcheck(apply_fn, inputs)

     def test_cuda(self, test_case, dtype=None, extra_args=None):
         def convert_dtype(obj, dtype, requires_grad=False):
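The harness now appends target to the gradcheck inputs whenever the test's target_fn makes it require grad, and runs gradgradcheck only when check_gradgrad is set. A standalone sketch of the same kind of check for MSELoss, in double precision as gradcheck expects; unlike the harness's apply_fn, it takes target explicitly from the inputs, along the lines of the review comment above:

import torch
import torch.nn.functional as F
from torch.autograd import gradcheck, gradgradcheck

# Sketch mirroring the updated test: check first and second derivatives of
# mse_loss w.r.t. both `input` and `target` (gradcheck wants float64 inputs).
input = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
target = torch.randn(3, 4, dtype=torch.double, requires_grad=True)

def apply_fn(input, target):
    return F.mse_loss(input, target, reduction='sum')

print(gradcheck(apply_fn, (input, target)))      # expected: True
print(gradgradcheck(apply_fn, (input, target)))  # expected: True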