8 changes: 5 additions & 3 deletions test/test_nn.py
@@ -2958,9 +2958,11 @@ def compare_cpu_gpu(outputs_cpu, outputs_gpu):
grad_hy = torch.randn(num_layers * num_directions, batch, hidden_size)

if variable_len:
batch_sizes = [7, 5, 5, 2, 1, 1]
input_val = rnn_utils.pack_padded_sequence(input_val, batch_sizes, batch_first=batch_first)
grad_output = rnn_utils.pack_padded_sequence(grad_output, batch_sizes, batch_first=batch_first).data
lengths = [7, 5, 5, 2, 1, 1]
input_val = Variable(input_val)
grad_output = Variable(grad_output)
input_val = rnn_utils.pack_padded_sequence(input_val, lengths, batch_first=batch_first)
grad_output = rnn_utils.pack_padded_sequence(grad_output, lengths, batch_first=batch_first).data

rnn = module(input_size,
hidden_size,
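A note on this test fix: `pack_padded_sequence` takes one length per batch element (sorted longest-first), not the per-timestep batch sizes it computes internally, so the old variable name was misleading. A small hedged illustration (shapes and values are made up for this sketch):

```python
import torch
import torch.nn.utils.rnn as rnn_utils

lengths = [7, 5, 5, 2, 1, 1]                  # one length per sequence, descending
padded = torch.zeros(7, 6, 10)                # (max_seq_len, batch, features)

packed = rnn_utils.pack_padded_sequence(padded, lengths)
print(packed.batch_sizes)                     # per-timestep sizes: tensor([6, 4, 3, 3, 3, 1, 1])
```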
7 changes: 5 additions & 2 deletions torch/autograd/function.py
@@ -295,8 +295,11 @@ def unflatten_helper(input, proto):
if not isinstance(proto, (list, tuple)):
return input[0], input[1:]
for e in proto:
res_e, input = unflatten_helper(input, e)
res.append(res_e)
if e is None:
res.append(e)
else:
res_e, input = unflatten_helper(input, e)
res.append(res_e)
return type(proto)(res), input

return unflatten_helper(input, proto)[0]
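For context, a minimal standalone sketch of the behavior this patch gives `unflatten_helper`: a `None` entry in the proto structure is kept as-is and consumes nothing from the flat list (this is an illustrative re-statement, not the actual torch.autograd internals):

```python
def unflatten(flat, proto):
    def helper(flat, proto):
        if not isinstance(proto, (list, tuple)):
            return flat[0], flat[1:]
        res = []
        for e in proto:
            if e is None:
                # None placeholders pass through unchanged
                res.append(None)
            else:
                r, flat = helper(flat, e)
                res.append(r)
        return type(proto)(res), flat
    return helper(flat, proto)[0]

print(unflatten([1, 2, 3], (0, None, (0, 0))))   # -> (1, None, (2, 3))
```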
16 changes: 9 additions & 7 deletions torch/csrc/jit/export.cpp
@@ -190,17 +190,22 @@ void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr<Graph> & g, const
encodeValueInfo(v, output);
}
for (auto node : g->nodes()) {
if (node->kind() == kUndefined && !node->hasUses()) {
// Undefined nodes never show up in ONNX; they're just a tool
// to help symbolics do the right thing.
if (node->kind() == kUndefined) {
// Undefined nodes are used to implement optional inputs. One
// way to "not provide" an optional input is to create an
// Undefined node, and pass its output as that input.
continue;
}
auto p_n = p_g->add_node();
if (node->getSourceLocation()) {
p_n->set_doc_string(node->getSourceLocation()->python_traceback);
}
for(auto input : node->inputs()) {
p_n->add_input(value_name(input));
if (input->node()->kind() == kUndefined) {
p_n->add_input("");
} else {
p_n->add_input(value_name(input));
}
}
for(auto output : node->outputs()) {
p_n->add_output(value_name(output));
@@ -244,9 +249,6 @@ void validateGraph(const std::shared_ptr<Graph>& graph) {
if (node->kind() == kExpand) {
FAIL_EXPORT("Couldn't export operator expand; this usually means you used a form of broadcasting that ONNX does not currently support");
}
if (node->kind() == kUndefined) {
FAIL_EXPORT("Couldn't export undefined constant tensor (please file an issue)")
}
std::string n = node->kind().toString();
if (n.size() == 0) {
FAIL_EXPORT("Operator to export had empty name (please file an issue)")
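For reference, ONNX marks an omitted optional input with an empty string in the node's input list, which is what the exporter now emits when an input comes from an `Undefined` node. A hedged sketch using the ONNX Python helpers (op choice, names, and the omitted bias `B` slot are illustrative):

```python
from onnx import helper

lstm = helper.make_node(
    "LSTM",
    inputs=["X", "W", "R", "", "sequence_lens"],   # "" marks the omitted optional bias B
    outputs=["Y", "Y_h"],
    hidden_size=20,
)
print(lstm)
```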
5 changes: 5 additions & 0 deletions torch/csrc/jit/interned_strings.h
@@ -45,6 +45,8 @@ _(value) \
_(Subgraph) \
_(BatchNormalization) \
_(Conv) \
_(PackPadded) \
_(PadPacked) \
_(ConvTranspose) \
_(is_test) \
_(epsilon) \
@@ -61,6 +63,9 @@ _(strides) \
_(stride) \
_(pads) \
_(pad) \
_(RNN) \
_(LSTM) \
_(GRU) \
_(beta) \
_(alpha) \
_(dilations) \
17 changes: 12 additions & 5 deletions torch/csrc/jit/ir.h
@@ -211,6 +211,8 @@ struct Value {
return uses_;
}

void replaceFirstUseWith(Value * newValue);

// Replaces all uses of this node with 'newValue'.
//
// Given: %3 = f(%1, %2)
@@ -1031,13 +1033,18 @@ inline const Graph * Value::owningGraph() const {
return node()->owningGraph();
}

inline void Value::replaceAllUsesWith(Value * newValue) {
inline void Value::replaceFirstUseWith(Value * newValue) {
JIT_ASSERT(owningGraph() == newValue->owningGraph());
for(auto u : uses()) {
u.user->inputs_[u.offset] = newValue;
newValue->uses_.push_back(u);
auto u = uses()[0];
u.user->inputs_[u.offset] = newValue;
newValue->uses_.push_back(u);
uses_.erase(uses_.begin());
}

inline void Value::replaceAllUsesWith(Value * newValue) {
while (!uses().empty()) {
replaceFirstUseWith(newValue);
}
uses_.clear();
}

inline Node::Node(Graph * graph_, NodeKind kind_) :
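A toy Python analogue (not the real JIT IR) of the distinction this change introduces: `replaceFirstUseWith` rewires only the first recorded use, which the RNN pass relies on when the lengths value feeds several layers, and `replaceAllUsesWith` is re-expressed in terms of it:

```python
class Value:
    def __init__(self, name):
        self.name = name
        self.uses = []                     # list of (user_node, input_offset)

    def replace_first_use_with(self, new_value):
        user, offset = self.uses.pop(0)    # detach the first use from this value
        user.inputs[offset] = new_value    # point the user at the new value
        new_value.uses.append((user, offset))

    def replace_all_uses_with(self, new_value):
        while self.uses:
            self.replace_first_use_with(new_value)
```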
91 changes: 90 additions & 1 deletion torch/csrc/jit/passes/onnx/peephole.cpp
@@ -20,6 +20,11 @@ std::unordered_set<NodeKind> broadcasting = {
kGemm,
};

bool isRNN(const Node *node) {
auto k = node->kind();
return k == kRNN || k == kLSTM || k == kGRU;
}

bool isNopTranspose(const std::vector<int64_t> & perm) {
for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++)
if (perm[i] != i)
@@ -167,6 +172,87 @@ void fuseTransposeIntoGemm(std::shared_ptr<Graph>& graph) {
}
}

// Why this is here:
//
// Pytorch has a "packed" representation of sequences, as well as a
// "padded" representation. ONNX has only one representation,
// corresponding to pytorch's "padded". Therefore, we need to remove
// any use of packed sequences before exporting.
//
// What this does:
//
// This code uses the observation that
// RNN(PackPadded(x)) == PackPadded(RNN(x))
// and converts the first form to the second whenever possible,
// "pushing" the packing operation past the RNN operation. Then,
// the removeNopPacking pass removes the packing operations
// entirely by pairing them with their inverse PadPacked. If the
// input graph does not pair the operations, export will fail.
void pushPackingPastRnn(std::shared_ptr<Graph>& graph) {
for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
auto* n = *it;

if (n->kind() != kPackPadded) {
continue;
}
if (n->outputs()[0]->uses().size() != 1) {
// For now, only handle the case where there is one consumer.
continue;
}
Node * rnn = n->outputs()[0]->uses()[0].user;
if (!isRNN(rnn)) {
continue;
}

// remove PackPadded from in front of the RNN
n->outputs()[0]->replaceAllUsesWith(n->inputs()[0]);

// note there can be multiple uses of the length blob. If we are
// translating a multi-level RNN it will be an input to each level.
n->outputs()[1]->replaceFirstUseWith(n->inputs()[1]);

// and insert new PackPadded after the RNN
Node * newPackPadded = graph->create(kPackPadded, 2);
newPackPadded->insertAfter(rnn);

// make things consume from the new PackPadded
rnn->outputs()[0]->replaceAllUsesWith(newPackPadded->outputs()[0]);
n->outputs()[1]->replaceAllUsesWith(newPackPadded->outputs()[1]);

// setup the new PackPadded's inputs
newPackPadded->addInput(rnn->outputs()[0]);
newPackPadded->addInput(n->inputs()[1]);

it.destroyCurrent();
}
}

void removeNopPacking(std::shared_ptr<Graph>& graph) {
for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
auto* n = *it;

if (n->kind() != kPadPacked) {
continue;
}
Node* input = n->inputs()[0]->node();
if (input->kind() != kPackPadded) {
continue;
}
if (input->outputs()[0] != n->inputs()[0]) {
continue;
}
if (input->outputs()[1] != n->inputs()[1]) {
continue;
}
n->outputs()[0]->replaceAllUsesWith(input->inputs()[0]);
n->outputs()[1]->replaceAllUsesWith(input->inputs()[1]);

n->removeAllInputs();
it.destroyCurrent();
}
}


// This optimization does ONNX-specific peephole optimizations.
//
// At the moment, here are the optimizations it does:
@@ -175,8 +261,9 @@ void fuseTransposeIntoGemm(std::shared_ptr<Graph>& graph) {
// local information. This optimization is not useful for PyTorch as 'expand'
// is free.
// - Fusing of consecutive transposes
// - Elimiation of NOP transposes
// - Elimination of NOP transposes
// - Fusing of transposes into Gemm
// - Elimination of PaddedSequences
//
// Before you write an optimization here, ask yourself, "Could I do this
// optimization on ATen operators"? If so, you should seriously consider
@@ -191,6 +278,8 @@ void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph) {
fuseConsecutiveTransposes(graph);
eliminateNopTranspose(graph);
fuseTransposeIntoGemm(graph);
pushPackingPastRnn(graph);
removeNopPacking(graph);
}

}}
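A hedged check of the identity the new pass relies on, RNN(PackPadded(x)) == PackPadded(RNN(x)), written with the Python-level packing utilities rather than the JIT IR (shapes, hyperparameters, and the single-layer unidirectional setup are illustrative assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

lstm = nn.LSTM(input_size=4, hidden_size=3)
x = torch.randn(5, 2, 4)                       # (seq_len, batch, features)
lengths = [5, 3]

# pack first, run the RNN, then unpack
packed_out, _ = lstm(rnn_utils.pack_padded_sequence(x, lengths))
a, _ = rnn_utils.pad_packed_sequence(packed_out)

# run the RNN on padded input, then pack/unpack to zero out the padding
padded_out, _ = lstm(x)
b, _ = rnn_utils.pad_packed_sequence(
    rnn_utils.pack_padded_sequence(padded_out, lengths))

print(torch.allclose(a, b))                    # True here (unidirectional, single layer)
```

This is the commutation `pushPackingPastRnn` exploits; `removeNopPacking` then cancels each pushed `PackPadded` against the matching `PadPacked`.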
5 changes: 4 additions & 1 deletion torch/nn/_functions/packing.py
@@ -14,7 +14,10 @@ def forward(ctx, input, lengths, batch_first):

steps = []
batch_sizes = []
lengths_iter = reversed(lengths)

# lengths is a Tensor, so we must convert to list before reversed()
lengths_iter = reversed(list(lengths))

batch_size = input.size(1)

if len(lengths) != batch_size:
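A minimal illustration of the conversion the new comment describes; current tensor semantics may differ from the PyTorch version this patch targeted, so treat it as a sketch:

```python
import torch

lengths = torch.tensor([7, 5, 5, 2, 1, 1])
lengths_iter = reversed(list(lengths))       # convert to a list before iterating back-to-front
print([int(l) for l in lengths_iter])        # [1, 1, 2, 5, 5, 7]
```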
51 changes: 24 additions & 27 deletions torch/nn/_functions/rnn.py
@@ -70,7 +70,7 @@ def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):
num_directions = len(inners)
total_layers = num_layers * num_directions

def forward(input, hidden, weight):
def forward(input, hidden, weight, batch_sizes):
assert(len(weight) == total_layers)
next_hidden = []

Expand All @@ -82,7 +82,7 @@ def forward(input, hidden, weight):
for j, inner in enumerate(inners):
l = i * num_directions + j

hy, output = inner(input, hidden[l], weight[l])
hy, output = inner(input, hidden[l], weight[l], batch_sizes)
next_hidden.append(hy)
all_output.append(output)

@@ -107,7 +107,7 @@ def forward(input, hidden, weight):


def Recurrent(inner, reverse=False):
def forward(input, hidden, weight):
def forward(input, hidden, weight, batch_sizes):
output = []
steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
for i in steps:
@@ -124,17 +124,16 @@ def forward(input, hidden, weight):
return forward


def variable_recurrent_factory(batch_sizes):
def fac(inner, reverse=False):
if reverse:
return VariableRecurrentReverse(batch_sizes, inner)
else:
return VariableRecurrent(batch_sizes, inner)
return fac
def variable_recurrent_factory(inner, reverse=False):
if reverse:
return VariableRecurrentReverse(inner)
else:
return VariableRecurrent(inner)


def VariableRecurrent(batch_sizes, inner):
def forward(input, hidden, weight):
def VariableRecurrent(inner):
def forward(input, hidden, weight, batch_sizes):

output = []
input_offset = 0
last_batch_size = batch_sizes[0]
@@ -172,8 +171,8 @@ def forward(input, hidden, weight):
return forward


def VariableRecurrentReverse(batch_sizes, inner):
def forward(input, hidden, weight):
def VariableRecurrentReverse(inner):
def forward(input, hidden, weight, batch_sizes):
output = []
input_offset = input.size(0)
last_batch_size = batch_sizes[-1]
Expand All @@ -183,7 +182,8 @@ def forward(input, hidden, weight):
hidden = (hidden,)
initial_hidden = (initial_hidden,)
hidden = tuple(h[:batch_sizes[-1]] for h in hidden)
for batch_size in reversed(batch_sizes):
for i in reversed(range(len(batch_sizes))):
batch_size = batch_sizes[i]
inc = batch_size - last_batch_size
if inc > 0:
hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0)
Expand All @@ -208,7 +208,7 @@ def forward(input, hidden, weight):


def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
dropout=0, train=True, bidirectional=False, batch_sizes=None,
dropout=0, train=True, bidirectional=False, variable_length=False,
dropout_state=None, flat_weight=None):

if mode == 'RNN_RELU':
@@ -222,10 +222,7 @@ def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
else:
raise Exception('Unknown mode: {}'.format(mode))

if batch_sizes is None:
rec_factory = Recurrent
else:
rec_factory = variable_recurrent_factory(batch_sizes)
rec_factory = variable_recurrent_factory if variable_length else Recurrent

if bidirectional:
layer = (rec_factory(cell), rec_factory(cell, reverse=True))
@@ -238,13 +235,13 @@ def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
dropout=dropout,
train=train)

def forward(input, weight, hidden):
if batch_first and batch_sizes is None:
def forward(input, weight, hidden, batch_sizes):
if batch_first and not variable_length:
input = input.transpose(0, 1)

nexth, output = func(input, hidden, weight)
nexth, output = func(input, hidden, weight, batch_sizes)

if batch_first and batch_sizes is None:
if batch_first and not variable_length:
output = output.transpose(0, 1)

return output, nexth
@@ -254,7 +251,7 @@ def forward(input, weight, hidden):

def CudnnRNN(mode, input_size, hidden_size, num_layers=1,
batch_first=False, dropout=0, train=True, bidirectional=False,
batch_sizes=None, dropout_state=None, flat_weight=None):
variable_length=False, dropout_state=None, flat_weight=None):
if dropout_state is None:
dropout_state = {}
mode = cudnn.rnn.get_cudnn_mode(mode)
@@ -265,7 +262,7 @@ def CudnnRNN(mode, input_size, hidden_size, num_layers=1,
"at every call, possibly greatly increasing memory usage. "
"To compact weights again call flatten_parameters().", stacklevel=5)

def forward(input, weight, hx):
def forward(input, weight, hx, batch_sizes):
if mode == cudnn.CUDNN_LSTM:
hx, cx = hx
else:
Expand All @@ -283,7 +280,7 @@ def forward(input, weight, hx):
hx, cx,
mode, hidden_size, num_layers,
batch_first, dropout, train, bool(bidirectional),
batch_sizes if batch_sizes else (),
list(batch_sizes.data) if variable_length else (),
Variable(dropout_desc.state) if dropout_desc.state is not None else None)

if cx is not None:
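A hedged sketch of what now travels with the input after this change: the per-timestep batch sizes are read off the `PackedSequence` and passed to `forward()` at call time (the cuDNN path then flattens them into a plain Python list for the descriptor). Shapes and values below are illustrative:

```python
import torch
import torch.nn.utils.rnn as rnn_utils

x = torch.randn(7, 6, 10)                    # (seq_len, batch, features)
packed = rnn_utils.pack_padded_sequence(x, [7, 5, 5, 2, 1, 1])

flat_input, batch_sizes = packed.data, packed.batch_sizes
print(flat_input.shape)                      # torch.Size([21, 10]); 21 == sum of lengths
print(batch_sizes)                           # tensor([6, 4, 3, 3, 3, 1, 1])
```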