|
25 | 25 | from __future__ import annotations |
26 | 26 |
|
27 | 27 | import copy |
28 | | -from typing import Optional, Tuple |
| 28 | +from typing import Any, Callable, Collection, Optional, Tuple, Union |
29 | 29 |
|
30 | 30 | import onnx_test_common |
31 | 31 |
|
|
162 | 162 | ] |
163 | 163 | ) |
164 | 164 |
|
| 165 | + |
| 166 | +# NOTE: For ATen signature modifications that will break ONNX export, |
| 167 | +# use **xfail_torchlib_forward_compatibility** and **skip_torchlib_forward_compatibility** instead of xfail or skip |
| 168 | +# to make the signal apparent for maintainers. |
| 169 | +def xfail_torchlib_forward_compatibility( |
| 170 | + op_name: str, |
| 171 | + variant_name: str = "", |
| 172 | + *, |
| 173 | + reason: str, |
| 174 | + github_issue: str, |
| 175 | + opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, |
| 176 | + dtypes: Optional[Collection[torch.dtype]] = None, |
| 177 | + matcher: Optional[Callable[[Any], bool]] = None, |
| 178 | + enabled_if: bool = True, |
| 179 | +): |
| 180 | + """Prefer using this (xfail) over skip when possible. |
| 181 | +
|
| 182 | + Only skip when the test is not failing consistently. |
| 183 | + """ |
| 184 | + return xfail( |
| 185 | + op_name, |
| 186 | + variant_name=variant_name, |
| 187 | + reason=f"{reason}. GitHub Issue: {github_issue}", |
| 188 | + opsets=opsets, |
| 189 | + dtypes=dtypes, |
| 190 | + matcher=matcher, |
| 191 | + enabled_if=enabled_if, |
| 192 | + ) |
| 193 | + |
| 194 | + |
| 195 | +def skip_torchlib_forward_compatibility( |
| 196 | + op_name: str, |
| 197 | + variant_name: str = "", |
| 198 | + *, |
| 199 | + reason: str, |
| 200 | + github_issue: str, |
| 201 | + opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, |
| 202 | + dtypes: Optional[Collection[torch.dtype]] = None, |
| 203 | + matcher: Optional[Callable[[Any], Any]] = None, |
| 204 | + enabled_if: bool = True, |
| 205 | +): |
| 206 | + """Prefer using xfail_torchlib_forward_compatibility over this (skip) when possible. |
| 207 | +
|
| 208 | + Only skip when the test is not failing consistently. |
| 209 | + """ |
| 210 | + return skip( |
| 211 | + op_name, |
| 212 | + variant_name=variant_name, |
| 213 | + reason=f"{reason}. GitHub Issue: {github_issue}", |
| 214 | + opsets=opsets, |
| 215 | + dtypes=dtypes, |
| 216 | + matcher=matcher, |
| 217 | + enabled_if=enabled_if, |
| 218 | + ) |
| 219 | + |
| 220 | + |
165 | 221 | # fmt: off |
166 | 222 | # Turn off black formatting to keep the list compact |
167 | 223 |
|
|
541 | 597 | matcher=lambda sample: not isinstance(sample.kwargs.get("weight"), int), |
542 | 598 | reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[weight] is int type", |
543 | 599 | ), |
544 | | - skip( |
| 600 | + skip_torchlib_forward_compatibility( |
545 | 601 | "nn.functional.embedding_bag", |
546 | 602 | matcher=lambda sample: sample.kwargs.get("padding_idx") is not None or True, |
547 | | - reason=( |
548 | | - "Torchlib does not support 'padding_idx' overload for _embedding_bag and _embedding_bag_forward_only. " |
549 | | - "'padding_idx=-1' is emitted for aten op when 'padding_idx' is not provided. " |
550 | | - "See https://github.com/microsoft/onnxscript/issues/1056 for details." |
| 603 | + reason=onnx_test_common.reason_onnx_script_does_not_support( |
| 604 | + "'padding_idx' overload for _embedding_bag and _embedding_bag_forward_only. " |
| 605 | + "'padding_idx=-1' is emitted for aten op when 'padding_idx' is not provided" |
551 | 606 | ), |
| 607 | + github_issue="https://github.com/microsoft/onnxscript/issues/1056", |
552 | 608 | ), |
553 | 609 | skip( |
554 | 610 | "nn.functional.max_pool3d", |
555 | 611 | matcher=lambda sample: sample.kwargs.get("ceil_mode") is True |
556 | 612 | and sample.kwargs.get("padding") == 1, |
557 | 613 | reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed", |
558 | 614 | ), |
559 | | - skip( |
560 | | - "nn.functional.nll_loss", |
561 | | - matcher=lambda sample: isinstance(sample.kwargs.get("reduction"), str), |
562 | | - reason=onnx_test_common.reason_onnx_script_does_not_support( |
563 | | - "string in reduction kwarg: https://github.com/microsoft/onnxscript/issues/726" |
564 | | - ), |
565 | | - ), |
566 | 615 | xfail( |
567 | 616 | "nonzero", |
568 | 617 | matcher=lambda sample: len(sample.input.shape) == 0 |
|
0 commit comments