Skip to content

Commit 82b078b

Browse files
[MPS] Fix views with 3 or more sliced dimensions (#95762) (#95871)
Fixes #95482 Pull Request resolved: #95762 Approved by: https://github.com/razarmehr Co-authored-by: Denis Vieriu <dvieriu@apple.com>
1 parent 77f7bc5 commit 82b078b

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

aten/src/ATen/native/mps/operations/View.mm

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,6 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
510510
MPSNDArrayDescriptor *srcTensorNDArrayDesc = nil;
511511
MPSNDArray *srcTensorNDArray = nil;
512512
id<MTLCommandBuffer> commandBuffer = getCurrentMPSStream()->commandBuffer();
513-
514513
int64_t base_idx = 0;
515514

516515
std::vector<int64_t> src_base_shape_vec;
@@ -544,20 +543,20 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
544543
}
545544

546545
int64_t sliceOffset = src.storage_offset() / view_numel;
547-
// There are cases where both dimensions of a view can shrink
548-
// E.g: x = torch.randn((3,6))[1, 1:3]
549-
int64_t nextSliceOffset = 0;
550-
bool sliceNextDim = (firstDimToSlice < (src_base_shape.size() - 1)) &&
551-
(src_view_shape[firstDimToSlice + 1] != src_base_shape[firstDimToSlice + 1]);
552-
553-
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice])}];
554-
if (sliceNextDim) {
555-
if (firstDimToSlice + 1 == src_base_shape.size() - 1) {
556-
nextSliceOffset = src.storage_offset() % src_base_shape[src_base_shape.size() - 1];
557-
} else {
558-
nextSliceOffset = (src.storage_offset() % view_numel) / (view_numel / src_base_shape[firstDimToSlice + 1]);
546+
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice
547+
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice])}];
548+
549+
// Slice any remaining dimensions
550+
for (const auto crtSliceOffset: c10::irange(firstDimToSlice + 1, src_base_shape.size())) {
551+
if (src_view_shape[crtSliceOffset] != src_base_shape[crtSliceOffset]) {
552+
if (crtSliceOffset == src_base_shape.size() - 1) {
553+
sliceOffset = src.storage_offset() % src_base_shape[src_base_shape.size() - 1];
554+
} else {
555+
sliceOffset = (src.storage_offset() % view_numel) / (view_numel / src_base_shape[crtSliceOffset]);
556+
}
557+
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - crtSliceOffset
558+
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[crtSliceOffset])}];
559559
}
560-
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 2 - firstDimToSlice withSubrange:{static_cast<NSUInteger>(nextSliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice+1])}];
561560
}
562561
srcTensorNDArrayView = [srcTensorNDArray arrayViewWithCommandBuffer:commandBuffer
563562
descriptor:srcTensorNDArrayDesc

test/test_mps.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,6 +1884,15 @@ def helper(shape):
18841884
helper([3, 4, 18, 22])
18851885
helper([3, 4, 18, 22, 150])
18861886

1887+
def test_contiguous_slice_3d(self):
1888+
x = torch.randn(2, 3, 3, device="mps")
1889+
x_cpu = x.detach().clone().cpu()
1890+
x = x[:1]
1891+
x_cpu = x_cpu[:1]
1892+
out = x[:, 0:1, 0:1] * x[:, 1:2, 1:2]
1893+
out_cpu = x_cpu[:, 0:1, 0:1] * x_cpu[:, 1:2, 1:2]
1894+
self.assertEqual(out, out_cpu)
1895+
18871896
def test_view_slice(self):
18881897
# https://github.com/pytorch/pytorch/issues/83995
18891898
NUM_SAMPLES = 60

0 commit comments

Comments
 (0)