Skip to content

Commit 9c4c9c3

Browse files
xuhdevfacebook-github-bot
authored andcommitted
Delegate Python ~ (invert operator) to Tensor.bitwise_not().
Summary: Pull Request resolved: #22326 Test Plan: Imported from OSS Differential Revision: D16183577 Pulled By: colesbury fbshipit-source-id: f86838c407db4ded9ce70998bf1ab1ffd75b3b58
1 parent 574e808 commit 9c4c9c3

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

test/test_torch.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11334,10 +11334,6 @@ def test_bitwise_ops(self):
1133411334
else:
1133511335
self.assertFalse(x[idx] ^ y[idx])
1133611336

11337-
invert_result = ~x
11338-
for idx in iter_indices(x):
11339-
self.assertEqual(1 - x[idx], invert_result[idx])
11340-
1134111337
x_clone = x.clone()
1134211338
x_clone &= y
1134311339
self.assertEqual(x_clone, and_result)
@@ -11350,9 +11346,20 @@ def test_bitwise_ops(self):
1135011346
x_clone ^= y
1135111347
self.assertEqual(x_clone, xor_result)
1135211348

11353-
def test_invert(self):
11354-
x = torch.ByteTensor([0, 1, 1])
11355-
self.assertEqual((~x).tolist(), [1, 0, 0])
11349+
def test_op_invert(self):
11350+
res = 0xffff - torch.arange(127, dtype=torch.int8)
11351+
for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
11352+
a = torch.arange(127, dtype=dtype)
11353+
self.assertEqual(res.type(dtype), ~a)
11354+
11355+
self.assertEqual(torch.tensor([True, False]),
11356+
~torch.tensor([False, True]))
11357+
11358+
# test exceptions
11359+
for dtype in(torch.half, torch.float, torch.double):
11360+
a = torch.zeros(10, dtype=dtype)
11361+
with self.assertRaises(TypeError):
11362+
b = ~a
1135611363

1135711364
def test_apply(self):
1135811365
x = torch.arange(1, 6)

tools/autograd/templates/python_variable_methods.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,14 +285,14 @@ static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) {
285285
static Tensor dispatch_invert(const Tensor & self) {
286286
AutoNoGIL no_gil;
287287
OptionalDeviceGuard device_guard(device_of(self));
288-
return 1 - self;
288+
return self.bitwise_not();
289289
}
290290

291291
static PyObject * THPVariable_invert(PyObject* self, PyObject* args) {
292292
HANDLE_TH_ERRORS
293293
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
294-
if (self_.scalar_type() != at::kByte) {
295-
throw TypeError("~ (operator.invert) is only implemented on byte tensors");
294+
if (!isIntegralType(self_.scalar_type()) && self_.scalar_type() != at::kBool) {
295+
throw TypeError("~ (operator.invert) is only implemented on integer and Boolean-type tensors");
296296
}
297297
return THPVariable_Wrap(dispatch_invert(self_));
298298
END_HANDLE_TH_ERRORS

0 commit comments

Comments
 (0)