@@ -1356,23 +1356,6 @@ def forward(self, input):
13561356 caffe2_out = prepared .run (inputs = [x .cpu ().numpy ()])
13571357 self .assertEqual (caffe2_out [0 ].shape , x .shape )
13581358
1359- def test_topk (self ):
1360- class TopKModel (torch .nn .Module ):
1361- def forward (self , input ):
1362- return torch .topk (input , 3 , dim = 0 )
1363-
1364- x = torch .randn (4 , 3 , requires_grad = True )
1365- self .run_model_test (TopKModel (), train = False , input = x , batch_size = BATCH_SIZE )
1366-
1367- def test_topk_script (self ):
1368- class TopKModel (torch .jit .ScriptModule ):
1369- @torch .jit .script_method
1370- def forward (self , input ):
1371- return torch .topk (input , 3 , dim = 0 )
1372-
1373- x = torch .randn (4 , 3 , requires_grad = True )
1374- self .run_model_test (TopKModel (), train = False , input = (x ,), batch_size = BATCH_SIZE , example_outputs = torch .topk (x , 3 , dim = 0 ))
1375-
13761359 def test_traced_ints (self ):
13771360 A = 4
13781361 H = 10
@@ -1600,10 +1583,19 @@ def test_topk(self):
16001583 class TopKModel (torch .nn .Module ):
16011584 def forward (self , input ):
16021585 return torch .topk (input , 3 )
1603- model = TopKModel ()
1586+
16041587 x = torch .arange (1. , 6. )
16051588 self .run_model_test (TopKModel (), train = False , input = x , batch_size = BATCH_SIZE )
16061589
1590+ def test_topk_script (self ):
1591+ class TopKModel (torch .jit .ScriptModule ):
1592+ @torch .jit .script_method
1593+ def forward (self , input ):
1594+ return torch .topk (input , 3 , dim = 0 )
1595+
1596+ x = torch .randn (4 , 3 , requires_grad = True )
1597+ self .run_model_test (TopKModel (), train = False , input = (x ,), batch_size = BATCH_SIZE , example_outputs = torch .topk (x , 3 , dim = 0 ))
1598+
16071599 def test_floor (self ):
16081600 class FloorModel (torch .nn .Module ):
16091601 def forward (self , input ):
0 commit comments