import torch
from pytorch_test_common import skipIfUnsupportedMinOpsetVersion
from torch.onnx import _constants, symbolic_helper
+from torch.onnx._internal import jit_utils
from torch.testing._internal import common_utils


@@ -22,6 +23,17 @@ def verify(actual_type):
    return verify


+def g_op(graph: torch.Graph, op_name: str, *args, **kwargs):
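+    """Emit an ONNX op `op_name` on `graph` via a temporary jit_utils.GraphContext (stand-in for the former `g.op(...)` calls below)."""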
+    return jit_utils.GraphContext(
+        graph=graph,
+        block=graph.block(),
+        opset=_constants.ONNX_MAX_OPSET,
+        original_node=None,  # type: ignore[arg-type]
+        params_dict={},
+        env={},
+    ).op(op_name, *args, **kwargs)
+
+
class TestONNXShapeInference(pytorch_test_common.ExportTestCase):
    def setUp(self):
        self.opset_version = _constants.ONNX_MAX_OPSET
@@ -43,21 +55,23 @@ def create_empty_graph(self):
        return g

    def insert_tensor_constant(self, g, tensor):
-        return g.op("Constant", value_t=tensor)
+        return g_op(g, "Constant", value_t=tensor)

    def test_cast(self):
        # Test cast with input of unknown scalar type.
        g = self.create_empty_graph()
        input = g.addInput()
-        cast_out = g.op("Cast", input, to_i=1)
+        cast_out = g_op(g, "Cast", input, to_i=1)
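+        # to_i=1 is ONNX TensorProto.FLOAT, hence the expected Float output type.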
        self.run_test(g, cast_out.node(), expect_tensor("Float"))

    def test_constant_of_shape(self):
        # Test ConstantOfShape with input of onnx::Shape node.
        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(1, 2, 3, 4))
-        shape = g.op("Shape", constant)
-        constant_of_shape = g.op("ConstantOfShape", shape, value_t=torch.tensor([2.0]))
+        shape = g_op(g, "Shape", constant)
+        constant_of_shape = g_op(
+            g, "ConstantOfShape", shape, value_t=torch.tensor([2.0])
+        )
        self.run_test(
            g, constant_of_shape.node(), expect_tensor("Float", shape=(1, 2, 3, 4))
        )
@@ -69,9 +83,11 @@ def test_constant_of_shape_static(self):
        constants = [
            self.insert_tensor_constant(g, torch.tensor(i + 1)) for i in range(rank)
        ]
-        shape = g.op("prim::ListConstruct", *constants)
+        shape = g_op(g, "prim::ListConstruct", *constants)
        shape.setType(torch._C.ListType.ofInts())
-        constant_of_shape = g.op("ConstantOfShape", shape, value_t=torch.tensor([2.0]))
+        constant_of_shape = g_op(
+            g, "ConstantOfShape", shape, value_t=torch.tensor([2.0])
+        )
        self.run_test(
            g, constant_of_shape.node(), expect_tensor("Float", shape=(1, 2, 3, 4))
        )
@@ -81,9 +97,11 @@ def test_constant_of_shape_dynamic(self):
        rank = 4
        g = self.create_empty_graph()
        inputs = [g.addInput() for i in range(rank)]
-        shape = g.op("prim::ListConstruct", *inputs)
+        shape = g_op(g, "prim::ListConstruct", *inputs)
        shape.setType(torch._C.ListType.ofInts())
-        constant_of_shape = g.op("ConstantOfShape", shape, value_t=torch.tensor([2.0]))
+        constant_of_shape = g_op(
+            g, "ConstantOfShape", shape, value_t=torch.tensor([2.0])
+        )
        self.run_test(
            g,
            constant_of_shape.node(),
@@ -98,7 +116,7 @@ def test_gather_dynamic_index(self):
        )
        indices = g.addInput()
        indices.setType(indices.type().with_dtype(torch.int64).with_sizes([None]))
-        output = g.op("Gather", input, indices, axis_i=1)
+        output = g_op(g, "Gather", input, indices, axis_i=1)
        self.run_test(
            g, output.node(), expect_tensor("Float", shape=([None, None, 16, 16]))
        )
@@ -110,34 +128,34 @@ def test_gather_scalar_index(self):
            input.type().with_dtype(torch.float).with_sizes([None, 3, 16, 16])
        )
        indices = self.insert_tensor_constant(g, torch.tensor(1))
-        output = g.op("Gather", input, indices, axis_i=1)
+        output = g_op(g, "Gather", input, indices, axis_i=1)
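+        # A 0-d (scalar) index drops the gathered axis: [None, 3, 16, 16] -> [None, 16, 16].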
        self.run_test(g, output.node(), expect_tensor("Float", shape=([None, 16, 16])))

    def test_reshape(self):
        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 5))
        constant_2 = self.insert_tensor_constant(g, torch.tensor([2, 0, -1]))
-        shape = g.op("Reshape", constant, constant_2)
+        shape = g_op(g, "Reshape", constant, constant_2)
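+        # In Reshape, 0 copies the input dim and -1 is inferred: (2, 16, 5, 5) -> (2, 16, 25).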
        self.run_test(g, shape.node(), expect_tensor("Float", shape=(2, 16, 25)))

        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 4))
        constant_2 = self.insert_tensor_constant(g, torch.tensor([-1, 0, 4]))
-        shape = g.op("Reshape", constant, constant_2)
+        shape = g_op(g, "Reshape", constant, constant_2)
        self.run_test(g, shape.node(), expect_tensor("Float", shape=(10, 16, 4)))

        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 4))
        constant_2 = self.insert_tensor_constant(g, torch.tensor([-1, 0, 0]))
-        shape = g.op("Reshape", constant, constant_2)
+        shape = g_op(g, "Reshape", constant, constant_2)
        self.run_test(g, shape.node(), expect_tensor("Float", shape=(8, 16, 5)))

    def test_reshape_symbolic(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_sizes([None, None, 2, 8]))
        constant = self.insert_tensor_constant(g, torch.tensor([0, 0, -1]))
-        output = g.op("Reshape", input, constant)
+        output = g_op(g, "Reshape", input, constant)
        self.run_test(g, output.node(), expect_tensor(None, shape=(None, None, 16)))

    @skipIfUnsupportedMinOpsetVersion(14)
@@ -146,7 +164,7 @@ def test_reshape_allowzero(self):
        input = g.addInput()
        input.setType(input.type().with_sizes([3, 4, 0]))
        constant = self.insert_tensor_constant(g, torch.tensor([0, 4, 3]))
-        output = g.op("Reshape", input, constant, allowzero_i=1)
+        output = g_op(g, "Reshape", input, constant, allowzero_i=1)
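+        # With allowzero=1 (opset 14+), a 0 in the target shape is a literal zero-sized dim, not "copy input dim".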
        self.run_test(g, output.node(), expect_tensor(None, shape=(0, 4, 3)))

    def test_slice(self):
@@ -158,62 +176,62 @@ def test_slice(self):
        end = self.insert_tensor_constant(g, torch.tensor([3]))
        axis = self.insert_tensor_constant(g, torch.tensor([0]))
        step = self.insert_tensor_constant(g, torch.tensor([1]))
-        slice = g.op("Slice", input, start_input, end, axis, step)
+        slice = g_op(g, "Slice", input, start_input, end, axis, step)
        self.run_test(g, slice.node(), expect_tensor(None, shape=(None, None)))

    def test_broadcast_matmul(self):
        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(5, 1, 2))
        constant_2 = self.insert_tensor_constant(g, torch.ones(3, 1, 2, 1))
-        shape = g.op("MatMul", constant, constant_2)
+        shape = g_op(g, "MatMul", constant, constant_2)
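+        # Batch dims broadcast numpy-style: (5, 1, 2) @ (3, 1, 2, 1) -> (3, 5, 1, 1).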
        self.run_test(g, shape.node(), expect_tensor("Float", shape=(3, 5, 1, 1)))

        # test when first input is of rank 1
        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(2))
        constant_2 = self.insert_tensor_constant(g, torch.ones(3, 1, 2, 1))
-        shape = g.op("MatMul", constant, constant_2)
+        shape = g_op(g, "MatMul", constant, constant_2)
        self.run_test(g, shape.node(), expect_tensor("Float", shape=(3, 1, 1)))

        # test when second input is of rank 1
        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(5, 1, 2))
        constant_2 = self.insert_tensor_constant(g, torch.ones(2))
-        shape = g.op("MatMul", constant, constant_2)
+        shape = g_op(g, "MatMul", constant, constant_2)
        self.run_test(g, shape.node(), expect_tensor("Float", shape=(5, 1)))

        # test when both inputs are of rank 1
        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(2))
        constant_2 = self.insert_tensor_constant(g, torch.ones(2))
-        shape = g.op("MatMul", constant, constant_2)
+        shape = g_op(g, "MatMul", constant, constant_2)
        self.run_test(g, shape.node(), expect_tensor("Float", shape=()))

    def test_expand(self):
        g = self.create_empty_graph()
        input = g.addInput()
        constant = self.insert_tensor_constant(g, torch.ones(2, 4))
        input.setType(constant.type().with_sizes([None, None]))
-        shape = g.op("Shape", input)
-        expand = g.op("Expand", constant, shape)
+        shape = g_op(g, "Shape", input)
+        expand = g_op(g, "Expand", constant, shape)
        self.run_test(g, expand.node(), expect_tensor("Float", shape=(None, None)))

    def test_pad(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.float).with_sizes([3, 320, 100]))
        constant = self.insert_tensor_constant(g, torch.ones(6, dtype=torch.long))
-        none = g.op("prim::Constant").setType(torch.NoneType.get())
-        pad = g.op("Pad", input, constant, none, mode_s="constant")
+        none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
+        pad = g_op(g, "Pad", input, constant, none, mode_s="constant")
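+        # Six unit pads = one at each end of all three dims: (3, 320, 100) -> (5, 322, 102).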
        self.run_test(g, pad.node(), expect_tensor("Float", shape=(5, 322, 102)))

    def test_pad_with_dynamic_input_shape(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.float).with_sizes([3, None, None]))
        constant = self.insert_tensor_constant(g, torch.ones(6, dtype=torch.long))
-        none = g.op("prim::Constant").setType(torch.NoneType.get())
-        pad = g.op("Pad", input, constant, none, mode_s="constant")
+        none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
+        pad = g_op(g, "Pad", input, constant, none, mode_s="constant")
        self.run_test(g, pad.node(), expect_tensor("Float", shape=(5, None, None)))

    def test_pad_with_dynamic_pad_size(self):
@@ -222,19 +240,20 @@ def test_pad_with_dynamic_pad_size(self):
        input.setType(input.type().with_dtype(torch.float).with_sizes([3, 320, 100]))
        pad_size = g.addInput()
        pad_size.setType(pad_size.type().with_dtype(torch.long).with_sizes([6]))
-        none = g.op("prim::Constant").setType(torch.NoneType.get())
-        pad = g.op("Pad", input, pad_size, none, mode_s="constant")
+        none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
+        pad = g_op(g, "Pad", input, pad_size, none, mode_s="constant")
        self.run_test(g, pad.node(), expect_tensor("Float", shape=(None, None, None)))

    def test_resize(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.float).with_sizes([4, 32, 64, 64]))
-        none = g.op("prim::Constant").setType(torch.NoneType.get())
+        none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
        scales = self.insert_tensor_constant(
            g, torch.tensor([1, 1, 2, 2], dtype=torch.float)
        )
-        resize = g.op(
+        resize = g_op(
+            g,
            "Resize",
            input,
            none,
@@ -250,16 +269,17 @@ def test_resize_after_concat(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.float).with_sizes([4, 32, 64, 64]))
-        none = g.op("prim::Constant").setType(torch.NoneType.get())
+        none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
        scale_1 = self.insert_tensor_constant(
            g, torch.tensor([1, 1], dtype=torch.float)
        )
        scale_2 = self.insert_tensor_constant(
            g, torch.tensor([2, 2], dtype=torch.float)
        )
        # `scales` values should be statically known due to constant folding in shape inference.
-        scales = g.op("Concat", scale_1, scale_2, axis_i=0)
-        resize = g.op(
+        scales = g_op(g, "Concat", scale_1, scale_2, axis_i=0)
+        resize = g_op(
+            g,
            "Resize",
            input,
            none,
@@ -275,14 +295,14 @@ def test_reduce_prod_with_axes(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.long).with_sizes([2]))
-        reduce_prod = g.op("ReduceProd", input, axes_i=[0])
+        reduce_prod = g_op(g, "ReduceProd", input, axes_i=[0])
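+        # keepdims defaults to 1, so the reduced axis remains with size 1.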
        self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,)))

    def test_reduce_prod_without_axes(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.long).with_sizes([2]))
-        reduce_prod = g.op("ReduceProd", input)
+        reduce_prod = g_op(g, "ReduceProd", input)
        self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,)))

    def test_proceeding_nodes_use_prim_pack_padded_output_dtype_correctly(self):
@@ -291,14 +311,14 @@ def test_proceeding_nodes_use_prim_pack_padded_output_dtype_correctly(self):
        input.setType(input.type().with_dtype(torch.float).with_sizes([4, 16]))
        length = g.addInput()
        length.setType(length.type().with_dtype(torch.long).with_sizes([4]))
-        padded, batch_size = g.op("prim::PackPadded", input, length, outputs=2)
+        padded, batch_size = g_op(g, "prim::PackPadded", input, length, outputs=2)
        # `prim::PackPadded` only occurs in tracing mode. Hence its outputs inherit
        # shape and data type from the traced graph.
        padded.setType(padded.type().with_dtype(torch.float).with_sizes([None, None]))
        batch_size.setType(batch_size.type().with_dtype(torch.long).with_sizes([None]))
        # `Gather` should use the data type of `batch_size` as the data type of its output.
        gather_idx = self.insert_tensor_constant(g, torch.tensor([0], dtype=torch.long))
-        gather = g.op("Gather", batch_size, gather_idx, axis_i=0)
+        gather = g_op(g, "Gather", batch_size, gather_idx, axis_i=0)
        self.run_test(g, gather.node(), expect_tensor("Long", shape=(None,)))
