Skip to content

Commit a3db284

Browse files
BowenBaofacebook-github-bot
authored andcommitted
Support tuples in ScriptModule inputs/outputs (#20784)
Summary: - [x] Add tests after #20256 is merged - Support exporting ScriptModule with inputs/outputs of arbitrarily constructed tuples. - Moved the assigning of output shapes to after graph conversion to ONNX is completed. By then all tuples in the IR has already been lowered by the pass ```_jit_pass_lower_all_tuples```. If assigning output shapes is required to happen before that, we'll need to hand parse the tuple structures in the graph, and repeat the same logic in ```_jit_pass_lower_all_tuples```. Handling inputs is easier because all tuple information is encoded within the input tensor type. - Swap the order of ```_jit_pass_lower_all_tuples``` and ```_jit_pass_erase_number_types```. Ops like ```prim::TupleIndex``` relies on index being a scalar. ```_jit_pass_erase_number_types``` will convert these kind of scalars to tensors. Pull Request resolved: #20784 Reviewed By: zrphercule Differential Revision: D15484171 Pulled By: houseroad fbshipit-source-id: 4767a84038244c929f5662758047af6cb92228d3
1 parent 4c03ac7 commit a3db284

File tree

3 files changed

+107
-34
lines changed

3 files changed

+107
-34
lines changed

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,6 +1651,29 @@ def forward(self, lstm_in):
16511651

16521652
self.run_model_test(MyModel(), train=False, input=lstm_in, batch_size=3, use_gpu=False)
16531653

1654+
def test_tuple_input_output(self):
1655+
class TupleModel(torch.jit.ScriptModule):
1656+
@torch.jit.script_method
1657+
def forward(self, a):
1658+
# type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
1659+
return a
1660+
1661+
x = (torch.randn(3, 4), torch.randn(4, 3))
1662+
self.run_model_test(TupleModel(), train=False, input=(x,), batch_size=BATCH_SIZE,
1663+
example_outputs=(x,))
1664+
1665+
def test_nested_tuple_input_output(self):
1666+
class NestedTupleModel(torch.jit.ScriptModule):
1667+
@torch.jit.script_method
1668+
def forward(self, a, b):
1669+
# type: (Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
1670+
return a + b[0] + b[1][0] + b[1][1]
1671+
1672+
x = torch.randn(4, 5)
1673+
y = (torch.randn(4, 5), (torch.randn(4, 5), torch.randn(4, 5)))
1674+
self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE,
1675+
example_outputs=x + y[0] + y[1][0] + y[1][1])
1676+
16541677
def test_topk(self):
16551678
class TopKModel(torch.nn.Module):
16561679
def forward(self, input):

torch/csrc/jit/script/init.cpp

Lines changed: 71 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,58 @@ static Self moduleSelf(
178178
};
179179
}
180180

181-
static void setInputTensorTypes(Graph& g, const Stack& stack) {
182-
AT_ASSERT(stack.size() == g.inputs().size());
183-
for (size_t i = 0; i < stack.size(); ++i) {
184-
g.inputs().at(i)->setType(
185-
DimensionedTensorType::create(stack.at(i).toTensor()));
181+
static TypePtr getTensorType(
182+
const at::Tensor& t,
183+
const TypeKind type_kind) {
184+
switch (type_kind) {
185+
case TypeKind::DimensionedTensorType:
186+
return DimensionedTensorType::create(t);
187+
case TypeKind::CompleteTensorType: {
188+
auto scalar_type = t.scalar_type();
189+
auto sizes = t.sizes();
190+
return CompleteTensorType::create(scalar_type, at::kCPU, sizes);
191+
}
192+
default:
193+
throw std::runtime_error(
194+
"Attempted to call getTensorType for type kind other than DimensionedTensorType or CompleteTensorType.");
195+
}
196+
}
197+
198+
static TupleTypePtr getTupleTensorType(
199+
const Stack::const_iterator& s_iter,
200+
const Stack::const_iterator& s_iter_end,
201+
const TypePtr& tupleType,
202+
const TypeKind type_kind) {
203+
AT_ASSERT(tupleType->kind() == TupleType::Kind);
204+
AT_ASSERT(s_iter != s_iter_end);
205+
206+
std::vector<TypePtr> types;
207+
for (const auto& subType : tupleType->containedTypes()) {
208+
if (subType->kind() == TupleType::Kind) {
209+
types.push_back(getTupleTensorType(s_iter+1, s_iter_end, subType, type_kind));
210+
} else {
211+
types.push_back(getTensorType(s_iter->toTensor(), type_kind));
212+
}
213+
}
214+
return TupleType::create(types);
215+
}
216+
217+
static void setInputTensorTypes(
218+
Graph& g,
219+
const Stack& stack,
220+
const TypeKind type_kind = TypeKind::DimensionedTensorType) {
221+
at::ArrayRef<Value*> input_values = g.inputs();
222+
auto s_iter = stack.begin();
223+
for (auto v : input_values) {
224+
AT_ASSERT(s_iter != stack.end());
225+
if (v->type()->kind() == TupleType::Kind) {
226+
AT_ASSERT(v->node()->kind() == prim::Param);
227+
v->setType(
228+
getTupleTensorType(s_iter, stack.end(), v->type(), type_kind));
229+
} else {
230+
v->setType(getTensorType(s_iter->toTensor(), type_kind));
231+
s_iter++;
232+
}
186233
}
187234
}
188235

@@ -197,38 +244,32 @@ static std::shared_ptr<Graph> _propagate_shapes(
197244
return retval;
198245
}
199246

200-
static std::shared_ptr<Graph> _propagate_and_assign_input_and_output_shapes(
247+
static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
201248
Graph& graph,
202-
std::vector<at::Tensor> inputs,
203-
std::vector<at::Tensor> outputs,
249+
const std::vector<at::Tensor>& inputs,
204250
bool with_grad = false,
205251
bool propagate = true) {
206252
auto retval = graph.copy();
207253
if (propagate) {
208-
setInputTensorTypes(*retval, fmap<IValue>(inputs));
254+
setInputTensorTypes(*retval, fmap<IValue>(inputs), TypeKind::DimensionedTensorType);
209255
PropagateInputShapes(retval);
210256
}
211-
AT_ASSERT(retval->inputs().size() == inputs.size());
212-
for (size_t i = 0; i < retval->inputs().size(); ++i) {
213-
auto scalar_type = inputs[i].scalar_type();
214-
auto sizes = inputs[i].sizes();
215-
auto type =
216-
torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
217-
retval->inputs()[i]->setType(type);
218-
}
219-
at::ArrayRef<Value*> output_values = retval->outputs();
220-
// patch this to still work if we are returning a tuple of multiple values
221-
if (output_values.at(0)->type()->kind() == TupleType::Kind) {
222-
AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
223-
output_values = output_values.at(0)->node()->inputs();
224-
}
225-
AT_ASSERT(output_values.size() == outputs.size());
257+
setInputTensorTypes(*retval, fmap<IValue>(inputs), TypeKind::CompleteTensorType);
258+
259+
return retval;
260+
}
261+
262+
static std::shared_ptr<Graph> _assign_output_shapes(
263+
Graph& graph,
264+
std::vector<at::Tensor> outputs) {
265+
auto retval = graph.copy();
266+
AT_ASSERT(retval->outputs().size() == outputs.size());
226267
for (size_t i = 0; i < outputs.size(); ++i) {
227268
auto scalar_type = outputs[i].scalar_type();
228269
auto sizes = outputs[i].sizes();
229270
auto type =
230271
torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
231-
output_values[i]->setType(type);
272+
retval->outputs()[i]->setType(type);
232273
}
233274
return retval;
234275
}
@@ -679,8 +720,11 @@ void initJitScriptBindings(PyObject* module) {
679720
debugSetAutodiffSubgraphInlining);
680721
m.def("_propagate_shapes", _propagate_shapes);
681722
m.def(
682-
"_propagate_and_assign_input_and_output_shapes",
683-
_propagate_and_assign_input_and_output_shapes);
723+
"_propagate_and_assign_input_shapes",
724+
_propagate_and_assign_input_shapes);
725+
m.def(
726+
"_assign_output_shapes",
727+
_assign_output_shapes);
684728
m.def("_jit_python_print", [](py::object obj) {
685729
std::ostringstream ss;
686730
std::vector<at::Tensor> constants;

torch/onnx/utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from torch._six import string_classes
1717
from torch.jit import _unique_state_dict
1818
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes
19-
from torch._C import ListType, _propagate_and_assign_input_and_output_shapes
19+
from torch._C import ListType, _propagate_and_assign_input_shapes, _assign_output_shapes
2020

2121

2222
# the flag to tell the user whether it's in the middle of ONNX export or not
@@ -214,10 +214,10 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
214214

215215
# onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
216216
torch._C._jit_pass_prepare_division_for_onnx(graph)
217-
# onnx only supports tensors, so we turn all out number types into tensors
218-
torch._C._jit_pass_erase_number_types(graph)
219217
# onnx does not support tuples, so try to remove them
220218
torch._C._jit_pass_lower_all_tuples(graph)
219+
# onnx only supports tensors, so we turn all out number types into tensors
220+
torch._C._jit_pass_erase_number_types(graph)
221221
torch._C._jit_pass_peephole(graph, True)
222222
torch._C._jit_pass_lint(graph)
223223

@@ -288,16 +288,18 @@ def _model_to_graph(model, args, verbose=False, training=False,
288288
assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
289289
try:
290290
method_graph, params = model.forward._lowered_graph()
291-
graph = _propagate_and_assign_input_and_output_shapes(
292-
method_graph, tuple(args) + tuple(params), example_outputs, False, propagate)
291+
in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
292+
graph = _propagate_and_assign_input_shapes(
293+
method_graph, tuple(in_vars), False, propagate)
293294
except AttributeError:
294295
raise RuntimeError('\'forward\' method must be a script method')
295296
elif isinstance(model, torch.jit.Function):
296297
assert example_outputs is not None, "example_outputs must be provided when exporting a TorchScript Function"
297298
method = model
298299
params = ()
299-
graph = _propagate_and_assign_input_and_output_shapes(
300-
model.graph, tuple(args), example_outputs, False, propagate)
300+
in_vars, in_desc = torch.jit._flatten(tuple(args))
301+
graph = _propagate_and_assign_input_shapes(
302+
model.graph, tuple(in_vars), False, propagate)
301303
else:
302304
graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
303305
state_dict = _unique_state_dict(model)
@@ -313,6 +315,10 @@ def _model_to_graph(model, args, verbose=False, training=False,
313315
graph = _optimize_graph(graph, operator_export_type,
314316
_disable_torch_constant_prop=_disable_torch_constant_prop)
315317

318+
if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.Function):
319+
out_vars, _ = torch.jit._flatten(tuple(example_outputs))
320+
graph = _assign_output_shapes(graph, out_vars)
321+
316322
# NB: ONNX requires complete information about output types, which might be
317323
# erased by some optimizations, so we need to set it explicitly again.
318324
if torch_out is not None:

0 commit comments

Comments
 (0)