@@ -6261,6 +6261,71 @@ def test_bitwise_xor(self, device):
62616261 torch.bitwise_xor(torch.tensor([True, True, False], device=device),
62626262 torch.tensor([False, True, False], device=device)))
62636263
6264+ @onlyOnCPUAndCUDA
6265+ @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
6266+ @dtypes(*list(product(torch.testing.get_all_dtypes(include_complex=False),
6267+ torch.testing.get_all_dtypes(include_complex=False))))
6268+ def test_heaviside(self, device, dtypes):
6269+ input_dtype = dtypes[0]
6270+ values_dtype = dtypes[1]
6271+
6272+ rng = np.random.default_rng()
6273+ input = np.array(rng.integers(-10, 10, size=10),
6274+ dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64])
6275+ input[0] = input[3] = input[7] = 0
6276+ values = np.array(rng.integers(-10, 10, size=10),
6277+ dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64])
6278+ np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype)
6279+
6280+ input = torch.from_numpy(input).to(device=device, dtype=input_dtype)
6281+ values = torch.from_numpy(values).to(device=device, dtype=values_dtype)
6282+ out = torch.empty_like(input)
6283+
6284+ if input_dtype == values_dtype:
6285+ torch_result = torch.heaviside(input, values)
6286+ self.assertEqual(np_result, torch_result)
6287+
6288+ torch_result = input.heaviside(values)
6289+ self.assertEqual(np_result, torch_result)
6290+
6291+ torch.heaviside(input, values, out=out)
6292+ self.assertEqual(np_result, out)
6293+
6294+ input.heaviside_(values)
6295+ self.assertEqual(np_result, input)
6296+ else:
6297+ with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
6298+ torch.heaviside(input, values)
6299+ with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
6300+ input.heaviside(values)
6301+ with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
6302+ torch.heaviside(input, values, out=out)
6303+ with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
6304+ input.heaviside_(values)
6305+
6306+
6307+ @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
6308+ @dtypes(*list(product(torch.testing.get_all_complex_dtypes(),
6309+ torch.testing.get_all_complex_dtypes())))
6310+ def test_heaviside_complex(self, device, dtypes):
6311+ input_dtype = dtypes[0]
6312+ values_dtype = dtypes[1]
6313+
6314+ data = (complex(0, -6), complex(-1, 3), complex(1, 1))
6315+ input = torch.tensor(data, device=device, dtype=input_dtype)
6316+ values = torch.tensor(data, device=device, dtype=values_dtype)
6317+ out = torch.empty_like(input)
6318+ real = input.real
6319+
6320+ with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
6321+ torch.heaviside(input, real)
6322+ with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
6323+ real.heaviside(values)
6324+ with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
6325+ input.heaviside_(values)
6326+ with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
6327+ torch.heaviside(real, real, out=out)
6328+
62646329 @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
62656330 @dtypes(*torch.testing.get_all_dtypes())
62666331 def test_logical_not(self, device, dtype):
0 commit comments