@@ -342,11 +342,50 @@ def backward_extended(self, grad_output, grad_hy):
342342 return grad_input , grad_weight , grad_hx
343343
344344
345- def RNN_symbolic_builder (* args , ** kwargs ):
346- def symbolic (g , input , all_weights , hx , ** kwargs ):
347- # Something can go here, e.g.
348- # return g.op('LSTM', input, *all_weights[0], outputs=2)
349- raise RuntimeError ("RNN symbolic NYI" )
345+ def RNN_symbolic_builder (cell_type , input_size , hidden_size , num_layers , batch_first , dropout , bidirectional , ** kwargs ):
346+ assert cell_type == 'LSTM'
347+ assert not batch_first
348+ assert not dropout
349+ assert not bidirectional
350+
351+ def symbolic (g , input , all_weights , h0_and_c0 , ** fkwargs ):
352+ h0 , c0 = h0_and_c0
353+ sequence_len = input .type ().sizes ()[0 ]
354+ batch_size = input .type ().sizes ()[1 ]
355+
356+ # TODO leave out this argument to increase parametricity.
357+ # This is nontrivial because we provide subsequent optional
358+ # arguments, and ONNX does not have a mechanism for skipping
359+ # non-trailing optional arguments.
360+ sequence_lens = g .op ('Constant' , value_t = torch .IntTensor (batch_size ).fill_ (sequence_len ))
361+
362+ prev_output = input
363+ h_outs = []
364+ for i in range (num_layers ):
365+ # pytorch is input, forget, cell, output.
366+ # onnx is input, output, forget, cell.
367+ # Therefore lots of awkward slicing and concatenation.
368+
369+ def reform (x ):
370+ return g .op (
371+ 'Concat' ,
372+ g .op ('Slice' , x , axes_i = [0 ], starts_i = [0 * hidden_size ], ends_i = [1 * hidden_size ]),
373+ g .op ('Slice' , x , axes_i = [0 ], starts_i = [3 * hidden_size ], ends_i = [4 * hidden_size ]),
374+ g .op ('Slice' , x , axes_i = [0 ], starts_i = [1 * hidden_size ], ends_i = [3 * hidden_size ]),
375+ axis_i = 0 )
376+
377+ weight_ih , weight_hh , bias_ih , bias_hh = map (reform , all_weights [i ])
378+
379+ bias_concat = g .op ('Concat' , bias_ih , bias_hh , axis_i = 0 )
380+
381+ h_in = h0 if num_layers == 1 else g .op ('Slice' , h0 , axes_i = [0 ], starts_i = [i ], ends_i = [i + 1 ])
382+ c_in = c0 if num_layers == 1 else g .op ('Slice' , c0 , axes_i = [0 ], starts_i = [i ], ends_i = [i + 1 ])
383+
384+ inputs = [prev_output , weight_ih , weight_hh , bias_concat , sequence_lens , h_in , c_in ]
385+ prev_output , h_out = g .op ('LSTM' , * inputs , outputs = 2 , hidden_size_i = hidden_size )
386+ h_outs .append (h_out )
387+ h_outs = h_out if num_layers == 1 else g .op ('Concat' , * h_outs , axis_i = 0 )
388+ return prev_output , h_outs , None
350389
351390 import torch .onnx
352391 return torch .onnx .symbolic_override (symbolic )
0 commit comments