Skip to content

Commit beaa5c5

Browse files
kulinseth, razarmehr, qqaatw, DenisVieriu97
authored
[MPS] View fixes (#95323)
* [MPS] Fix the uint8 type issue with View ops kernels (#95145) This should fix the problem in Resnet model with image artifacts due to saturation on int8 type and also the incorrect class recognition reported in #86954. Fixes #86954 Pull Request resolved: #95145 Approved by: https://github.com/kulinseth, https://github.com/DenisVieriu97 * [MPS] Fix tensor with non-zero storage offset graph gathering (#91071) Previously, the "can slice" flag in Placeholder constructor in `OperationUtils.mm` is conditioned on whether the numbers of dimensions of base shape and view shape are the same. This doesn't consider the situation that a view tensor could be the base tensor's sliced and then unsqueezed version, resulting in different num of dims. For example, if we want to stack `y_mps` and `x_mps` on the last dim: ``` t_mps = torch.tensor([1, 2, 3, 4], device="mps") x_mps = t_mps[2:] # [3, 4] y_mps = t_mps[:2] # [1, 2] res_mps = torch.stack((y_mps, x_mps), dim=-1) ``` the kernel will unsqueeze both of them on the last dim and then concatenate them, which is equivalent to: ``` res_mps = torch.cat((y_mps.unsqueeze(-1), x_mps.unsqueeze(-1)), dim=-1) ``` `x_mps.unsqueeze(-1)` is an unsqueezed and contiguous tensor with a storage offset, this kind of tensors should be sliceable without cloning its storage. Fixes #87856 Fixes #91065 Pull Request resolved: #91071 Approved by: https://github.com/kulinseth * [MPS] Fix fill_ where input tensor has a storage offset (#95113) Fixes #94390 Apart from fixing the issue above, this PR also fixes a bug that when an input tensor can be sliced, a sliced array view is created. This array view seems to be not writable or have a different storage from the original tensor, causing incorrect results with the in-place `fill`. 
Pull Request resolved: #95113 Approved by: https://github.com/kulinseth * [MPS] Fix view op slicing for 2nd dim in case of 0 offset (#95381) * Fix view op slicing for 2nd dim in case of 0 offset Pull Request resolved: #95381 Approved by: https://github.com/razarmehr --------- Co-authored-by: Ramin Azarmehr <razarmehr@apple.com> Co-authored-by: Li-Huai (Allan) Lin <qqaatw@gmail.com> Co-authored-by: Denis Vieriu <104024078+DenisVieriu97@users.noreply.github.com>
1 parent 4bd5c1e commit beaa5c5

File tree

4 files changed

+173
-31
lines changed

4 files changed

+173
-31
lines changed

aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ void printTensorNDArray(const Tensor& t) {
289289
} else {
290290
if (!mpsShape) {
291291
mpsShape = getMPSShape(_tensor);
292-
}
292+
}
293293

294294
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
295295
shape:mpsShape

aten/src/ATen/native/mps/operations/ConstantOps.mm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
}
1313
Tensor output = self;
1414
bool needsCopyToOutput = false;
15-
if (!self.is_contiguous()) {
15+
if (!self.is_contiguous() || self.storage_offset()) {
1616
output = empty_mps(self.sizes(), self.scalar_type(), c10::nullopt, kMPS);
1717
needsCopyToOutput = true;
1818
}
@@ -89,7 +89,7 @@ bool fill_mps_tensor_(Tensor& self, uint8_t value) {
8989
if (self.is_contiguous()) {
9090
MPSStream* stream = getCurrentMPSStream();
9191
auto storage_byte_offset = self.storage_offset() * self.itemsize();
92-
stream->fill(mps::getMTLBufferStorage(self), 0, self.nbytes(), storage_byte_offset);
92+
stream->fill(mps::getMTLBufferStorage(self), 0, self.storage().nbytes(), storage_byte_offset);
9393
return true;
9494
}
9595
return false;

aten/src/ATen/native/mps/operations/View.mm

Lines changed: 83 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -424,38 +424,76 @@
424424
}
425425

426426
static
427-
std::vector<int64_t> getViewShape(const Tensor& src, MPSShape *mpsShape) {
427+
std::vector<int64_t> getViewShape(const Tensor& src, MPSShape *mpsShape, const bool squeeze) {
428428
bool hasMPSShape = (mpsShape != nil);
429429
std::vector<int64_t> src_view_shape;
430430
if (hasMPSShape) {
431431
int src_ndim_view = [mpsShape count];
432-
src_view_shape.resize(src_ndim_view);
433-
for (const auto i : c10::irange(src_ndim_view)) {
434-
src_view_shape[i] = [mpsShape[i] intValue];
432+
if (squeeze) {
433+
for (const auto i : c10::irange(src_ndim_view)) {
434+
if ([mpsShape[i] intValue] == 1)
435+
continue;
436+
src_view_shape.emplace_back([mpsShape[i] intValue]);
437+
}
438+
} else {
439+
src_view_shape.resize(src_ndim_view);
440+
for (const auto i : c10::irange(src_ndim_view)) {
441+
src_view_shape[i] = [mpsShape[i] intValue];
442+
}
435443
}
444+
436445
} else {
437-
src_view_shape = src.sizes().vec();
446+
if (squeeze) {
447+
IntArrayRef src_shape = src.sizes();
448+
size_t src_ndim_view = src_shape.size();
449+
for (const auto i : c10::irange(src_ndim_view)) {
450+
if (src_shape[i] == 1)
451+
continue;
452+
src_view_shape.emplace_back(src_shape[i]);
453+
}
454+
} else {
455+
src_view_shape = src.sizes().vec();
456+
}
438457
}
439458

440459
return src_view_shape;
441460
}
442461

462+
463+
std::vector<int64_t> getSqueezedBaseShape(const Tensor& src, IntArrayRef shape) {
464+
std::vector<int64_t> src_base_shape;
465+
for (const auto i : c10::irange(shape.size())) {
466+
if (shape[i] == 1)
467+
continue;
468+
src_base_shape.emplace_back(shape[i]);
469+
}
470+
471+
return src_base_shape;
472+
}
473+
474+
443475
bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
444476
if (!src.is_contiguous()) {
445477
return false;
446478
}
447479

448480
IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data());
481+
std::vector<int64_t> src_base_squeezed_shape = getSqueezedBaseShape(src, src_base_shape);
449482
size_t src_ndim_base = src_base_shape.size();
450-
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape);
451-
size_t src_ndim_view = src_view_shape.size();
483+
size_t src_squeezed_ndim_base = src_base_squeezed_shape.size();
484+
std::vector<int64_t> src_view_squeezed_shape = getViewShape(src, mpsShape, true);
485+
size_t src_ndim_view = getViewShape(src, mpsShape, false).size();
486+
size_t src_squeezed_ndim_view = src_view_squeezed_shape.size();
487+
452488
if (src_ndim_base != src_ndim_view) {
453489
return false;
454490
}
455491

456-
for (const auto i: c10::irange(src_ndim_base)) {
457-
if (src_view_shape[i] > src_base_shape[i]) {
458-
return false;
492+
if (src_squeezed_ndim_base == src_squeezed_ndim_view) {
493+
for (const auto i: c10::irange(src_squeezed_ndim_base)) {
494+
if (src_view_squeezed_shape[i] > src_base_squeezed_shape[i]) {
495+
return false;
496+
}
459497
}
460498
}
461499

@@ -464,40 +502,63 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
464502

465503
MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType) {
466504
IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data());
467-
int src_ndim_base = src_base_shape.size();
468-
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape);
469-
int src_ndim_view = src_view_shape.size();
470-
471-
TORCH_CHECK(src_ndim_base == src_ndim_view);
505+
size_t src_ndim_base = src_base_shape.size();
506+
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape, false);
507+
size_t src_ndim_view = src_view_shape.size();
472508

473509
MPSNDArray *srcTensorNDArrayView = nil;
474510
MPSNDArrayDescriptor *srcTensorNDArrayDesc = nil;
475511
MPSNDArray *srcTensorNDArray = nil;
476512
id<MTLCommandBuffer> commandBuffer = getCurrentMPSStream()->commandBuffer();
477513

514+
int64_t base_idx = 0;
515+
516+
std::vector<int64_t> src_base_shape_vec;
517+
518+
if (src_ndim_view != src_ndim_base) {
519+
src_base_shape_vec.reserve(src_ndim_view);
520+
for (const auto i : c10::irange(src_ndim_view)) {
521+
if (src_view_shape[i] == 1 && src_base_shape[base_idx] != 1) {
522+
src_base_shape_vec.emplace_back(1);
523+
} else {
524+
src_base_shape_vec.emplace_back(src_base_shape[base_idx]);
525+
if (base_idx < src_ndim_base - 1)
526+
base_idx += 1;
527+
}
528+
}
529+
src_base_shape = IntArrayRef(src_base_shape_vec);
530+
src_ndim_base = src_base_shape.size();
531+
}
532+
478533
srcTensorNDArray = ndArrayFromTensor(src, getMPSShape(src_base_shape), mpsDataType);
479534
srcTensorNDArrayDesc = srcTensorNDArray.descriptor;
480535

481-
int firstDimToSlice = 0;
536+
size_t firstDimToSlice = 0;
482537
while (src_base_shape[firstDimToSlice] == src_view_shape[firstDimToSlice]) {
483538
firstDimToSlice++;
484539
}
485540

486-
int view_numel = 1;
541+
int64_t view_numel = 1;
487542
for (const auto i : c10::irange(firstDimToSlice + 1, src_base_shape.size())) {
488543
view_numel *= src_base_shape[i];
489544
}
490545

491-
int sliceOffset = src.storage_offset() / view_numel;
546+
int64_t sliceOffset = src.storage_offset() / view_numel;
492547
// There are cases where both dimensions of a view can shrink
493548
// E.g: x = torch.randn((3,6))[1, 1:3]
494-
int nextSliceOffset = src.storage_offset() % view_numel;
549+
int64_t nextSliceOffset = 0;
550+
bool sliceNextDim = (firstDimToSlice < (src_base_shape.size() - 1)) &&
551+
(src_view_shape[firstDimToSlice + 1] != src_base_shape[firstDimToSlice + 1]);
495552

496553
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice])}];
497-
if (nextSliceOffset) {
554+
if (sliceNextDim) {
555+
if (firstDimToSlice + 1 == src_base_shape.size() - 1) {
556+
nextSliceOffset = src.storage_offset() % src_base_shape[src_base_shape.size() - 1];
557+
} else {
558+
nextSliceOffset = (src.storage_offset() % view_numel) / (view_numel / src_base_shape[firstDimToSlice + 1]);
559+
}
498560
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 2 - firstDimToSlice withSubrange:{static_cast<NSUInteger>(nextSliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice+1])}];
499561
}
500-
501562
srcTensorNDArrayView = [srcTensorNDArray arrayViewWithCommandBuffer:commandBuffer
502563
descriptor:srcTensorNDArrayDesc
503564
aliasing:MPSAliasingStrategyShallAlias];
@@ -696,7 +757,7 @@ static IntArrayRef updateTensorBaseShape(const Tensor& self)
696757
{c10::ScalarType::Int, "int"},
697758
{c10::ScalarType::Short, "short"},
698759
{c10::ScalarType::Char, "char"},
699-
{c10::ScalarType::Byte, "char"},
760+
{c10::ScalarType::Byte, "uchar"},
700761
{c10::ScalarType::Bool, "bool"},
701762
};
702763

test/test_mps.py

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,27 @@ def helper(val, shape):
435435
helper(0, [1024])
436436
helper(0.2, [2, 3])
437437

438+
def test_fill_storage_offset(self):
439+
shape = [2, 10]
440+
val = 0.2
441+
tensor = torch.ones(shape, device="mps")
442+
tensor_mps = tensor[:][1].fill_(val)
443+
tensor_0 = torch.ones(shape, device="cpu")
444+
tensor_cpu = tensor_0[:][1].fill_(val)
445+
446+
self.assertEqual(tensor_mps, tensor_cpu)
447+
448+
shape = [1, 10]
449+
val = 0.0
450+
tensor = torch.ones(shape, device="mps")
451+
val_tensor_mps = torch.tensor(val, device="mps")
452+
tensor_mps = tensor[:, 9].fill_(val_tensor_mps)
453+
tensor_0 = torch.ones(shape, device="cpu")
454+
val_tensor_cpu = torch.tensor(val, device="cpu")
455+
tensor_cpu = tensor_0[:, 9].fill_(val_tensor_cpu)
456+
457+
self.assertEqual(tensor_mps, tensor_cpu)
458+
438459
def test_cdist_large(self, device="mps"):
439460
for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
440461
x = torch.randn(100, 10, device=device)
@@ -1806,6 +1827,63 @@ def test_slice_reshape_contg_view(self):
18061827

18071828
self.assertEqual(r_mps, r_cpu)
18081829

1830+
def test_contiguous_slice_2d(self):
1831+
def helper(shape):
1832+
for i in range(0, shape[0]):
1833+
for j in range(0, shape[1]):
1834+
t_mps = torch.randn(shape, device="mps")
1835+
t_cpu = t_mps.detach().clone().cpu()
1836+
1837+
y_mps = t_mps[i:, :j]
1838+
y_cpu = t_cpu[i:, :j]
1839+
self.assertEqual(y_mps + 1, y_cpu + 1)
1840+
1841+
y_mps = t_mps[i:, j]
1842+
y_cpu = t_cpu[i:, j]
1843+
self.assertEqual(y_mps + 1, y_cpu + 1)
1844+
1845+
y_mps = t_mps[i, :j]
1846+
y_cpu = t_cpu[i, :j]
1847+
self.assertEqual(y_mps + 1, y_cpu + 1)
1848+
1849+
y_mps = t_mps[:i, :j]
1850+
y_cpu = t_cpu[:i, :j]
1851+
self.assertEqual(y_mps + 1, y_cpu + 1)
1852+
1853+
y_mps = t_mps[:i, j]
1854+
y_cpu = t_cpu[:i, j]
1855+
self.assertEqual(y_mps + 1, y_cpu + 1)
1856+
1857+
y_mps = t_mps[:i, j:]
1858+
y_cpu = t_cpu[:i, j:]
1859+
self.assertEqual(y_mps + 1, y_cpu + 1)
1860+
1861+
l = []
1862+
for N in range(1, 3):
1863+
l.append(N)
1864+
for C in range(1, 3):
1865+
l.append(C)
1866+
helper(l)
1867+
for D in range(1, 3):
1868+
l.append(D)
1869+
helper(l)
1870+
for H in range(1, 3):
1871+
l.append(H)
1872+
helper(l)
1873+
for W in range(1, 3):
1874+
l.append(W)
1875+
helper(l)
1876+
l.pop()
1877+
l.pop()
1878+
l.pop()
1879+
l.pop()
1880+
l.pop()
1881+
1882+
helper([9, 15, 4])
1883+
helper([9, 3, 2])
1884+
helper([3, 4, 18, 22])
1885+
helper([3, 4, 18, 22, 150])
1886+
18091887
def test_view_slice(self):
18101888
# https://github.com/pytorch/pytorch/issues/83995
18111889
NUM_SAMPLES = 60
@@ -1899,25 +1977,28 @@ def helper(operator):
18991977
if operator == "<=":
19001978
res_mps = x_mps <= y_mps
19011979
res_cpu = x_cpu <= y_cpu
1902-
if operator == "<":
1980+
elif operator == "<":
19031981
res_mps = x_mps < y_mps
19041982
res_cpu = x_cpu < y_cpu
1905-
if operator == ">=":
1983+
elif operator == ">=":
19061984
res_mps = x_mps >= y_mps
19071985
res_cpu = x_cpu >= y_cpu
1908-
if operator == ">":
1986+
elif operator == ">":
19091987
res_mps = x_mps >= y_mps
19101988
res_cpu = x_cpu >= y_cpu
1911-
if operator == "==":
1989+
elif operator == "==":
19121990
res_mps = x_mps == y_mps
19131991
res_cpu = x_cpu == y_cpu
1914-
if operator == "!=":
1992+
elif operator == "!=":
19151993
res_mps = x_mps != y_mps
19161994
res_cpu = x_cpu != y_cpu
1995+
elif operator == "stack":
1996+
res_mps = torch.stack((y_mps, x_mps), dim=-1)
1997+
res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
19171998

19181999
self.assertEqual(res_mps, res_cpu)
19192000

1920-
for op in ["<=", "<", ">=", ">", "==", "!="]:
2001+
for op in ["<=", "<", ">=", ">", "==", "!=", "stack"]:
19212002
helper(op)
19222003

19232004
def test_slice_of_slice(self):

0 commit comments

Comments (0)