Skip to content

Commit c21388f

Browse files
bhushan23jramseyer
authored andcommitted
Introducing IsInf (pytorch#9169)
Summary: torch.isinf - checks element wise +/- inf implements pytorch#9132 Pull Request resolved: pytorch#9169 Reviewed By: SsnL Differential Revision: D8768614 Pulled By: zou3519 fbshipit-source-id: dd1b5f6c976deb421d626e22cdd25500ec04d796
1 parent b6ea6e8 commit c21388f

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

docs/source/torch.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ Comparison Ops
223223
.. autofunction:: equal
224224
.. autofunction:: ge
225225
.. autofunction:: gt
226+
.. autofunction:: isinf
226227
.. autofunction:: isnan
227228
.. autofunction:: kthvalue
228229
.. autofunction:: le

test/test_torch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4711,6 +4711,10 @@ def test_isnan(self):
47114711
x = torch.Tensor([1, float('nan'), 2])
47124712
self.assertEqual(torch.isnan(x), torch.ByteTensor([0, 1, 0]))
47134713

4714+
def test_isinf(self):
4715+
x = torch.Tensor([1, float('inf'), 2, float('-inf'), float('nan')])
4716+
self.assertEqual(torch.isinf(x), torch.ByteTensor([0, 1, 0, 1, 0]))
4717+
47144718
def test_RNGState(self):
47154719
state = torch.get_rng_state()
47164720
stateCloned = state.clone()

torch/functional.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
'argmin',
99
'btrifact',
1010
'btriunpack',
11+
'isinf',
1112
'isnan',
1213
'split',
1314
'unique',
@@ -135,6 +136,25 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
135136
return P, L, U
136137

137138

139+
def isinf(tensor):
140+
r"""Returns a new tensor with boolean elements representing if each element is `+/-INF` or not.
141+
142+
Arguments:
143+
tensor (Tensor): A tensor to check
144+
145+
Returns:
146+
Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `+/-INF` elements and 0 otherwise
147+
148+
Example::
149+
150+
>>> torch.isinf(torch.Tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
151+
tensor([ 0, 1, 0, 1, 0], dtype=torch.uint8)
152+
"""
153+
if not isinstance(tensor, torch.Tensor):
154+
raise ValueError("The argument is not a tensor", str(tensor))
155+
return tensor.abs() == float('inf')
156+
157+
138158
def isnan(tensor):
139159
r"""Returns a new tensor with boolean elements representing if each element is `NaN` or not.
140160
@@ -150,7 +170,7 @@ def isnan(tensor):
150170
tensor([ 0, 1, 0], dtype=torch.uint8)
151171
"""
152172
if not isinstance(tensor, torch.Tensor):
153-
raise ValueError("The argument is not a tensor")
173+
raise ValueError("The argument is not a tensor", str(tensor))
154174
return tensor != tensor
155175

156176

0 commit comments

Comments
 (0)