@@ -83,6 +83,7 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
8383 pool2d_shape_check (input, kH , kW , dH, dW, padH, padW, dilationH, dilationW,
8484 nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
8585
86+ auto output_memory_format = output.suggest_memory_format ();
8687 // the output and indices are 'empty', so we could avoid unnecessary gatherView on empty tensors
8788 // by simply restriding them (instead of calling the costly Contiguous()).
8889 if (indices.suggest_memory_format () == MemoryFormat::ChannelsLast) {
@@ -94,8 +95,9 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
9495 outputSizes.insert (outputSizes.begin (), nbatch);
9596 }
9697 output.resize_ (outputSizes);
97- } else if (output. suggest_memory_format () == MemoryFormat::ChannelsLast) {
98+ } else if (output_memory_format == MemoryFormat::ChannelsLast) {
9899 output.unsafeGetTensorImpl ()->empty_tensor_restride (MemoryFormat::Contiguous);
100+ output_memory_format = MemoryFormat::Contiguous;
99101 }
100102
101103 if (output.numel () == 0 || (is_backward_pass && grad_output.numel () == 0 )) {
@@ -196,6 +198,10 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
196198 }
197199
198200 runMPSGraph (mpsStream, cachedGraph->graph (), feeds, results);
201+
202+ if (output_memory_format != suggested_memory_format) {
203+ const_cast <Tensor&>(output) = output.to (suggested_memory_format);
204+ }
199205 }
200206}
201207
@@ -302,7 +308,7 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
302308
303309} // namespace mps
304310
305- Tensor _mps_max_pool2d (
311+ Tensor mps_max_pool2d (
306312 const Tensor& input,
307313 IntArrayRef kernel_size,
308314 IntArrayRef stride,
@@ -356,6 +362,8 @@ Tensor mps_max_pool2d_backward(
356362 const Tensor& output,
357363 const Tensor& indices) {
358364
365+ auto indices_memory_format = indices.suggest_memory_format ();
366+
359367 mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn (cachedGraph, desc) {
360368 MPSGraph* mpsGraph = cachedGraph.graph ();
361369 NSArray <MPSGraphTensor*>* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor: cachedGraph.inputTensor
@@ -366,6 +374,10 @@ Tensor mps_max_pool2d_backward(
366374 };
367375 mps::pool2d_template (input, output, indices, c10::nullopt , kernel_size, stride,
368376 padding, dilation, ceil_mode, false , c10::nullopt , pooling_op_block, " max_pool2d_indices" );
377+
378+ if (indices_memory_format == MemoryFormat::ChannelsLast) {
379+ const_cast <Tensor&>(indices) = indices.to (MemoryFormat::ChannelsLast);
380+ }
369381}
370382
371383TORCH_IMPL_FUNC (max_pool2d_with_indices_backward_out_mps)(
0 commit comments