Closed

21 commits
1a8e29a  Add a bitwise NOT operator for integer and Boolean types (CUDA).  xuhdev, Jun 27, 2019
e7188e4  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jun 27, 2019
10d3816  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jun 27, 2019
20de04d  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jun 27, 2019
6cde803  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jun 27, 2019
6cf4053  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jun 27, 2019
52aa0fb  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jun 30, 2019
ad0a74b  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jun 30, 2019
d76b541  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jul 1, 2019
524bce4  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jul 1, 2019
9ca2457  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jul 1, 2019
9b3472d  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jul 2, 2019
8351b8e  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jul 2, 2019
c03b85e  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jul 2, 2019
1d8b291  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jul 2, 2019
7bd290c  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jul 2, 2019
7613335  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jul 3, 2019
c7096d4  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jul 3, 2019
4102d18  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jul 3, 2019
31bb406  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jul 3, 2019
9abde47  Update on "Add a bitwise NOT operator for integer and Boolean types (…  xuhdev, Jul 8, 2019
21 changes: 20 additions & 1 deletion aten/src/ATen/native/UnaryOps.cpp
@@ -36,6 +36,26 @@
 namespace at {
 namespace native {
 
+Tensor bitwise_not(const Tensor& self) {
+  Tensor result = at::empty({0}, self.options());
+  return at::bitwise_not_out(result, self);
+}
+
+Tensor& bitwise_not_(Tensor& self) {
+  return at::bitwise_not_out(self, self);
+}
+
+Tensor& bitwise_not_out(Tensor& result, const Tensor& self) {
+  checkBackend("bitwise_not", result, self.type().backend());
+  assert_no_internal_overlap(result, "bitwise_not");
+  auto iter = TensorIterator::unary_op(result, self);
+  bitwise_not_stub(iter->device_type(), *iter);
+#ifdef BUILD_NAMEDTENSOR
+  at::namedinference::propagate_names(result, self);
+#endif
+  return result;
+}
+
 Tensor clamp(const Tensor& self, optional<Scalar> min, optional<Scalar> max) {
   Tensor result = at::empty({0}, self.options());
   return clamp_out(result, self, min, max);

@@ -167,7 +187,6 @@ IMPLEMENT_UNARY_OP_VEC(abs)
 IMPLEMENT_UNARY_OP_VEC(acos)
 IMPLEMENT_UNARY_OP_VEC(asin)
 IMPLEMENT_UNARY_OP_VEC(atan)
-IMPLEMENT_UNARY_OP_VEC(bitwise_not)
 IMPLEMENT_UNARY_OP_VEC(ceil)
 IMPLEMENT_UNARY_OP_VEC(cos)
 IMPLEMENT_UNARY_OP_VEC(cosh)
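The three wrappers above follow ATen's standard unary-op pattern: the out-of-place variant allocates an empty result and forwards to the out variant, and the in-place variant forwards to the out variant with result and self aliased. A minimal sketch of the resulting Python-level behavior (assuming a build that includes this PR; values shown for int8):

    import torch

    a = torch.tensor([0, 1, -2], dtype=torch.int8)

    # out-of-place: allocates a result, then routes through bitwise_not_out
    print(a.bitwise_not())        # tensor([-1, -2, 1], dtype=torch.int8)

    # out= variant: writes into a preallocated tensor, resizing it as needed
    out = torch.empty(0, dtype=torch.int8)
    torch.bitwise_not(a, out=out)

    # in-place: calls the out variant with result aliased to self
    a.bitwise_not_()
    print(a)                      # tensor([-1, -2, 1], dtype=torch.int8)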
17 changes: 16 additions & 1 deletion aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -1,14 +1,28 @@
-#include <limits>
 #include <ATen/native/UnaryOps.h>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/Context.h>
 #include <ATen/Dispatch.h>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorIterator.h>
+#include <limits>
 
 namespace at { namespace native {
 
+void bitwise_not_kernel_cuda(TensorIterator& iter) {
+  if (iter.dtype() == ScalarType::Bool) {
+    gpu_kernel(iter, []GPU_LAMBDA(bool a) {
+      return !a;
+    });
+  } else {
+    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_not_cuda", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        return ~a;
+      });
+    });
+  }
+}
+
 template <typename scalar_t>
 void fill_kernel_impl(TensorIterator& iter, Scalar value_scalar) {
   auto value = value_scalar.to<scalar_t>();

@@ -24,5 +38,6 @@ static void fill_kernel_cuda(TensorIterator& iter, Scalar value) {
 }
 
 REGISTER_DISPATCH(fill_stub, &fill_kernel_cuda);
+REGISTER_DISPATCH(bitwise_not_stub, &bitwise_not_kernel_cuda);
 
 }}
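The kernel branches on dtype: torch.bool gets a logical NOT (so True becomes False rather than having its bits flipped), while AT_DISPATCH_INTEGRAL_TYPES instantiates a two's-complement ~ for each integer type. A small sketch of the semantics this implies, assuming a CUDA build that includes this kernel:

    import torch

    # integral dtypes: elementwise two's-complement NOT
    a = torch.tensor([0, 1, 7], dtype=torch.int32, device='cuda')
    expected = torch.tensor([~0, ~1, ~7], dtype=torch.int32, device='cuda')  # Python-level ~
    assert torch.equal(a.bitwise_not(), expected)

    # bool: logical NOT rather than bit flipping
    b = torch.tensor([True, False], device='cuda')
    assert torch.equal(b.bitwise_not(), torch.tensor([False, True], device='cuda'))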
5 changes: 2 additions & 3 deletions aten/src/ATen/native/native_functions.yaml
@@ -362,12 +362,11 @@
 
 - func: bitwise_not_(Tensor(a!) self) -> Tensor(a!)
   variants: method
-  dispatch:
-    CPU: _bitwise_not__cpu
 
 - func: bitwise_not(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
-    CPU: _bitwise_not_out_cpu
+    CPU: bitwise_not_out
+    CUDA: bitwise_not_out
 
 - func: blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

Review thread on the added dispatch lines:

Contributor:
No need to list them if they dispatch to the same function.

Collaborator:
There are other places like this, e.g.

    CPU: dense_to_sparse
    CUDA: dense_to_sparse

where the CPU and CUDA dispatches go to the same function - I assumed this is so that sparse tensors error out with a nice message if they call this function? If that's the case, it makes sense to leave it here too.
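With both backends listed and pointing at the same backend-agnostic bitwise_not_out (device selection happens inside, via the DispatchStub), the same call now works on either device. A quick sketch, assuming a CUDA build:

    import torch

    for device in ('cpu', 'cuda'):
        t = torch.arange(4, dtype=torch.uint8, device=device)
        out = torch.empty(0, dtype=torch.uint8, device=device)
        torch.bitwise_not(t, out=out)
        print(out)  # tensor([255, 254, 253, 252], dtype=torch.uint8, ...)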
3 changes: 3 additions & 0 deletions test/test_cuda.py
@@ -1088,6 +1088,9 @@ def test_type_conversions_same_gpu(self):
     def test_neg(self):
         _TestTorchMixin._test_neg(self, lambda t: t.cuda())
 
+    def test_bitwise_not(self):
+        _TestTorchMixin._test_bitwise_not(self, 'cuda')
+
     def test_isinf(self):
         _TestTorchMixin._test_isinf(self, lambda t: t.cuda())
 
29 changes: 16 additions & 13 deletions test/test_torch.py
@@ -1748,40 +1748,43 @@ def _test_neg(self, cast):
     def test_neg(self):
         self._test_neg(self, lambda t: t)
 
-    def test_bitwise_not(self):
-        res = 0xffff - torch.arange(127, dtype=torch.int8)
-        for t in (torch.BoolTensor,
-                  torch.ByteTensor, torch.LongTensor, torch.IntTensor, torch.ShortTensor, torch.CharTensor):
-            if t == torch.BoolTensor:
-                a = torch.tensor([True, False])
-                expected_res = torch.tensor([False, True])
+    @staticmethod
+    def _test_bitwise_not(self, device):
+        res = 0xffff - torch.arange(127, dtype=torch.int8, device=device)
+        for dtype in (torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
+            if dtype == torch.bool:
+                a = torch.tensor([True, False], device=device)
+                expected_res = torch.tensor([False, True], device=device)
             else:
-                a = torch.arange(127, dtype=t.dtype)
-                expected_res = res.type(t)
+                a = torch.arange(127, dtype=dtype, device=device)
+                expected_res = res.type(dtype)
             # new tensor
             self.assertEqual(expected_res, a.bitwise_not())
             # out
-            b = t()
+            b = torch.empty(0, dtype=dtype, device=device)
             torch.bitwise_not(a, out=b)
             self.assertEqual(expected_res, b)
             # in-place
             a.bitwise_not_()
             self.assertEqual(expected_res, a)
 
         # test exceptions
-        for t in(torch.HalfTensor, torch.FloatTensor, torch.DoubleTensor):
-            a = torch.zeros(10, dtype=t.dtype)
+        for dtype in(torch.half, torch.float, torch.double):
+            a = torch.zeros(10, dtype=dtype, device=device)
             # new tensor
             with self.assertRaises(RuntimeError):
                 a.bitwise_not()
             # out
-            b = t()
+            b = torch.empty(0, dtype=dtype, device=device)
             with self.assertRaises(RuntimeError):
                 torch.bitwise_not(a, out=b)
             # in-place
             with self.assertRaises(RuntimeError):
                 a.bitwise_not_()
 
+    def test_bitwise_not(self):
+        self._test_bitwise_not(self, 'cpu')
+
     def test_threshold(self):
         for dtype in torch.testing.get_all_math_dtypes('cpu'):
             if dtype != torch.uint8 and dtype != torch.float16:
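A note on the res expression in the test above: for a two's-complement integer, ~x == -1 - x, and 0xffff is congruent to -1 modulo 256, so 0xffff - x wraps to exactly ~x once cast to an 8-bit type. A standalone check of that identity:

    # why `0xffff - x` equals `~x` after wrapping to 8 bits:
    # ~x == -1 - x in two's complement, and 0xffff ≡ -1 (mod 256)
    for x in range(127):
        assert (0xffff - x) % 256 == (~x) % 256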