Commit e499b46
Speed up half tensors printing (#141927)
This PR removes copycast of reduced precision types to float before printing, that was added in #14418 to probably unblock printing when many operations, like `is_nan` and `max` were not supported on CPUs
(Reusing old test plan) Before the PR:
```python
In [1]: import torch; a = torch.rand(1, 1700, 34, 50, dtype=torch.float16)
In [2]: %timeit str(a)
621 μs ± 5.06 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
after the PR
```python
In [1]: import torch; a = torch.rand(1, 1700, 34, 50, dtype=torch.float16)
In [2]: %timeit str(a)
449 μs ± 2.34 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
Also, this allows one printing 15Gb Metal tensors on 32GB Mac machine:
```
% python3 -c "import torch;print(torch.empty(72250,72250, device='mps', dtype=torch.float16))"
tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], device='mps:0', dtype=torch.float16)
```
Before this change it failed with non-descriptive
```
% python3 -c "import torch;print(torch.empty(72250,72250, device='mps', dtype=torch.float16))"
Traceback (most recent call last):
File "<string>", line 1, in <module>
import torch;print(torch.empty(72250,72250, device='mps', dtype=torch.float16))
~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/malfet/git/pytorch/pytorch/torch/_tensor.py", line 568, in __repr__
return torch._tensor_str._str(self, tensor_contents=tensor_contents)
~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/malfet/git/pytorch/pytorch/torch/_tensor_str.py", line 708, in _str
return _str_intern(self, tensor_contents=tensor_contents)
File "/Users/malfet/git/pytorch/pytorch/torch/_tensor_str.py", line 625, in _str_intern
tensor_str = _tensor_str(self, indent)
File "/Users/malfet/git/pytorch/pytorch/torch/_tensor_str.py", line 339, in _tensor_str
self = self.float()
RuntimeError: Invalid buffer size: 19.45 GB
```
Convert fp8 dtypes to float16, as float range is an overkill
Pull Request resolved: #141927
Approved by: https://github.com/ezyang1 parent d035db3 commit e499b46
1 file changed
+2
-6
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
328 | 328 | | |
329 | 329 | | |
330 | 330 | | |
| 331 | + | |
331 | 332 | | |
332 | | - | |
333 | | - | |
334 | 333 | | |
335 | 334 | | |
336 | 335 | | |
337 | 336 | | |
338 | 337 | | |
339 | | - | |
340 | | - | |
341 | | - | |
342 | | - | |
| 338 | + | |
343 | 339 | | |
344 | 340 | | |
345 | 341 | | |
| |||
0 commit comments