@@ -573,25 +573,38 @@ def symbolic(g, input, all_weights, h0, batch_sizes):
         return _unimplemented("RNN", "batch_first")
     if dropout:
         return _unimplemented("RNN", "dropout")
-    if bidirectional:
-        return _unimplemented("RNN", "bidirectional")
+
+    unidirectional = not bidirectional
 
     prev_output = input
     h_outs = []
 
     sequence_lens = unused(g) if batch_sizes is None else batch_sizes
 
     for i in range(num_layers):
-        weight_ih, weight_hh, bias_ih, bias_hh = all_weights[i]
-
-        bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
-
-        h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+        if unidirectional:
+            weight_ih, weight_hh, bias_ih, bias_hh = all_weights[i]
+            bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
+
+            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+        else:
+            weight_ih = g.op('Concat', all_weights[2 * i][0], all_weights[2 * i + 1][0], axis_i=0)
+            weight_hh = g.op('Concat', all_weights[2 * i][1], all_weights[2 * i + 1][1], axis_i=0)
+            bias_concat = g.op('Concat',
+                               all_weights[2 * i][2],
+                               all_weights[2 * i][3],
+                               all_weights[2 * i + 1][2],
+                               all_weights[2 * i + 1][3],
+                               axis_i=0)
+
+            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])
 
         inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens, h_in]
+        extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
         prev_output, h_out = g.op('RNN', *inputs, outputs=2,
                                   hidden_size_i=hidden_size,
-                                  activations_s=[nonlinearity.lower()])
+                                  activations_s=[nonlinearity.lower()],
+                                  **extra_kwargs)
         h_outs.append(h_out)
     h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
     return prev_output, h_outs
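The hunk above teaches the RNN exporter to handle `bidirectional=True`: the forward and backward weights for layer `i` (stored at `all_weights[2 * i]` and `all_weights[2 * i + 1]`) are concatenated along a leading num_directions axis, `h0` is sliced two rows at a time, and `direction_s='bidirectional'` is passed to the ONNX `RNN` node. A minimal export sketch of what this enables; the module sizes and file name are illustrative, not from the patch:

```python
import torch
import torch.nn as nn

# Bidirectional two-layer Elman RNN; before this patch the exporter
# bailed out with _unimplemented("RNN", "bidirectional").
model = nn.RNN(input_size=10, hidden_size=20, num_layers=2,
               bidirectional=True)
seq = torch.randn(5, 3, 10)     # (seq_len, batch, input_size)
h0 = torch.randn(2 * 2, 3, 20)  # (num_layers * num_directions, batch, hidden_size)

# Each layer lowers to one ONNX RNN node carrying stacked fwd/bwd weights.
torch.onnx.export(model, (seq, h0), 'birnn.onnx')
```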
@@ -605,8 +618,8 @@ def symbolic(g, input, all_weights, h0_and_c0, batch_sizes):
         return _unimplemented("LSTM", "batch_first")
     if dropout:
         return _unimplemented("LSTM", "dropout")
-    if bidirectional:
-        return _unimplemented("LSTM", "bidirectional")
+
+    unidirectional = not bidirectional
 
     h0, c0 = h0_and_c0
 
@@ -616,18 +629,36 @@ def symbolic(g, input, all_weights, h0_and_c0, batch_sizes):
     sequence_lens = unused(g) if batch_sizes is None else batch_sizes
 
     for i in range(num_layers):
-        # pytorch is input, forget, cell, output.
-        # onnx is input, output, forget, cell.
-        weight_ih, weight_hh, bias_ih, bias_hh = \
-            [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[i]]
-
-        bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
-
-        h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
-        c_in = c0 if num_layers == 1 else g.op('Slice', c0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+        if unidirectional:
+            # pytorch is input, forget, cell, output.
+            # onnx is input, output, forget, cell.
+            weight_ih, weight_hh, bias_ih, bias_hh = \
+                [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[i]]
+
+            bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
+
+            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+            c_in = c0 if num_layers == 1 else g.op('Slice', c0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+        else:
+            # pytorch is input, forget, cell, output.
+            # onnx is input, output, forget, cell.
+            weight_ih_f, weight_hh_f, bias_ih_f, bias_hh_f = \
+                [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[2 * i]]
+            weight_ih_b, weight_hh_b, bias_ih_b, bias_hh_b = \
+                [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[2 * i + 1]]
+
+            weight_ih = g.op('Concat', weight_ih_f, weight_ih_b, axis_i=0)
+            weight_hh = g.op('Concat', weight_hh_f, weight_hh_b, axis_i=0)
+            bias_concat = g.op('Concat', bias_ih_f, bias_hh_f, bias_ih_b, bias_hh_b, axis_i=0)
+
+            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])
+            c_in = c0 if num_layers == 1 else g.op('Slice', c0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])
 
         inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens, h_in, c_in]
-        prev_output, h_out = g.op('LSTM', *inputs, outputs=2, hidden_size_i=hidden_size)
+        extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
+        prev_output, h_out = g.op('LSTM', *inputs, outputs=2,
+                                  hidden_size_i=hidden_size,
+                                  **extra_kwargs)
        h_outs.append(h_out)
     h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
     return prev_output, h_outs, None
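The LSTM branch additionally reorders gate blocks: PyTorch stacks them as input, forget, cell, output, while ONNX expects input, output, forget, cell, hence the `[(0, 1), (3, 4), (1, 3)]` intervals handed to `reform_weights`. The real `reform_weights` is defined elsewhere in this file and operates on graph values; the tensor-level sketch below is an assumption, meant only to illustrate the row permutation those intervals describe:

```python
import torch

def reorder_gate_blocks(w, hidden_size, intervals):
    # Hypothetical stand-in for reform_weights: w stacks four gate blocks
    # of `hidden_size` rows each along dim 0, and each (a, b) interval
    # selects rows [a * hidden_size, b * hidden_size).
    return torch.cat([w[a * hidden_size:b * hidden_size] for a, b in intervals], dim=0)

h = 20
w_pt = torch.randn(4 * h, 10)  # PyTorch gate order: i, f, g (cell), o
# (0, 1) keeps i, (3, 4) pulls o forward, (1, 3) appends f and g:
w_onnx = reorder_gate_blocks(w_pt, h, [(0, 1), (3, 4), (1, 3)])  # ONNX order: i, o, f, c
```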
@@ -641,27 +672,43 @@ def symbolic(g, input, all_weights, h0, batch_sizes):
         return _unimplemented("GRU", "batch_first")
     if dropout:
         return _unimplemented("GRU", "dropout")
-    if bidirectional:
-        return _unimplemented("GRU", "bidirectional")
+
+    unidirectional = not bidirectional
 
     prev_output = input
     h_outs = []
 
     sequence_lens = unused(g) if batch_sizes is None else batch_sizes
 
     for i in range(num_layers):
-        # pytorch is reset, input, hidden
-        # onnx is input, reset, hidden
-        weight_ih, weight_hh, bias_ih, bias_hh = \
-            [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[i]]
+        if unidirectional:
+            # pytorch is reset, input, hidden
+            # onnx is input, reset, hidden
+            weight_ih, weight_hh, bias_ih, bias_hh = \
+                [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[i]]
+
+            bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
+
+            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+        else:
+            # pytorch is reset, input, hidden
+            # onnx is input, reset, hidden
+            weight_ih_f, weight_hh_f, bias_ih_f, bias_hh_f = \
+                [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[2 * i]]
+            weight_ih_b, weight_hh_b, bias_ih_b, bias_hh_b = \
+                [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[2 * i + 1]]
 
-        bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
+            weight_ih = g.op('Concat', weight_ih_f, weight_ih_b, axis_i=0)
+            weight_hh = g.op('Concat', weight_hh_f, weight_hh_b, axis_i=0)
+            bias_concat = g.op('Concat', bias_ih_f, bias_hh_f, bias_ih_b, bias_hh_b, axis_i=0)
 
-        h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])
 
         inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens, h_in]
-        prev_output, h_out = g.op(
-            'GRU', *inputs, outputs=2, hidden_size_i=hidden_size, linear_before_reset_i=1)
+        extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
+        prev_output, h_out = g.op('GRU', *inputs, outputs=2,
+                                  hidden_size_i=hidden_size, linear_before_reset_i=1,
+                                  **extra_kwargs)
         h_outs.append(h_out)
     h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
     return prev_output, h_outs
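The GRU path follows the same pattern: `reform_weights` runs once per direction (intervals `[(1, 2), (0, 1), (2, 3)]` turn PyTorch's reset, input, hidden order into ONNX's input, reset, hidden), and `h0` is sliced at `[2 * i, 2 * i + 2)` so each layer's node receives both directions' initial state. A hedged end-to-end sketch, with illustrative sizes and file name, that also inspects the exported graph:

```python
import onnx
import torch
import torch.nn as nn

# Bidirectional three-layer GRU; each layer becomes one ONNX GRU node
# tagged with the 'bidirectional' direction attribute via extra_kwargs.
model = nn.GRU(input_size=8, hidden_size=16, num_layers=3, bidirectional=True)
seq = torch.randn(7, 2, 8)      # (seq_len, batch, input_size)
h0 = torch.randn(3 * 2, 2, 16)  # (num_layers * num_directions, batch, hidden_size)
torch.onnx.export(model, (seq, h0), 'bigru.onnx')

# The symbolic emits one GRU node per layer, so we expect 3 here.
proto = onnx.load('bigru.onnx')
print(sum(1 for n in proto.graph.node if n.op_type == 'GRU'))
```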