@@ -158,6 +158,47 @@ def f(x, y):
158158 trace , z = torch .jit .trace (f , (x , y ), nderivs = 0 )
159159 self .assertExpectedTrace (trace )
160160
# Trace a minimal module and verify that every node in the optimized graph
# is tagged with the module's scope name ('Net').
class Net(nn.Module):
    def forward(self, x):
        return F.log_softmax(x, dim=0)

net = Net()
t = Variable(torch.ones(2), requires_grad=True)
trace, _ = torch.jit.trace(net, (t,))
torch.onnx._optimize_trace(trace, False)
g = torch._C._jit_get_graph(trace)
for node in g.nodes():
    # assertEqual (not assertTrue(a == b)) so a failure reports the
    # actual scope name instead of just "False is not true".
    self.assertEqual(node.scopeName(), 'Net')
172+
# Trace a module with a nested Sequential and verify that each traced node
# carries the full hierarchical scope path (Net/Sequential[features]/...).
class Net(nn.Module):

    def __init__(self, num_classes=1000):
        super(Net, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

    def forward(self, x):
        x = self.features(x)
        return x

model = Net()

t = Variable(torch.ones(1, 3, 227, 227), requires_grad=True)

# Trace in eval mode so training-only behavior doesn't leak into the graph.
with torch.onnx.set_training(model, False):
    trace, _ = torch.jit.trace(model, (t,))

torch.onnx._optimize_trace(trace, False)
graph = torch._C._jit_get_graph(trace)
nodes = list(graph.nodes())

expected_scopes = [
    'Net/Sequential[features]/Conv2d[0]',
    'Net/Sequential[features]/ReLU[1]',
    'Net/Sequential[features]/MaxPool2d[2]',
]
# Check the count first so a short graph fails with a clear message
# instead of an IndexError; assertEqual gives the mismatching scope
# name on failure, unlike assertTrue(a == b).
self.assertEqual(len(nodes), len(expected_scopes))
for node, expected in zip(nodes, expected_scopes):
    self.assertEqual(node.scopeName(), expected)
201+
161202 @unittest .skipIf (IS_WINDOWS , "NYI: fuser support for Windows" )
162203 @unittest .skipIf (not RUN_CUDA , "fuser requires CUDA" )
163204 def test_lstm_fusion (self ):
0 commit comments