
Commit 65fb885

anderspapitto authored and ezyang committed
Bidirectional RNN export to ONNX (Elman/LSTM/GRU) (#5120)
1 parent 873f116 commit 65fb885

File tree

1 file changed: +77 −30


torch/onnx/symbolic.py

Lines changed: 77 additions & 30 deletions
@@ -573,25 +573,38 @@ def symbolic(g, input, all_weights, h0, batch_sizes):
             return _unimplemented("RNN", "batch_first")
         if dropout:
             return _unimplemented("RNN", "dropout")
-        if bidirectional:
-            return _unimplemented("RNN", "bidirectional")
+
+        unidirectional = not bidirectional
 
         prev_output = input
         h_outs = []
 
         sequence_lens = unused(g) if batch_sizes is None else batch_sizes
 
         for i in range(num_layers):
-            weight_ih, weight_hh, bias_ih, bias_hh = all_weights[i]
-
-            bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
-
-            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+            if unidirectional:
+                weight_ih, weight_hh, bias_ih, bias_hh = all_weights[i]
+                bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
+
+                h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+            else:
+                weight_ih = g.op('Concat', all_weights[2 * i][0], all_weights[2 * i + 1][0], axis_i=0)
+                weight_hh = g.op('Concat', all_weights[2 * i][1], all_weights[2 * i + 1][1], axis_i=0)
+                bias_concat = g.op('Concat',
+                                   all_weights[2 * i][2],
+                                   all_weights[2 * i][3],
+                                   all_weights[2 * i + 1][2],
+                                   all_weights[2 * i + 1][3],
+                                   axis_i=0)
+
+                h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])
 
             inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens, h_in]
+            extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
             prev_output, h_out = g.op('RNN', *inputs, outputs=2,
                                       hidden_size_i=hidden_size,
-                                      activations_s=[nonlinearity.lower()])
+                                      activations_s=[nonlinearity.lower()],
+                                      **extra_kwargs)
             h_outs.append(h_out)
         h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
         return prev_output, h_outs
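
With the bidirectional branch above, an Elman RNN built with bidirectional=True should now export instead of falling into _unimplemented. A minimal sketch of how the change would be exercised; the module sizes and output file name are illustrative, not part of the commit:

import torch
import torch.nn as nn

# Toy single-layer bidirectional Elman RNN (sizes are arbitrary).
rnn = nn.RNN(input_size=4, hidden_size=8, num_layers=1,
             bidirectional=True, nonlinearity='tanh')

seq = torch.randn(5, 3, 4)   # (seq_len, batch, input_size); batch_first stays False
h0 = torch.zeros(2, 3, 8)    # 2 = num_layers * num_directions

# Export traces the module and routes nn.RNN through the symbolic above,
# which should now emit one ONNX RNN node with direction='bidirectional'.
torch.onnx.export(rnn, (seq, h0), "birnn.onnx", verbose=True)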
@@ -605,8 +618,8 @@ def symbolic(g, input, all_weights, h0_and_c0, batch_sizes):
             return _unimplemented("LSTM", "batch_first")
         if dropout:
             return _unimplemented("LSTM", "dropout")
-        if bidirectional:
-            return _unimplemented("LSTM", "bidirectional")
+
+        unidirectional = not bidirectional
 
         h0, c0 = h0_and_c0
 
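
The LSTM path gets the same unidirectional flag; the slicing in the next hunk assumes h0 and c0 stack one slice per layer and per direction along dim 0. A quick sanity check of those shapes with a plain nn.LSTM (a hedged illustration, sizes arbitrary):

import torch
import torch.nn as nn

# For bidirectional=True the state tensors carry num_layers * num_directions
# slices along dim 0, which is why the export code slices [2*i : 2*i + 2].
lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=2, bidirectional=True)
out, (h_n, c_n) = lstm(torch.randn(5, 3, 4))
print(h_n.shape, c_n.shape)  # torch.Size([4, 3, 8]) each: 2 layers * 2 directions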
@@ -616,18 +629,36 @@ def symbolic(g, input, all_weights, h0_and_c0, batch_sizes):
         sequence_lens = unused(g) if batch_sizes is None else batch_sizes
 
         for i in range(num_layers):
-            # pytorch is input, forget, cell, output.
-            # onnx is input, output, forget, cell.
-            weight_ih, weight_hh, bias_ih, bias_hh = \
-                [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[i]]
-
-            bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
-
-            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
-            c_in = c0 if num_layers == 1 else g.op('Slice', c0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+            if unidirectional:
+                # pytorch is input, forget, cell, output.
+                # onnx is input, output, forget, cell.
+                weight_ih, weight_hh, bias_ih, bias_hh = \
+                    [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[i]]
+
+                bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
+
+                h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+                c_in = c0 if num_layers == 1 else g.op('Slice', c0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+            else:
+                # pytorch is input, forget, cell, output.
+                # onnx is input, output, forget, cell.
+                weight_ih_f, weight_hh_f, bias_ih_f, bias_hh_f = \
+                    [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[2 * i]]
+                weight_ih_b, weight_hh_b, bias_ih_b, bias_hh_b = \
+                    [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[2 * i + 1]]
+
+                weight_ih = g.op('Concat', weight_ih_f, weight_ih_b, axis_i=0)
+                weight_hh = g.op('Concat', weight_hh_f, weight_hh_b, axis_i=0)
+                bias_concat = g.op('Concat', bias_ih_f, bias_hh_f, bias_ih_b, bias_hh_b, axis_i=0)
+
+                h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])
+                c_in = c0 if num_layers == 1 else g.op('Slice', c0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])
 
             inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens, h_in, c_in]
-            prev_output, h_out = g.op('LSTM', *inputs, outputs=2, hidden_size_i=hidden_size)
+            extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
+            prev_output, h_out = g.op('LSTM', *inputs, outputs=2,
+                                      hidden_size_i=hidden_size,
+                                      **extra_kwargs)
             h_outs.append(h_out)
         h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
         return prev_output, h_outs, None
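
The reform_weights intervals encode the gate reorder from PyTorch's (input, forget, cell, output) packing to ONNX's (input, output, forget, cell): each (start, end) pair is an offset in units of hidden_size, so [(0, 1), (3, 4), (1, 3)] takes the input gate, then the output gate, then forget and cell together. reform_weights itself operates on graph values; the hypothetical helper below only mirrors the same index math on an ordinary tensor to make the permutation concrete:

import torch

def reorder_lstm_gates(w, hidden_size):
    # PyTorch packs LSTM parameters as [i, f, g, o] blocks of hidden_size rows;
    # ONNX wants [i, o, f, c]. Slice and re-concatenate along dim 0 using the
    # same (start, end) intervals as reform_weights: [(0, 1), (3, 4), (1, 3)].
    blocks = [w[s * hidden_size:e * hidden_size] for s, e in [(0, 1), (3, 4), (1, 3)]]
    return torch.cat(blocks, dim=0)

w_ih = torch.randn(4 * 3, 2)                           # toy (4 * hidden_size, input_size)
print(reorder_lstm_gates(w_ih, hidden_size=3).shape)   # torch.Size([12, 2])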
@@ -641,27 +672,43 @@ def symbolic(g, input, all_weights, h0, batch_sizes):
             return _unimplemented("GRU", "batch_first")
         if dropout:
             return _unimplemented("GRU", "dropout")
-        if bidirectional:
-            return _unimplemented("GRU", "bidirectional")
+
+        unidirectional = not bidirectional
 
         prev_output = input
         h_outs = []
 
         sequence_lens = unused(g) if batch_sizes is None else batch_sizes
 
         for i in range(num_layers):
-            # pytorch is reset, input, hidden
-            # onnx is input, reset, hidden
-            weight_ih, weight_hh, bias_ih, bias_hh = \
-                [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[i]]
+            if unidirectional:
+                # pytorch is reset, input, hidden
+                # onnx is input, reset, hidden
+                weight_ih, weight_hh, bias_ih, bias_hh = \
+                    [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[i]]
+
+                bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
+
+                h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+            else:
+                # pytorch is reset, input, hidden
+                # onnx is input, reset, hidden
+                weight_ih_f, weight_hh_f, bias_ih_f, bias_hh_f = \
+                    [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[2 * i]]
+                weight_ih_b, weight_hh_b, bias_ih_b, bias_hh_b = \
+                    [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[2 * i + 1]]
 
-            bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
+                weight_ih = g.op('Concat', weight_ih_f, weight_ih_b, axis_i=0)
+                weight_hh = g.op('Concat', weight_hh_f, weight_hh_b, axis_i=0)
+                bias_concat = g.op('Concat', bias_ih_f, bias_hh_f, bias_ih_b, bias_hh_b, axis_i=0)
 
-            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+                h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])
 
             inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens, h_in]
-            prev_output, h_out = g.op(
-                'GRU', *inputs, outputs=2, hidden_size_i=hidden_size, linear_before_reset_i=1)
+            extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
+            prev_output, h_out = g.op('GRU', *inputs, outputs=2,
+                                      hidden_size_i=hidden_size, linear_before_reset_i=1,
+                                      **extra_kwargs)
             h_outs.append(h_out)
         h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
         return prev_output, h_outs
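
One way to check that the bidirectional attributes actually land in the exported graph is to load the model back with the onnx package and inspect the recurrent node. A hedged sketch, assuming 'bigru.onnx' is a hypothetical file produced from an nn.GRU(..., bidirectional=True) via torch.onnx.export:

import onnx

model = onnx.load("bigru.onnx")   # hypothetical file exported as described above
for node in model.graph.node:
    if node.op_type in ("RNN", "LSTM", "GRU"):
        attrs = {a.name: a for a in node.attribute}
        # Expect direction == b'bidirectional' alongside hidden_size.
        direction = attrs["direction"].s if "direction" in attrs else b"forward"
        print(node.op_type, direction, attrs["hidden_size"].i)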

0 commit comments
