Skip to content

Commit b57fe3c

Browse files
bhushan23facebook-github-bot
authored andcommitted
Introducing array-like sequence methods __contains__ (#17733)
Summary: for tensor Fixes: #17000 Pull Request resolved: #17733 Differential Revision: D14401952 Pulled By: soumith fbshipit-source-id: c841b128c5a1fceda1094323ed4ef1d0cf494909
1 parent 906f9ef commit b57fe3c

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

test/test_torch.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8151,6 +8151,17 @@ def test_reversed(self):
81518151
val = torch.tensor(42)
81528152
self.assertEqual(reversed(val), torch.tensor(42))
81538153

8154+
def test_contains(self):
8155+
x = torch.arange(0, 10)
8156+
self.assertEqual(4 in x, True)
8157+
self.assertEqual(12 in x, False)
8158+
8159+
x = torch.arange(1, 10).view(3, 3)
8160+
val = torch.arange(1, 4)
8161+
self.assertEqual(val in x, True)
8162+
val += 10
8163+
self.assertEqual(val in x, False)
8164+
81548165
@staticmethod
81558166
def _test_rot90(self, use_cuda=False):
81568167
device = torch.device("cuda" if use_cuda else "cpu")

torch/tensor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import weakref
88
from torch._six import imap
99
from torch._C import _add_docstr
10+
from numbers import Number
1011

1112

1213
# NB: If you subclass Tensor, and want to share the subclassed class
@@ -426,6 +427,17 @@ def __array_wrap__(self, array):
426427
array = array.astype('uint8')
427428
return torch.from_numpy(array)
428429

430+
def __contains__(self, element):
431+
r"""Check if `element` is present in tensor
432+
433+
Arguments:
434+
element (Tensor or scalar): element to be checked
435+
for presence in current tensor"
436+
"""
437+
if isinstance(element, (torch.Tensor, Number)):
438+
return (element == self).any().item()
439+
return NotImplemented
440+
429441
@property
430442
def __cuda_array_interface__(self):
431443
"""Array view description for cuda tensors.

0 commit comments

Comments
 (0)