Skip to content

Commit 6df7741

Browse files
committed
Update on "[ONNX] Remove the deprecated monkey patches to torch.Graph"
cc ezyang gchanan [ghstack-poisoned]
1 parent d443fdb commit 6df7741

File tree

5 files changed

+64
-51
lines changed

5 files changed

+64
-51
lines changed

test/onnx/test_pytorch_onnx_shape_inference.py

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
from pytorch_test_common import skipIfUnsupportedMinOpsetVersion
1010
from torch.onnx import _constants, symbolic_helper
11+
from torch.onnx._internal import jit_utils
1112
from torch.testing._internal import common_utils
1213

1314

@@ -22,6 +23,17 @@ def verify(actual_type):
2223
return verify
2324

2425

26+
def g_op(graph: torch.Graph, op_name: str, *args, **kwargs):
27+
return jit_utils.GraphContext(
28+
graph=graph,
29+
block=graph.block(),
30+
opset=_constants.ONNX_MAX_OPSET,
31+
original_node=None, # type: ignore[arg-type]
32+
params_dict={},
33+
env={},
34+
).op(op_name, *args, **kwargs)
35+
36+
2537
class TestONNXShapeInference(pytorch_test_common.ExportTestCase):
2638
def setUp(self):
2739
self.opset_version = _constants.ONNX_MAX_OPSET
@@ -43,21 +55,23 @@ def create_empty_graph(self):
4355
return g
4456

4557
def insert_tensor_constant(self, g, tensor):
46-
return g.op("Constant", value_t=tensor)
58+
return g_op(g, "Constant", value_t=tensor)
4759

4860
def test_cast(self):
4961
# Test cast with input of unknown scalar type.
5062
g = self.create_empty_graph()
5163
input = g.addInput()
52-
cast_out = g.op("Cast", input, to_i=1)
64+
cast_out = g_op(g, "Cast", input, to_i=1)
5365
self.run_test(g, cast_out.node(), expect_tensor("Float"))
5466

5567
def test_constant_of_shape(self):
5668
# Test ConstantOfShape with input of onnx::Shape node.
5769
g = self.create_empty_graph()
5870
constant = self.insert_tensor_constant(g, torch.ones(1, 2, 3, 4))
59-
shape = g.op("Shape", constant)
60-
constant_of_shape = g.op("ConstantOfShape", shape, value_t=torch.tensor([2.0]))
71+
shape = g_op(g, "Shape", constant)
72+
constant_of_shape = g_op(
73+
g, "ConstantOfShape", shape, value_t=torch.tensor([2.0])
74+
)
6175
self.run_test(
6276
g, constant_of_shape.node(), expect_tensor("Float", shape=(1, 2, 3, 4))
6377
)
@@ -69,9 +83,11 @@ def test_constant_of_shape_static(self):
6983
constants = [
7084
self.insert_tensor_constant(g, torch.tensor(i + 1)) for i in range(rank)
7185
]
72-
shape = g.op("prim::ListConstruct", *constants)
86+
shape = g_op(g, "prim::ListConstruct", *constants)
7387
shape.setType(torch._C.ListType.ofInts())
74-
constant_of_shape = g.op("ConstantOfShape", shape, value_t=torch.tensor([2.0]))
88+
constant_of_shape = g_op(
89+
g, "ConstantOfShape", shape, value_t=torch.tensor([2.0])
90+
)
7591
self.run_test(
7692
g, constant_of_shape.node(), expect_tensor("Float", shape=(1, 2, 3, 4))
7793
)
@@ -81,9 +97,11 @@ def test_constant_of_shape_dynamic(self):
8197
rank = 4
8298
g = self.create_empty_graph()
8399
inputs = [g.addInput() for i in range(rank)]
84-
shape = g.op("prim::ListConstruct", *inputs)
100+
shape = g_op(g, "prim::ListConstruct", *inputs)
85101
shape.setType(torch._C.ListType.ofInts())
86-
constant_of_shape = g.op("ConstantOfShape", shape, value_t=torch.tensor([2.0]))
102+
constant_of_shape = g_op(
103+
g, "ConstantOfShape", shape, value_t=torch.tensor([2.0])
104+
)
87105
self.run_test(
88106
g,
89107
constant_of_shape.node(),
@@ -98,7 +116,7 @@ def test_gather_dynamic_index(self):
98116
)
99117
indices = g.addInput()
100118
indices.setType(indices.type().with_dtype(torch.int64).with_sizes([None]))
101-
output = g.op("Gather", input, indices, axis_i=1)
119+
output = g_op(g, "Gather", input, indices, axis_i=1)
102120
self.run_test(
103121
g, output.node(), expect_tensor("Float", shape=([None, None, 16, 16]))
104122
)
@@ -110,34 +128,34 @@ def test_gather_scalar_index(self):
110128
input.type().with_dtype(torch.float).with_sizes([None, 3, 16, 16])
111129
)
112130
indices = self.insert_tensor_constant(g, torch.tensor(1))
113-
output = g.op("Gather", input, indices, axis_i=1)
131+
output = g_op(g, "Gather", input, indices, axis_i=1)
114132
self.run_test(g, output.node(), expect_tensor("Float", shape=([None, 16, 16])))
115133

116134
def test_reshape(self):
117135
g = self.create_empty_graph()
118136
constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 5))
119137
constant_2 = self.insert_tensor_constant(g, torch.tensor([2, 0, -1]))
120-
shape = g.op("Reshape", constant, constant_2)
138+
shape = g_op(g, "Reshape", constant, constant_2)
121139
self.run_test(g, shape.node(), expect_tensor("Float", shape=(2, 16, 25)))
122140

123141
g = self.create_empty_graph()
124142
constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 4))
125143
constant_2 = self.insert_tensor_constant(g, torch.tensor([-1, 0, 4]))
126-
shape = g.op("Reshape", constant, constant_2)
144+
shape = g_op(g, "Reshape", constant, constant_2)
127145
self.run_test(g, shape.node(), expect_tensor("Float", shape=(10, 16, 4)))
128146

129147
g = self.create_empty_graph()
130148
constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 4))
131149
constant_2 = self.insert_tensor_constant(g, torch.tensor([-1, 0, 0]))
132-
shape = g.op("Reshape", constant, constant_2)
150+
shape = g_op(g, "Reshape", constant, constant_2)
133151
self.run_test(g, shape.node(), expect_tensor("Float", shape=(8, 16, 5)))
134152

135153
def test_reshape_symbolic(self):
136154
g = self.create_empty_graph()
137155
input = g.addInput()
138156
input.setType(input.type().with_sizes([None, None, 2, 8]))
139157
constant = self.insert_tensor_constant(g, torch.tensor([0, 0, -1]))
140-
output = g.op("Reshape", input, constant)
158+
output = g_op(g, "Reshape", input, constant)
141159
self.run_test(g, output.node(), expect_tensor(None, shape=(None, None, 16)))
142160

143161
@skipIfUnsupportedMinOpsetVersion(14)
@@ -146,7 +164,7 @@ def test_reshape_allowzero(self):
146164
input = g.addInput()
147165
input.setType(input.type().with_sizes([3, 4, 0]))
148166
constant = self.insert_tensor_constant(g, torch.tensor([0, 4, 3]))
149-
output = g.op("Reshape", input, constant, allowzero_i=1)
167+
output = g_op(g, "Reshape", input, constant, allowzero_i=1)
150168
self.run_test(g, output.node(), expect_tensor(None, shape=(0, 4, 3)))
151169

152170
def test_slice(self):
@@ -158,62 +176,62 @@ def test_slice(self):
158176
end = self.insert_tensor_constant(g, torch.tensor([3]))
159177
axis = self.insert_tensor_constant(g, torch.tensor([0]))
160178
step = self.insert_tensor_constant(g, torch.tensor([1]))
161-
slice = g.op("Slice", input, start_input, end, axis, step)
179+
slice = g_op(g, "Slice", input, start_input, end, axis, step)
162180
self.run_test(g, slice.node(), expect_tensor(None, shape=(None, None)))
163181

164182
def test_broadcast_matmul(self):
165183
g = self.create_empty_graph()
166184
constant = self.insert_tensor_constant(g, torch.ones(5, 1, 2))
167185
constant_2 = self.insert_tensor_constant(g, torch.ones(3, 1, 2, 1))
168-
shape = g.op("MatMul", constant, constant_2)
186+
shape = g_op(g, "MatMul", constant, constant_2)
169187
self.run_test(g, shape.node(), expect_tensor("Float", shape=(3, 5, 1, 1)))
170188

171189
# test when first input is of rank 1
172190
g = self.create_empty_graph()
173191
constant = self.insert_tensor_constant(g, torch.ones(2))
174192
constant_2 = self.insert_tensor_constant(g, torch.ones(3, 1, 2, 1))
175-
shape = g.op("MatMul", constant, constant_2)
193+
shape = g_op(g, "MatMul", constant, constant_2)
176194
self.run_test(g, shape.node(), expect_tensor("Float", shape=(3, 1, 1)))
177195

178196
# test when second input is of rank 1
179197
g = self.create_empty_graph()
180198
constant = self.insert_tensor_constant(g, torch.ones(5, 1, 2))
181199
constant_2 = self.insert_tensor_constant(g, torch.ones(2))
182-
shape = g.op("MatMul", constant, constant_2)
200+
shape = g_op(g, "MatMul", constant, constant_2)
183201
self.run_test(g, shape.node(), expect_tensor("Float", shape=(5, 1)))
184202

185203
# test when both inputs are of rank 1
186204
g = self.create_empty_graph()
187205
constant = self.insert_tensor_constant(g, torch.ones(2))
188206
constant_2 = self.insert_tensor_constant(g, torch.ones(2))
189-
shape = g.op("MatMul", constant, constant_2)
207+
shape = g_op(g, "MatMul", constant, constant_2)
190208
self.run_test(g, shape.node(), expect_tensor("Float", shape=()))
191209

192210
def test_expand(self):
193211
g = self.create_empty_graph()
194212
input = g.addInput()
195213
constant = self.insert_tensor_constant(g, torch.ones(2, 4))
196214
input.setType(constant.type().with_sizes([None, None]))
197-
shape = g.op("Shape", input)
198-
expand = g.op("Expand", constant, shape)
215+
shape = g_op(g, "Shape", input)
216+
expand = g_op(g, "Expand", constant, shape)
199217
self.run_test(g, expand.node(), expect_tensor("Float", shape=(None, None)))
200218

201219
def test_pad(self):
202220
g = self.create_empty_graph()
203221
input = g.addInput()
204222
input.setType(input.type().with_dtype(torch.float).with_sizes([3, 320, 100]))
205223
constant = self.insert_tensor_constant(g, torch.ones(6, dtype=torch.long))
206-
none = g.op("prim::Constant").setType(torch.NoneType.get())
207-
pad = g.op("Pad", input, constant, none, mode_s="constant")
224+
none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
225+
pad = g_op(g, "Pad", input, constant, none, mode_s="constant")
208226
self.run_test(g, pad.node(), expect_tensor("Float", shape=(5, 322, 102)))
209227

210228
def test_pad_with_dynamic_input_shape(self):
211229
g = self.create_empty_graph()
212230
input = g.addInput()
213231
input.setType(input.type().with_dtype(torch.float).with_sizes([3, None, None]))
214232
constant = self.insert_tensor_constant(g, torch.ones(6, dtype=torch.long))
215-
none = g.op("prim::Constant").setType(torch.NoneType.get())
216-
pad = g.op("Pad", input, constant, none, mode_s="constant")
233+
none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
234+
pad = g_op(g, "Pad", input, constant, none, mode_s="constant")
217235
self.run_test(g, pad.node(), expect_tensor("Float", shape=(5, None, None)))
218236

219237
def test_pad_with_dynamic_pad_size(self):
@@ -222,19 +240,20 @@ def test_pad_with_dynamic_pad_size(self):
222240
input.setType(input.type().with_dtype(torch.float).with_sizes([3, 320, 100]))
223241
pad_size = g.addInput()
224242
pad_size.setType(pad_size.type().with_dtype(torch.long).with_sizes([6]))
225-
none = g.op("prim::Constant").setType(torch.NoneType.get())
226-
pad = g.op("Pad", input, pad_size, none, mode_s="constant")
243+
none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
244+
pad = g_op(g, "Pad", input, pad_size, none, mode_s="constant")
227245
self.run_test(g, pad.node(), expect_tensor("Float", shape=(None, None, None)))
228246

229247
def test_resize(self):
230248
g = self.create_empty_graph()
231249
input = g.addInput()
232250
input.setType(input.type().with_dtype(torch.float).with_sizes([4, 32, 64, 64]))
233-
none = g.op("prim::Constant").setType(torch.NoneType.get())
251+
none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
234252
scales = self.insert_tensor_constant(
235253
g, torch.tensor([1, 1, 2, 2], dtype=torch.float)
236254
)
237-
resize = g.op(
255+
resize = g_op(
256+
g,
238257
"Resize",
239258
input,
240259
none,
@@ -250,16 +269,17 @@ def test_resize_after_concat(self):
250269
g = self.create_empty_graph()
251270
input = g.addInput()
252271
input.setType(input.type().with_dtype(torch.float).with_sizes([4, 32, 64, 64]))
253-
none = g.op("prim::Constant").setType(torch.NoneType.get())
272+
none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
254273
scale_1 = self.insert_tensor_constant(
255274
g, torch.tensor([1, 1], dtype=torch.float)
256275
)
257276
scale_2 = self.insert_tensor_constant(
258277
g, torch.tensor([2, 2], dtype=torch.float)
259278
)
260279
# `scales` values should be statically known due to constant folding in shape inference.
261-
scales = g.op("Concat", scale_1, scale_2, axis_i=0)
262-
resize = g.op(
280+
scales = g_op(g, "Concat", scale_1, scale_2, axis_i=0)
281+
resize = g_op(
282+
g,
263283
"Resize",
264284
input,
265285
none,
@@ -275,14 +295,14 @@ def test_reduce_prod_with_axes(self):
275295
g = self.create_empty_graph()
276296
input = g.addInput()
277297
input.setType(input.type().with_dtype(torch.long).with_sizes([2]))
278-
reduce_prod = g.op("ReduceProd", input, axes_i=[0])
298+
reduce_prod = g_op(g, "ReduceProd", input, axes_i=[0])
279299
self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,)))
280300

281301
def test_reduce_prod_without_axes(self):
282302
g = self.create_empty_graph()
283303
input = g.addInput()
284304
input.setType(input.type().with_dtype(torch.long).with_sizes([2]))
285-
reduce_prod = g.op("ReduceProd", input)
305+
reduce_prod = g_op(g, "ReduceProd", input)
286306
self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,)))
287307

288308
def test_proceeding_nodes_use_prim_pack_padded_output_dtype_correctly(self):
@@ -291,14 +311,14 @@ def test_proceeding_nodes_use_prim_pack_padded_output_dtype_correctly(self):
291311
input.setType(input.type().with_dtype(torch.float).with_sizes([4, 16]))
292312
length = g.addInput()
293313
length.setType(length.type().with_dtype(torch.long).with_sizes([4]))
294-
padded, batch_size = g.op("prim::PackPadded", input, length, outputs=2)
314+
padded, batch_size = g_op(g, "prim::PackPadded", input, length, outputs=2)
295315
# `prim::PackPadded` only occurs in tracing mode. Hence its outputs inherits
296316
# shape and data type from traced graph.
297317
padded.setType(padded.type().with_dtype(torch.float).with_sizes([None, None]))
298318
batch_size.setType(batch_size.type().with_dtype(torch.long).with_sizes([None]))
299319
# `Gather` should use the data type of `batch_size` as the data type of its output.
300320
gather_idx = self.insert_tensor_constant(g, torch.tensor([0], dtype=torch.long))
301-
gather = g.op("Gather", batch_size, gather_idx, axis_i=0)
321+
gather = g_op(g, "Gather", batch_size, gather_idx, axis_i=0)
302322
self.run_test(g, gather.node(), expect_tensor("Long", shape=(None,)))
303323

304324

torch/onnx/_internal/jit_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs):
9999
**kwargs,
100100
)
101101

102+
# NOTE: For backward compatibility with the old symbolic functions.
103+
# We are probably going to remove this only after the fx exporter is established.
104+
at = aten_op
105+
102106
@_beartype.beartype
103107
def onnxscript_op(
104108
self,

torch/onnx/symbolic_helper.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,7 @@
2323
from torch import _C
2424

2525
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
26-
from torch.onnx import ( # noqa: F401
27-
_constants,
28-
_deprecation,
29-
_type_utils,
30-
errors,
31-
)
26+
from torch.onnx import _constants, _deprecation, _type_utils, errors
3227
from torch.onnx._globals import GLOBALS
3328
from torch.onnx._internal import _beartype, jit_utils
3429
from torch.types import Number

torch/onnx/symbolic_opset10.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch import _C
1010

1111
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
12-
from torch.onnx import ( # noqa: F401
12+
from torch.onnx import (
1313
_constants,
1414
_type_utils,
1515
errors,

torch/onnx/symbolic_opset9.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,7 @@
1818
from torch import _C
1919

2020
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
21-
from torch.onnx import ( # noqa: F401
22-
_constants,
23-
_deprecation,
24-
_type_utils,
25-
errors,
26-
symbolic_helper,
27-
)
21+
from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_helper
2822
from torch.onnx._globals import GLOBALS
2923
from torch.onnx._internal import _beartype, jit_utils, registration
3024
from torch.types import Number

0 commit comments

Comments
 (0)