
Commit 5b223c4

lezcano authored and pytorchmergebot committed
Avoid calling allclose in the backward if there are tensor subclasses (#91444)
`allclose` is data-dependent (it returns a bool), so it does not play well with functorch. We skip that check in the context of subclasses to avoid hard errors. Partially fixes #90499.

Pull Request resolved: #91444
Approved by: https://github.com/albanD
1 parent 4444138 commit 5b223c4
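
For context, here is a minimal sketch, not taken from this patch, of why a data-dependent call such as `torch.allclose` composes badly with functorch: producing a Python bool requires reading the tensor's values, which the batched tensor subclasses used by vmap cannot provide. The toy function `f` and the caught exception type are assumptions for illustration.

    import torch
    from functorch import vmap  # torch.func.vmap in newer PyTorch releases

    def f(x, y):
        # allclose collapses both tensors into a single Python bool,
        # a data-dependent operation.
        if torch.allclose(x, y):
            return x - y
        return x + y

    x, y = torch.randn(5, 3), torch.randn(5, 3)
    f(x, y)  # fine on plain tensors

    # Under vmap, x and y arrive wrapped in a Tensor subclass (BatchedTensor),
    # and the data-dependent allclose check errors out -- hence the xfails
    # removed below and the guard added in linalg_eig_backward.
    try:
        vmap(f)(x, y)
    except RuntimeError as err:
        print("vmap failed:", err)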

File tree

2 files changed: +3, -9 lines changed


test/functorch/test_ops.py

Lines changed: 0 additions & 6 deletions
@@ -731,8 +731,6 @@ def fn(inp, *args, **kwargs):
     # (2) attempting to use a Tensor in some data-dependent control flow or
     # (3) encountering this error in PyTorch internals.
     xfail("index_reduce"),
-    xfail("linalg.eig"),  # vmap over torch.allclose
-    xfail("linalg.eigvals"),  # vmap over torch.allclose
     xfail("linalg.householder_product"),  # vmap: inplace into a regular tensor
     xfail("nanquantile", device_type='cpu'),  # vmap not implemented for at::equal.
     xfail("native_layer_norm"),  # vmap: inplace into a regular tensor
@@ -879,7 +877,6 @@ def vjp_of_vjp(*args_and_cotangents):
     skip("native_batch_norm"),
     skip("_native_batch_norm_legit"),
     xfail('__getitem__', ''),  # dynamic error
-    xfail('linalg.eig'),  # Uses aten::allclose
     xfail('nanquantile', device_type='cpu'),  # checks q via a .item() call
     xfail('nn.functional.gaussian_nll_loss'),  # checks var for if any value < 0
     xfail('narrow'),  # .item() call
@@ -1135,7 +1132,6 @@ def test():
     xfail('special.log_ndtr'),
     xfail('index_copy'),
     xfail('index_fill'),
-    xfail('linalg.eig'),
     xfail('linalg.householder_product'),
     xfail('lu'),
     xfail('lu_solve'),
@@ -1487,7 +1483,6 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
     xfail('float'),  # required rank 4 tensor to use channels_last format
     xfail('half'),  # required rank 4 tensor to use channels_last format
     xfail('index_reduce'),  # Forward AD not implemented and no decomposition
-    xfail('linalg.eig'),  # vmap over torch.allclose isn't supported yet.
     xfail('logcumsumexp'),  # Forward AD not implemented and no decomposition
     xfail('mvlgamma', 'mvlgamma_p_1'),  # vmap: inplace into a regular tensor
     xfail('mvlgamma', 'mvlgamma_p_3'),  # vmap: inplace into a regular tensor
@@ -1771,7 +1766,6 @@ def fn(input, weight, bias):
     @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
     @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float32, torch.double))
     @skipOps('TestOperators', 'test_vmap_autograd_grad', {
-        xfail('linalg.eig'),  # all close?
         # The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0
         xfail('masked_select'),
         xfail('nn.functional.max_unpool2d', 'grad'),  # contiguous call

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 3 additions & 3 deletions
@@ -3549,9 +3549,9 @@ Tensor linalg_eig_backward(
   auto VhgV = at::matmul(V.mH(), gV);
   const auto diag_VhgV = VhgV.diagonal(0, -2, -1);
 
-  if (V.is_complex()) {
-    // Check invariance of the loss function wrt the transformation V -> V
-    // e^{i\phi}
+  if (V.is_complex() && !at::isTensorSubclassLike(diag_VhgV)) {
+    // Check invariance of the loss function wrt the transformation
+    // V -> V * e^{i\phi} for an arbitrary phi in RR^n
     const auto imdiag_VhgV = at::imag(diag_VhgV);
     TORCH_CHECK(
         at::allclose(
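
For illustration, a rough Python analogue of the guarded check above; the tolerance, the error message, and the `type(...) is torch.Tensor` test (a simplified stand-in for `at::isTensorSubclassLike`, which has no exact one-line Python equivalent) are assumptions, not the values used in FunctionsManual.cpp.

    import torch

    def check_phase_invariance(V: torch.Tensor, gV: torch.Tensor) -> None:
        # diag(V^H @ gV): if the loss is invariant under V -> V * e^{i*phi},
        # the imaginary part of this diagonal is approximately zero.
        diag_VhgV = torch.matmul(V.mH, gV).diagonal(0, -2, -1)
        # Skip the data-dependent allclose for tensor subclasses
        # (e.g. functorch's batched tensors), mirroring the C++ guard.
        if V.is_complex() and type(diag_VhgV) is torch.Tensor:
            imdiag_VhgV = torch.imag(diag_VhgV)
            if not torch.allclose(imdiag_VhgV, torch.zeros_like(imdiag_VhgV), atol=1e-6):
                raise RuntimeError("linalg_eig_backward: cotangent is not phase-invariant")

    # Example: gV = V is a valid cotangent for a phase-invariant loss such as
    # ||V||^2 (up to a constant factor); diag(V^H V) is real, so the check passes.
    A = torch.randn(4, 4, dtype=torch.complex128)
    _, V = torch.linalg.eig(A)
    check_phase_invariance(V, gV=V)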
