@@ -510,7 +510,6 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
510510 MPSNDArrayDescriptor *srcTensorNDArrayDesc = nil ;
511511 MPSNDArray *srcTensorNDArray = nil ;
512512 id <MTLCommandBuffer > commandBuffer = getCurrentMPSStream ()->commandBuffer ();
513-
514513 int64_t base_idx = 0 ;
515514
516515 std::vector<int64_t > src_base_shape_vec;
@@ -544,20 +543,20 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
544543 }
545544
546545 int64_t sliceOffset = src.storage_offset () / view_numel;
547- // There are cases where both dimensions of a view can shrink
548- // E.g: x = torch.randn((3,6))[1, 1:3]
549- int64_t nextSliceOffset = 0 ;
550- bool sliceNextDim = (firstDimToSlice < (src_base_shape.size () - 1 )) &&
551- (src_view_shape[firstDimToSlice + 1 ] != src_base_shape[firstDimToSlice + 1 ]);
552-
553- [srcTensorNDArrayDesc sliceDimension: src_ndim_base - 1 - firstDimToSlice withSubrange: {static_cast <NSUInteger >(sliceOffset), static_cast <NSUInteger >(src.sizes ()[firstDimToSlice])}];
554- if (sliceNextDim) {
555- if (firstDimToSlice + 1 == src_base_shape.size () - 1 ) {
556- nextSliceOffset = src.storage_offset () % src_base_shape[src_base_shape.size () - 1 ];
557- } else {
558- nextSliceOffset = (src.storage_offset () % view_numel) / (view_numel / src_base_shape[firstDimToSlice + 1 ]);
546+ [srcTensorNDArrayDesc sliceDimension: src_ndim_base - 1 - firstDimToSlice
547+ withSubrange: {static_cast <NSUInteger >(sliceOffset), static_cast <NSUInteger >(src.sizes ()[firstDimToSlice])}];
548+
549+ // Slice any remaining dimensions
550+ for (const auto crtSliceOffset: c10::irange (firstDimToSlice + 1 , src_base_shape.size ())) {
551+ if (src_view_shape[crtSliceOffset] != src_base_shape[crtSliceOffset]) {
552+ if (crtSliceOffset == src_base_shape.size () - 1 ) {
553+ sliceOffset = src.storage_offset () % src_base_shape[src_base_shape.size () - 1 ];
554+ } else {
555+ sliceOffset = (src.storage_offset () % view_numel) / (view_numel / src_base_shape[crtSliceOffset]);
556+ }
557+ [srcTensorNDArrayDesc sliceDimension: src_ndim_base - 1 - crtSliceOffset
558+ withSubrange: {static_cast <NSUInteger >(sliceOffset), static_cast <NSUInteger >(src.sizes ()[crtSliceOffset])}];
559559 }
560- [srcTensorNDArrayDesc sliceDimension: src_ndim_base - 2 - firstDimToSlice withSubrange: {static_cast <NSUInteger >(nextSliceOffset), static_cast <NSUInteger >(src.sizes ()[firstDimToSlice+1 ])}];
561560 }
562561 srcTensorNDArrayView = [srcTensorNDArray arrayViewWithCommandBuffer: commandBuffer
563562 descriptor: srcTensorNDArrayDesc
0 commit comments