
Commit d07d25a

gchanan authored and facebook-github-bot committed
Fix MSELoss when target.requires_grad is True. (#44437)
Summary:
Pull Request resolved: #44437

MSELoss had a completely different (and incorrect, see #43228) code path when target.requires_grad was True. This PR does the following:

1) adds derivative support for target via the normal derivatives.yaml route;
2) removes the different (and incorrect) path for when target.requires_grad was True;
3) modifies the MSELoss CriterionTests to verify that the target derivative is checked.

TODO:
1) do we still need check_criterion_jacobian when we run grad/gradgrad checks?
2) ensure the Module tests check when target.requires_grad
3) do we actually test when reduction='none' and reduction='mean'?

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D23612166

Pulled By: gchanan

fbshipit-source-id: 4f74d38d8a81063c74e002e07fbb7837b2172a10
1 parent 9a3b83c commit d07d25a
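
In broad strokes, the fix means gradients now flow to target through the same native mse_loss kernel as for input. A minimal sketch of the resulting behavior (illustrative only, not part of the diff; shapes are arbitrary):

    import torch
    import torch.nn.functional as F

    # Both arguments may require grad; no separate Python path is taken anymore.
    input = torch.randn(2, 3, requires_grad=True)
    target = torch.randn(2, 3, requires_grad=True)

    loss = F.mse_loss(input, target, reduction='mean')
    loss.backward()

    # loss = mean((input - target) ** 2), so the target gradient is just the
    # negation of the input gradient.
    assert torch.allclose(target.grad, -input.grad)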

File tree

4 files changed, +19 / -35 lines:
  tools/autograd/derivatives.yaml
  torch/csrc/api/include/torch/nn/functional/loss.h
  torch/nn/functional.py
  torch/testing/_internal/common_nn.py

tools/autograd/derivatives.yaml

Lines changed: 2 additions & 0 deletions

@@ -1205,6 +1205,7 @@
 
 - name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
   self: mse_loss_backward(grad, self, target, reduction)
+  target: mse_loss_backward(grad, target, self, reduction)
 
 - name: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor
   self: multi_margin_loss_backward(grad, self, target, p, margin, weight, reduction)

@@ -1551,6 +1552,7 @@
 - name: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
   grad_output: mse_loss_double_backward_grad_output(grad, grad_output, self, target, reduction)
   self: mse_loss_double_backward(grad * grad_output, self, reduction)
+  target: -mse_loss_double_backward(grad * grad_output, target, reduction)
 
 - name: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor
   grad_output: nll_loss(grad, target, weight, reduction, ignore_index)
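
A brief aside on why the target formula can reuse mse_loss_backward with self and target swapped: for l = sum((x - t)^2), dl/dx = 2*(x - t) while dl/dt = 2*(t - x), i.e. the same expression with the two arguments exchanged. A hedged numerical check of that identity (not part of the commit):

    import torch

    x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
    t = torch.randn(3, 4, dtype=torch.double, requires_grad=True)

    loss = torch.nn.functional.mse_loss(x, t, reduction='sum')
    gx, gt = torch.autograd.grad(loss, (x, t))

    # Autograd output matches the closed-form derivatives above.
    assert torch.allclose(gx, 2 * (x - t))
    assert torch.allclose(gt, 2 * (t - x))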

torch/csrc/api/include/torch/nn/functional/loss.h

Lines changed: 7 additions & 16 deletions

@@ -107,22 +107,13 @@ inline Tensor mse_loss(
         "This will likely lead to incorrect results due to broadcasting. ",
         "Please ensure they have the same size.");
   }
-  torch::Tensor ret;
-  if (target.requires_grad()) {
-    ret = torch::pow(input - target, 2);
-    if (!c10::get_if<enumtype::kNone>(&reduction)) {
-      ret = c10::get_if<enumtype::kMean>(&reduction) ? torch::mean(ret) : torch::sum(ret);
-    }
-  } else {
-    std::vector<torch::Tensor> broadcast_tensors = torch::broadcast_tensors({input, target});
-    auto expanded_input = broadcast_tensors[0];
-    auto expanded_target = broadcast_tensors[1];
-    ret = torch::mse_loss(
-        expanded_input,
-        expanded_target,
-        enumtype::reduction_get_enum(reduction));
-  }
-  return ret;
+  std::vector<torch::Tensor> broadcast_tensors = torch::broadcast_tensors({input, target});
+  auto expanded_input = broadcast_tensors[0];
+  auto expanded_target = broadcast_tensors[1];
+  return torch::mse_loss(
+      expanded_input,
+      expanded_target,
+      enumtype::reduction_get_enum(reduction));
 }
 } // namespace detail
 #endif /* DOXYGEN_SHOULD_SKIP_THIS */

torch/nn/functional.py

Lines changed: 3 additions & 9 deletions

@@ -2661,15 +2661,9 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
                       stacklevel=2)
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
-    if target.requires_grad:
-        _Reduction.get_enum(reduction)  # throw an error if reduction is invalid
-        ret = (input - target) ** 2
-        if reduction != 'none':
-            ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
-    else:
-        expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-        ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
-    return ret
+
+    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
+    return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
 
 
 def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
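
One of the TODO items in the summary asks whether all reduction modes are exercised. As an informal spot-check (assumptions: double precision and small shapes; this is not the project's test suite), the single path above should pass finite-difference checks for every reduction while target requires grad:

    import torch
    from torch.autograd import gradcheck

    for reduction in ('none', 'mean', 'sum'):
        x = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
        t = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
        # gradcheck raises if the analytical and numerical Jacobians disagree.
        assert gradcheck(
            lambda x, t: torch.nn.functional.mse_loss(x, t, reduction=reduction),
            (x, t))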

torch/testing/_internal/common_nn.py

Lines changed: 7 additions & 10 deletions

@@ -3955,7 +3955,7 @@ def padding3d_circular(input, pad):
     dict(
         module_name='MSELoss',
         input_size=(2, 3, 4, 5),
-        target_size=(2, 3, 4, 5),
+        target_fn=lambda: torch.randn((2, 3, 4, 5), requires_grad=True),
         reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / (i.numel()
                                       if get_reduction(m) == 'mean' else 1)),
         check_sum_reduction=True,

@@ -4302,7 +4302,7 @@ def padding3d_circular(input, pad):
     dict(
         module_name='MSELoss',
         input_size=(),
-        target_size=(),
+        target_fn=lambda: torch.randn((), requires_grad=True),
         reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() /
                                       (i.numel() if get_reduction(m) == 'mean' else 1)),
         check_sum_reduction=True,

@@ -5070,8 +5070,6 @@ def __call__(self, test_case):
             self._do_extra_tests(test_case, module, input, target)
 
     def _do_extra_tests(self, test_case, module, input, target):
-        test_case.assertFalse(target.requires_grad)
-
         params = tuple(x for x in module.parameters())
         if not isinstance(input, tuple):
             inputs = (input,) + params

@@ -5084,14 +5082,13 @@ def apply_fn(input, *params):
            def apply_fn(input1, input2, *params):
                return module(input1, input2, target)
 
-        # TODO: we don't pass `target` as part of inputs because we don't
-        # currently compute the gradient w.r.t. target for loss functions.
-        gradcheck(apply_fn, inputs)
+        if target.requires_grad:
+            inputs = inputs + (target,)
 
-        if not self.check_gradgrad:
-            return
+        gradcheck(apply_fn, inputs)
 
-        gradgradcheck(apply_fn, inputs)
+        if self.check_gradgrad:
+            gradgradcheck(apply_fn, inputs)
 
     def test_cuda(self, test_case, dtype=None, extra_args=None):
         def convert_dtype(obj, dtype, requires_grad=False):
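
Stripped of the harness plumbing, the updated _do_extra_tests boils down to the pattern below. This is a simplified, self-contained sketch (the module and shapes are illustrative, and apply_fn closes over target just as the harness's apply_fn does):

    import torch
    from torch.autograd import gradcheck, gradgradcheck

    module = torch.nn.MSELoss()
    input = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
    target = torch.randn(2, 3, dtype=torch.double, requires_grad=True)

    def apply_fn(input, *params):
        # target is captured from the enclosing scope.
        return module(input, target)

    inputs = (input,)
    if target.requires_grad:
        # Adding target to the gradcheck inputs means it gets perturbed and
        # differentiated too, which is exactly what the removed TODO ruled out.
        inputs = inputs + (target,)

    gradcheck(apply_fn, inputs)
    gradgradcheck(apply_fn, inputs)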
