Skip to content

Commit b234d94

Browse files
authored
[ONNX] Improve error handling for adaptive_pool (#45874) (#46100)
Summary: Duplicate of #43032. This update also improves error handling for interpolate with 'area' mode. Pull Request resolved: #45874 Reviewed By: albanD Differential Revision: D24141266 Pulled By: bzinodev fbshipit-source-id: 7559f1d6af4f1ef3507c15a1aee76fe01fa433cd
1 parent 36fe788 commit b234d94

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1667,7 +1667,14 @@ def forward(self, x, y):
16671667
y = torch.randn(16, 16, requires_grad=True)
16681668
self.run_test(MyModel(), (x, y))
16691669

1670-
@disableScriptTest()
1670+
def test_interpolate_adaptive_pooling_error(self):
1671+
x = torch.randn(1, 2, 6, requires_grad=True)
1672+
with self.assertRaises(RuntimeError) as cm:
1673+
self._interpolate(x, "area", True, True)
1674+
1675+
with self.assertRaises(RuntimeError) as cm:
1676+
self._interpolate(x, "area", False, True)
1677+
16711678
def test_groupnorm(self):
16721679
model = torch.nn.GroupNorm(3, 6, 0.002)
16731680
x = torch.randn(4, 6, 180, 180, 180)

torch/onnx/symbolic_opset9.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,6 @@ def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include
826826

827827

828828
def _adaptive_pool(name, type, tuple_fn, fn=None):
829-
@parse_args('v', 'is')
830829
def symbolic_fn(g, input, output_size):
831830
# _adaptive_pool is supported for cases where output_size is 1 for all dimensions,
832831
# by executing a GlobalPool.
@@ -837,6 +836,10 @@ def symbolic_fn(g, input, output_size):
837836
# so we try using max_poolxd_with_indices, and if it is not possible
838837
# (input is not a complete tensor or output size not factor of input size)
839838
# then we call GlobalAveragePool and return None for the indices
839+
try:
840+
output_size = _parse_arg(output_size, 'is')
841+
except Exception:
842+
return sym_help._onnx_unsupported('adaptive pooling, since output_size is not constant.')
840843
if output_size == [1] * len(output_size) and type == "AveragePool":
841844
return g.op("GlobalAveragePool", input)
842845
if not input.isCompleteTensor():
@@ -849,7 +852,10 @@ def symbolic_fn(g, input, output_size):
849852
if mod != [0] * len(mod):
850853
if output_size == [1] * len(output_size):
851854
return g.op("GlobalMaxPool", input), None
852-
return _unimplemented(name, 'output size that are not factor of input size')
855+
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
856+
return _unimplemented(name, 'output size that are not factor of input size')
857+
else:
858+
return sym_help._onnx_unsupported(name + ', since output size is not factor of input size')
853859
k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
854860
# call max_poolxd_with_indices to get indices in the output
855861
if type == "MaxPool":

torch/onnx/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,8 +1003,7 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor
10031003
else:
10041004
raise RuntimeError("ONNX export failed on an operator with unrecognized namespace {}::{}. "
10051005
"If you are trying to export a custom operator, make sure you registered "
1006-
"it with the right domain and version. "
1007-
"Otherwise, please report a bug.".format(ns, op_name))
1006+
"it with the right domain and version.".format(ns, op_name))
10081007
except RuntimeError:
10091008
if operator_export_type == OperatorExportTypes.ONNX_FALLTHROUGH:
10101009
return None

0 commit comments

Comments (0)