Skip to content

Commit e293c4e

Browse files
tczhangzhifacebook-github-bot
authored andcommitted
Fix 'in' return true incorrectly (#24156)
Summary: Because of 'return NotImplemented', __contains__ return True when the element is not a number. bool(NotImplemented) == True Pull Request resolved: #24156 Differential Revision: D16829895 Pulled By: zou3519 fbshipit-source-id: 9d3d58025b2b78b33a26fdfcfa6029d0d049f11f
1 parent 079cd4e commit e293c4e

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

test/test_torch.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9952,6 +9952,15 @@ def test_contains(self):
99529952
val += 10
99539953
self.assertEqual(val in x, False)
99549954

9955+
self.assertRaisesRegex(
9956+
RuntimeError,
9957+
"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {}.".format(type("foo")),
9958+
lambda: "foo" in x)
9959+
self.assertRaisesRegex(
9960+
RuntimeError,
9961+
"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {}.".format(type([1, 2])),
9962+
lambda: [1, 2] in x)
9963+
99559964
@staticmethod
99569965
def _test_rot90(self, use_cuda=False):
99579966
device = torch.device("cuda" if use_cuda else "cpu")

torch/tensor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,11 @@ def __contains__(self, element):
427427
"""
428428
if isinstance(element, (torch.Tensor, Number)):
429429
return (element == self).any().item()
430-
return NotImplemented
430+
431+
raise RuntimeError(
432+
"Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." %
433+
type(element)
434+
)
431435

432436
@property
433437
def __cuda_array_interface__(self):

0 commit comments

Comments
 (0)