4040import caffe2 .python .onnx .backend as c2
4141
4242from test_pytorch_common import skipIfTravis , skipIfNoLapack , skipIfNoCuda
43- from test_pytorch_common import skipIfUnsupportedOpsetVersion
43+ from test_pytorch_common import skipIfUnsupportedOpsetVersion , skipIfUnsupportedMinOpsetVersion
4444import verify
4545
4646skip = unittest .skip
@@ -664,6 +664,7 @@ def forward(self, input):
664664 input = torch .empty (BATCH_SIZE , 10 , 10 ).uniform_ (4 , 9 )
665665 self .run_model_test (MyModel (), train = False , input = input , batch_size = BATCH_SIZE )
666666
667+ @skipIfUnsupportedMinOpsetVersion (9 )
667668 def test_erf (self ):
668669 class MyModel (torch .nn .Module ):
669670 def __init__ (self ):
@@ -807,16 +808,19 @@ def test_adaptive_avg_pool3D(self):
807808 x = torch .randn (20 , 16 , 50 , 44 , 30 , requires_grad = True )
808809 self .run_model_test (model , train = False , input = x , batch_size = BATCH_SIZE )
809810
811+ @skipIfUnsupportedMinOpsetVersion (8 )
810812 def test_adaptive_max_pool1D (self ):
811813 model = torch .nn .AdaptiveMaxPool1d ((5 ))
812814 x = torch .randn (20 , 16 , 50 , requires_grad = True )
813815 self .run_model_test (model , train = False , input = x , batch_size = BATCH_SIZE )
814816
817+ @skipIfUnsupportedMinOpsetVersion (8 )
815818 def test_adaptive_max_pool2D (self ):
816819 model = torch .nn .AdaptiveMaxPool2d ((5 , 4 ))
817820 x = torch .randn (20 , 16 , 50 , 32 , requires_grad = True )
818821 self .run_model_test (model , train = False , input = x , batch_size = BATCH_SIZE )
819822
823+ @skipIfUnsupportedMinOpsetVersion (8 )
820824 def test_adaptive_max_pool3D (self ):
821825 model = torch .nn .AdaptiveMaxPool3d ((5 , 4 , 3 ))
822826 x = torch .randn (20 , 16 , 50 , 44 , 30 , requires_grad = True )
@@ -993,7 +997,7 @@ def forward(self, x):
993997 self .run_model_test (model , train = False , input = (x ),
994998 batch_size = BATCH_SIZE , use_gpu = False )
995999
996- @skipIfUnsupportedOpsetVersion ([10 ])
1000+ @skipIfUnsupportedOpsetVersion ([7 , 8 , 10 ])
9971001 def test_interpolate_upsample_dynamic_sizes (self ):
9981002 class MyModel (torch .nn .Module ):
9991003 def __init__ (self ):
@@ -1291,6 +1295,7 @@ def forward(self, x):
12911295 self .run_model_test (FullClass (), train = False , input = (x ,), batch_size = BATCH_SIZE ,
12921296 use_gpu = False , example_outputs = FullClass ()(x ))
12931297
1298+ @skipIfUnsupportedMinOpsetVersion (9 )
12941299 def test_where_functional (self ):
12951300 class WhereFunctional (torch .nn .Module ):
12961301 def forward (self , x ):
@@ -1299,6 +1304,7 @@ def forward(self, x):
12991304 x = torch .randn (3 , 4 )
13001305 self .run_model_test (WhereFunctional (), train = False , input = (x ,), batch_size = BATCH_SIZE , use_gpu = False )
13011306
1307+ @skipIfUnsupportedMinOpsetVersion (9 )
13021308 def test_where_method (self ):
13031309 class WhereMethod (torch .nn .Module ):
13041310 def forward (self , x ):
@@ -1353,6 +1359,7 @@ def forward(self, x):
13531359 self .run_model_test (RsubModel (), train = False , input = (x ,),
13541360 batch_size = BATCH_SIZE , use_gpu = False )
13551361
1362+ @skipIfUnsupportedMinOpsetVersion (9 )
13561363 def test_isnan (self ):
13571364 class IsNaNModel (torch .nn .Module ):
13581365 def forward (self , input ):
@@ -1361,6 +1368,7 @@ def forward(self, input):
13611368 x = torch .tensor ([1.0 , float ('nan' ), 2.0 ])
13621369 self .run_model_test (IsNaNModel (), train = False , input = x , batch_size = BATCH_SIZE , use_gpu = False )
13631370
1371+ @skipIfUnsupportedMinOpsetVersion (9 )
13641372 def test_scatter (self ):
13651373 class ScatterModel (torch .nn .Module ):
13661374 def forward (self , input , indices , values ):
@@ -1396,6 +1404,23 @@ def forward(self, input):
13961404 x = torch .randn (4 , 4 , requires_grad = True )
13971405 self .run_model_test (MaxModel (), train = False , input = x , batch_size = BATCH_SIZE )
13981406
1407+ def test_max_keepdim (self ):
1408+ class MaxModel (torch .nn .Module ):
1409+ def forward (self , input ):
1410+ return torch .max (input , dim = 1 , keepdim = True )
1411+
1412+ x = torch .randn (4 , 4 , requires_grad = True )
1413+ self .run_model_test (MaxModel (), train = False , input = x , batch_size = BATCH_SIZE )
1414+
1415+ def test_max_tensors (self ):
1416+ class MaxModel (torch .nn .Module ):
1417+ def forward (self , input , other ):
1418+ return torch .max (input , other )
1419+
1420+ x = torch .randn (4 , 4 , requires_grad = True )
1421+ y = torch .randn (4 , 4 , requires_grad = True )
1422+ self .run_model_test (MaxModel (), train = False , input = (x , y ), batch_size = BATCH_SIZE )
1423+
13991424 def test_min (self ):
14001425 class MinModel (torch .nn .Module ):
14011426 def forward (self , input ):
@@ -1841,6 +1866,7 @@ def forward(self, x):
18411866 x = torch .randn (1 , 2 , 3 )
18421867 self .run_model_test (DropoutModel (), train = False , input = x , batch_size = BATCH_SIZE )
18431868
1869+ @skipIfUnsupportedMinOpsetVersion (9 )
18441870 def test_while (self ):
18451871 class WhileModel (torch .jit .ScriptModule ):
18461872 @torch .jit .script_method
@@ -1901,6 +1927,7 @@ def forward(self, x):
19011927 self .run_model_test (model , train = False , input = (inputs ,), batch_size = BATCH_SIZE ,
19021928 example_outputs = (outputs ,))
19031929
1930+ @skipIfUnsupportedMinOpsetVersion (9 )
19041931 def test_nested_loops (self ):
19051932 class NestedLoopsModel (torch .jit .ScriptModule ):
19061933 @torch .jit .script_method
@@ -2013,6 +2040,24 @@ def setup_rnn_tests():
20132040 (unittest .TestCase ,),
20142041 dict (TestCaffe2Backend_opset9 .__dict__ , embed_params = True ))
20152042
2043+ # opset 7 tests
2044+ TestCaffe2Backend_opset7 = type (str ("TestCaffe2Backend_opset7" ),
2045+ (unittest .TestCase ,),
2046+ dict (TestCaffe2Backend_opset9 .__dict__ , opset_version = 7 ))
2047+ TestCaffe2BackendEmbed_opset7 = type (str ("TestCaffe2BackendEmbed_opset7" ),
2048+ (unittest .TestCase ,),
2049+ dict (TestCaffe2Backend_opset9 .__dict__ ,
2050+ embed_params = True , opset_version = 7 ))
2051+
2052+ # opset 8 tests
2053+ TestCaffe2Backend_opset8 = type (str ("TestCaffe2Backend_opset8" ),
2054+ (unittest .TestCase ,),
2055+ dict (TestCaffe2Backend_opset9 .__dict__ , opset_version = 8 ))
2056+ TestCaffe2BackendEmbed_opset8 = type (str ("TestCaffe2BackendEmbed_opset8" ),
2057+ (unittest .TestCase ,),
2058+ dict (TestCaffe2Backend_opset9 .__dict__ ,
2059+ embed_params = True , opset_version = 8 ))
2060+
20162061# opset 10 tests
20172062TestCaffe2Backend_opset10 = type (str ("TestCaffe2Backend_opset10" ),
20182063 (unittest .TestCase ,),
0 commit comments