
Commit b2cfd96
anderspapitto authored and soumith committed
Handle sequence lengths correctly when exporting RNNs to ONNX (#4695)

* PackedSequence: store batch_sizes as a tensor rather than converting it to a list of Python integers. This maintains the invariant that a module's inputs/outputs are collections of Variables; in particular, the JIT no longer chokes when flattening and unflattening arguments.

* Handle sequence lengths correctly when exporting RNNs to ONNX:
  - When uniform sequence lengths are provided, omit the argument when constructing the ONNX graph, so that the graph is not fixed to the batch size.
  - Handle PackedSequences by floating them through the graph and eliminating them in an optimization pass. ONNX has no packed sequences; it operates on a representation equivalent to PaddedSequence, so the representation switching is hidden from ONNX.
  - As a preliminary step towards handling PackedSequences (not directly tied to ONNX export), change batch_sizes from an argument to the RNN operators into an argument to those operators' forward() functions. This more closely models the reality that batch_sizes are effectively part of the input sequences.
1 parent f796080 commit b2cfd96
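For context, a minimal sketch of the user-facing workflow this commit touches (the model, shapes, and lengths below are hypothetical, and the snippet assumes the torch.nn.utils.rnn API of this PyTorch version):

    import torch
    from torch.autograd import Variable
    import torch.nn as nn
    import torch.nn.utils.rnn as rnn_utils

    lstm = nn.LSTM(input_size=4, hidden_size=8)
    padded = Variable(torch.randn(7, 3, 4))   # (seq_len, batch, features)
    lengths = [7, 5, 2]                       # per-example lengths, sorted descending

    packed = rnn_utils.pack_padded_sequence(padded, lengths)
    # With this commit, packed.batch_sizes is kept as a tensor rather than a
    # list of Python ints, so the JIT can flatten/unflatten it like any other
    # Variable when tracing for ONNX export.
    output, (h_n, c_n) = lstm(packed)
    unpacked, unpacked_lengths = rnn_utils.pad_packed_sequence(output)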

11 files changed: +259 −113 lines


test/test_nn.py
Lines changed: 5 additions & 3 deletions

@@ -3054,9 +3054,11 @@ def compare_cpu_gpu(outputs_cpu, outputs_gpu):
         grad_hy = torch.randn(num_layers * num_directions, batch, hidden_size)
 
         if variable_len:
-            batch_sizes = [7, 5, 5, 2, 1, 1]
-            input_val = rnn_utils.pack_padded_sequence(input_val, batch_sizes, batch_first=batch_first)
-            grad_output = rnn_utils.pack_padded_sequence(grad_output, batch_sizes, batch_first=batch_first).data
+            lengths = [7, 5, 5, 2, 1, 1]
+            input_val = Variable(input_val)
+            grad_output = Variable(grad_output)
+            input_val = rnn_utils.pack_padded_sequence(input_val, lengths, batch_first=batch_first)
+            grad_output = rnn_utils.pack_padded_sequence(grad_output, lengths, batch_first=batch_first).data
 
         rnn = module(input_size,
                      hidden_size,
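The rename above is more than cosmetic: rnn_utils.pack_padded_sequence takes per-example lengths, while the per-timestep batch_sizes stored on the resulting PackedSequence are derived from them. A standalone illustration using the values from this test (plain Python, nothing PyTorch-specific):

    lengths = [7, 5, 5, 2, 1, 1]   # one entry per example, sorted descending
    batch_sizes = [sum(1 for l in lengths if l > t) for t in range(max(lengths))]
    print(batch_sizes)             # [6, 4, 3, 3, 3, 1, 1]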

torch/autograd/function.py
Lines changed: 5 additions & 2 deletions

@@ -295,8 +295,11 @@ def unflatten_helper(input, proto):
         if not isinstance(proto, (list, tuple)):
             return input[0], input[1:]
         for e in proto:
-            res_e, input = unflatten_helper(input, e)
-            res.append(res_e)
+            if e is None:
+                res.append(e)
+            else:
+                res_e, input = unflatten_helper(input, e)
+                res.append(res_e)
         return type(proto)(res), input
 
     return unflatten_helper(input, proto)[0]
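A self-contained sketch of the unflatten logic after this fix (the helper below is illustrative, not the exact torch.autograd internals): None entries in the prototype are passed through unchanged instead of consuming an element from the flat input, which lets optional values such as an absent batch_sizes survive flatten/unflatten round trips.

    def unflatten(flat, proto):
        def helper(flat, proto):
            if not isinstance(proto, (list, tuple)):
                return flat[0], flat[1:]
            res = []
            for e in proto:
                if e is None:
                    res.append(e)
                else:
                    res_e, flat = helper(flat, e)
                    res.append(res_e)
            return type(proto)(res), flat
        return helper(flat, proto)[0]

    # unflatten([1, 2, 3], (None, object(), (object(), object())))
    # -> (None, 1, (2, 3))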

torch/csrc/jit/export.cpp
Lines changed: 9 additions & 7 deletions

@@ -190,17 +190,22 @@ void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr<Graph> & g, const
     encodeValueInfo(v, output);
   }
   for (auto node : g->nodes()) {
-    if (node->kind() == kUndefined && !node->hasUses()) {
-      // Undefined nodes never show up in ONNX; they're just a tool
-      // to help symbolics do the right thing.
+    if (node->kind() == kUndefined) {
+      // Undefined nodes are used to implement optional inputs. One
+      // way to "not provide" an optional input is to create an
+      // Undefined node, and pass its output as that input.
       continue;
     }
     auto p_n = p_g->add_node();
     if (node->getSourceLocation()) {
       p_n->set_doc_string(node->getSourceLocation()->python_traceback);
     }
     for(auto input : node->inputs()) {
-      p_n->add_input(value_name(input));
+      if (input->node()->kind() == kUndefined) {
+        p_n->add_input("");
+      } else {
+        p_n->add_input(value_name(input));
+      }
     }
     for(auto output : node->outputs()) {
       p_n->add_output(value_name(output));
@@ -244,9 +249,6 @@ void validateGraph(const std::shared_ptr<Graph>& graph) {
     if (node->kind() == kExpand) {
       FAIL_EXPORT("Couldn't export operator expand; this usually means you used a form of broadcasting that ONNX does not currently support");
     }
-    if (node->kind() == kUndefined) {
-      FAIL_EXPORT("Couldn't export undefined constant tensor (please file an issue)")
-    }
     std::string n = node->kind().toString();
     if (n.size() == 0) {
       FAIL_EXPORT("Operator to export had empty name (please file an issue)")
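The empty-string input relies on a standard ONNX convention: a node marks an optional input as "not provided" by placing an empty name in its input list. A hedged sketch using the onnx Python helpers (hypothetical tensor names; requires the separate onnx package):

    from onnx import helper

    # ONNX LSTM inputs are X, W, R, B, sequence_lens, initial_h, ...;
    # here B and sequence_lens are omitted via empty strings.
    node = helper.make_node(
        'LSTM',
        inputs=['X', 'W', 'R', '', '', 'initial_h'],
        outputs=['Y', 'Y_h'],
        hidden_size=8,
    )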

torch/csrc/jit/interned_strings.h
Lines changed: 5 additions & 0 deletions

@@ -45,6 +45,8 @@ _(value) \
 _(Subgraph) \
 _(BatchNormalization) \
 _(Conv) \
+_(PackPadded) \
+_(PadPacked) \
 _(ConvTranspose) \
 _(is_test) \
 _(epsilon) \
@@ -61,6 +63,9 @@ _(strides) \
 _(stride) \
 _(pads) \
 _(pad) \
+_(RNN) \
+_(LSTM) \
+_(GRU) \
 _(beta) \
 _(alpha) \
 _(dilations) \

torch/csrc/jit/ir.h
Lines changed: 12 additions & 5 deletions

@@ -211,6 +211,8 @@ struct Value {
     return uses_;
   }
 
+  void replaceFirstUseWith(Value * newValue);
+
   // Replaces all uses of this node with 'newValue'.
   //
   // Given: %3 = f(%1, %2)
@@ -1031,13 +1033,18 @@ inline const Graph * Value::owningGraph() const {
   return node()->owningGraph();
 }
 
-inline void Value::replaceAllUsesWith(Value * newValue) {
+inline void Value::replaceFirstUseWith(Value * newValue) {
   JIT_ASSERT(owningGraph() == newValue->owningGraph());
-  for(auto u : uses()) {
-    u.user->inputs_[u.offset] = newValue;
-    newValue->uses_.push_back(u);
+  auto u = uses()[0];
+  u.user->inputs_[u.offset] = newValue;
+  newValue->uses_.push_back(u);
+  uses_.erase(uses_.begin());
+}
+
+inline void Value::replaceAllUsesWith(Value * newValue) {
+  while (!uses().empty()) {
+    replaceFirstUseWith(newValue);
   }
-  uses_.clear();
 }
 
 inline Node::Node(Graph * graph_, NodeKind kind_) :
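A toy Python model of the use-list bookkeeping above (illustrative only; the real Value, Use, and Node types are C++ JIT IR classes). replaceAllUsesWith is now just repeated replaceFirstUseWith, which is what lets the packing pass redirect a single consumer of a value, such as one RNN layer's length input, while leaving other consumers alone.

    class Use:
        def __init__(self, user, offset):
            self.user, self.offset = user, offset

    class Node:
        def __init__(self, inputs):
            self.inputs = list(inputs)
            for offset, v in enumerate(self.inputs):
                v.uses.append(Use(self, offset))

    class Value:
        def __init__(self):
            self.uses = []

        def replace_first_use_with(self, new_value):
            u = self.uses.pop(0)                 # oldest use only
            u.user.inputs[u.offset] = new_value
            new_value.uses.append(u)

        def replace_all_uses_with(self, new_value):
            while self.uses:
                self.replace_first_use_with(new_value)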

torch/csrc/jit/passes/onnx/peephole.cpp
Lines changed: 90 additions & 1 deletion

@@ -20,6 +20,11 @@ std::unordered_set<NodeKind> broadcasting = {
   kGemm,
 };
 
+bool isRNN(const Node *node) {
+  auto k = node->kind();
+  return k == kRNN || k == kLSTM || k == kGRU;
+}
+
 bool isNopTranspose(const std::vector<int64_t> & perm) {
   for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++)
     if (perm[i] != i)
@@ -167,6 +172,87 @@ void fuseTransposeIntoGemm(std::shared_ptr<Graph>& graph) {
   }
 }
 
+// Why this is here:
+//
+// Pytorch has a "packed" representation of sequences, as well as a
+// "padded" representation. ONNX has only one representation,
+// corresponding to pytorch's "padded". Therefore, we need to remove
+// any use of packed sequences before exporting.
+//
+// What this does:
+//
+// This code uses the observation that
+//     RNN(PackPadded(x)) == PackPadded(RNN(x))
+// and converts the first form to the second whenever possible,
+// "pushing" the packing operation past the RNN operation. Then,
+// the removeNopPacking pass removes the packing operations
+// entirely by pairing them with their inverse PadPacked. If the
+// input graph does not pair the operations, export will fail.
+void pushPackingPastRnn(std::shared_ptr<Graph>& graph) {
+  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
+    auto* n = *it;
+
+    if (n->kind() != kPackPadded) {
+      continue;
+    }
+    if (n->outputs()[0]->uses().size() != 1) {
+      // For now, only handle the case where there is one consumer.
+      continue;
+    }
+    Node * rnn = n->outputs()[0]->uses()[0].user;
+    if (!isRNN(rnn)) {
+      continue;
+    }
+
+    // remove PackPadded from in front of the RNN
+    n->outputs()[0]->replaceAllUsesWith(n->inputs()[0]);
+
+    // note there can be multiple uses of the length blob. If we are
+    // translating a multi-level RNN it will be an input to each level.
+    n->outputs()[1]->replaceFirstUseWith(n->inputs()[1]);
+
+    // and insert new PackPadded after the RNN
+    Node * newPackPadded = graph->create(kPackPadded, 2);
+    newPackPadded->insertAfter(rnn);
+
+    // make things consume from the new PackPadded
+    rnn->outputs()[0]->replaceAllUsesWith(newPackPadded->outputs()[0]);
+    n->outputs()[1]->replaceAllUsesWith(newPackPadded->outputs()[1]);
+
+    // setup the new PackPadded's inputs
+    newPackPadded->addInput(rnn->outputs()[0]);
+    newPackPadded->addInput(n->inputs()[1]);
+
+    it.destroyCurrent();
+  }
+}
+
+void removeNopPacking(std::shared_ptr<Graph>& graph) {
+  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
+    auto* n = *it;
+
+    if (n->kind() != kPadPacked) {
+      continue;
+    }
+    Node* input = n->inputs()[0]->node();
+    if (input->kind() != kPackPadded) {
+      continue;
+    }
+    if (input->outputs()[0] != n->inputs()[0]) {
+      continue;
+    }
+    if (input->outputs()[1] != n->inputs()[1]) {
+      continue;
+    }
+    n->outputs()[0]->replaceAllUsesWith(input->inputs()[0]);
+    n->outputs()[1]->replaceAllUsesWith(input->inputs()[1]);
+
+    n->removeAllInputs();
+    it.destroyCurrent();
+  }
+}
+
+
 // This optimization does ONNX-specific peephole optimizations.
 //
 // At the moment, here are the optimizations it does:
@@ -175,8 +261,9 @@ void fuseTransposeIntoGemm(std::shared_ptr<Graph>& graph) {
 // local information. This optimization is not useful for PyTorch as 'expand'
 // is free.
 // - Fusing of consecutive transposes
-// - Elimiation of NOP transposes
+// - Elimination of NOP transposes
 // - Fusing of transposes into Gemm
+// - Elimination of PaddedSequences
 //
 // Before you write an optimization here, ask yourself, "Could I do this
 // optimization on ATen operators"? If so, you should seriously consider
@@ -191,6 +278,8 @@ void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph) {
   fuseConsecutiveTransposes(graph);
  eliminateNopTranspose(graph);
  fuseTransposeIntoGemm(graph);
+  pushPackingPastRnn(graph);
+  removeNopPacking(graph);
 }
 
 }}
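Seen at the graph level, the two new passes perform the following rewrite (a schematic, not literal JIT IR syntax):

    traced graph:            data, batch_sizes = PackPadded(padded, lengths)
                             y, h              = RNN(data, batch_sizes, ...)
                             out, out_lengths  = PadPacked(y, batch_sizes)

    after pushPackingPastRnn:
                             y, h              = RNN(padded, lengths, ...)
                             data2, bs2        = PackPadded(y, lengths)
                             out, out_lengths  = PadPacked(data2, bs2)

removeNopPacking then cancels the now-adjacent PackPadded/PadPacked pair, leaving a graph that operates only on padded tensors, which is the representation ONNX expects; if a traced graph packs without a matching unpack, export fails.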

torch/nn/_functions/packing.py
Lines changed: 4 additions & 1 deletion

@@ -14,7 +14,10 @@ def forward(ctx, input, lengths, batch_first):
 
         steps = []
         batch_sizes = []
-        lengths_iter = reversed(lengths)
+
+        # lengths is a Tensor, so we must convert to list before reversed()
+        lengths_iter = reversed(list(lengths))
+
         batch_size = input.size(1)
 
         if len(lengths) != batch_size:
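Mirroring that fix in isolation (a minimal standalone sketch; the tensor values are hypothetical):

    import torch

    lengths = torch.LongTensor([7, 5, 2])
    # Materialize the tensor as a Python list so reversed() yields plain
    # per-example lengths, shortest first.
    for l in reversed(list(lengths)):
        print(int(l))   # prints 2, then 5, then 7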

torch/nn/_functions/rnn.py
Lines changed: 24 additions & 27 deletions

@@ -70,7 +70,7 @@ def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):
     num_directions = len(inners)
     total_layers = num_layers * num_directions
 
-    def forward(input, hidden, weight):
+    def forward(input, hidden, weight, batch_sizes):
         assert(len(weight) == total_layers)
         next_hidden = []
 
@@ -82,7 +82,7 @@ def forward(input, hidden, weight):
             for j, inner in enumerate(inners):
                 l = i * num_directions + j
 
-                hy, output = inner(input, hidden[l], weight[l])
+                hy, output = inner(input, hidden[l], weight[l], batch_sizes)
                 next_hidden.append(hy)
                 all_output.append(output)
 
@@ -107,7 +107,7 @@ def forward(input, hidden, weight):
 
 
 def Recurrent(inner, reverse=False):
-    def forward(input, hidden, weight):
+    def forward(input, hidden, weight, batch_sizes):
         output = []
         steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
         for i in steps:
@@ -124,17 +124,16 @@ def forward(input, hidden, weight):
     return forward
 
 
-def variable_recurrent_factory(batch_sizes):
-    def fac(inner, reverse=False):
-        if reverse:
-            return VariableRecurrentReverse(batch_sizes, inner)
-        else:
-            return VariableRecurrent(batch_sizes, inner)
-    return fac
+def variable_recurrent_factory(inner, reverse=False):
+    if reverse:
+        return VariableRecurrentReverse(inner)
+    else:
+        return VariableRecurrent(inner)
 
 
-def VariableRecurrent(batch_sizes, inner):
-    def forward(input, hidden, weight):
+def VariableRecurrent(inner):
+    def forward(input, hidden, weight, batch_sizes):
+
         output = []
         input_offset = 0
         last_batch_size = batch_sizes[0]
@@ -172,8 +171,8 @@ def forward(input, hidden, weight):
     return forward
 
 
-def VariableRecurrentReverse(batch_sizes, inner):
-    def forward(input, hidden, weight):
+def VariableRecurrentReverse(inner):
+    def forward(input, hidden, weight, batch_sizes):
         output = []
         input_offset = input.size(0)
         last_batch_size = batch_sizes[-1]
@@ -183,7 +182,8 @@ def forward(input, hidden, weight):
             hidden = (hidden,)
             initial_hidden = (initial_hidden,)
         hidden = tuple(h[:batch_sizes[-1]] for h in hidden)
-        for batch_size in reversed(batch_sizes):
+        for i in reversed(range(len(batch_sizes))):
+            batch_size = batch_sizes[i]
             inc = batch_size - last_batch_size
             if inc > 0:
                 hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0)
@@ -208,7 +208,7 @@ def forward(input, hidden, weight):
 
 
 def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
-                dropout=0, train=True, bidirectional=False, batch_sizes=None,
+                dropout=0, train=True, bidirectional=False, variable_length=False,
                 dropout_state=None, flat_weight=None):
 
     if mode == 'RNN_RELU':
@@ -222,10 +222,7 @@ def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
     else:
         raise Exception('Unknown mode: {}'.format(mode))
 
-    if batch_sizes is None:
-        rec_factory = Recurrent
-    else:
-        rec_factory = variable_recurrent_factory(batch_sizes)
+    rec_factory = variable_recurrent_factory if variable_length else Recurrent
 
     if bidirectional:
         layer = (rec_factory(cell), rec_factory(cell, reverse=True))
@@ -238,13 +235,13 @@ def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
                      dropout=dropout,
                      train=train)
 
-    def forward(input, weight, hidden):
-        if batch_first and batch_sizes is None:
+    def forward(input, weight, hidden, batch_sizes):
+        if batch_first and not variable_length:
             input = input.transpose(0, 1)
 
-        nexth, output = func(input, hidden, weight)
+        nexth, output = func(input, hidden, weight, batch_sizes)
 
-        if batch_first and batch_sizes is None:
+        if batch_first and not variable_length:
            output = output.transpose(0, 1)
 
        return output, nexth
@@ -254,7 +251,7 @@ def forward(input, weight, hidden):
 
 def CudnnRNN(mode, input_size, hidden_size, num_layers=1,
             batch_first=False, dropout=0, train=True, bidirectional=False,
-            batch_sizes=None, dropout_state=None, flat_weight=None):
+            variable_length=False, dropout_state=None, flat_weight=None):
     if dropout_state is None:
         dropout_state = {}
     mode = cudnn.rnn.get_cudnn_mode(mode)
@@ -265,7 +262,7 @@ def CudnnRNN(mode, input_size, hidden_size, num_layers=1,
               "at every call, possibly greatly increasing memory usage. "
               "To compact weights again call flatten_parameters().", stacklevel=5)
 
-    def forward(input, weight, hx):
+    def forward(input, weight, hx, batch_sizes):
         if mode == cudnn.CUDNN_LSTM:
             hx, cx = hx
         else:
@@ -283,7 +280,7 @@ def forward(input, weight, hx):
             hx, cx,
             mode, hidden_size, num_layers,
             batch_first, dropout, train, bool(bidirectional),
-            batch_sizes if batch_sizes else (),
+            list(batch_sizes.data) if variable_length else (),
            Variable(dropout_desc.state) if dropout_desc.state is not None else None)
 
        if cx is not None:
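A stripped-down sketch of what the VariableRecurrent forward above does with its new batch_sizes argument (illustrative only; the real code also threads cell weights and handles LSTM's (h, c) tuples). Each step consumes batch_sizes[t] rows of the flattened packed data and shrinks the hidden state as sequences end:

    import torch

    def variable_recurrent_sketch(cell, packed_data, batch_sizes, hidden):
        output, finished = [], []
        offset = 0
        last_bs = batch_sizes[0]
        for bs in batch_sizes:
            step_input = packed_data[offset:offset + bs]
            offset += bs
            if bs < last_bs:
                # sequences that just ended keep their final hidden rows
                finished.append(hidden[bs:])
                hidden = hidden[:bs]
            last_bs = bs
            hidden = cell(step_input, hidden)
            output.append(hidden)
        finished.append(hidden)
        finished.reverse()
        return torch.cat(finished, 0), torch.cat(output, 0)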
