Skip to content

Commit e921ed1

Browse files
committed
[MPS] move test_nansum into TestMPS
1 parent 3bd6b32 commit e921ed1

File tree

1 file changed

+45
-47
lines changed

1 file changed

+45
-47
lines changed

test/test_mps.py

Lines changed: 45 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2279,6 +2279,51 @@ def test_binops_dtype_precedence(self):
22792279
getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
22802280
(torch.full(full_shape, val2, dtype=dtype2, device='cpu')))
22812281

2282+
def test_nansum(self):
2283+
def helper(dtype, noncontiguous, dim):
2284+
zero_cpu = torch.zeros((), dtype=dtype)
2285+
2286+
# Randomly scale the values
2287+
scale = random.randint(10, 100)
2288+
x_cpu: torch.Tensor = make_tensor(
2289+
(5, 5), dtype=dtype, device='cpu',
2290+
low=-scale, high=scale, noncontiguous=noncontiguous)
2291+
2292+
if dtype.is_floating_point:
2293+
nan_mask_cpu = x_cpu < (0.2 * scale)
2294+
x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
2295+
x_cpu[nan_mask_cpu] = np.nan
2296+
else:
2297+
x_no_nan_cpu = x_cpu
2298+
2299+
x_mps = x_cpu.to('mps')
2300+
actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
2301+
expect_out_cpu = torch.empty(0, dtype=dtype)
2302+
dim_kwargs = {"dim": dim} if dim is not None else {}
2303+
expect = torch.sum(x_no_nan_cpu, **dim_kwargs)
2304+
2305+
actual_cpu = torch.nansum(x_cpu, **dim_kwargs)
2306+
# Sanity check on CPU
2307+
self.assertEqual(expect, actual_cpu)
2308+
2309+
# Test MPS
2310+
actual_mps = torch.nansum(x_mps, **dim_kwargs)
2311+
# Test out= variant
2312+
torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs)
2313+
torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs)
2314+
self.assertEqual(expect, actual_mps)
2315+
self.assertEqual(expect_out_cpu, actual_out_mps)
2316+
2317+
args = itertools.product(
2318+
(torch.float16, torch.float32, torch.int32, torch.int64), # dtype
2319+
(True, False), # noncontiguous
2320+
(0, 1, None), # dim
2321+
)
2322+
2323+
for dtype, noncontiguous, dim in args:
2324+
with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
2325+
helper(dtype, noncontiguous, dim)
2326+
22822327

22832328
class TestLogical(TestCase):
22842329
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
@@ -8252,53 +8297,6 @@ def test_serialization_map_location(self):
82528297
self.assertEqual(x2.device.type, "cpu")
82538298

82548299

8255-
class TestNanSum(TestCase):
8256-
8257-
def helper(self, dtype, noncontiguous, dim):
8258-
zero_cpu = torch.zeros((), dtype=dtype)
8259-
8260-
# Randomly scale the values
8261-
scale = random.randint(10, 100)
8262-
x_cpu: torch.Tensor = make_tensor(
8263-
(5, 5), dtype=dtype, device='cpu',
8264-
low=-scale, high=scale, noncontiguous=noncontiguous)
8265-
8266-
if dtype.is_floating_point:
8267-
nan_mask_cpu = x_cpu < (0.2 * scale)
8268-
x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
8269-
x_cpu[nan_mask_cpu] = np.nan
8270-
else:
8271-
x_no_nan_cpu = x_cpu
8272-
8273-
x_mps = x_cpu.to('mps')
8274-
actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
8275-
expect_out_cpu = torch.empty(0, dtype=dtype)
8276-
dim_kwargs = {"dim": dim} if dim is not None else {}
8277-
expect = torch.sum(x_no_nan_cpu, **dim_kwargs)
8278-
8279-
actual_cpu = torch.nansum(x_cpu, **dim_kwargs)
8280-
# Sanity check on CPU
8281-
self.assertEqual(expect, actual_cpu)
8282-
8283-
# Test MPS
8284-
actual_mps = torch.nansum(x_mps, **dim_kwargs)
8285-
# Test out= variant
8286-
torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs)
8287-
torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs)
8288-
self.assertEqual(expect, actual_mps)
8289-
self.assertEqual(expect_out_cpu, actual_out_mps)
8290-
8291-
def test_nansum(self):
8292-
args = itertools.product(
8293-
(torch.float16, torch.float32, torch.int32, torch.int64), # dtype
8294-
(True, False), # noncontiguous
8295-
(0, 1, None), # dim
8296-
)
8297-
8298-
for dtype, noncontiguous, dim in args:
8299-
with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
8300-
self.helper(dtype, noncontiguous, dim)
8301-
83028300
MPS_DTYPES = get_all_dtypes()
83038301
for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]:
83048302
del MPS_DTYPES[MPS_DTYPES.index(t)]

0 commit comments

Comments
 (0)