
Commit 19264b5

soof-golan authored and pytorchmergebot committed
[MPS] Add support for nansum on mps (#93845)
* Add `nansum_out_mps` and `nansum_mps` functions
* Moved `get_dtype_from_self` into ReduceOpsUtils.h

Fixes #86809

Pull Request resolved: #93845
Approved by: https://github.com/malfet
1 parent 8a9ea44 commit 19264b5
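
For orientation, a minimal usage sketch of what this commit enables. It is not part of the diff and assumes a PyTorch build with MPS available (macOS on Apple silicon):

import torch

# torch.nansum treats NaN as zero. With this commit, calling it on an "mps"
# tensor dispatches to the new nansum_mps / nansum_out_mps kernels registered
# for the MPS backend.
x = torch.tensor([1.0, float("nan"), 2.0], device="mps")
print(torch.nansum(x))         # 3.0
print(torch.nansum(x, dim=0))  # reduction along an explicit dimension

# out= variant, matching the nansum.out schema registered in native_functions.yaml
out = torch.empty((), device="mps")
torch.nansum(x, out=out)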

File tree

5 files changed: +116 −39 lines changed


aten/src/ATen/native/ReduceOps.cpp

Lines changed: 0 additions & 22 deletions
@@ -128,20 +128,6 @@
 namespace at {
 namespace native {
 
-inline ScalarType get_dtype_from_self(
-    const Tensor& self,
-    const optional<ScalarType>& dtype,
-    bool promote_integers) {
-  if (dtype.has_value()) {
-    return dtype.value();
-  }
-  ScalarType src_type = self.scalar_type();
-  if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
-    return kLong;
-  }
-  return src_type;
-}
-
 } // namespace native
 
 namespace meta {
@@ -1163,14 +1149,6 @@ std::vector<Tensor> gradient(const Tensor& self, IntArrayRef dim, int64_t edge_o
 
 // ALL REDUCE #################################################################
 
-inline ScalarType get_dtype_from_result(Tensor& result, optional<ScalarType> dtype) {
-  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
-  if (dtype.has_value()) {
-    return dtype.value();
-  } else {
-    return result.scalar_type();
-  }
-}
 
 TORCH_IMPL_FUNC(sum_out)
 (const Tensor& self,
aten/src/ATen/native/ReduceOpsUtils.h

Lines changed: 24 additions & 0 deletions
@@ -320,6 +320,30 @@ static C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_i
   at::native::resize_output(result_indices, sizes);
 }
 
+inline ScalarType get_dtype_from_self(
+    const Tensor& self,
+    const c10::optional<ScalarType>& dtype,
+    bool promote_integers) {
+  if (dtype.has_value()) {
+    return dtype.value();
+  }
+  ScalarType src_type = self.scalar_type();
+  if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
+    return kLong;
+  }
+  return src_type;
+}
+
+inline ScalarType get_dtype_from_result(Tensor& result, c10::optional<ScalarType> dtype) {
+  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
+  if (dtype.has_value()) {
+    return dtype.value();
+  } else {
+    return result.scalar_type();
+  }
+}
+
+
 } // native
 
 namespace meta {
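
The two helpers now shared from ReduceOpsUtils.h encode the dtype rules that reduction ops follow. A rough Python rendering of the same logic, for illustration only (the function names below are stand-ins, not PyTorch API):

import torch

def dtype_from_self(self_dtype, dtype=None, promote_integers=True):
    # Mirrors get_dtype_from_self: an explicit dtype wins; otherwise
    # integral (and bool) inputs are promoted to int64, floats keep their dtype.
    if dtype is not None:
        return dtype
    if promote_integers and not self_dtype.is_floating_point and not self_dtype.is_complex:
        return torch.int64
    return self_dtype

def dtype_from_result(result, dtype=None):
    # Mirrors get_dtype_from_result: the out= tensor must already be defined,
    # and its dtype is used unless an explicit dtype overrides it.
    assert result is not None, "out= tensor must be defined"
    return dtype if dtype is not None else result.dtype

assert dtype_from_self(torch.int32) == torch.int64                          # integrals promote to long
assert dtype_from_self(torch.float16) == torch.float16                      # floats are kept
assert dtype_from_self(torch.int32, dtype=torch.float32) == torch.float32   # explicit dtype wins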

aten/src/ATen/native/mps/operations/ReduceOps.mm

Lines changed: 45 additions & 17 deletions
@@ -31,7 +31,8 @@
   PROD,
   MEAN,
   COUNT_NONZERO,
-  TRACE
+  TRACE,
+  NANSUM,
 };
 
 using namespace mps;
@@ -247,6 +248,22 @@ void reduction_out_mps(
       castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor
                                                      axes:@[@0, @1]
                                                      name:nil];
+    } else if (reduction_type == MPSReductionType::NANSUM) {
+      // Create a 0 tensor of the same shape as inputTensor
+      MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0.0
+                                                  dataType:castInputTensor.dataType];
+      // Find NaNs
+      MPSGraphTensor* nanMask = [mpsGraph isNaNWithTensor:castInputTensor
+                                                     name:nil];
+      // Replace NaNs with 0
+      MPSGraphTensor* nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask
+                                                    truePredicateTensor:zeros
+                                                   falsePredicateTensor:castInputTensor
+                                                                   name:nil];
+      // Sum
+      castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced
+                                                     axes:wrappedAxes
+                                                     name:nil];
     }
 
     MPSGraphTensor* outputTensor = nil;
@@ -289,6 +306,33 @@
   reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps");
 }
 
+Tensor& nansum_out_mps(
+    const Tensor& self,
+    OptionalIntArrayRef dim,
+    bool keepdim,
+    c10::optional<ScalarType> opt_dtype,
+    Tensor& result) {
+  TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "nansum does not support complex inputs");
+  if (c10::isIntegralType(self.scalar_type(), true)){
+    return at::sum_out(result, self, dim, keepdim, opt_dtype);
+  }
+  ScalarType dtype = get_dtype_from_result(result, opt_dtype);
+  const auto mask = make_dim_mask(dim, self.dim());
+  resize_reduction_result(result, self, mask, keepdim, dtype);
+  reduction_out_mps(self, dim, keepdim, dtype, result, MPSReductionType::NANSUM, "nansum_out_mps");
+  return result;
+}
+
+Tensor nansum_mps(
+    const Tensor& self,
+    OptionalIntArrayRef dim,
+    bool keepdim,
+    c10::optional<ScalarType> opt_dtype) {
+  ScalarType dtype = get_dtype_from_self(self, opt_dtype, true);
+  Tensor result = create_reduction_result(self, dim, keepdim, dtype);
+  return nansum_out_mps(self, dim, keepdim, dtype, result);
+}
+
 Tensor trace_mps_out(const Tensor& self) {
   Tensor output_t = at::native::empty_mps(
       {},
@@ -316,22 +360,6 @@ Tensor trace_mps_out(const Tensor& self) {
   reduction_out_mps(input_t, IntArrayRef(dims, 1), keepdim, dtype, output_t, MPSReductionType::PROD, "prod_out_mps");
 }
 
-// Taken from ReduceOps.cpp
-inline ScalarType get_dtype_from_self(
-    const Tensor& self,
-    const c10::optional<ScalarType>& dtype,
-    bool promote_integers) {
-  if (dtype.has_value()) {
-    return dtype.value();
-  }
-
-  ScalarType src_type = self.scalar_type();
-  if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
-    return kLong;
-  }
-  return src_type;
-}
-
 TORCH_IMPL_FUNC(amax_out_mps)(
     const Tensor& input_t,
     IntArrayRef dim,
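
The NANSUM branch above builds the graph as "replace NaN with zero, then sum". A plain-PyTorch sketch of the same computation, included only to clarify the graph being constructed (nansum_reference is a hypothetical helper, not part of the change):

import torch

def nansum_reference(x, dim=None, keepdim=False):
    # Same recipe as the MPSGraph NANSUM branch:
    zeros = torch.zeros((), dtype=x.dtype)           # constantWithScalar:0.0
    nan_mask = torch.isnan(x)                        # isNaNWithTensor
    nan_replaced = torch.where(nan_mask, zeros, x)   # selectWithPredicateTensor
    if dim is None:
        return nan_replaced.sum()                    # reductionSumWithTensor over all axes
    return nan_replaced.sum(dim=dim, keepdim=keepdim)

x = torch.tensor([[1.0, float("nan")], [float("nan"), 4.0]])
assert torch.equal(nansum_reference(x), torch.nansum(x))
assert torch.equal(nansum_reference(x, dim=1), torch.nansum(x, dim=1))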

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -5354,10 +5354,12 @@
   variants: function, method
   dispatch:
     CPU, CUDA: nansum
+    MPS: nansum_mps
 
 - func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU, CUDA: nansum_out
+    MPS: nansum_out_mps
 
 - func: sum_to_size(Tensor self, int[] size) -> Tensor
   variants: method

test/test_mps.py

Lines changed: 45 additions & 0 deletions
@@ -2279,6 +2279,51 @@ def test_binops_dtype_precedence(self):
                     getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                         (torch.full(full_shape, val2, dtype=dtype2, device='cpu')))
 
+    def test_nansum(self):
+        def helper(dtype, noncontiguous, dim):
+            zero_cpu = torch.zeros((), dtype=dtype)
+
+            # Randomly scale the values
+            scale = random.randint(10, 100)
+            x_cpu: torch.Tensor = make_tensor(
+                (5, 5), dtype=dtype, device='cpu',
+                low=-scale, high=scale, noncontiguous=noncontiguous)
+
+            if dtype.is_floating_point:
+                nan_mask_cpu = x_cpu < (0.2 * scale)
+                x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
+                x_cpu[nan_mask_cpu] = np.nan
+            else:
+                x_no_nan_cpu = x_cpu
+
+            x_mps = x_cpu.to('mps')
+            actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
+            expect_out_cpu = torch.empty(0, dtype=dtype)
+            dim_kwargs = {"dim": dim} if dim is not None else {}
+            expect = torch.sum(x_no_nan_cpu, **dim_kwargs)
+
+            actual_cpu = torch.nansum(x_cpu, **dim_kwargs)
+            # Sanity check on CPU
+            self.assertEqual(expect, actual_cpu)
+
+            # Test MPS
+            actual_mps = torch.nansum(x_mps, **dim_kwargs)
+            # Test out= variant
+            torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs)
+            torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs)
+            self.assertEqual(expect, actual_mps)
+            self.assertEqual(expect_out_cpu, actual_out_mps)
+
+        args = itertools.product(
+            (torch.float16, torch.float32, torch.int32, torch.int64),  # dtype
+            (True, False),  # noncontiguous
+            (0, 1, None),  # dim
+        )
+
+        for dtype, noncontiguous, dim in args:
+            with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
+                helper(dtype, noncontiguous, dim)
+
 
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
