Skip to content

Commit 1211cee

Browse files
kulinseth, DenisVieriu97, qqaatw, razarmehr
authored
[MPS] Fix issues with max_pool2d (#95325)
* [MPS] Fix upsample for NHWC output (#94963)

  Fixes huggingface/diffusers#941

  **Before**: <img width="1144" alt="Screenshot 2023-02-15 at 8 11 53 PM" src="https://user-images.githubusercontent.com/104024078/219266709-6a77636a-2fc0-4802-b130-85069b95953f.png">

  **After**: <img width="1144" alt="Screenshot 2023-02-15 at 8 12 02 PM" src="https://user-images.githubusercontent.com/104024078/219266694-ea743c02-fb55-44f1-b7d6-5946106527c3.png">

  Pull Request resolved: #94963
  Approved by: https://github.com/razarmehr

* [MPS] Move max_pool2d to mps dispatch key (#90772)

  Related issue: #77394

  This PR also modifies some assertions in the codegen; an explanatory comment for them has been added.

  Pull Request resolved: #90772
  Approved by: https://github.com/albanD

* [MPS] Convert output back to ChannelsLast for MaxPool2D (#94877)

  Since we re-stride the indices and output in MPS pooling from ChannelsLast to Contiguous, we need to convert the results back to ChannelsLast. This will fix the failure with test_memory_format with MaxPool2D in test_modules.py.

  Pull Request resolved: #94877
  Approved by: https://github.com/kulinseth, https://github.com/DenisVieriu97

---------

Co-authored-by: Denis Vieriu <104024078+DenisVieriu97@users.noreply.github.com>
Co-authored-by: Li-Huai (Allan) Lin <qqaatw@gmail.com>
Co-authored-by: Ramin Azarmehr <razarmehr@apple.com>
1 parent beaa5c5 commit 1211cee

File tree

9 files changed

+51
-30
lines changed

9 files changed

+51
-30
lines changed

aten/src/ATen/native/Pooling.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include <ATen/Functions.h>
1010
#include <ATen/NativeFunctions.h>
1111
#else
12-
#include <ATen/ops/_mps_max_pool2d.h>
1312
#include <ATen/ops/adaptive_avg_pool1d_native.h>
1413
#include <ATen/ops/adaptive_avg_pool2d.h>
1514
#include <ATen/ops/adaptive_max_pool1d_native.h>
@@ -141,12 +140,6 @@ Tensor max_pool2d(
141140
return at::mkldnn_max_pool2d(
142141
self, kernel_size, stride, padding, dilation, ceil_mode);
143142
}
144-
#ifdef USE_MPS
145-
if (self.is_mps()) {
146-
return at::_mps_max_pool2d(
147-
self, kernel_size, stride, padding, dilation, ceil_mode);
148-
}
149-
#endif
150143
#if defined(C10_MOBILE)
151144
if(xnnpack::use_max_pool2d(self, kernel_size, padding, stride,
152145
dilation, ceil_mode)) {

aten/src/ATen/native/mps/operations/Pooling.mm

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
8383
pool2d_shape_check(input, kH, kW, dH, dW, padH, padW, dilationH, dilationW,
8484
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
8585

86+
auto output_memory_format = output.suggest_memory_format();
8687
// the output and indices are 'empty', so we could avoid unnecessary gatherView on empty tensors
8788
// by simply restriding them (instead of calling the costly Contiguous()).
8889
if (indices.suggest_memory_format() == MemoryFormat::ChannelsLast) {
@@ -94,8 +95,9 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
9495
outputSizes.insert(outputSizes.begin(), nbatch);
9596
}
9697
output.resize_(outputSizes);
97-
} else if (output.suggest_memory_format() == MemoryFormat::ChannelsLast) {
98+
} else if (output_memory_format == MemoryFormat::ChannelsLast) {
9899
output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
100+
output_memory_format = MemoryFormat::Contiguous;
99101
}
100102

101103
if (output.numel() == 0 || (is_backward_pass && grad_output.numel() == 0)) {
@@ -196,6 +198,10 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
196198
}
197199

198200
runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
201+
202+
if (output_memory_format != suggested_memory_format) {
203+
const_cast<Tensor&>(output) = output.to(suggested_memory_format);
204+
}
199205
}
200206
}
201207

@@ -302,7 +308,7 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
302308

303309
} // namespace mps
304310

305-
Tensor _mps_max_pool2d(
311+
Tensor mps_max_pool2d(
306312
const Tensor& input,
307313
IntArrayRef kernel_size,
308314
IntArrayRef stride,
@@ -356,6 +362,8 @@ Tensor mps_max_pool2d_backward(
356362
const Tensor& output,
357363
const Tensor& indices) {
358364

365+
auto indices_memory_format = indices.suggest_memory_format();
366+
359367
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
360368
MPSGraph* mpsGraph = cachedGraph.graph();
361369
NSArray<MPSGraphTensor*>* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor: cachedGraph.inputTensor
@@ -366,6 +374,10 @@ Tensor mps_max_pool2d_backward(
366374
};
367375
mps::pool2d_template(input, output, indices, c10::nullopt, kernel_size, stride,
368376
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_indices");
377+
378+
if (indices_memory_format == MemoryFormat::ChannelsLast) {
379+
const_cast<Tensor&>(indices) = indices.to(MemoryFormat::ChannelsLast);
380+
}
369381
}
370382

371383
TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)(

aten/src/ATen/native/mps/operations/UpSample.mm

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ void upsample_out_template(const Tensor& input,
2626
} else {
2727
native::upsample_2d_common_check(input.sizes(), output_size);
2828
}
29+
Tensor out;
30+
if (!output.is_contiguous()) {
31+
out = at::empty_like(output, MemoryFormat::Contiguous);
32+
}
33+
2934
bool centerResults = false;
3035
MPSGraphResizeMode resizeMode = MPSGraphResizeNearest;
3136
MPSGraphResizeNearestRoundingMode nearestRoundingMode = MPSGraphResizeNearestRoundingModeFloor;
@@ -199,7 +204,7 @@ void upsample_out_template(const Tensor& input,
199204
MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: sizeNDArray] autorelease];
200205

201206
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
202-
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
207+
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false);
203208

204209
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
205210
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
@@ -209,6 +214,10 @@ void upsample_out_template(const Tensor& input,
209214
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
210215
};
211216
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
217+
218+
if (out.has_storage()) {
219+
output.copy_(out);
220+
}
212221
}
213222
}
214223

aten/src/ATen/native/native_functions.yaml

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3567,19 +3567,14 @@
35673567
- func: max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor
35683568

35693569
- func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
3570-
3571-
# TODO: Add this function to MPS dispatch key so that we avoid declaring it in
3572-
# native_functions.yaml
3573-
# https://github.com/pytorch/pytorch/issues/77394
3574-
- func: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
35753570
dispatch:
3576-
MPS: _mps_max_pool2d
3577-
autogen: _mps_max_pool2d.out
3571+
CompositeImplicitAutograd: max_pool2d
3572+
MPS: mps_max_pool2d
35783573

3579-
- func: mps_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
3574+
- func: max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
35803575
dispatch:
35813576
MPS: mps_max_pool2d_backward
3582-
autogen: mps_max_pool2d_backward.out
3577+
autogen: max_pool2d_backward.out
35833578

35843579
- func: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
35853580
dispatch:

test/expect/HasDecompTest.test_has_decomposition.expect

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,6 @@ aten::_mps_convolution
377377
aten::_mps_convolution.out
378378
aten::_mps_convolution_transpose
379379
aten::_mps_convolution_transpose.out
380-
aten::_mps_max_pool2d
381-
aten::_mps_max_pool2d.out
382380
aten::_native_batch_norm_legit.no_stats_out
383381
aten::_native_batch_norm_legit.out
384382
aten::_native_decoder_only_multi_head_attention
@@ -857,6 +855,8 @@ aten::max
857855
aten::max.dim
858856
aten::max.dim_max
859857
aten::max.unary_out
858+
aten::max_pool2d_backward
859+
aten::max_pool2d_backward.out
860860
aten::max_pool2d_with_indices
861861
aten::max_pool2d_with_indices.out
862862
aten::max_pool2d_with_indices_backward
@@ -930,8 +930,6 @@ aten::mps_convolution_backward
930930
aten::mps_convolution_backward.out
931931
aten::mps_convolution_transpose_backward
932932
aten::mps_convolution_transpose_backward.out
933-
aten::mps_max_pool2d_backward
934-
aten::mps_max_pool2d_backward.out
935933
aten::multi_margin_loss
936934
aten::multi_margin_loss.out
937935
aten::multi_margin_loss_backward

test/forward_backward_compatibility/check_forward_backward_compatibility.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,10 @@
150150
("aten::sum.SymInt", datetime.date(2022, 11, 30)),
151151
("aten::mps_linear", datetime.date(9999, 1, 1)),
152152
("aten::_mps_linear", datetime.date(9999, 1, 1)),
153+
("aten::_mps_max_pool2d", datetime.date(9999, 1, 1)),
154+
("aten::_mps_max_pool2d.out", datetime.date(9999, 1, 1)),
155+
("aten::mps_max_pool2d_backward", datetime.date(9999, 1, 1)),
156+
("aten::mps_max_pool2d_backward.out", datetime.date(9999, 1, 1)),
153157
("aten::view_copy.SymInt", datetime.date(2022, 11, 30)),
154158
("aten::view_copy.SymInt_out", datetime.date(2022, 11, 30)),
155159
("aten::expand_copy.SymInt", datetime.date(2022, 11, 30)),

test/test_mps.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4655,9 +4655,9 @@ def test_sort(self):
46554655
)
46564656

46574657
def test_upsample_nearest2d(self):
4658-
def helper(N, C, H, W):
4658+
def helper(N, C, H, W, memory_format):
46594659
inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
4660-
requires_grad=True).reshape(N, C, H, W)
4660+
requires_grad=True).reshape(N, C, H, W).to(memory_format=memory_format)
46614661
inputCPU.retain_grad()
46624662
inputMPS = inputCPU.detach().to('mps').requires_grad_()
46634663

@@ -4683,8 +4683,9 @@ def helper(N, C, H, W):
46834683

46844684
self.assertEqual(inputCPU.grad, inputMPS.grad)
46854685

4686-
helper(1, 1, 4, 4)
4687-
helper(7, 5, 3, 2)
4686+
for memory_format in [torch.channels_last, torch.contiguous_format]:
4687+
helper(1, 1, 4, 4, memory_format=memory_format)
4688+
helper(7, 5, 3, 2, memory_format=memory_format)
46884689

46894690
def test_upsample_bilinear2d(self):
46904691
def helper(N, C, H, W):

tools/autograd/derivatives.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2170,8 +2170,8 @@
21702170
input, weight, bias: linear_backward(input, grad, weight, grad_input_mask)
21712171

21722172
#mps
2173-
- name: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
2174-
self: mps_max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)
2173+
- name: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
2174+
self: max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)
21752175

21762176
- name: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
21772177
self, weight, bias: "grad.defined() ? mps_convolution_backward(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"

torchgen/model.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,7 @@ def from_yaml(
638638
raw_dispatch = e.pop("dispatch", None)
639639
assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
640640
dispatch: Dict[DispatchKey, BackendMetadata] = {}
641+
num_dispatch_keys: int = 0
641642
if raw_dispatch is not None:
642643
assert not manual_kernel_registration, (
643644
"cannot specify both manual_kernel_registration and dispatch; with "
@@ -650,6 +651,8 @@ def from_yaml(
650651
assert isinstance(ks, str), e
651652
for k in ks.split(","):
652653
dispatch_key = DispatchKey.parse(k.strip())
654+
num_dispatch_keys += 1
655+
653656
if ignore_keys and dispatch_key in ignore_keys:
654657
continue
655658
assert dispatch_key in dispatch_keys, (
@@ -677,7 +680,12 @@ def from_yaml(
677680
):
678681
redundant_composite_implicit_autograd = True
679682

680-
assert not (len(dispatch) == 1 and redundant_composite_implicit_autograd), (
683+
# We count the number of dispatch keys which have not been ignored to prevent a dispatch table
684+
# in which all backend keys are ignored but necessarily kept, remaining compositeimplicit,
685+
# from being treated as redundant.
686+
assert not (
687+
num_dispatch_keys == 1 and redundant_composite_implicit_autograd
688+
), (
681689
"unnecessary dispatch table for this function; just delete the dispatch "
682690
"key entirely"
683691
)
@@ -687,6 +695,7 @@ def from_yaml(
687695
structured_delegate
688696
or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
689697
or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
698+
or num_dispatch_keys != 1
690699
), (
691700
f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
692701
f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected "

0 commit comments

Comments
 (0)