Commit a1eae6d

kiyosora authored and facebook-github-bot committed

Implementing NumPy-like function torch.heaviside() (#42523)

Summary:
- Related to #38349
- Implements the NumPy-like function `torch.heaviside()`.

Pull Request resolved: #42523
Reviewed By: glaringlee
Differential Revision: D23391941
Pulled By: mruberry
fbshipit-source-id: 7b942321a62567a5fc0a3679a289f4c4c19e6134
1 parent 633d239 commit a1eae6d
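For context: torch.heaviside(input, values) follows np.heaviside semantics, returning 0 where the input is negative, 1 where it is positive, and the corresponding entry of values where the input is exactly zero. A minimal usage sketch (illustrative, not part of this diff):

import torch

input = torch.tensor([-1.5, 0.0, 2.0])
values = torch.tensor([0.5, 0.5, 0.5])
# 0 where input < 0, 1 where input > 0, values where input == 0
print(torch.heaviside(input, values))  # tensor([0.0000, 0.5000, 1.0000])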

File tree: 12 files changed, +175 -0 lines changed

aten/src/ATen/core/aten_interned_strings.h

Lines changed: 1 addition & 0 deletions

@@ -367,6 +367,7 @@ _(aten, hardsigmoid_backward) \
 _(aten, hardtanh) \
 _(aten, hardtanh_backward) \
 _(aten, hardtanh_forward) \
+_(aten, heaviside) \
 _(aten, hinge_embedding_loss) \
 _(aten, histc) \
 _(aten, hspmm) \

aten/src/ATen/native/BinaryOps.cpp

Lines changed: 28 additions & 0 deletions

@@ -47,6 +47,7 @@ DEFINE_DISPATCH(gcd_stub);
 DEFINE_DISPATCH(lcm_stub);
 DEFINE_DISPATCH(hypot_stub);
 DEFINE_DISPATCH(nextafter_stub);
+DEFINE_DISPATCH(heaviside_stub);

 Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
   auto iter = TensorIterator::binary_op(result, self, other);
@@ -931,6 +932,33 @@ Tensor _test_serialization_subcmul(const Tensor& self, const Tensor& other, Scal
   return self - (other * alpha);
 }

+Tensor& heaviside_out(Tensor& result, const Tensor& self, const Tensor& values) {
+  TORCH_CHECK(!self.is_complex() && !result.is_complex() && !values.is_complex(),
+              "heaviside is not yet implemented for complex tensors.");
+  TORCH_CHECK(self.dtype() == values.dtype() && result.dtype() == self.dtype(),
+              "heaviside is not yet implemented for tensors with different dtypes.");
+
+  auto iter = TensorIterator::binary_op(result, self, values, /*check_mem_overlap=*/true);
+  heaviside_stub(iter.device_type(), iter);
+  return result;
+}
+
+Tensor heaviside(const Tensor& self, const Tensor& values) {
+  TORCH_CHECK(!self.is_complex() && !values.is_complex(),
+              "heaviside is not yet implemented for complex tensors.");
+  TORCH_CHECK(self.dtype() == values.dtype(),
+              "heaviside is not yet implemented for tensors with different dtypes.");
+
+  Tensor result;
+  auto iter = TensorIterator::binary_op(result, self, values);
+  heaviside_stub(iter.device_type(), iter);
+  return iter.output();
+}
+
+Tensor& heaviside_(Tensor& self, const Tensor& values) {
+  return at::heaviside_out(self, self, values);
+}
+
 // TODO: Deduplicate this with the TensorIterator logic. This would
 // also fix the TODOs below.
 Tensor binary_op_meta(const Tensor& self, const Tensor& other) {
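All three entry points guard against complex inputs and mismatched dtypes before building a TensorIterator and dispatching to the device kernel. A short sketch of the resulting Python-side behavior (error text taken from the TORCH_CHECK messages above):

import torch

x = torch.tensor([-1.0, 0.0, 1.0])     # float32
v = torch.zeros(3, dtype=torch.int64)  # deliberately mismatched dtype
try:
    torch.heaviside(x, v)
except RuntimeError as e:
    print(e)  # heaviside is not yet implemented for tensors with different dtypes.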

aten/src/ATen/native/BinaryOps.h

Lines changed: 1 addition & 0 deletions

@@ -67,5 +67,6 @@ DECLARE_DISPATCH(binary_fn, gcd_stub);
 DECLARE_DISPATCH(binary_fn, lcm_stub);
 DECLARE_DISPATCH(binary_fn, hypot_stub);
 DECLARE_DISPATCH(binary_fn, nextafter_stub);
+DECLARE_DISPATCH(binary_fn, heaviside_stub);

 }} // namespace at::native

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

Lines changed: 9 additions & 0 deletions

@@ -764,6 +764,14 @@ void nextafter_kernel(TensorIterator& iter) {
   });
 }

+void heaviside_kernel(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cpu", [&]() {
+    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
+      return a == 0 ? b : static_cast<scalar_t>(a > 0);
+    });
+  });
+}
+
 } // namespace

 REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -802,6 +810,7 @@ REGISTER_DISPATCH(gcd_stub, &gcd_kernel);
 REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
 REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
 REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
+REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);

 } // namespace native
 } // namespace at
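The whole op reduces to the elementwise lambda `a == 0 ? b : static_cast<scalar_t>(a > 0)`, dispatched over all non-complex dtypes plus Half, Bool, and BFloat16. For illustration, an equivalent formulation with torch.where (the helper name heaviside_reference is hypothetical, not part of the diff):

import torch

def heaviside_reference(a, b):
    # Mirrors the kernel lambda: a == 0 ? b : static_cast<scalar_t>(a > 0)
    return torch.where(a == 0, b, (a > 0).to(a.dtype))

a = torch.tensor([-3.0, 0.0, 7.0])
b = torch.full((3,), 9.0)
print(heaviside_reference(a, b))  # tensor([0., 9., 1.])
print(torch.heaviside(a, b))      # same result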

aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu

Lines changed: 9 additions & 0 deletions

@@ -101,6 +101,14 @@ void nextafter_kernel_cuda(TensorIterator& iter) {
   });
 }

+void heaviside_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cuda", [&]() {
+    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      return a == 0 ? b : static_cast<scalar_t>(a > 0);
+    });
+  });
+}
+
 REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
 REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
 REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
@@ -110,5 +118,6 @@ REGISTER_DISPATCH(gcd_stub, &gcd_kernel_cuda);
 REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
 REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
 REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
+REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);

 }} // namespace at::native
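The CUDA kernel registers the same lambda through gpu_kernel, so CPU and GPU results should agree elementwise. A sketch, assuming a CUDA device is available:

import torch

if torch.cuda.is_available():
    a = torch.tensor([-3.0, 0.0, 7.0], device='cuda')
    b = torch.full_like(a, 9.0)
    print(torch.heaviside(a, b))  # tensor([0., 9., 1.], device='cuda:0')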

aten/src/ATen/native/native_functions.yaml

Lines changed: 11 additions & 0 deletions

@@ -3613,6 +3613,17 @@
   use_c10_dispatcher: full
   variants: function

+- func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: heaviside_out
+
+- func: heaviside(Tensor self, Tensor values) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+
+- func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)
+  variants: method
+
 # For C++ only, until we have conversion from C++ numbers to Tensor
 - func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
   use_c10_dispatcher: full
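Together the three schemas register the full surface: an out= overload dispatched to CPU/CUDA, a functional/method variant, and an in-place method. Roughly, what they expose in Python:

import torch

x = torch.tensor([-2.0, 0.0, 5.0])
v = torch.tensor([0.5, 0.5, 0.5])
out = torch.empty_like(x)

torch.heaviside(x, v, out=out)  # heaviside.out overload
y = x.heaviside(v)              # method variant
x.heaviside_(v)                 # in-place variant; x now holds the result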

docs/source/tensors.rst

Lines changed: 1 addition & 0 deletions

@@ -332,6 +332,7 @@ view of a storage and defines numeric operations on it.
 .. automethod:: gt_
 .. automethod:: half
 .. automethod:: hardshrink
+.. automethod:: heaviside
 .. automethod:: histc
 .. automethod:: hypot
 .. automethod:: hypot_

docs/source/torch.rst

Lines changed: 1 addition & 0 deletions

@@ -75,6 +75,7 @@ Creation Ops
     dequantize
     complex
     polar
+    heaviside

 Indexing, Slicing, Joining, Mutating Ops
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

test/test_torch.py

Lines changed: 65 additions & 0 deletions

@@ -6261,6 +6261,71 @@ def test_bitwise_xor(self, device):
                          torch.bitwise_xor(torch.tensor([True, True, False], device=device),
                                            torch.tensor([False, True, False], device=device)))

+    @onlyOnCPUAndCUDA
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    @dtypes(*list(product(torch.testing.get_all_dtypes(include_complex=False),
+                          torch.testing.get_all_dtypes(include_complex=False))))
+    def test_heaviside(self, device, dtypes):
+        input_dtype = dtypes[0]
+        values_dtype = dtypes[1]
+
+        rng = np.random.default_rng()
+        input = np.array(rng.integers(-10, 10, size=10),
+                         dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64])
+        input[0] = input[3] = input[7] = 0
+        values = np.array(rng.integers(-10, 10, size=10),
+                          dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64])
+        np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype)
+
+        input = torch.from_numpy(input).to(device=device, dtype=input_dtype)
+        values = torch.from_numpy(values).to(device=device, dtype=values_dtype)
+        out = torch.empty_like(input)
+
+        if input_dtype == values_dtype:
+            torch_result = torch.heaviside(input, values)
+            self.assertEqual(np_result, torch_result)
+
+            torch_result = input.heaviside(values)
+            self.assertEqual(np_result, torch_result)
+
+            torch.heaviside(input, values, out=out)
+            self.assertEqual(np_result, out)
+
+            input.heaviside_(values)
+            self.assertEqual(np_result, input)
+        else:
+            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
+                torch.heaviside(input, values)
+            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
+                input.heaviside(values)
+            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
+                torch.heaviside(input, values, out=out)
+            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
+                input.heaviside_(values)
+
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    @dtypes(*list(product(torch.testing.get_all_complex_dtypes(),
+                          torch.testing.get_all_complex_dtypes())))
+    def test_heaviside_complex(self, device, dtypes):
+        input_dtype = dtypes[0]
+        values_dtype = dtypes[1]
+
+        data = (complex(0, -6), complex(-1, 3), complex(1, 1))
+        input = torch.tensor(data, device=device, dtype=input_dtype)
+        values = torch.tensor(data, device=device, dtype=values_dtype)
+        out = torch.empty_like(input)
+        real = input.real
+
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            torch.heaviside(input, real)
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            real.heaviside(values)
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            input.heaviside_(values)
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            torch.heaviside(real, real, out=out)
+
     @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
     @dtypes(*torch.testing.get_all_dtypes())
     def test_logical_not(self, device, dtype):
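The main test cross-checks every non-complex dtype pair against np.heaviside, with mismatched dtypes and complex inputs expected to raise. Boiled down, the parity check looks like this (illustrative):

import numpy as np
import torch

x = np.array([-1.0, 0.0, 2.0])
v = np.array([0.5, 0.5, 0.5])
np_out = np.heaviside(x, v)
torch_out = torch.heaviside(torch.from_numpy(x), torch.from_numpy(v))
assert np.allclose(np_out, torch_out.numpy())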

torch/_tensor_docs.py

Lines changed: 14 additions & 0 deletions

@@ -1535,6 +1535,20 @@ def add_docstr_all(method, docstr):
 See :func:`torch.nn.functional.hardshrink`
 """)

+add_docstr_all('heaviside',
+               r"""
+heaviside(values) -> Tensor
+
+See :func:`torch.heaviside`
+""")
+
+add_docstr_all('heaviside_',
+               r"""
+heaviside_(values) -> Tensor
+
+In-place version of :meth:`~Tensor.heaviside`
+""")
+
 add_docstr_all('histc',
                r"""
 histc(bins=100, min=0, max=0) -> Tensor
