Skip to content

Commit cf71385

Browse files
tunzsoumith
authored andcommitted
Implement torch.isnan (#5273)
* Implement torch.isnan * Simple python implementation * Fix typo
1 parent fae6c67 commit cf71385

File tree

3 files changed

+28
-1
lines changed

3 files changed

+28
-1
lines changed

docs/source/torch.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ Comparison Ops
152152
.. autofunction:: equal
153153
.. autofunction:: ge
154154
.. autofunction:: gt
155+
.. autofunction:: isnan
155156
.. autofunction:: kthvalue
156157
.. autofunction:: le
157158
.. autofunction:: lt

test/test_torch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3032,6 +3032,10 @@ def test_logical(self):
30323032
self.assertEqual(neqs.sum(), xne.sum(), 0)
30333033
self.assertEqual(x.nelement(), all.sum())
30343034

3035+
def test_isnan(self):
3036+
x = torch.Tensor([1, float('nan'), 2])
3037+
self.assertEqual(torch.isnan(x), torch.ByteTensor([0, 1, 0]))
3038+
30353039
def test_RNGState(self):
30363040
state = torch.get_rng_state()
30373041
stateCloned = state.clone()

torch/functional.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
__all__ = [
77
'split', 'chunk', 'empty_like', 'stack', 'unbind', 'btriunpack', 'matmul', 'det', 'stft',
8-
'hann_window', 'hamming_window', 'bartlett_window', 'where',
8+
'hann_window', 'hamming_window', 'bartlett_window', 'where', 'isnan'
99
]
1010

1111

@@ -514,3 +514,25 @@ def where(condition, x, y):
514514
# the parameter order is changed here; the functional order is the same as numpy; the
515515
# method follows the usual torch mask semantics of x.fn(mask, y)
516516
return torch._C._VariableBase.where(x, condition, y)
517+
518+
519+
def isnan(tensor):
520+
r"""Returns a new tensor with boolean elements representing if each element is NaN or not.
521+
522+
Arguments:
523+
tensor (Tensor): A tensor to check
524+
525+
Returns:
526+
Tensor: A ``torch.ByteTensor`` containing a 1 at each location of NaN elements.
527+
528+
Example::
529+
530+
>>> torch.isnan(torch.Tensor([1, float('nan'), 2]))
531+
0
532+
1
533+
0
534+
[torch.ByteTensor of size 3]
535+
"""
536+
if not torch.is_tensor(tensor):
537+
raise ValueError("The argument is not a tensor")
538+
return tensor != tensor

0 commit comments

Comments
 (0)