
Commit e499b46

malfet authored and pytorchmergebot committed
Speed up half tensors printing (#141927)
This PR removes the copycast of reduced-precision types to float before printing, which was added in #14418, probably to unblock printing back when many operations, like `is_nan` and `max`, were not supported on CPU.

(Reusing the old test plan.) Before this PR:

```python
In [1]: import torch; a = torch.rand(1, 1700, 34, 50, dtype=torch.float16)

In [2]: %timeit str(a)
621 μs ± 5.06 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```

After this PR:

```python
In [1]: import torch; a = torch.rand(1, 1700, 34, 50, dtype=torch.float16)

In [2]: %timeit str(a)
449 μs ± 2.34 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```

This also allows printing a 15 GB Metal tensor on a 32 GB Mac:

```
% python3 -c "import torch;print(torch.empty(72250,72250, device='mps', dtype=torch.float16))"
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0', dtype=torch.float16)
```

Before this change it failed with a non-descriptive error:

```
% python3 -c "import torch;print(torch.empty(72250,72250, device='mps', dtype=torch.float16))"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
    import torch;print(torch.empty(72250,72250, device='mps', dtype=torch.float16))
                       ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malfet/git/pytorch/pytorch/torch/_tensor.py", line 568, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
           ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malfet/git/pytorch/pytorch/torch/_tensor_str.py", line 708, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/Users/malfet/git/pytorch/pytorch/torch/_tensor_str.py", line 625, in _str_intern
    tensor_str = _tensor_str(self, indent)
  File "/Users/malfet/git/pytorch/pytorch/torch/_tensor_str.py", line 339, in _tensor_str
    self = self.float()
RuntimeError: Invalid buffer size: 19.45 GB
```

FP8 dtypes are still converted, but to float16 rather than float32, as the full float range is overkill for them.

Pull Request resolved: #141927
Approved by: https://github.com/ezyang
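The 19.45 GB figure in the error is exactly the size of the temporary float32 copy the old printing path tried to allocate. A quick back-of-the-envelope check (illustrative only, not part of the PR):

```python
# The pre-PR code upcast the fp16 tensor to float32 before printing,
# so it needed a second buffer of 4 bytes per element.
n = 72250
fp32_bytes = n * n * 4  # float32 copy of a 72250 x 72250 tensor
print(f"{fp32_bytes / 2**30:.2f} GiB")  # 19.45 GiB, matching the RuntimeError
```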
1 parent d035db3 commit e499b46

File tree

1 file changed (+2, −6 lines)

torch/_tensor_str.py

Lines changed: 2 additions & 6 deletions
```diff
@@ -328,18 +328,14 @@ def _tensor_str(self, indent):
     if self.is_neg():
         self = self.resolve_neg()
 
+    # TODO: Remove me when `masked_select` is implemented for FP8
     if self.dtype in [
-        torch.float16,
-        torch.bfloat16,
         torch.float8_e5m2,
         torch.float8_e5m2fnuz,
         torch.float8_e4m3fn,
         torch.float8_e4m3fnuz,
     ]:
-        self = self.float()
-
-    if self.dtype is torch.complex32:
-        self = self.cfloat()
+        self = self.half()
 
     if self.dtype.is_complex:
         # handle the conjugate bit
```
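The resulting dtype policy can be sketched in pure Python (an illustration only; the helper name and string dtype keys are not from the PR):

```python
# Sketch of the post-PR printing policy: FP8 dtypes are upcast to
# float16 (ops such as `masked_select` are not yet implemented for FP8),
# while float16/bfloat16/complex32 tensors are now formatted natively.
FP8_DTYPES = {
    "float8_e5m2",
    "float8_e5m2fnuz",
    "float8_e4m3fn",
    "float8_e4m3fnuz",
}

def printing_dtype(dtype_name: str) -> str:
    """Return the dtype used when formatting a tensor as a string."""
    if dtype_name in FP8_DTYPES:
        return "float16"  # half range/precision is plenty for FP8 values
    return dtype_name     # no copycast: format in the native dtype

print(printing_dtype("float8_e5m2"))  # float16
print(printing_dtype("bfloat16"))     # bfloat16 (printed without a copy)
```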
