Skip to content

Commit b3147bc

Browse files
BowenBaofacebook-github-bot
authored andcommitted
PyTorch export to ONNX Opset 7 and 8 - Cont (#22421)
Summary: This is an extension to the original PR #21765 1. Increase the coverage of different opsets support, comments, and blacklisting. 2. Adding backend tests for both caffe2 and onnxruntime on opset 7 and opset 8. 3. Reusing onnx model tests in caffe2 for onnxruntime. Pull Request resolved: #22421 Reviewed By: zrphercule Differential Revision: D16225518 Pulled By: houseroad fbshipit-source-id: 01ae3eed85111a83a0124e9e95512b80109d6aee
1 parent 9f8e2c0 commit b3147bc

19 files changed

+798
-73
lines changed

scripts/onnx/test.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ pytest "${args[@]}" \
5555
'not (TestOperators and test_full_like) and not (TestOperators and test_zeros_like) and not (TestOperators and test_ones_like) and not (TestModels and test_vgg16) and not (TestModels and test_vgg16_bn) and not (TestModels and test_vgg19) and not (TestModels and test_vgg19_bn)' \
5656
--ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py" \
5757
--ignore "$top_dir/test/onnx/test_custom_ops.py" \
58+
--ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
5859
"${test_paths[@]}"
5960

6061
# onnxruntime only support py3
@@ -63,5 +64,6 @@ if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then
6364
pip install --user onnxruntime
6465
pytest "${args[@]}" "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py"
6566
pytest "${args[@]}" "$top_dir/test/onnx/test_custom_ops.py"
67+
pytest "${args[@]}" "$top_dir/test/onnx/test_models_onnxruntime.py"
6668
fi
6769

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
from __future__ import unicode_literals
5+
6+
import unittest
7+
import onnxruntime # noqa
8+
9+
from test_models import TestModels
10+
from test_pytorch_onnx_onnxruntime import run_model_test
11+
12+
13+
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
14+
opset_versions = opset_versions if opset_versions else [7, 8, 9, 10]
15+
for opset_version in opset_versions:
16+
self.opset_version = opset_version
17+
run_model_test(self, model, False,
18+
input=inputs, rtol=rtol, atol=atol)
19+
20+
21+
if __name__ == '__main__':
22+
TestModels.exportTest = exportTest
23+
unittest.main()

test/onnx/test_onnx_opset.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,45 @@ def test_maxpool(self):
132132
x = torch.randn(20, 16, 50)
133133
check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
134134

135+
def test_upsample(self):
136+
class MyModule(Module):
137+
def __init__(self):
138+
super(MyModule, self).__init__()
139+
140+
def forward(self, x):
141+
size = [v * 2 for v in x.size()[2:]]
142+
size = [int(i) for i in size]
143+
return torch.nn.functional.interpolate(x, size=size, mode='nearest')
144+
145+
module = MyModule()
146+
ops8 = [{"op_name" : "Upsample", "attributes" : [{"name": "mode", "s": ("nearest").encode(), "type": 3},
147+
{"name": "scales", "floats": [1.0, 1.0, 2.0, 2.0], "type": 6}]}]
148+
ops9 = [{"op_name" : "Constant"},
149+
{"op_name" : "Upsample", "attributes" : [{"name": "mode", "s": ("nearest").encode(), "type": 3}]}]
150+
ops = {8 : ops8, 9 : ops9}
151+
x = torch.randn(2, 2, 2, 2)
152+
check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])
153+
154+
def test_cast_constant(self):
155+
class MyModule(Module):
156+
def __init__(self):
157+
super(MyModule, self).__init__()
158+
159+
def forward(self, x):
160+
return torch._dim_arange(x, 1)
161+
162+
module = MyModule()
163+
ops_8 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
164+
{"op_name" : "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]},
165+
{"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
166+
{"op_name" : "Range"}]
167+
ops_9 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
168+
{"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
169+
{"op_name" : "Range"}]
170+
ops = {8 : ops_8, 9 : ops_9}
171+
x = torch.ones(5, 6)
172+
check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])
173+
135174
def test_slice(self):
136175
class MyModule(Module):
137176
def forward(self, x):

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
import caffe2.python.onnx.backend as c2
4141

4242
from test_pytorch_common import skipIfTravis, skipIfNoLapack, skipIfNoCuda
43-
from test_pytorch_common import skipIfUnsupportedOpsetVersion
43+
from test_pytorch_common import skipIfUnsupportedOpsetVersion, skipIfUnsupportedMinOpsetVersion
4444
import verify
4545

4646
skip = unittest.skip
@@ -664,6 +664,7 @@ def forward(self, input):
664664
input = torch.empty(BATCH_SIZE, 10, 10).uniform_(4, 9)
665665
self.run_model_test(MyModel(), train=False, input=input, batch_size=BATCH_SIZE)
666666

667+
@skipIfUnsupportedMinOpsetVersion(9)
667668
def test_erf(self):
668669
class MyModel(torch.nn.Module):
669670
def __init__(self):
@@ -807,16 +808,19 @@ def test_adaptive_avg_pool3D(self):
807808
x = torch.randn(20, 16, 50, 44, 30, requires_grad=True)
808809
self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE)
809810

811+
@skipIfUnsupportedMinOpsetVersion(8)
810812
def test_adaptive_max_pool1D(self):
811813
model = torch.nn.AdaptiveMaxPool1d((5))
812814
x = torch.randn(20, 16, 50, requires_grad=True)
813815
self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE)
814816

817+
@skipIfUnsupportedMinOpsetVersion(8)
815818
def test_adaptive_max_pool2D(self):
816819
model = torch.nn.AdaptiveMaxPool2d((5, 4))
817820
x = torch.randn(20, 16, 50, 32, requires_grad=True)
818821
self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE)
819822

823+
@skipIfUnsupportedMinOpsetVersion(8)
820824
def test_adaptive_max_pool3D(self):
821825
model = torch.nn.AdaptiveMaxPool3d((5, 4, 3))
822826
x = torch.randn(20, 16, 50, 44, 30, requires_grad=True)
@@ -993,7 +997,7 @@ def forward(self, x):
993997
self.run_model_test(model, train=False, input=(x),
994998
batch_size=BATCH_SIZE, use_gpu=False)
995999

996-
@skipIfUnsupportedOpsetVersion([10])
1000+
@skipIfUnsupportedOpsetVersion([7, 8, 10])
9971001
def test_interpolate_upsample_dynamic_sizes(self):
9981002
class MyModel(torch.nn.Module):
9991003
def __init__(self):
@@ -1291,6 +1295,7 @@ def forward(self, x):
12911295
self.run_model_test(FullClass(), train=False, input=(x,), batch_size=BATCH_SIZE,
12921296
use_gpu=False, example_outputs=FullClass()(x))
12931297

1298+
@skipIfUnsupportedMinOpsetVersion(9)
12941299
def test_where_functional(self):
12951300
class WhereFunctional(torch.nn.Module):
12961301
def forward(self, x):
@@ -1299,6 +1304,7 @@ def forward(self, x):
12991304
x = torch.randn(3, 4)
13001305
self.run_model_test(WhereFunctional(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False)
13011306

1307+
@skipIfUnsupportedMinOpsetVersion(9)
13021308
def test_where_method(self):
13031309
class WhereMethod(torch.nn.Module):
13041310
def forward(self, x):
@@ -1353,6 +1359,7 @@ def forward(self, x):
13531359
self.run_model_test(RsubModel(), train=False, input=(x,),
13541360
batch_size=BATCH_SIZE, use_gpu=False)
13551361

1362+
@skipIfUnsupportedMinOpsetVersion(9)
13561363
def test_isnan(self):
13571364
class IsNaNModel(torch.nn.Module):
13581365
def forward(self, input):
@@ -1361,6 +1368,7 @@ def forward(self, input):
13611368
x = torch.tensor([1.0, float('nan'), 2.0])
13621369
self.run_model_test(IsNaNModel(), train=False, input=x, batch_size=BATCH_SIZE, use_gpu=False)
13631370

1371+
@skipIfUnsupportedMinOpsetVersion(9)
13641372
def test_scatter(self):
13651373
class ScatterModel(torch.nn.Module):
13661374
def forward(self, input, indices, values):
@@ -1396,6 +1404,23 @@ def forward(self, input):
13961404
x = torch.randn(4, 4, requires_grad=True)
13971405
self.run_model_test(MaxModel(), train=False, input=x, batch_size=BATCH_SIZE)
13981406

1407+
def test_max_keepdim(self):
1408+
class MaxModel(torch.nn.Module):
1409+
def forward(self, input):
1410+
return torch.max(input, dim=1, keepdim=True)
1411+
1412+
x = torch.randn(4, 4, requires_grad=True)
1413+
self.run_model_test(MaxModel(), train=False, input=x, batch_size=BATCH_SIZE)
1414+
1415+
def test_max_tensors(self):
1416+
class MaxModel(torch.nn.Module):
1417+
def forward(self, input, other):
1418+
return torch.max(input, other)
1419+
1420+
x = torch.randn(4, 4, requires_grad=True)
1421+
y = torch.randn(4, 4, requires_grad=True)
1422+
self.run_model_test(MaxModel(), train=False, input=(x, y), batch_size=BATCH_SIZE)
1423+
13991424
def test_min(self):
14001425
class MinModel(torch.nn.Module):
14011426
def forward(self, input):
@@ -1841,6 +1866,7 @@ def forward(self, x):
18411866
x = torch.randn(1, 2, 3)
18421867
self.run_model_test(DropoutModel(), train=False, input=x, batch_size=BATCH_SIZE)
18431868

1869+
@skipIfUnsupportedMinOpsetVersion(9)
18441870
def test_while(self):
18451871
class WhileModel(torch.jit.ScriptModule):
18461872
@torch.jit.script_method
@@ -1901,6 +1927,7 @@ def forward(self, x):
19011927
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
19021928
example_outputs=(outputs,))
19031929

1930+
@skipIfUnsupportedMinOpsetVersion(9)
19041931
def test_nested_loops(self):
19051932
class NestedLoopsModel(torch.jit.ScriptModule):
19061933
@torch.jit.script_method
@@ -2013,6 +2040,24 @@ def setup_rnn_tests():
20132040
(unittest.TestCase,),
20142041
dict(TestCaffe2Backend_opset9.__dict__, embed_params=True))
20152042

2043+
# opset 7 tests
2044+
TestCaffe2Backend_opset7 = type(str("TestCaffe2Backend_opset7"),
2045+
(unittest.TestCase,),
2046+
dict(TestCaffe2Backend_opset9.__dict__, opset_version=7))
2047+
TestCaffe2BackendEmbed_opset7 = type(str("TestCaffe2BackendEmbed_opset7"),
2048+
(unittest.TestCase,),
2049+
dict(TestCaffe2Backend_opset9.__dict__,
2050+
embed_params=True, opset_version=7))
2051+
2052+
# opset 8 tests
2053+
TestCaffe2Backend_opset8 = type(str("TestCaffe2Backend_opset8"),
2054+
(unittest.TestCase,),
2055+
dict(TestCaffe2Backend_opset9.__dict__, opset_version=8))
2056+
TestCaffe2BackendEmbed_opset8 = type(str("TestCaffe2BackendEmbed_opset8"),
2057+
(unittest.TestCase,),
2058+
dict(TestCaffe2Backend_opset9.__dict__,
2059+
embed_params=True, opset_version=8))
2060+
20162061
# opset 10 tests
20172062
TestCaffe2Backend_opset10 = type(str("TestCaffe2Backend_opset10"),
20182063
(unittest.TestCase,),

0 commit comments

Comments
 (0)