torch/onnx/symbolic.py: 77 additions & 30 deletions (107 changes)
@@ -573,25 +573,38 @@ def symbolic(g, input, all_weights, h0, batch_sizes):
         return _unimplemented("RNN", "batch_first")
     if dropout:
         return _unimplemented("RNN", "dropout")
-    if bidirectional:
-        return _unimplemented("RNN", "bidirectional")
+
+    unidirectional = not bidirectional
 
     prev_output = input
     h_outs = []
 
     sequence_lens = unused(g) if batch_sizes is None else batch_sizes
 
     for i in range(num_layers):
-        weight_ih, weight_hh, bias_ih, bias_hh = all_weights[i]
-
-        bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
-
-        h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+        if unidirectional:
+            weight_ih, weight_hh, bias_ih, bias_hh = all_weights[i]
+            bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
+
+            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+        else:
+            weight_ih = g.op('Concat', all_weights[2 * i][0], all_weights[2 * i + 1][0], axis_i=0)
+            weight_hh = g.op('Concat', all_weights[2 * i][1], all_weights[2 * i + 1][1], axis_i=0)
+            bias_concat = g.op('Concat',
+                               all_weights[2 * i][2],
+                               all_weights[2 * i][3],
+                               all_weights[2 * i + 1][2],
+                               all_weights[2 * i + 1][3],
+                               axis_i=0)
+
+            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])
 
         inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens, h_in]
+        extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
         prev_output, h_out = g.op('RNN', *inputs, outputs=2,
                                   hidden_size_i=hidden_size,
-                                  activations_s=[nonlinearity.lower()])
+                                  activations_s=[nonlinearity.lower()],
+                                  **extra_kwargs)
         h_outs.append(h_out)
     h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
     return prev_output, h_outs
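
For context, a minimal sketch of what the RNN change enables, assuming a PyTorch build that includes this patch (layer sizes and the output file name are made up for illustration): exporting a bidirectional RNN no longer hits the _unimplemented("RNN", "bidirectional") path.

    import torch
    import torch.nn as nn

    # Hypothetical example: a 2-layer bidirectional Elman RNN.
    model = nn.RNN(input_size=10, hidden_size=20, num_layers=2,
                   nonlinearity='tanh', bidirectional=True)
    x = torch.randn(5, 3, 10)    # (seq_len, batch, input_size)
    h0 = torch.randn(4, 3, 20)   # (num_layers * num_directions, batch, hidden_size)

    # Each layer now maps to one ONNX RNN node with direction='bidirectional'.
    torch.onnx.export(model, (x, h0), "birnn.onnx")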
@@ -605,8 +618,8 @@ def symbolic(g, input, all_weights, h0_and_c0, batch_sizes):
         return _unimplemented("LSTM", "batch_first")
     if dropout:
         return _unimplemented("LSTM", "dropout")
-    if bidirectional:
-        return _unimplemented("LSTM", "bidirectional")
+
+    unidirectional = not bidirectional
 
     h0, c0 = h0_and_c0
 
@@ -616,18 +629,36 @@ def symbolic(g, input, all_weights, h0_and_c0, batch_sizes):
     sequence_lens = unused(g) if batch_sizes is None else batch_sizes
 
     for i in range(num_layers):
-        # pytorch is input, forget, cell, output.
-        # onnx is input, output, forget, cell.
-        weight_ih, weight_hh, bias_ih, bias_hh = \
-            [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[i]]
-
-        bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
-
-        h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
-        c_in = c0 if num_layers == 1 else g.op('Slice', c0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+        if unidirectional:
+            # pytorch is input, forget, cell, output.
+            # onnx is input, output, forget, cell.
+            weight_ih, weight_hh, bias_ih, bias_hh = \
+                [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[i]]
+
+            bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
+
+            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+            c_in = c0 if num_layers == 1 else g.op('Slice', c0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+        else:
+            # pytorch is input, forget, cell, output.
+            # onnx is input, output, forget, cell.
+            weight_ih_f, weight_hh_f, bias_ih_f, bias_hh_f = \
+                [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[2 * i]]
+            weight_ih_b, weight_hh_b, bias_ih_b, bias_hh_b = \
+                [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[2 * i + 1]]
+
+            weight_ih = g.op('Concat', weight_ih_f, weight_ih_b, axis_i=0)
+            weight_hh = g.op('Concat', weight_hh_f, weight_hh_b, axis_i=0)
+            bias_concat = g.op('Concat', bias_ih_f, bias_hh_f, bias_ih_b, bias_hh_b, axis_i=0)
+
+            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])
+            c_in = c0 if num_layers == 1 else g.op('Slice', c0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])
 
         inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens, h_in, c_in]
-        prev_output, h_out = g.op('LSTM', *inputs, outputs=2, hidden_size_i=hidden_size)
+        extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
+        prev_output, h_out = g.op('LSTM', *inputs, outputs=2,
+                                  hidden_size_i=hidden_size,
+                                  **extra_kwargs)
         h_outs.append(h_out)
     h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
     return prev_output, h_outs, None
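
The interval lists passed to reform_weights implement the gate reordering described in the comments. A rough eager-mode analogue, with reorder_gates as a hypothetical stand-in for what the symbolic builds out of Slice/Concat ops:

    import torch

    def reorder_gates(w, hidden_size, intervals):
        # Slice the stacked gate blocks (each hidden_size rows) and
        # re-concatenate them in the order given by `intervals`, which are
        # expressed in units of hidden_size.
        return torch.cat([w[s * hidden_size:e * hidden_size] for s, e in intervals], dim=0)

    # PyTorch stacks LSTM gates as (input, forget, cell, output); the ONNX
    # LSTM op expects (input, output, forget, cell), hence [(0, 1), (3, 4), (1, 3)].
    hidden_size, input_size = 20, 10
    w_ih = torch.randn(4 * hidden_size, input_size)
    w_ih_onnx = reorder_gates(w_ih, hidden_size, [(0, 1), (3, 4), (1, 3)])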
@@ -641,27 +672,43 @@ def symbolic(g, input, all_weights, h0, batch_sizes):
         return _unimplemented("GRU", "batch_first")
     if dropout:
         return _unimplemented("GRU", "dropout")
-    if bidirectional:
-        return _unimplemented("GRU", "bidirectional")
+
+    unidirectional = not bidirectional
 
     prev_output = input
     h_outs = []
 
     sequence_lens = unused(g) if batch_sizes is None else batch_sizes
 
     for i in range(num_layers):
-        # pytorch is reset, input, hidden
-        # onnx is input, reset, hidden
-        weight_ih, weight_hh, bias_ih, bias_hh = \
-            [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[i]]
+        if unidirectional:
+            # pytorch is reset, input, hidden
+            # onnx is input, reset, hidden
+            weight_ih, weight_hh, bias_ih, bias_hh = \
+                [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[i]]
 
-        bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
+            bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
 
-        h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+        else:
+            # pytorch is reset, input, hidden
+            # onnx is input, reset, hidden
+            weight_ih_f, weight_hh_f, bias_ih_f, bias_hh_f = \
+                [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[2 * i]]
+            weight_ih_b, weight_hh_b, bias_ih_b, bias_hh_b = \
+                [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[2 * i + 1]]
+
+            weight_ih = g.op('Concat', weight_ih_f, weight_ih_b, axis_i=0)
+            weight_hh = g.op('Concat', weight_hh_f, weight_hh_b, axis_i=0)
+            bias_concat = g.op('Concat', bias_ih_f, bias_hh_f, bias_ih_b, bias_hh_b, axis_i=0)
+
+            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])
 
         inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens, h_in]
-        prev_output, h_out = g.op(
-            'GRU', *inputs, outputs=2, hidden_size_i=hidden_size, linear_before_reset_i=1)
+        extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
+        prev_output, h_out = g.op('GRU', *inputs, outputs=2,
+                                  hidden_size_i=hidden_size, linear_before_reset_i=1,
+                                  **extra_kwargs)
         h_outs.append(h_out)
     h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
     return prev_output, h_outs
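
In each bidirectional branch the forward- and backward-direction parameters are concatenated along the first axis before being handed to the ONNX node, one block per direction, matching direction_s='bidirectional'. A plain-tensor sketch of that layout (shapes assume a GRU layer with its three gates; the graph code does the same thing with g.op('Concat', ..., axis_i=0)):

    import torch

    hidden_size, input_size = 20, 10
    # Hypothetical forward/backward input-to-hidden weights for one GRU layer.
    w_ih_forward = torch.randn(3 * hidden_size, input_size)
    w_ih_backward = torch.randn(3 * hidden_size, input_size)

    # Forward block first, then backward block, stacked along dim 0,
    # mirroring the per-direction layout the exporter emits.
    w_ih_stacked = torch.cat([w_ih_forward, w_ih_backward], dim=0)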