Skip to content

Commit 38de981

Browse files
DenisVieriu97 authored and
pytorchmergebot committed
[MPS] Add nonzero mps support (#91616)
Adds nonzero support for mps.

**Pseudocode**:
```
// inputTensor   = [1,  0,  0, 3]
// inputNonZero  = [1,  0,  0, 1]   (input != 0)
// scan          = [1,  1,  1, 2]   (prefix sum)
// maskedIndices = [0, -1, -1, 1]   (select)
// coordinates   = [0,  1,  2, 3]   (coordinateAlongAxis)
// scatterResult = [0,  3]          (scatter)
```

Pull Request resolved: #91616
Approved by: https://github.com/razarmehr
1 parent 97ff20d commit 38de981

File tree

4 files changed

+296
-3
lines changed

4 files changed

+296
-3
lines changed

aten/src/ATen/mps/MPSFallback.mm

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ Tensor slow_conv2d_forward_mps(
6161
m.impl("_fft_r2c", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
6262
m.impl("linalg_vector_norm", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
6363
m.impl("sgn.out", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
64-
m.impl("nonzero", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
6564
m.impl("_slow_conv2d_forward", slow_conv2d_forward_mps);
6665
}
6766

aten/src/ATen/native/mps/operations/Indexing.mm

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <ATen/native/LinearAlgebraUtils.h>
1313
#include <ATen/native/mps/OperationUtils.h>
1414
#include <ATen/native/mps/operations/Indexing.h>
15+
#include <ATen/native/mps/MPSGraphVenturaOps.h>
1516
#include <ATen/native/Resize.h>
1617
#include <ATen/AccumulateType.h>
1718
#include <torch/library.h>
@@ -211,6 +212,185 @@ void index_put_kernel_mps(TensorIterator& iter, IntArrayRef index_size, IntArray
211212
return result;
212213
}
213214

215+
// CPU fallback for nonzero, used when the native MPS implementation is not
// available (callers in this file take this path when is_macos_13_or_newer()
// is false). Warns once, computes the result on the CPU and returns it as a
// tensor on the MPS device.
static
Tensor nonzero_fallback(const Tensor& self) {
  TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ",
                  "Falling back on CPU. This may have performance implications.");

  // at::nonzero() already returns a freshly allocated tensor and the device
  // transfer creates another one, so no intermediate clone() is needed.
  return at::nonzero(self.to("cpu")).to("mps");
}
222+
223+
// Writes the coordinates of every non-zero element of `self` into `out_`,
// which is resized to {count_nonzero(self), self.dim()}, and returns `out_`.
//
// On systems where is_macos_13_or_newer() is false the result is computed by
// nonzero_fallback() and copied into `out_`.
//
// Otherwise the result is produced by a cached MPSGraph working on the
// flattened input (1-D example):
//
//   inputTensor   = [1,  0,  0, 3]
//   inputNonZero  = [1,  0,  0, 1]   (input != 0)
//   indices       = [1,  1,  1, 2]   (cumulative sum of the mask)
//   maskedIndices = [0,  S,  S, 1]   (indices - 1 where non-zero, else a negative sentinel S)
//   coordinates   = [0,  1,  2, 3]   (coordinateAlongAxis)
//   scatterResult = [0,  3]          (coordinates scattered at maskedIndices)
//
// For nDim > 1 the output row k occupies flat slots k * nDim + i (i = axis),
// so the masked indices and per-axis coordinates are expanded and interleaved
// accordingly before the scatter.
//
// Raises (via TORCH_CHECK) when `self` has INT_MAX or more elements, more
// than 16 dimensions, when `out_` is not a Long tensor, or when `self` and
// `out_` are not on the same (MPS) device.
Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){
  if (!is_macos_13_or_newer()) {
    // nonzero_fallback() already returns its result on the MPS device,
    // so it can be copied into the output without another transfer.
    Tensor out_fallback = nonzero_fallback(self);
    at::native::resize_output(out_, out_fallback.sizes());
    out_.copy_(out_fallback);
    return out_;
  }

  using namespace mps;
  // Signed on purpose: the graph uses -maxDimensions as the "no output slot"
  // sentinel. With an unsigned type the negation would wrap around to
  // 4294967280 instead of yielding -16.
  constexpr int32_t maxDimensions = 16;

  // The message is split into adjacent arguments instead of using a backslash
  // continuation inside the literal, which would embed the source indentation
  // of the following line into the error text.
  TORCH_CHECK(self.numel() < std::numeric_limits<int>::max(),
              "nonzero is not supported for tensors with more than INT_MAX elements, ",
              "file a support request");
  TORCH_CHECK(out_.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out_.dtype());
  TORCH_CHECK(self.device() == out_.device(), "expected self and out to be on the same device, but got out on ",
              out_.device(), " and self on ", self.device());
  TORCH_CHECK(self.dim() <= maxDimensions, "nonzero is not supported for tensor with more than ", maxDimensions, " dimensions");
  TORCH_CHECK(out_.is_mps(), "nonzero: expected out to be an MPS tensor, but got out on ", out_.device());

  MPSStream *stream = getCurrentMPSStream();
  struct CachedGraph : public MPSCachedGraph
  {
    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
    MPSGraphTensor* scatterDataTensor_ = nil;
  };

  // The output shape is data dependent: the number of non-zero elements is
  // read back to the host (item()) before the output can be sized.
  int64_t total_nonzero = at::count_nonzero(self).item<int64_t>();
  int64_t nDim = self.dim();
  at::native::resize_output(out_, {total_nonzero, nDim});
  if (out_.numel() == 0) {
    return out_;
  }

  // The graph writes into a flat buffer; when the user-provided output is a
  // view or non-contiguous, compute into a temporary and copy back at the end.
  bool contiguous_output = (out_.is_contiguous() && !out_.is_view());
  Tensor out = out_;
  if (!contiguous_output) {
    out = at::native::empty_mps(
            out_.sizes(),
            out_.scalar_type(),
            c10::nullopt,
            kMPS,
            c10::nullopt,
            c10::nullopt);
  }

  // Both the input and the output are presented to the graph as 1-D tensors.
  MPSShape *apparentOutputShape = @[@(total_nonzero * nDim)];
  MPSShape *apparentInputShape = @[@(self.numel())];

  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
  @autoreleasepool {
    string key = "nonzero_out_mps" + getTensorsStringKey(self);
    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));

    if(!cachedGraph) {
      MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
        CachedGraph *newCachedGraph = nil;
        @autoreleasepool {
          MPSDataType inputDataType = getMPSDataType(self.scalar_type());
          MPSShape* inputShape = getMPSShape(self);
          MPSGraph* mpsGraph = make_mps_graph();
          newCachedGraph = new CachedGraph(mpsGraph);

          MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), apparentInputShape);
          MPSGraphTensor *scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(out.scalar_type()));
          MPSGraphTensor *zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputDataType];
          MPSGraphTensor *oneTensor = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeInt32];
          // Sentinel index for zero elements. It stays negative after the
          // `index * nDim + axis` expansion below because nDim <= maxDimensions.
          // NOTE(review): this relies on the scatter skipping out-of-range
          // indices - confirm against the MPSGraph documentation.
          MPSGraphTensor *minusMaxDimTensor = [mpsGraph constantWithScalar:-maxDimensions dataType:MPSDataTypeInt32];
          MPSGraphTensor *inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor
                                                                          secondaryTensor:zeroTensor
                                                                                     name:nil];
          MPSGraphTensor *maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor
                                                     toType:MPSDataTypeInt32
                                                       name:@"castToInt32"];
          // Inclusive prefix sum of the mask: 1-based rank of each non-zero element.
          MPSGraphTensor *indicesTensor = [mpsGraph cumulativeSumWithTensor:maskTensor
                                                                       axis:0
                                                                       name:nil];
          MPSGraphTensor *indicesMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:indicesTensor
                                                                         secondaryTensor:oneTensor
                                                                                    name:nil];
          MPSGraphTensor *maskedIndicesTensor = [mpsGraph selectWithPredicateTensor:inputNotEqualToZeroTensor
                                                                truePredicateTensor:indicesMinusOneTensor
                                                               falsePredicateTensor:minusMaxDimTensor
                                                                               name:nil];
          MPSGraphTensor *coordinatesTensor = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:0 withShape:inputShape name:nil]
                                                            withShape:@[@-1]
                                                                 name:nil];
          if (nDim > 1) {
            NSMutableArray<MPSGraphTensor*> *maskedIndicesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
            NSMutableArray<MPSGraphTensor*> *coordinatesTensorArray = [NSMutableArray arrayWithCapacity:nDim];

            // Output row k starts at flat offset k * nDim; axis i is written at k * nDim + i.
            MPSGraphTensor *constantRankTensor = [mpsGraph constantWithScalar:nDim
                                                                     dataType:MPSDataTypeInt32];
            maskedIndicesTensorArray[0] = [mpsGraph multiplicationWithPrimaryTensor:maskedIndicesTensor
                                                                    secondaryTensor:constantRankTensor
                                                                               name:nil];
            coordinatesTensorArray[0] = coordinatesTensor;
            for (int i = 1; i < nDim; i++){
              maskedIndicesTensorArray[i] = [mpsGraph additionWithPrimaryTensor:maskedIndicesTensorArray[i - 1]
                                                                secondaryTensor:oneTensor
                                                                           name:nil];
              coordinatesTensorArray[i] = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:i withShape:inputShape name:nil]
                                                        withShape:@[@-1]
                                                             name:nil];
            }
            maskedIndicesTensor = [mpsGraph concatTensors:maskedIndicesTensorArray dimension:0 interleave:YES name:nil];
            coordinatesTensor = [mpsGraph concatTensors:coordinatesTensorArray dimension:0 interleave:YES name:nil];
          }

          MPSGraphTensor *outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor
                                                           updatesTensor:coordinatesTensor
                                                           indicesTensor:maskedIndicesTensor
                                                                    axis:0
                                                                    mode:MPSGraphScatterModeSet
                                                                    name:nil];

          newCachedGraph->inputTensor_ = inputTensor;
          newCachedGraph->scatterDataTensor_ = scatterDataTensor;
          newCachedGraph->outputTensor_ = outputTensor;
        }
        return newCachedGraph;
      });
      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
    }

    // `out` aliases `out_` when the output is contiguous and is the temporary
    // otherwise, so it is the right scatter source and destination in both cases.
    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, apparentInputShape);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out, apparentOutputShape);
    Placeholder scatterPlaceholder = Placeholder(cachedGraph->scatterDataTensor_, out, apparentOutputShape);

    // Create dictionary of inputs and outputs
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
      scatterPlaceholder.getMPSGraphTensor() : scatterPlaceholder.getMPSGraphTensorData()
    };

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
    };

    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
    if (!contiguous_output) {
      out_.copy_(out);
    }
  }

  return out_;
}
384+
385+
// Functional variant of nonzero for MPS tensors: returns a new Long tensor
// holding the indices of the non-zero elements of `self`. Uses the native
// out-variant when is_macos_13_or_newer() holds and the CPU fallback otherwise.
Tensor nonzero_mps(const Tensor& self){
  if (is_macos_13_or_newer()) {
    // Start from an empty Long tensor; the out-variant resizes it as needed.
    Tensor indices = at::empty({0}, self.options().dtype(kLong));
    return nonzero_out_mps(self, indices);
  }

  return nonzero_fallback(self);
}
393+
214394
Tensor masked_select_mps(const Tensor & self, const Tensor & mask) {
215395
namedinference::compute_broadcast_outnames(self, mask);
216396
Tensor result = at::empty({0}, self.options());

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8521,13 +8521,15 @@
85218521
dispatch:
85228522
CPU: nonzero_out_cpu
85238523
CUDA: nonzero_out_cuda
8524+
MPS: nonzero_out_mps
85248525
tags: dynamic_output_shape
85258526

85268527
- func: nonzero(Tensor self) -> Tensor
85278528
variants: method, function
85288529
dispatch:
85298530
CPU: nonzero_cpu
85308531
CUDA: nonzero_cuda
8532+
MPS: nonzero_mps
85318533
tags: [dynamic_output_shape, canonical]
85328534

85338535
- func: nonzero_numpy(Tensor self) -> Tensor[]

test/test_mps.py

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6687,6 +6687,116 @@ class TestAdvancedIndexing(TestCase):
66876687
supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
66886688
supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
66896689

6690+
def test_nonzero_no_warning(self):
    # Neither the functional nor the method form of nonzero may emit a
    # warning when called on an MPS tensor.
    sample = torch.randn((2, 2), device="mps")
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        torch.nonzero(sample)
        sample.nonzero()
        self.assertEqual(len(caught), 0)
6698+
6699+
def test_nonzero(self):
    # Checks every flavour of nonzero on MPS (function, method, out=, and
    # as_tuple=True) against NumPy's nonzero, for each supported dtype and
    # a handful of 1-D, 2-D and 3-D shapes.
    device = "mps"
    shapes = [
        torch.Size((12,)),
        torch.Size((12, 1)),
        torch.Size((1, 12)),
        torch.Size((6, 2)),
        torch.Size((3, 2, 2)),
        torch.Size((5, 5, 5)),
    ]

    def make_binary_input(shape, dtype):
        # Random tensor of zeros and ones in the requested dtype.
        if dtype == torch.bfloat16:
            # randint with bfloat16 is reported not to work on Windows,
            # so sample as float and cast afterwards.
            return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype)
        return torch.randint(2, shape, device=device, dtype=dtype)

    def check_dtype(dtype):
        for shape in shapes:
            tensor = make_binary_input(shape, dtype)

            from_function = torch.nonzero(tensor, as_tuple=False)
            from_method = tensor.nonzero(as_tuple=False)
            # Empty long tensor used as the out= destination.
            from_out = torch.empty([], dtype=torch.long, device=device)
            from_out = from_out.resize_(0)
            torch.nonzero(tensor, out=from_out)

            # NumPy reference; bfloat16 has to go through float first.
            if dtype != torch.bfloat16:
                np_array = tensor.cpu().numpy()
            else:
                np_array = tensor.float().cpu().numpy()
            expected = torch.from_numpy(np.stack(np_array.nonzero())).t()

            self.assertEqual(from_function.cpu(), expected, atol=0, rtol=0)
            self.assertEqual(from_method.cpu(), expected, atol=0, rtol=0)
            self.assertEqual(from_out.cpu(), expected, atol=0, rtol=0)

            # The tuple form must stack back into the same index matrix.
            tuple_from_function = torch.nonzero(tensor, as_tuple=True)
            tuple_from_method = tensor.nonzero(as_tuple=True)
            stacked_function = torch.stack(tuple_from_function).t().cpu()
            stacked_method = torch.stack(tuple_from_method).t().cpu()
            self.assertEqual(stacked_function, expected, atol=0, rtol=0)
            self.assertEqual(stacked_method, expected, atol=0, rtol=0)

    for dtype in self.supported_dtypes:
        check_dtype(dtype)
6737+
6738+
def test_nonzero_astuple_out(self):
    # as_tuple=True cannot be combined with out=, while as_tuple=False can.
    # Also checks how the JIT (script vs. trace) copes with the as_tuple kwarg.
    device = "mps"
    t = torch.randn((3, 3, 3), device=device)
    out = torch.empty([], dtype=torch.long, device=device)
    out = out.resize_(0)

    with self.assertRaises(RuntimeError):
        torch.nonzero(t, as_tuple=True, out=out)

    explicit_false = torch.nonzero(t, as_tuple=False, out=out)
    default_form = torch.nonzero(t, out=out)
    self.assertEqual(explicit_false, default_form)

    # Verifies that JIT script cannot handle the as_tuple kwarg
    # See Issue https://github.com/pytorch/pytorch/issues/45499.
    def _foo(t):
        tuple_result = torch.nonzero(t, as_tuple=True)
        nontuple_result = torch.nonzero(t, as_tuple=False)
        out = torch.empty_like(nontuple_result)
        torch.nonzero(t, as_tuple=False, out=out)
        return tuple_result, nontuple_result, out

    with self.assertRaises(RuntimeError):
        torch.jit.script(_foo)

    # Verifies that JIT tracing works fine
    traced_foo = torch.jit.trace(_foo, t)
    traced_tuple, traced_nontuple, traced_out = traced_foo(t)

    expected_nontuple = torch.nonzero(t)
    self.assertEqual(traced_tuple, torch.nonzero(t, as_tuple=True))
    self.assertEqual(traced_nontuple, expected_nontuple)
    self.assertEqual(traced_out, expected_nontuple)
6770+
6771+
def test_nonzero_discontiguous(self):
    # nonzero must give the same answer for contiguous and strided inputs,
    # and must write into a provided out= tensor (contiguous or strided)
    # without reallocating it or changing its strides.
    device = "mps"
    rows, cols = 4, 4
    tensor = torch.randint(2, (rows, cols), device=device)
    # Same values, stored in every other column of a wider buffer.
    strided_input = torch.empty(rows, cols * 2, device=device)[:, ::2].copy_(tensor)

    expected = tensor.nonzero(as_tuple=False)
    from_strided = strided_input.nonzero(as_tuple=False)
    self.assertEqual(expected, from_strided, atol=0, rtol=0)

    # Contiguous out: expect its storage to be reused.
    contiguous_out = torch.empty_like(expected)
    ptr_before = contiguous_out.data_ptr()
    torch.nonzero(tensor, out=contiguous_out)
    self.assertEqual(ptr_before, contiguous_out.data_ptr())
    self.assertEqual(expected, contiguous_out, atol=0, rtol=0)

    # Discontiguous out: storage and strides must both be preserved.
    strided_out = torch.empty(
        expected.size(0), expected.size(1) * 2, dtype=torch.long, device=device
    )[:, ::2]
    ptr_before = strided_out.data_ptr()
    strides_before = strided_out.stride()
    torch.nonzero(tensor, out=strided_out)
    self.assertEqual(ptr_before, strided_out.data_ptr())
    self.assertEqual(expected, strided_out, atol=0, rtol=0)
    self.assertEqual(strides_before, strided_out.stride())
6793+
6794+
def test_nonzero_non_diff(self):
    # The index tensor returned by nonzero must not require grad, even when
    # the input does. The input is created on the MPS device so that the
    # MPS implementation is the one being exercised.
    device = "mps"
    x = torch.randn(10, requires_grad=True, device=device)
    nz = x.nonzero()
    self.assertFalse(nz.requires_grad)
6799+
66906800
def test_masked_select(self):
66916801
x = torch.randn(3, 4)
66926802
x_mps = x.to("mps")
@@ -7841,7 +7951,8 @@ class TestConsistency(TestCase):
78417951
'vsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
78427952
'vstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
78437953
'zero_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
7844-
'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8']
7954+
'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
7955+
'nonzero': ['f32', 'i16', 'i32', 'i64']
78457956
}
78467957

78477958

@@ -8066,6 +8177,8 @@ class TestConsistency(TestCase):
80668177
'slice_scatter': [torch.uint8],
80678178
'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8], # moved from section below
80688179

8180+
# count_nonzero returns wrong results for these dtypes
8181+
'nonzero': [torch.uint8, torch.float16],
80698182
# ALLOW_LIST doesn't know about variants
80708183
'nn.functional.padconstant': None,
80718184

@@ -8141,7 +8254,6 @@ class TestConsistency(TestCase):
81418254
'eq': None,
81428255
'mul': None,
81438256
'cartesian_prod': None,
8144-
'nonzero': None,
81458257
'bool': None,
81468258
'inner': None,
81478259
'dstack': None,

0 commit comments

Comments
 (0)