Skip to content

Commit 9c85e5e

Browse files
authored
Speed up half tensors printing
This removes cast of reduced precision types to float before testing, which were added in #14418 (Reusing old test plan) Before the PR: ```python In [1]: import torch; a = torch.rand(1, 1700, 34, 50, dtype=torch.float16) In [2]: %timeit str(a) 621 μs ± 5.06 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) ``` after the PR ```python In [1]: import torch; a = torch.rand(1, 1700, 34, 50, dtype=torch.float16) In [2]: %timeit str(a) 449 μs ± 2.34 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) ``` Also, this allows one printing 15Gb Metal tensors on 32GB Mac machine: ``` % python3 -c "import torch;print(torch.empty(72250,72250, device='mps', dtype=torch.float16))" tensor([[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], device='mps:0', dtype=torch.float16) ``` Before this change it failed with non-descriptive ``` % python3 -c "import torch;print(torch.empty(72250,72250, device='mps', dtype=torch.float16))" Traceback (most recent call last): File "<string>", line 1, in <module> import torch;print(torch.empty(72250,72250, device='mps', dtype=torch.float16)) ~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/malfet/git/pytorch/pytorch/torch/_tensor.py", line 568, in __repr__ return torch._tensor_str._str(self, tensor_contents=tensor_contents) ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/malfet/git/pytorch/pytorch/torch/_tensor_str.py", line 708, in _str return _str_intern(self, tensor_contents=tensor_contents) File "/Users/malfet/git/pytorch/pytorch/torch/_tensor_str.py", line 625, in _str_intern tensor_str = _tensor_str(self, indent) File "/Users/malfet/git/pytorch/pytorch/torch/_tensor_str.py", line 339, in _tensor_str self = self.float() RuntimeError: Invalid buffer size: 19.45 GB ```
1 parent 4959784 commit 9c85e5e

File tree

1 file changed

+0
-13
lines changed

1 file changed

+0
-13
lines changed

torch/_tensor_str.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -328,19 +328,6 @@ def _tensor_str(self, indent):
328328
if self.is_neg():
329329
self = self.resolve_neg()
330330

331-
if self.dtype in [
332-
torch.float16,
333-
torch.bfloat16,
334-
torch.float8_e5m2,
335-
torch.float8_e5m2fnuz,
336-
torch.float8_e4m3fn,
337-
torch.float8_e4m3fnuz,
338-
]:
339-
self = self.float()
340-
341-
if self.dtype is torch.complex32:
342-
self = self.cfloat()
343-
344331
if self.dtype.is_complex:
345332
# handle the conjugate bit
346333
self = self.resolve_conj()

0 commit comments

Comments
 (0)