22 changes: 0 additions & 22 deletions aten/src/ATen/native/ReduceOps.cpp
@@ -128,20 +128,6 @@
namespace at {
namespace native {

inline ScalarType get_dtype_from_self(
const Tensor& self,
const optional<ScalarType>& dtype,
bool promote_integers) {
if (dtype.has_value()) {
return dtype.value();
}
ScalarType src_type = self.scalar_type();
if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
return kLong;
}
return src_type;
}

} // namespace native

namespace meta {
@@ -1163,14 +1149,6 @@ std::vector<Tensor> gradient(const Tensor& self, IntArrayRef dim, int64_t edge_o

// ALL REDUCE #################################################################

inline ScalarType get_dtype_from_result(Tensor& result, optional<ScalarType> dtype) {
TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
if (dtype.has_value()) {
return dtype.value();
} else {
return result.scalar_type();
}
}

TORCH_IMPL_FUNC(sum_out)
(const Tensor& self,
24 changes: 24 additions & 0 deletions aten/src/ATen/native/ReduceOpsUtils.h
@@ -320,6 +320,30 @@ static C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_i
at::native::resize_output(result_indices, sizes);
}

inline ScalarType get_dtype_from_self(
const Tensor& self,
const c10::optional<ScalarType>& dtype,
bool promote_integers) {
if (dtype.has_value()) {
return dtype.value();
}
ScalarType src_type = self.scalar_type();
if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
return kLong;
}
return src_type;
}

inline ScalarType get_dtype_from_result(Tensor& result, c10::optional<ScalarType> dtype) {
TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
if (dtype.has_value()) {
return dtype.value();
} else {
return result.scalar_type();
}
}


} // native

namespace meta {
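Note: the two helpers centralized in this header define the dtype contract for reductions. `get_dtype_from_self` promotes integral (and bool) inputs to `kLong` when no explicit dtype is given; `get_dtype_from_result` takes the dtype from a defined `out` tensor. A minimal sketch of the observable behavior from Python (the values are illustrative, not part of this PR):

```python
import torch

# Integral inputs promote to int64 when no dtype is passed
# (the promote_integers=true path in get_dtype_from_self).
x = torch.tensor([1, 2, 3], dtype=torch.int32)
assert torch.sum(x).dtype == torch.int64

# An explicit dtype short-circuits the promotion.
assert torch.sum(x, dtype=torch.int32).dtype == torch.int32

# With out=, the result dtype comes from the out tensor
# (the get_dtype_from_result path).
out = torch.empty((), dtype=torch.float64)
torch.sum(x.to(torch.float64), out=out)
assert out.dtype == torch.float64
```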
62 changes: 45 additions & 17 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -31,7 +31,8 @@
PROD,
MEAN,
COUNT_NONZERO,
- TRACE
+ TRACE,
+ NANSUM,
};

using namespace mps;
@@ -247,6 +248,22 @@ void reduction_out_mps(
castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor
axes:@[@0, @1]
name:nil];
} else if (reduction_type == MPSReductionType::NANSUM) {
// Create a 0 tensor of the same shape as inputTensor
MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0.0
dataType:castInputTensor.dataType];
// Find NaNs
MPSGraphTensor* nanMask = [mpsGraph isNaNWithTensor:castInputTensor
name:nil];
// Replace NaNs with 0
MPSGraphTensor* nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask
truePredicateTensor:zeros
falsePredicateTensor:castInputTensor
name:nil];
// Sum
castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced
axes:wrappedAxes
name:nil];
}

MPSGraphTensor* outputTensor = nil;
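Note: the NANSUM branch is a mask-and-sum graph: `isNaNWithTensor` builds the mask, `selectWithPredicateTensor` substitutes zeros for NaNs, and the result goes through an ordinary `reductionSumWithTensor`. An eager-mode reference for the same computation (a sketch; `nansum_reference` is a hypothetical name, not part of this PR):

```python
import torch

def nansum_reference(x: torch.Tensor, dim=None, keepdim=False):
    # Replace NaNs with 0, then do a plain sum -- this mirrors the
    # isNaN -> select -> reductionSum chain in the MPSGraph above.
    zeros = torch.zeros((), dtype=x.dtype)
    cleaned = torch.where(torch.isnan(x), zeros, x)
    if dim is None:
        return cleaned.sum()
    return cleaned.sum(dim=dim, keepdim=keepdim)

x = torch.tensor([[1.0, float("nan")], [2.0, 3.0]])
assert torch.equal(nansum_reference(x), torch.tensor(6.0))
assert torch.equal(nansum_reference(x, dim=0), torch.nansum(x, dim=0))
```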
@@ -289,6 +306,33 @@ void reduction_out_mps(
reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps");
}

Tensor& nansum_out_mps(
const Tensor& self,
OptionalIntArrayRef dim,
bool keepdim,
c10::optional<ScalarType> opt_dtype,
Tensor& result) {
TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "nansum does not support complex inputs");
if (c10::isIntegralType(self.scalar_type(), true)){
return at::sum_out(result, self, dim, keepdim, opt_dtype);
}
ScalarType dtype = get_dtype_from_result(result, opt_dtype);
const auto mask = make_dim_mask(dim, self.dim());
resize_reduction_result(result, self, mask, keepdim, dtype);
reduction_out_mps(self, dim, keepdim, dtype, result, MPSReductionType::NANSUM, "nansum_out_mps");
return result;
}

Tensor nansum_mps(
const Tensor& self,
OptionalIntArrayRef dim,
bool keepdim,
c10::optional<ScalarType> opt_dtype) {
ScalarType dtype = get_dtype_from_self(self, opt_dtype, true);
Tensor result = create_reduction_result(self, dim, keepdim, dtype);
return nansum_out_mps(self, dim, keepdim, dtype, result);
}

Tensor trace_mps_out(const Tensor& self) {
Tensor output_t = at::native::empty_mps(
{},
@@ -316,22 +360,6 @@ Tensor trace_mps_out(const Tensor& self) {
reduction_out_mps(input_t, IntArrayRef(dims, 1), keepdim, dtype, output_t, MPSReductionType::PROD, "prod_out_mps");
}

// Taken from ReduceOps.cpp
inline ScalarType get_dtype_from_self(
const Tensor& self,
const c10::optional<ScalarType>& dtype,
bool promote_integers) {
if (dtype.has_value()) {
return dtype.value();
}

ScalarType src_type = self.scalar_type();
if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
return kLong;
}
return src_type;
}

TORCH_IMPL_FUNC(amax_out_mps)(
const Tensor& input_t,
IntArrayRef dim,
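Note: a usage sketch of the kernels added above (requires an Apple-silicon machine with MPS available; not part of this PR). Floating-point inputs take the NANSUM graph, while integral inputs, which cannot hold NaN, are forwarded by `nansum_out_mps` to `at::sum_out`:

```python
import torch

if torch.backends.mps.is_available():
    x = torch.tensor([1.0, float("nan"), 2.0], device="mps")
    print(torch.nansum(x))  # tensor(3., device='mps:0')

    # The out= variant dispatches to nansum_out_mps.
    out = torch.empty((), device="mps")
    torch.nansum(x, out=out)

    # Integral tensors cannot contain NaN, so the MPS path
    # simply forwards to at::sum_out.
    i = torch.tensor([1, 2, 3], device="mps")
    print(torch.nansum(i))  # tensor(6, device='mps:0')
```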
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -5353,10 +5353,12 @@
variants: function, method
dispatch:
CPU, CUDA: nansum
MPS: nansum_mps

- func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: nansum_out
MPS: nansum_out_mps

- func: sum_to_size(Tensor self, int[] size) -> Tensor
variants: method
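Note: these two entries register the kernels with the dispatcher, so `torch.nansum` on an MPS tensor now routes to `nansum_mps`, and the `out=` overload to `nansum_out_mps`, instead of relying on the CPU fallback. A quick routing check (a sketch, assuming an MPS-enabled build):

```python
import torch

if torch.backends.mps.is_available():
    x = torch.tensor([1.0, float("nan")], device="mps")
    # With the dispatch entries above, this runs natively on MPS;
    # previously it needed the CPU fallback (or failed without it).
    assert torch.nansum(x).item() == 1.0
```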
45 changes: 45 additions & 0 deletions test/test_mps.py
@@ -2279,6 +2279,51 @@ def test_binops_dtype_precedence(self):
getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
(torch.full(full_shape, val2, dtype=dtype2, device='cpu')))

    def test_nansum(self):
        def helper(dtype, noncontiguous, dim):
            zero_cpu = torch.zeros((), dtype=dtype)

            # Randomly scale the values
            scale = random.randint(10, 100)
            x_cpu: torch.Tensor = make_tensor(
                (5, 5), dtype=dtype, device='cpu',
                low=-scale, high=scale, noncontiguous=noncontiguous)

            if dtype.is_floating_point:
                nan_mask_cpu = x_cpu < (0.2 * scale)
                x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
                x_cpu[nan_mask_cpu] = np.nan
            else:
                x_no_nan_cpu = x_cpu

            x_mps = x_cpu.to('mps')
            actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
            expect_out_cpu = torch.empty(0, dtype=dtype)
            dim_kwargs = {"dim": dim} if dim is not None else {}
            expect = torch.sum(x_no_nan_cpu, **dim_kwargs)

            actual_cpu = torch.nansum(x_cpu, **dim_kwargs)
            # Sanity check on CPU
            self.assertEqual(expect, actual_cpu)

            # Test MPS
            actual_mps = torch.nansum(x_mps, **dim_kwargs)
            # Test out= variant
            torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs)
            torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs)
            self.assertEqual(expect, actual_mps)
            self.assertEqual(expect_out_cpu, actual_out_mps)

        args = itertools.product(
            (torch.float16, torch.float32, torch.int32, torch.int64),  # dtype
            (True, False),  # noncontiguous
            (0, 1, None),  # dim
        )

        for dtype, noncontiguous, dim in args:
            with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
                helper(dtype, noncontiguous, dim)
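Note: the new test can be run on its own with the usual unittest filter (assuming a source checkout on a machine with an MPS device):

```
python test/test_mps.py -k test_nansum
```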


class TestLogical(TestCase):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):