Skip to content

Commit 447b17d

Browse files
committed
Add tests for ONNX pass scopes
1 parent 2fc0ec5 commit 447b17d

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

test/test_jit.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,47 @@ def f(x, y):
158158
trace, z = torch.jit.trace(f, (x, y), nderivs=0)
159159
self.assertExpectedTrace(trace)
160160

161+
class Net(nn.Module):
162+
def forward(self, x):
163+
return F.log_softmax(x, dim=0)
164+
165+
net = Net()
166+
t = Variable(torch.ones(2), requires_grad=True)
167+
trace, _ = torch.jit.trace(net, (t, ))
168+
torch.onnx._optimize_trace(trace, False)
169+
g = torch._C._jit_get_graph(trace)
170+
for node in g.nodes():
171+
self.assertTrue(node.scopeName() == 'Net')
172+
173+
class Net(nn.Module):
174+
175+
def __init__(self, num_classes=1000):
176+
super(Net, self).__init__()
177+
self.features = nn.Sequential(
178+
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
179+
nn.ReLU(inplace=True),
180+
nn.MaxPool2d(kernel_size=3, stride=2),
181+
)
182+
183+
def forward(self, x):
184+
x = self.features(x)
185+
return x
186+
187+
model = Net()
188+
189+
t = Variable(torch.ones(1, 3, 227, 227), requires_grad=True)
190+
191+
with torch.onnx.set_training(model, False):
192+
trace, _ = torch.jit.trace(model, (t, ))
193+
194+
torch.onnx._optimize_trace(trace, False)
195+
graph = torch._C._jit_get_graph(trace)
196+
nodes = list(graph.nodes())
197+
198+
self.assertTrue(nodes[0].scopeName() == 'Net/Sequential[features]/Conv2d[0]')
199+
self.assertTrue(nodes[1].scopeName() == 'Net/Sequential[features]/ReLU[1]')
200+
self.assertTrue(nodes[2].scopeName() == 'Net/Sequential[features]/MaxPool2d[2]')
201+
161202
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
162203
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
163204
def test_lstm_fusion(self):

0 commit comments

Comments
 (0)