Skip to content

Commit e76ae8e

Browse files
committed
merge tests for topk
1 parent 094850f commit e76ae8e

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,23 +1356,6 @@ def forward(self, input):
13561356
caffe2_out = prepared.run(inputs=[x.cpu().numpy()])
13571357
self.assertEqual(caffe2_out[0].shape, x.shape)
13581358

1359-
def test_topk(self):
1360-
class TopKModel(torch.nn.Module):
1361-
def forward(self, input):
1362-
return torch.topk(input, 3, dim=0)
1363-
1364-
x = torch.randn(4, 3, requires_grad=True)
1365-
self.run_model_test(TopKModel(), train=False, input=x, batch_size=BATCH_SIZE)
1366-
1367-
def test_topk_script(self):
1368-
class TopKModel(torch.jit.ScriptModule):
1369-
@torch.jit.script_method
1370-
def forward(self, input):
1371-
return torch.topk(input, 3, dim=0)
1372-
1373-
x = torch.randn(4, 3, requires_grad=True)
1374-
self.run_model_test(TopKModel(), train=False, input=(x,), batch_size=BATCH_SIZE, example_outputs=torch.topk(x, 3, dim=0))
1375-
13761359
def test_traced_ints(self):
13771360
A = 4
13781361
H = 10
@@ -1600,10 +1583,19 @@ def test_topk(self):
16001583
class TopKModel(torch.nn.Module):
16011584
def forward(self, input):
16021585
return torch.topk(input, 3)
1603-
model = TopKModel()
1586+
16041587
x = torch.arange(1., 6.)
16051588
self.run_model_test(TopKModel(), train=False, input=x, batch_size=BATCH_SIZE)
16061589

1590+
def test_topk_script(self):
1591+
class TopKModel(torch.jit.ScriptModule):
1592+
@torch.jit.script_method
1593+
def forward(self, input):
1594+
return torch.topk(input, 3, dim=0)
1595+
1596+
x = torch.randn(4, 3, requires_grad=True)
1597+
self.run_model_test(TopKModel(), train=False, input=(x,), batch_size=BATCH_SIZE, example_outputs=torch.topk(x, 3, dim=0))
1598+
16071599
def test_floor(self):
16081600
class FloorModel(torch.nn.Module):
16091601
def forward(self, input):

0 commit comments

Comments
 (0)