Commit 4447b80

Anders Papitto committed
support RNN export
1 parent e6ad0ea commit 4447b80

File tree

3 files changed: +49 −7 lines changed


torch/autograd/function.py

Lines changed: 2 additions & 1 deletion
@@ -304,7 +304,8 @@ def unflatten_helper(input, proto):
 
 _iter_variables = _iter_filter(lambda o: isinstance(o, torch.autograd.Variable), condition_msg="Variables")
 _iter_variables_permissive = _iter_filter(lambda o: isinstance(o, torch.autograd.Variable), skip_unknown=True)
-_iter_jit_values = _iter_filter(lambda o: isinstance(o, torch._C.Value), condition_msg="jit's Values")
+_iter_jit_values = _iter_filter(lambda o: o is None or isinstance(o, torch._C.Value),
+                                condition_msg="jit's Values or None")
 _iter_tensors = _iter_filter(torch.is_tensor, condition_msg="Tensors")
 _iter_None_tensors = _iter_filter(lambda o: o is None or torch.is_tensor(o), condition_msg="Tensors or None")
 _map_variable_tensor = _nested_map(lambda o: isinstance(o, torch.autograd.Variable),
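
Note on the hunk above: the jit-value filter now treats None as a valid element, so a symbolic function can yield None in place of an output it does not produce (as the RNN symbolic in the next file does). A minimal, self-contained sketch of the behavior, assuming _iter_filter builds a recursive iterator that rejects anything failing its predicate (toy code, not PyTorch's implementation):

import torch

def make_filter(condition, condition_msg):
    def _iter(obj):
        if condition(obj):
            yield obj
        elif isinstance(obj, (list, tuple)):
            for item in obj:
                for value in _iter(item):
                    yield value
        else:
            raise ValueError("expected " + condition_msg +
                             ", got " + type(obj).__name__)
    return _iter

# Before this commit the predicate was only isinstance(o, torch._C.Value),
# so a None anywhere in the structure raised. With the relaxed predicate,
# None flows through untouched:
iter_values = make_filter(lambda o: o is None or isinstance(o, torch._C.Value),
                          condition_msg="jit's Values or None")
list(iter_values([None]))  # -> [None] instead of a ValueError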

torch/nn/_functions/rnn.py

Lines changed: 44 additions & 5 deletions
@@ -342,11 +342,50 @@ def backward_extended(self, grad_output, grad_hy):
         return grad_input, grad_weight, grad_hx
 
 
-def RNN_symbolic_builder(*args, **kwargs):
-    def symbolic(g, input, all_weights, hx, **kwargs):
-        # Something can go here, e.g.
-        # return g.op('LSTM', input, *all_weights[0], outputs=2)
-        raise RuntimeError("RNN symbolic NYI")
+def RNN_symbolic_builder(cell_type, input_size, hidden_size, num_layers, batch_first, dropout, bidirectional, **kwargs):
+    assert cell_type == 'LSTM'
+    assert not batch_first
+    assert not dropout
+    assert not bidirectional
+
+    def symbolic(g, input, all_weights, h0_and_c0, **fkwargs):
+        h0, c0 = h0_and_c0
+        sequence_len = input.type().sizes()[0]
+        batch_size = input.type().sizes()[1]
+
+        # TODO leave out this argument to increase parametricity.
+        # This is nontrivial because we provide subsequent optional
+        # arguments, and ONNX does not have a mechanism for skipping
+        # non-trailing optional arguments.
+        sequence_lens = g.op('Constant', value_t=torch.IntTensor(batch_size).fill_(sequence_len))
+
+        prev_output = input
+        h_outs = []
+        for i in range(num_layers):
+            # pytorch is input, forget, cell, output.
+            # onnx is input, output, forget, cell.
+            # Therefore lots of awkward slicing and concatenation.
+
+            def reform(x):
+                return g.op(
+                    'Concat',
+                    g.op('Slice', x, axes_i=[0], starts_i=[0 * hidden_size], ends_i=[1 * hidden_size]),
+                    g.op('Slice', x, axes_i=[0], starts_i=[3 * hidden_size], ends_i=[4 * hidden_size]),
+                    g.op('Slice', x, axes_i=[0], starts_i=[1 * hidden_size], ends_i=[3 * hidden_size]),
+                    axis_i=0)
+
+            weight_ih, weight_hh, bias_ih, bias_hh = map(reform, all_weights[i])
+
+            bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
+
+            h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+            c_in = c0 if num_layers == 1 else g.op('Slice', c0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
+
+            inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens, h_in, c_in]
+            prev_output, h_out = g.op('LSTM', *inputs, outputs=2, hidden_size_i=hidden_size)
+            h_outs.append(h_out)
+        h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
+        return prev_output, h_outs, None
 
     import torch.onnx
     return torch.onnx.symbolic_override(symbolic)
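
Two editorial sketches to clarify the hunk above. First, the "awkward slicing and concatenation" in reform() is a gate reorder: PyTorch packs LSTM gate parameters along dim 0 as [input, forget, cell, output] blocks of hidden_size rows each, while ONNX's LSTM expects [input, output, forget, cell]. A hedged, standalone illustration using plain tensor ops instead of graph ops (hidden_size is arbitrary):

import torch

hidden_size = 2
packed = torch.arange(0, 4 * hidden_size)            # stand-in for a packed bias
i_gate = packed[0 * hidden_size:1 * hidden_size]     # input gate
f_c_gates = packed[1 * hidden_size:3 * hidden_size]  # forget + cell gates
o_gate = packed[3 * hidden_size:4 * hidden_size]     # output gate
# Same reorder that reform() performs with ONNX Slice/Concat nodes:
reordered = torch.cat([i_gate, o_gate, f_c_gates])

Second, a hypothetical export call that satisfies the builder's assertions (LSTM cell, batch_first=False, no dropout, unidirectional); the sizes and filename are illustrative only, not taken from this commit:

import torch
import torch.nn as nn
from torch.autograd import Variable

model = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
x = Variable(torch.randn(5, 3, 10))   # (seq_len, batch, input_size)
h0 = Variable(torch.zeros(2, 3, 20))  # (num_layers, batch, hidden_size)
c0 = Variable(torch.zeros(2, 3, 20))
torch.onnx.export(model, (x, (h0, c0)), 'lstm.onnx')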

torch/onnx/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -266,7 +266,9 @@ def _graph_op(g, opname, *raw_args, **kwargs):
     kwargs = dict((k, v) for k, v in kwargs.items() if v is not None)
 
     def const_if_tensor(arg):
-        if isinstance(arg, torch._C.Value):
+        if arg is None:
+            return arg
+        elif isinstance(arg, torch._C.Value):
             return arg
         else:
             return g.op("Constant", value_z=arg)
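
The effect of this change, restated as a standalone sketch (here taking the graph g as an explicit parameter rather than closing over it, as the real inner function does): const_if_tensor now forwards None untouched, just as it already did for existing graph Values, and only wraps raw attribute values in a Constant node. Together with the relaxed _iter_jit_values filter above, this lets symbolic functions pass None where an optional input or output has nothing to contribute.

import torch

def const_if_tensor(g, arg):
    # None and torch._C.Value pass through; anything else becomes a Constant.
    if arg is None or isinstance(arg, torch._C.Value):
        return arg
    return g.op("Constant", value_z=arg)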
