Skip to content

Commit ae0e89a

Browse files
committed
Update on "Deprecate .mT,.T,.mH,.H on 0D tensors"
As discussed with ngimel, this is not only not documented, but also an unnecessary edge case. See #90463 (comment) [ghstack-poisoned]
1 parent b169ce7 commit ae0e89a

File tree

2 files changed

+1
-5
lines changed

2 files changed

+1
-5
lines changed

aten/src/ATen/native/TensorShape.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3546,7 +3546,6 @@ Tensor numpy_T(const Tensor &self) {
35463546
if (n == 0) {
35473547
// Added in PyTorch 2.0
35483548
TORCH_WARN_ONCE("Tensor.T is deprecated on 0-D tensors. This function is the identity in these cases.");
3549-
throw 1;
35503549
}
35513550
DimVector transpose_dims;
35523551
for (int64_t i = n - 1; i >= 0; --i) {
@@ -3560,7 +3559,6 @@ Tensor matrix_H(const Tensor &self) {
35603559
if (ndim == 0) {
35613560
// Added in PyTorch 2.0
35623561
TORCH_WARN_ONCE("Tensor.H is deprecated on 0-D tensors. Consider using x.conj().");
3563-
throw 1;
35643562
}
35653563
TORCH_CHECK(ndim == 2 || ndim == 0,
35663564
"tensor.H is only supported on matrices (2-D tensors). Got ", ndim, "-D tensor.",
@@ -3589,7 +3587,6 @@ Tensor mT(const Tensor &self) {
35893587
if (self.dim() == 0) {
35903588
// Added in PyTorch 2.0
35913589
TORCH_WARN_ONCE("Tensor.mT is deprecated on 0-D tensors. This function is the identity in these cases.");
3592-
throw 1;
35933590
}
35943591
return _adjoint(self, /*transpose=*/true, "mT");
35953592
}
@@ -3598,7 +3595,6 @@ Tensor mH(const Tensor &self) {
35983595
if (self.dim() == 0) {
35993596
// Added in PyTorch 2.0
36003597
TORCH_WARN_ONCE("Tensor.mH is deprecated on 0-D tensors. Consider using x.conj().");
3601-
throw 1;
36023598
}
36033599
return _adjoint(self, /*transpose=*/false, "mH");
36043600
}

test/test_mps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6880,7 +6880,7 @@ def test_T(self, device="mps"):
68806880

68816881
def test_transposes(self, device="mps", dtype=torch.float32):
68826882
for op in ("T", "H", "mT", "mH", "adjoint"):
6883-
shapes = ((), (2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((), (2, 3),)
6883+
shapes = ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),)
68846884
for shape in shapes:
68856885
a = make_tensor(shape, device=device, dtype=dtype)
68866886
t1 = getattr(a, op)

0 commit comments

Comments
 (0)