File tree Expand file tree Collapse file tree 2 files changed +23
-0
lines changed
Expand file tree Collapse file tree 2 files changed +23
-0
lines changed Original file line number Diff line number Diff line change @@ -8151,6 +8151,17 @@ def test_reversed(self):
81518151 val = torch .tensor (42 )
81528152 self .assertEqual (reversed (val ), torch .tensor (42 ))
81538153
8154+ def test_contains (self ):
8155+ x = torch .arange (0 , 10 )
8156+ self .assertEqual (4 in x , True )
8157+ self .assertEqual (12 in x , False )
8158+
8159+ x = torch .arange (1 , 10 ).view (3 , 3 )
8160+ val = torch .arange (1 , 4 )
8161+ self .assertEqual (val in x , True )
8162+ val += 10
8163+ self .assertEqual (val in x , False )
8164+
81548165 @staticmethod
81558166 def _test_rot90 (self , use_cuda = False ):
81568167 device = torch .device ("cuda" if use_cuda else "cpu" )
Original file line number Diff line number Diff line change 77import weakref
88from torch ._six import imap
99from torch ._C import _add_docstr
10+ from numbers import Number
1011
1112
1213# NB: If you subclass Tensor, and want to share the subclassed class
@@ -426,6 +427,17 @@ def __array_wrap__(self, array):
426427 array = array .astype ('uint8' )
427428 return torch .from_numpy (array )
428429
def __contains__(self, element):
    r"""Check if `element` is present in tensor

    Arguments:
        element (Tensor or scalar): element to be checked
            for presence in current tensor

    Returns:
        bool: ``True`` if any entry of the tensor compares equal to
        ``element`` (comparison broadcasts, so a tensor ``element``
        matches elementwise), ``False`` otherwise.

    Raises:
        RuntimeError: if ``element`` is neither a :class:`torch.Tensor`
            nor a :class:`numbers.Number`.
    """
    if isinstance(element, (torch.Tensor, Number)):
        # Broadcasting `==` yields an elementwise match mask; `.any()`
        # collapses it and `.item()` unwraps the 0-dim result to a bool.
        return (element == self).any().item()
    # BUG FIX: the `in` operator coerces __contains__'s result to bool,
    # and `NotImplemented` is truthy -- returning it here made every
    # unsupported type (e.g. a str) appear to be "contained". Raise
    # explicitly instead.
    raise RuntimeError(
        "Tensor.__contains__ only supports Tensor or scalar, "
        "but you passed in a %s." % type(element)
    )
429441 @property
430442 def __cuda_array_interface__ (self ):
431443 """Array view description for cuda tensors.
You can’t perform that action at this time.
0 commit comments