
Commit 97b7e4c

malfet authored and pytorchmergebot committed
Fix GroupNorm backward prop on CUDA (#92671)
Fixes regression introduced by #89485. Adds a test to prevent those regressions from happening in the future. In the process, discovered that GroupNorm backward on CPU does not produce the same results if the input and gradient memory_format differ.

Fixes #92166

Pull Request resolved: #92671
Approved by: https://github.com/ngimel, https://github.com/xuzhao9
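For context, the layout mismatch described above can be reproduced along these lines. This is only an illustrative sketch: the shapes, group count, and variable names are arbitrary choices, not taken from the commit or the issue.

import copy
import torch

# Sketch of the mismatch this commit addresses (see issue #92166).
device = "cuda" if torch.cuda.is_available() else "cpu"

net = torch.nn.GroupNorm(2, 4).to(device=device)
net_cl = copy.deepcopy(net)

x = torch.rand(2, 4, 4, 4, device=device, requires_grad=True)
grad = torch.rand(2, 4, 4, 4, device=device)

# Same data, but channels_last input and gradient.
x_cl = x.detach().clone().to(memory_format=torch.channels_last).requires_grad_(True)
grad_cl = grad.to(memory_format=torch.channels_last)

net(x).backward(grad)
net_cl(x_cl).backward(grad_cl)

# Before this fix, the CUDA backward could return different input gradients
# depending on the memory_format of the incoming gradient.
print(torch.allclose(x.grad, x_cl.grad))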
1 parent 8c0289a commit 97b7e4c

2 files changed: 29 additions, 1 deletion

test/test_nn.py

Lines changed: 28 additions & 0 deletions
@@ -8229,6 +8229,34 @@ def helper(self, size, groups, memory_format, is_mixed):
         helper(self, (2, 9, 7, 200, 15), 3, torch.channels_last_3d, True)
         helper(self, (2, 60, 7, 200, 15), 3, torch.channels_last_3d, True)
 
+    @onlyNativeDeviceTypes
+    def test_GroupNorm_memory_format(self, device):
+        # Tests for regression reported in https://github.com/pytorch/pytorch/issues/92166
+
+        def helper(input_format, grad_format, B=2, C=4, W=4, H=4):
+            import copy
+            net_orig = torch.nn.GroupNorm(B, C).to(device=device)
+            net = copy.deepcopy(net_orig)
+            x_orig = torch.rand(B, C, W, H, device=device, requires_grad=True)
+            grad_orig = torch.rand(B, C, W, H, device=device)
+            x = x_orig.clone().detach().to(memory_format=input_format).requires_grad_(True)
+            grad = grad_orig.detach().to(memory_format=grad_format)
+
+            y = net(x)
+            y.backward(grad)
+
+            y_orig = net_orig(x_orig)
+            y_orig.backward(grad_orig)
+
+            self.assertEqual(y, y_orig)
+            # TODO: Fix me, CPU should produce valid results here, but it is not
+            if device != "cpu":
+                self.assertEqual(x.grad, x_orig.grad)
+
+        for input_format in [torch.contiguous_format, torch.channels_last]:
+            for grad_format in [torch.contiguous_format, torch.channels_last]:
+                helper(input_format, grad_format)
+
     @onlyNativeDeviceTypes
     def test_GroupNorm_numeric(self, device):
         def group_norm_ref(X, gamma, beta, groups, channels, eps):
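As a usage note: from a PyTorch source checkout, one possible way to run just the new test is the standard test-file invocation, e.g. `python test/test_nn.py -k GroupNorm_memory_format`; the exact instantiated test names depend on which device types are available on the machine.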

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion
@@ -1171,7 +1171,7 @@
   rstd: not_implemented("native_layer_norm_backward rstd")
 
 - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
-  input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].suggest_memory_format()), input.device().is_xpu() ? input : input.contiguous(input.suggest_memory_format()), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
+  input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? grads[0].suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
   result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group)
   result1: group_norm_mean_jvp(input_t, result1, group)
   result2: group_norm_invstd_jvp(input_p, input_t, result1, result2, group)
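In rough terms, the one-character-scope change above makes the backward formula consult suggest_memory_format() only on CPU, and force a plain contiguous layout for grad_output and input on other non-XPU devices such as CUDA, so the kernel sees a consistent layout even when the two tensors arrive in different memory formats. Below is a minimal Python sketch of the layout distinction involved; suggest_memory_format() itself is a C++ helper, so is_contiguous(memory_format=...) is used here only as a stand-in for illustration.

import torch

x = torch.rand(2, 4, 4, 4)
x_cl = x.to(memory_format=torch.channels_last)  # same values, NHWC-style strides

# A default-contiguous tensor satisfies the NCHW layout but not channels_last,
# and a channels_last tensor is the reverse. Mixing the two layouts between
# input and grad_output is the situation the derivative formula now normalizes.
print(x.is_contiguous(memory_format=torch.contiguous_format))     # True
print(x.is_contiguous(memory_format=torch.channels_last))         # False
print(x_cl.is_contiguous(memory_format=torch.contiguous_format))  # False
print(x_cl.is_contiguous(memory_format=torch.channels_last))      # True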
