@@ -194,7 +194,7 @@ def sign(g, self):
     return g.op("Sign", self)


-def _slice_op(g, input, axes, starts, ends):
+def slice_op(g, input, axes, starts, ends):
     assert len(starts) == len(ends)
     if len(starts) == 1 and starts[0] == 0 and ends[0] == 9223372036854775807:
         return input
@@ -360,8 +360,8 @@ def select(g, self, dim, index):
         # of Gather in caffe2. We need to change this as soon as possible.
         # TODO: this breaks if index == -1
         index_val = _parse_arg(index, 'i')
-        slice_node = _slice_op(g, self, axes=[dim],
-                               starts=[index_val], ends=[index_val + 1])
+        slice_node = sym_help._slice_op(g, self, axes=[dim],
+                                        starts=[index_val], ends=[index_val + 1])
         return g.op("Squeeze", slice_node, axes_i=[dim])
     else:
         return g.op("Gather", self, index, axis_i=dim)
@@ -538,8 +538,8 @@ def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
                                         kernel_shape_i=[1 for _ in range(ndims)],
                                         strides_i=[1 for _ in range(ndims)])
             # convert indices to have non-flattened indices values
-            s = _slice_op(g, flattened_indices, axes=[2 + i for i in range(ndims)],
-                          starts=tuple_fn(0), ends=tuple_fn(1))
+            s = sym_help._slice_op(g, flattened_indices, axes=[2 + i for i in range(ndims)],
+                                   starts=tuple_fn(0), ends=tuple_fn(1))
             indices = sub(g, indices, s)
             return r, indices
         else:
@@ -674,13 +674,11 @@ def upsample_nearest2d(g, input, output_size):
         input_length = len(input.type().sizes())
         offsets = g.op("Constant", value_t=torch.tensor([1. for i in range(offset)]))
         dividend = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Float"])
-        divisor = g.op(
-            "Slice",
-            g.op("Shape", input),
-            axes_i=[0],
-            ends_i=[input_length],
-            starts_i=[offset]
-        )
+        divisor = sym_help._slice_op(g,
+                                     g.op("Shape", input),
+                                     axes=[0],
+                                     starts=[offset],
+                                     ends=[input_length])
         divisor = g.op("Cast", divisor, to_i=sym_help.cast_pytorch_to_onnx["Float"])
         scale_dims = g.op("Div", dividend, divisor)
         scales = g.op("Concat", offsets, scale_dims, axis_i=0)
@@ -703,13 +701,11 @@ def upsample_bilinear2d(g, input, output_size, align_corners):
         input_length = len(input.type().sizes())
         offsets = g.op("Constant", value_t=torch.tensor([1. for i in range(offset)]))
         dividend = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Float"])
-        divisor = g.op(
-            "Slice",
-            g.op("Shape", input),
-            axes_i=[0],
-            ends_i=[input_length],
-            starts_i=[offset]
-        )
+        divisor = sym_help._slice_op(g,
+                                     g.op("Shape", input),
+                                     axes=[0],
+                                     starts=[offset],
+                                     ends=[input_length])
         divisor = g.op("Cast", divisor, to_i=sym_help.cast_pytorch_to_onnx["Float"])
         scale_dims = g.op("Div", dividend, divisor)
         scales = g.op("Concat", offsets, scale_dims, axis_i=0)
@@ -1161,7 +1157,7 @@ def slice(g, self, dim, start, end, step):
     start = _parse_arg(start, 'i')
     end = _parse_arg(end, 'i')
     dim = _parse_arg(dim, 'i')
-    return _slice_op(g, self, axes=[dim], starts=[start], ends=[end])
+    return sym_help._slice_op(g, self, axes=[dim], starts=[start], ends=[end])


 @parse_args('v', 'f', 'f')
@@ -1306,7 +1302,7 @@ def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
         reform_permutation = [(0, 1), (3, 4), (1, 3)]

     def reform_weights(g, w, n, intervals):
-        slices = [g.op('Slice', w, axes_i=[0], starts_i=[x * n], ends_i=[y * n]) for x, y in intervals]
+        slices = [sym_help._slice_op(g, w, axes=[0], starts=[x * n], ends=[y * n]) for x, y in intervals]
         return g.op('Concat', *slices, axis_i=0)

     def transform_weights(layer_index):
@@ -1320,7 +1316,7 @@ def transform_weights(layer_index):
         return tuple(g.op('Unsqueeze', x, axes_i=[0]) for x in (weight_ih, weight_hh, bias_concat))

     def retrieve_state(x, start, end):
-        return x if num_layers == 1 else g.op('Slice', x, axes_i=[0], starts_i=[start], ends_i=[end])
+        return x if num_layers == 1 else sym_help._slice_op(g, x, axes=[0], starts=[start], ends=[end])

     for i in range(num_layers):
         if unidirectional:
@@ -1552,7 +1548,7 @@ def isnan(g, input):
 
 @parse_args('v', 'i', 'i', 'i')
 def narrow(g, input, dim, start, length):
-    return _slice_op(g, input, axes=[dim], starts=[start], ends=[start + length])
+    return sym_help._slice_op(g, input, axes=[dim], starts=[start], ends=[start + length])


 def argmax(g, input, dim, keepdim):
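Note: sym_help._slice_op is assumed here to behave like the local slice_op kept above — a thin wrapper that skips the full-range no-op case before emitting an ONNX Slice node. A minimal sketch; the assert and no-op check are taken from the hunk at line 197, while the final g.op("Slice", ...) call is inferred from the call sites this commit replaces:

    def _slice_op(g, input, axes, starts, ends):
        # One start/end pair per sliced axis.
        assert len(starts) == len(ends)
        if len(starts) == 1 and starts[0] == 0 and ends[0] == 9223372036854775807:
            # Slicing [0, INT64_MAX) selects the whole axis; emit no node.
            return input
        # Assumed tail: mirrors the g.op("Slice", ..., axes_i=..., starts_i=...,
        # ends_i=...) calls removed in the hunks above.
        return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)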