Commit cc85c3d

Lara authored and facebook-github-bot committed
ONNX Export Slice and Flip ops for opset 10
Summary: Pull Request resolved: #20533

Reviewed By: zrphercule

Differential Revision: D15579713

Pulled By: houseroad

fbshipit-source-id: 91f3ac0cb14ef226f980362b0013b6b92cb8b8da
1 parent 3eced79 commit cc85c3d
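
For context, the heart of this change is the opset-10 Slice schema: opset 9 carries starts/ends/axes as node attributes, while opset 10 takes them (plus an optional steps input) as tensor inputs, which is what makes dynamic slicing and negative-step flips expressible. A minimal sketch of the two node forms using onnx.helper (illustrative, not part of this commit):

import onnx.helper as oh

# Opset 9: slice parameters are baked into the node as attributes.
slice_v9 = oh.make_node("Slice", inputs=["x"], outputs=["y"],
                        axes=[0], starts=[0], ends=[1])

# Opset 10: slice parameters arrive as inputs, typically fed by Constant
# nodes (static case) or by Shape/Gather/Unsqueeze chains (dynamic case).
slice_v10 = oh.make_node("Slice",
                         inputs=["x", "starts", "ends", "axes", "steps"],
                         outputs=["y"])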

File tree: 5 files changed (+159 −28 lines)

test/onnx/test_onnx_opset.py
78 additions, 2 deletions

@@ -40,10 +40,13 @@ def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_version):
             assert attributes[j][attribute_field] == getattr(graph.node[i].attribute[j], attribute_field)
 
 
-def check_onnx_opsets_operator(module, x, ops, opset_versions, training=False):
+def check_onnx_opsets_operator(module, x, ops, opset_versions, training=False, example_outputs=None):
     for opset_version in opset_versions:
         f = io.BytesIO()
-        torch.onnx.export(module, x, f, opset_version=opset_version, training=training)
+        torch.onnx.export(module, x, f,
+                          opset_version=opset_version,
+                          training=training,
+                          example_outputs=example_outputs)
         model = onnx.load(io.BytesIO(f.getvalue()))
         check_onnx_opset_operator(model, ops[opset_version], opset_version)
 
@@ -107,6 +110,79 @@ def test_maxpool(self):
         x = torch.randn(20, 16, 50)
         check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
 
+    def test_slice(self):
+        class MyModule(Module):
+            def forward(self, x):
+                return x[0:1]
+
+        ops_9 = [{"op_name" : "Slice",
+                  "attributes" :
+                  [{"name": "axes", "ints": [0], "type": 7},
+                   {"name": "ends", "ints": [1], "type": 7},
+                   {"name": "starts", "ints": [0], "type": 7}]}]
+        ops_10 = [{"op_name" : "Constant"},
+                  {"op_name" : "Constant"},
+                  {"op_name" : "Constant"},
+                  {"op_name" : "Constant"},
+                  {"op_name" : "Slice",
+                   "attributes" : []}]
+        ops = {9 : ops_9, 10 : ops_10}
+        x = torch.randn(3)
+        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
+
+        class DynamicSliceModel(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                return x[1:x.size(0)]
+
+        ops_9 = [{"op_name" : "Constant"},
+                 {"op_name" : "Constant"},
+                 {"op_name" : "Shape"},
+                 {"op_name" : "Gather",
+                  "attributes" : [{"name" : "axis", "i" : 0, "type" : 2}]},
+                 {"op_name" : "Unsqueeze",
+                  "attributes" : [{"name" : "axes", "i" : 0, "type" : 7}]},
+                 {"op_name" : "Unsqueeze",
+                  "attributes" : [{"name" : "axes", "i" : 0, "type" : 7}]},
+                 {"op_name" : "Unsqueeze",
+                  "attributes" : [{"name" : "axes", "i" : 0, "type" : 7}]},
+                 {"op_name" : "DynamicSlice"}]
+        ops_10 = [{"op_name" : "Constant"},
+                  {"op_name" : "Constant"},
+                  {"op_name" : "Shape"},
+                  {"op_name" : "Gather",
+                   "attributes" : [{"name" : "axis", "i" : 0, "type" : 2}]},
+                  {"op_name" : "Unsqueeze",
+                   "attributes" : [{"name" : "axes", "i" : 0, "type" : 7}]},
+                  {"op_name" : "Unsqueeze",
+                   "attributes" : [{"name" : "axes", "i" : 0, "type" : 7}]},
+                  {"op_name" : "Unsqueeze",
+                   "attributes" : [{"name" : "axes", "i" : 0, "type" : 7}]},
+                  {"op_name" : "Constant"},
+                  {"op_name" : "Slice",
+                   "attributes" : []}]
+        ops = {9 : ops_9, 10 : ops_10}
+        module = DynamicSliceModel()
+        x = torch.rand(1, 2)
+        example_output = module(x)
+        check_onnx_opsets_operator(module, x, ops, opset_versions=[9, 10], example_outputs=example_output)
+
+    def test_flip(self):
+        class MyModule(Module):
+            def forward(self, x):
+                return torch.flip(x, dims=[0])
+
+        ops_10 = [{"op_name" : "Constant"},
+                  {"op_name" : "Constant"},
+                  {"op_name" : "Constant"},
+                  {"op_name" : "Constant"},
+                  {"op_name" : "Slice",
+                   "attributes" : []}]
+        ops = {10 : ops_10}
+        import numpy
+        x = torch.tensor(numpy.arange(6.0).reshape(2, 3))
+        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[10])
+
     def test_dropout(self):
         class MyModule(Module):
             def __init__(self):
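
The example_outputs plumbing above matters because ScriptModules are exported without tracing, so the exporter needs sample outputs to infer output types. A hedged usage sketch against the torch.onnx.export API as it stood at this commit (the argument was removed in later PyTorch releases):

import io
import torch

class DynamicSliceModel(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        return x[1:x.size(0)]

module = DynamicSliceModel()
x = torch.rand(1, 2)
f = io.BytesIO()
# example_outputs lets the exporter type the graph outputs without tracing.
torch.onnx.export(module, x, f, opset_version=10,
                  example_outputs=module(x))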

test/onnx/test_pytorch_onnx_caffe2.py
11 additions, 0 deletions

@@ -1198,6 +1198,17 @@ def forward(self, x):
         x = torch.rand(5, 5, 5)
         self.run_model_test(DynamicSliceExportMod(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False)
 
+    def test_dynamic_slice_script(self):
+        class DynamicSliceModel(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                return x[1:x.size(0)]
+        module = DynamicSliceModel()
+        x = torch.rand(1, 2)
+        example_output = module(x)
+        self.run_model_test(DynamicSliceModel(), train=False, input=(x,),
+                            batch_size=BATCH_SIZE, use_gpu=False, example_outputs=example_output)
+
     def test_dynamic_slice_to_the_end(self):
         class DynamicSliceExportMod(torch.nn.Module):
             def forward(self, x):
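
As a rough picture of what the scripted x[1:x.size(0)] lowers to in opset 10 (reconstructed from the expected-ops lists in test_onnx_opset.py above, not dumped from the exporter): the runtime end index comes from Shape followed by Gather, each scalar is Unsqueezed to a 1-D tensor, and Slice consumes them all as inputs. Node names below are made up for the sketch:

from onnx import helper

# Gather the runtime value of x.size(0) out of Shape(x).
end = helper.make_node("Gather", ["x_shape", "zero"], ["end"], axis=0)
# Opset 10 Slice wants 1-D tensors, so each scalar gets Unsqueezed.
end_1d = helper.make_node("Unsqueeze", ["end"], ["end_1d"], axes=[0])
# All slice parameters are now graph inputs rather than attributes.
y = helper.make_node("Slice", ["x", "starts_1d", "end_1d", "axes_1d"], ["y"])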

torch/onnx/symbolic_helper.py
7 additions, 0 deletions

@@ -187,6 +187,13 @@ def _try_get_scalar_type(*args):
         pass
     return None
 
+def _slice_op(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
+    if _export_onnx_opset_version == 9:
+        from torch.onnx.symbolic_opset9 import slice_op
+        return slice_op(g, input, axes, starts, ends)
+    if _export_onnx_opset_version == 10:
+        from torch.onnx.symbolic_opset10 import slice_op
+        return slice_op(g, input, axes, starts, ends, steps, dynamic_slice)
 
 # ---------------------------------------------------------------------
 # ONNX operator version
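
The helper deliberately imports the opset modules inside the function body: symbolic_opset9 and symbolic_opset10 both import symbolic_helper at module level, so a top-level import here would be circular. A minimal, self-contained sketch of the dispatch pattern (illustrative; the shipped helper reads the module-level _export_onnx_opset_version set by the exporter rather than taking it as a parameter):

def _dispatch_slice(opset_version, g, input, axes, starts, ends,
                    steps=None, dynamic_slice=False):
    # Late imports break the cycle between the helper and the opset files.
    if opset_version == 9:
        from torch.onnx.symbolic_opset9 import slice_op
        return slice_op(g, input, axes, starts, ends)  # no steps/dynamic
    if opset_version == 10:
        from torch.onnx.symbolic_opset10 import slice_op
        return slice_op(g, input, axes, starts, ends, steps, dynamic_slice)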

torch/onnx/symbolic_opset10.py
44 additions, 3 deletions

@@ -5,6 +5,7 @@
 # ONNX symbolics
 import torch.onnx.utils
 
+import torch.onnx.symbolic_helper as sym_help
 from torch.onnx.symbolic_helper import parse_args, _unimplemented, _black_list_in_opset
 import torch.onnx.symbolic_opset9
 
@@ -22,9 +23,7 @@
 # It is very important to blacklist these operators to avoid exporting
 # models with mixed versions of operators.
 # TODO : add support for the blacklisted operators in black_listed_operators
-black_listed_operators = ["flip",
-                          "slice",
-                          "upsample_nearest2d", "upsample_bilinear2d"]
+black_listed_operators = ["upsample_nearest2d", "upsample_bilinear2d"]
 
 for black_listed_op in black_listed_operators:
     vars()[black_listed_op] = _black_list_in_opset(black_listed_op)
@@ -118,3 +117,45 @@ def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
 avg_pool1d = _avg_pool('avg_pool1d', _single)
 avg_pool2d = _avg_pool('avg_pool2d', _pair)
 avg_pool3d = _avg_pool('avg_pool3d', _triple)
+
+
+def slice_op(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
+    if dynamic_slice:
+        starts = g.op("Unsqueeze", starts, axes_i=[0])
+        ends = g.op("Unsqueeze", ends, axes_i=[0])
+        axes = g.op("Unsqueeze", axes, axes_i=[0])
+    else:
+        assert len(starts) == len(ends)
+        assert len(starts) == len(axes)
+        assert steps is None or len(starts) == len(steps)
+        if len(starts) == 1 and starts[0] == 0 and ends[0] == 9223372036854775807 \
+                and (steps is None or (len(steps) == 1 and steps[0] == 1)):
+            return input
+        axes = g.op("Constant", value_t=torch.tensor(axes))
+        starts = g.op("Constant", value_t=torch.tensor(starts))
+        ends = g.op("Constant", value_t=torch.tensor(ends))
+    if steps is None:
+        return g.op("Slice", input, starts, ends, axes)
+    steps = g.op("Constant", value_t=torch.tensor(steps))
+    return g.op("Slice", input, starts, ends, axes, steps)
+
+
+@parse_args('v', 'v', 'v', 'v', 'i')
+def slice(g, self, dim, start, end, step):
+    if (start.node().kind() != 'onnx::Constant' or
+            end.node().kind() != 'onnx::Constant' or dim.node().kind() != 'onnx::Constant'):
+        dynamic_slice = True
+    else:
+        start = [sym_help._parse_arg(start, 'i')]
+        end = [sym_help._parse_arg(end, 'i')]
+        dim = [sym_help._parse_arg(dim, 'i')]
+        dynamic_slice = False
+    return sym_help._slice_op(g, self, axes=dim, starts=start, ends=end, steps=[step], dynamic_slice=dynamic_slice)
+
+
+@parse_args('v', 'is')
+def flip(g, input, dims):
+    return sym_help._slice_op(g, input, axes=dims,
+                              starts=[-1] * len(dims),
+                              ends=[-9223372036854775807] * len(dims),
+                              steps=[-1] * len(dims))
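
Those flip parameters are worth unpacking: with steps of -1, starting at index -1 and running off the front of the axis is exactly Python's reversed slice, and -9223372036854775807 (INT64_MIN + 1) serves as the "run past the beginning" sentinel. A quick equivalence check in plain PyTorch/numpy (illustrative, not exporter code; torch itself rejects negative-step indexing, hence the numpy detour):

import torch

x = torch.arange(6.0).reshape(2, 3)
flipped = torch.flip(x, dims=[0])
# Slice(starts=[-1], ends=[large negative], axes=[0], steps=[-1]) computes
# what numpy spells as x[::-1] along axis 0:
assert torch.equal(flipped, torch.tensor(x.numpy()[::-1].copy()))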

torch/onnx/symbolic_opset9.py
19 additions, 23 deletions

@@ -194,7 +194,7 @@ def sign(g, self):
     return g.op("Sign", self)
 
 
-def _slice_op(g, input, axes, starts, ends):
+def slice_op(g, input, axes, starts, ends):
     assert len(starts) == len(ends)
     if len(starts) == 1 and starts[0] == 0 and ends[0] == 9223372036854775807:
         return input
@@ -360,8 +360,8 @@ def select(g, self, dim, index):
         # of Gather in caffe2. We need to change this as soon as possible.
         # TODO: this breaks if index == -1
         index_val = _parse_arg(index, 'i')
-        slice_node = _slice_op(g, self, axes=[dim],
-                               starts=[index_val], ends=[index_val + 1])
+        slice_node = sym_help._slice_op(g, self, axes=[dim],
+                                        starts=[index_val], ends=[index_val + 1])
         return g.op("Squeeze", slice_node, axes_i=[dim])
     else:
         return g.op("Gather", self, index, axis_i=dim)
@@ -538,8 +538,8 @@ def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
                            kernel_shape_i=[1 for _ in range(ndims)],
                            strides_i=[1 for _ in range(ndims)])
         # convert indices to have non-flattened indices values
-        s = _slice_op(g, flattened_indices, axes=[2 + i for i in range(ndims)],
-                      starts=tuple_fn(0), ends=tuple_fn(1))
+        s = sym_help._slice_op(g, flattened_indices, axes=[2 + i for i in range(ndims)],
+                               starts=tuple_fn(0), ends=tuple_fn(1))
         indices = sub(g, indices, s)
         return r, indices
     else:
@@ -674,13 +674,11 @@ def upsample_nearest2d(g, input, output_size):
     input_length = len(input.type().sizes())
     offsets = g.op("Constant", value_t=torch.tensor([1. for i in range(offset)]))
     dividend = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Float"])
-    divisor = g.op(
-        "Slice",
-        g.op("Shape", input),
-        axes_i=[0],
-        ends_i=[input_length],
-        starts_i=[offset]
-    )
+    divisor = sym_help._slice_op(g,
+                                 g.op("Shape", input),
+                                 axes=[0],
+                                 starts=[offset],
+                                 ends=[input_length])
     divisor = g.op("Cast", divisor, to_i=sym_help.cast_pytorch_to_onnx["Float"])
     scale_dims = g.op("Div", dividend, divisor)
     scales = g.op("Concat", offsets, scale_dims, axis_i=0)
@@ -703,13 +701,11 @@ def upsample_bilinear2d(g, input, output_size, align_corners):
     input_length = len(input.type().sizes())
     offsets = g.op("Constant", value_t=torch.tensor([1. for i in range(offset)]))
     dividend = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Float"])
-    divisor = g.op(
-        "Slice",
-        g.op("Shape", input),
-        axes_i=[0],
-        ends_i=[input_length],
-        starts_i=[offset]
-    )
+    divisor = sym_help._slice_op(g,
+                                 g.op("Shape", input),
+                                 axes=[0],
+                                 starts=[offset],
+                                 ends=[input_length])
     divisor = g.op("Cast", divisor, to_i=sym_help.cast_pytorch_to_onnx["Float"])
     scale_dims = g.op("Div", dividend, divisor)
     scales = g.op("Concat", offsets, scale_dims, axis_i=0)
@@ -1161,7 +1157,7 @@ def slice(g, self, dim, start, end, step):
     start = _parse_arg(start, 'i')
     end = _parse_arg(end, 'i')
     dim = _parse_arg(dim, 'i')
-    return _slice_op(g, self, axes=[dim], starts=[start], ends=[end])
+    return sym_help._slice_op(g, self, axes=[dim], starts=[start], ends=[end])
 
 
 @parse_args('v', 'f', 'f')
@@ -1306,7 +1302,7 @@ def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
     reform_permutation = [(0, 1), (3, 4), (1, 3)]
 
     def reform_weights(g, w, n, intervals):
-        slices = [g.op('Slice', w, axes_i=[0], starts_i=[x * n], ends_i=[y * n]) for x, y in intervals]
+        slices = [sym_help._slice_op(g, w, axes=[0], starts=[x * n], ends=[y * n]) for x, y in intervals]
        return g.op('Concat', *slices, axis_i=0)
 
     def transform_weights(layer_index):
@@ -1320,7 +1316,7 @@ def transform_weights(layer_index):
        return tuple(g.op('Unsqueeze', x, axes_i=[0]) for x in (weight_ih, weight_hh, bias_concat))
 
     def retrieve_state(x, start, end):
-        return x if num_layers == 1 else g.op('Slice', x, axes_i=[0], starts_i=[start], ends_i=[end])
+        return x if num_layers == 1 else sym_help._slice_op(g, x, axes=[0], starts=[start], ends=[end])
 
     for i in range(num_layers):
         if unidirectional:
@@ -1552,7 +1548,7 @@ def isnan(g, input):
 
 @parse_args('v', 'i', 'i', 'i')
 def narrow(g, input, dim, start, length):
-    return _slice_op(g, input, axes=[dim], starts=[start], ends=[start + length])
+    return sym_help._slice_op(g, input, axes=[dim], starts=[start], ends=[start + length])
 
 
 def argmax(g, input, dim, keepdim):
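
A side note on the magic number in slice_op's early return: 9223372036854775807 is INT64_MAX, which the exporter uses to encode an open-ended slice like x[start:]. A slice with starts [0] and that end covers the whole axis, so it is the identity and no Slice node needs to be emitted. Python's own index clamping makes this easy to confirm (illustrative check, not exporter code):

import torch

INT64_MAX = 9223372036854775807
x = torch.randn(4)
assert torch.equal(x[0:INT64_MAX], x)  # out-of-range ends clamp to len(x)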
