Skip to content

Commit 37c574f

Browse files
committed
Deprecate .mT,.T,.mH,.H on 0D tensors
As discussed with ngimel, this is not only undocumented, but also an unnecessary edge case. See #90463 (comment). ghstack-source-id: b4e85ca. Pull Request resolved: #92143
1 parent ec3941a commit 37c574f

File tree

7 files changed

+24
-13
lines changed

7 files changed

+24
-13
lines changed

aten/src/ATen/native/TensorShape.cpp

Lines changed: 19 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -3543,6 +3543,10 @@ Tensor numpy_T(const Tensor &self) {
35433543
"or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor."
35443544
);
35453545
}
3546+
if (n == 0) {
3547+
// Added in PyTorch 2.0
3548+
TORCH_WARN_ONCE("Tensor.T is deprecated on 0-D tensors. This function is the identity in these cases.");
3549+
}
35463550
DimVector transpose_dims;
35473551
for (int64_t i = n - 1; i >= 0; --i) {
35483552
transpose_dims.push_back(i);
@@ -3552,6 +3556,10 @@ Tensor numpy_T(const Tensor &self) {
35523556

35533557
Tensor matrix_H(const Tensor &self) {
35543558
const auto ndim = self.dim();
3559+
if (ndim == 0) {
3560+
// Added in PyTorch 2.0
3561+
TORCH_WARN_ONCE("Tensor.H is deprecated on 0-D tensors. Consider using x.conj().");
3562+
}
35553563
TORCH_CHECK(ndim == 2 || ndim == 0,
35563564
"tensor.H is only supported on matrices (2-D tensors). Got ", ndim, "-D tensor.",
35573565
ndim > 2 ? " For batches of matrices, consider using tensor.mH" : "");
@@ -3576,14 +3584,25 @@ Tensor _adjoint(const Tensor &self, const bool transpose, const char* const name
35763584
} // anonymous namespace
35773585

35783586
Tensor mT(const Tensor &self) {
3587+
if (self.dim() == 0) {
3588+
// Added in PyTorch 2.0
3589+
TORCH_WARN_ONCE("Tensor.mT is deprecated on 0-D tensors. This function is the identity in these cases.");
3590+
}
35793591
return _adjoint(self, /*transpose=*/true, "mT");
35803592
}
35813593

35823594
Tensor mH(const Tensor &self) {
3595+
if (self.dim() == 0) {
3596+
// Added in PyTorch 2.0
3597+
TORCH_WARN_ONCE("Tensor.mH is deprecated on 0-D tensors. Consider using x.conj().");
3598+
}
35833599
return _adjoint(self, /*transpose=*/false, "mH");
35843600
}
35853601

35863602
Tensor adjoint(const Tensor &self) {
3603+
if (self.dim() == 0) {
3604+
TORCH_WARN_ONCE("adjoint() is deprecated on 0-D tensors. Consider using x.conj().");
3605+
}
35873606
return _adjoint(self, /*transpose=*/false, "adjoint()");
35883607
}
35893608

test/functorch/test_vmap.py

Lines changed: 0 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -2625,7 +2625,6 @@ def op(t):
26252625
test = self._vmap_view_test
26262626
B0, B1, B2 = 7, 11, 13
26272627
test(op, (torch.rand(B0, 2, 3, 5),))
2628-
test(op, (torch.rand(B0),))
26292628
test(op, (torch.rand(2, B0, 3, 5),), in_dims=1)
26302629
test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
26312630
test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2)

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 0 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1210,7 +1210,6 @@ def forward(self, x):
12101210
return x.T
12111211

12121212
self.run_test(NumpyTranspose(), torch.randn(4, 7))
1213-
self.run_test(NumpyTranspose(), torch.tensor(-42.0))
12141213

12151214
# Conversion of Transpose depends on input shape to be known.
12161215
# The following test only works when onnx shape inference is enabled.

test/test_legacy_vmap.py

Lines changed: 0 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -2040,7 +2040,6 @@ def op(t):
20402040
test = self._vmap_view_test
20412041
B0, B1, B2 = 7, 11, 13
20422042
test(op, (torch.rand(B0, 2, 3, 5),))
2043-
test(op, (torch.rand(B0),))
20442043
test(op, (torch.rand(2, B0, 3, 5),), in_dims=1)
20452044
test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
20462045
test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2)

test/test_mps.py

Lines changed: 1 addition & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -6877,12 +6877,10 @@ def test_T(self, device="mps"):
68776877
self.assertEqual(t2, t1)
68786878
b = torch.randn(10, device=device)
68796879
self.assertEqual(b, b.T)
6880-
scalar = torch.tensor(5, device=device)
6881-
self.assertEqual(scalar, scalar.T)
68826880

68836881
def test_transposes(self, device="mps", dtype=torch.float32):
68846882
for op in ("T", "H", "mT", "mH", "adjoint"):
6885-
shapes = ((), (2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((), (2, 3),)
6883+
shapes = ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),)
68866884
for shape in shapes:
68876885
a = make_tensor(shape, device=device, dtype=dtype)
68886886
t1 = getattr(a, op)

test/test_view_ops.py

Lines changed: 2 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -1307,21 +1307,18 @@ def test_T(self, device):
13071307
self.assertEqual(t2, t1)
13081308
b = torch.randn(10, device=device)
13091309
self.assertEqual(b, b.T)
1310-
scalar = torch.tensor(5, device=device)
1311-
self.assertEqual(scalar, scalar.T)
13121310

13131311
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
13141312
def test_transposes(self, device, dtype):
13151313
for op in ("T", "H", "mT", "mH", "adjoint"):
1316-
shapes = ((), (2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((), (2, 3),)
1314+
shapes = ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),)
13171315
for shape in shapes:
13181316
a = make_tensor(shape, device=device, dtype=dtype)
13191317
t1 = getattr(a, op)
13201318
if op == "adjoint":
13211319
t1 = t1()
13221320
t2 = a
1323-
if a.ndim != 0:
1324-
t2 = t2.transpose(-2, -1)
1321+
t2 = t2.transpose(-2, -1)
13251322
if op[-1] == "H" or op == "adjoint":
13261323
t2 = t2.conj()
13271324
self.assertEqual(t2, t1)

torch/testing/_internal/common_methods_invocations.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -1656,13 +1656,13 @@ def _numpy_ref_transpose(a, dim0, dim1):
16561656
def sample_inputs_adjoint(self, device, dtype, requires_grad, **kwargs):
16571657
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
16581658

1659-
shapes = ((1, 2, 3), (), (M, M), (S, S, S), (S, M, S), (M, S, M, S))
1659+
shapes = ((1, 2, 3), (M, M), (S, S, S), (S, M, S), (M, S, M, S))
16601660
return (SampleInput(make_arg(shape)) for shape in shapes)
16611661

16621662
def sample_inputs_T(self, device, dtype, requires_grad, **kwargs):
16631663
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
16641664

1665-
shapes = ((), (M, M), (M, L))
1665+
shapes = ((M, M), (M, L))
16661666
return (SampleInput(make_arg(shape)) for shape in shapes)
16671667

16681668
def error_inputs_T(self, device, has_ndims_error=False):

0 commit comments

Comments (0)