Skip to content

Commit 8ab426f

Browse files
committed
add comments
1 parent dbf48d6 commit 8ab426f

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

torch/onnx/symbolic_opset8.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,28 +88,23 @@ def _cast_to_type(g, input, to_type):
8888
return getattr(sym_opset9, '_cast_{}'.format(to_type))(g, input, False)
8989

9090

91-
@wrap_logical_op_with_cast_to('Byte')
92-
def gt(g, input, other):
93-
return gt_impl(g, input, other)
94-
95-
96-
def gt_impl(g, input, other):
91+
def _comparison_operator(g, input, other, op_name):
9792
other = sym_help._maybe_get_scalar(other)
9893
other = sym_help._if_scalar_type_as(g, other, input)
9994
_, input, other = _try_cast_integer_to_float(g, input, other)
100-
return g.op("Greater", input, other)
95+
return g.op(op_name, input, other)
10196

10297

98+
# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
99+
# integer input type not supported in opset8. Cast to float if possible.
103100
@wrap_logical_op_with_cast_to('Byte')
104-
def lt(g, input, other):
105-
return lt_impl(g, input, other)
101+
def gt(g, input, other):
102+
return _comparison_operator(g, input, other, "Greater")
106103

107104

108-
def lt_impl(g, input, other):
109-
other = sym_help._maybe_get_scalar(other)
110-
other = sym_help._if_scalar_type_as(g, other, input)
111-
_, input, other = _try_cast_integer_to_float(g, input, other)
112-
return g.op("Less", input, other)
105+
@wrap_logical_op_with_cast_to('Byte')
106+
def lt(g, input, other):
107+
return _comparison_operator(g, input, other, "Less")
113108

114109

115110
def bmm(g, self, other):
@@ -121,11 +116,7 @@ def bmm(g, self, other):
121116

122117

123118
def matmul(g, self, other):
124-
if _try_get_scalar_type(self):
125-
old_type, self, other = _try_cast_integer_to_float(g, self, other)
126-
return _cast_to_type(g, g.op("MatMul", self, other), old_type)
127-
else:
128-
return g.op("MatMul", self, other)
119+
return bmm(g, self, other)
129120

130121

131122
def prelu(g, self, weight):

torch/onnx/symbolic_registry.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,17 @@ def register_ops_in_version(domain, version):
3636
# opset versions for operators supported in
3737
# previous versions.
3838

39-
# Opset 9 is the base version.
40-
# For operators of different opset version, updated symbolic functions are added
41-
# in the respective symbolic_opset{version}.py file.
39+
# Opset 9 is the base version. It is selected as the base version because
40+
# 1. It is the first opset version supported by PyTorch export.
41+
# 2. opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
42+
# that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations,
43+
# we chose to handle them as special cases separately.
44+
# Backward support for opset versions beyond opset 7 is not in our roadmap.
45+
46+
# For opset versions other than 9, by default they will inherit the symbolic functions defined in
47+
# symbolic_opset9.py.
48+
# To extend support for updated operators in different opset versions on top of opset 9,
49+
# simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
4250
# Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
4351
iter_version = version
4452
while iter_version != 9:

0 commit comments

Comments
 (0)