
Commit 28be521

BowenBao authored and facebook-github-bot committed
Fix bug in exporting node with multiple outputs by scripting
Summary: Pull Request resolved: #20256

Differential Revision: D15422040
Pulled By: houseroad
fbshipit-source-id: 5de2a992d7d99a48905c39a1878eb0b3b68d6a3f
1 parent c2e3e79 commit 28be521
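For context, here is a minimal sketch of the scenario this commit fixes, modeled on the test_topk_script test added below (the model, input shape, and export call are taken from that test; nothing here is new API). torch.topk returns both values and indices, so its graph node has multiple outputs; because a ScriptModule is not traced at export time, the internal torch.onnx._export entry point needs example_outputs to recover output shapes and dtypes:

    import io
    import torch

    class TopKModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            # topk is a single node with two outputs: (values, indices)
            return torch.topk(input, 3, dim=0)

    model = TopKModel()
    x = torch.randn(4, 3)
    f = io.BytesIO()
    # Before this fix, shape propagation over the two flattened outputs
    # iterated the wrong collection (see the init.cpp hunk below).
    torch.onnx._export(model, (x,), f,
                       example_outputs=torch.topk(x, 3, dim=0))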

File tree (4 files changed, +28 −8 lines):

    test/onnx/test_pytorch_onnx_caffe2.py
    test/onnx/verify.py
    torch/csrc/jit/script/init.cpp
    torch/onnx/utils.py

test/onnx/test_pytorch_onnx_caffe2.py (16 additions, 2 deletions)

@@ -77,6 +77,11 @@ def wrapper(self):
 def do_export(model, inputs, *args, **kwargs):
     f = io.BytesIO()
     out = torch.onnx._export(model, inputs, f, *args, **kwargs)
+    if isinstance(model, torch.jit.ScriptModule):
+        # Special case for common case of passing a single Tensor
+        if isinstance(inputs, torch.Tensor):
+            inputs = (inputs,)
+        out = model(*inputs)
     return f.getvalue(), out


@@ -178,7 +183,7 @@ def run_actual_test(self, model, train, batch_size, state_dict=None,

         # Verify the model runs the same in Caffe2
         verify.verify(model, input, c2, rtol=rtol, atol=atol,
-                      do_constant_folding=do_constant_folding)
+                      example_outputs=example_outputs, do_constant_folding=do_constant_folding)

     def run_model_test(self, model, train, batch_size, state_dict=None,
                        input=None, use_gpu=True, rtol=0.001, atol=1e-7,

@@ -1592,10 +1597,19 @@ def test_topk(self):
         class TopKModel(torch.nn.Module):
             def forward(self, input):
                 return torch.topk(input, 3)
-        model = TopKModel()
+
         x = torch.arange(1., 6.)
         self.run_model_test(TopKModel(), train=False, input=x, batch_size=BATCH_SIZE)

+    def test_topk_script(self):
+        class TopKModel(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, input):
+                return torch.topk(input, 3, dim=0)
+
+        x = torch.randn(4, 3, requires_grad=True)
+        self.run_model_test(TopKModel(), train=False, input=(x,), batch_size=BATCH_SIZE, example_outputs=torch.topk(x, 3, dim=0))
+
     def test_floor(self):
         class FloorModel(torch.nn.Module):
             def forward(self, input):
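A note on the do_export change above: when example_outputs drive the export of a ScriptModule, the value returned by torch.onnx._export is presumably not the module's actual output in the scripted path, so the helper recomputes the reference result by invoking the module directly, normalizing a bare Tensor input to a one-element tuple first. A hedged sketch of that branch in isolation (reference_output is a hypothetical name, not part of the test suite):

    import torch

    def reference_output(model, inputs):
        # Mirror do_export's ScriptModule branch: accept a bare Tensor for
        # convenience, but always call the module with an argument tuple.
        if isinstance(inputs, torch.Tensor):
            inputs = (inputs,)
        return model(*inputs)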

test/onnx/verify.py (7 additions, 3 deletions)

@@ -244,7 +244,7 @@ def set_training(model, mode):


 def verify(model, args, backend, verbose=False, training=False, rtol=1e-3, atol=1e-7,
-           test_args=2, do_constant_folding=False):
+           test_args=2, do_constant_folding=False, example_outputs=None):
     """
     Export a model into ONNX, import it into a specified ONNX backend, and then
     on a few random inputs verify that PyTorch and the backend produced the same

@@ -358,14 +358,18 @@ def load_bytes(b):
     with set_training(model, training):
         proto_bytes = io.BytesIO()
         torch_out = torch.onnx._export(model, args, proto_bytes, verbose=verbose,
-                                       do_constant_folding=do_constant_folding)
+                                       do_constant_folding=do_constant_folding, example_outputs=example_outputs)
+        if isinstance(model, torch.jit.ScriptModule):
+            torch_out = model(*args)
         proto = load_bytes(proto_bytes)
         prepared = backend.prepare(proto)

         def run(args):
             alt_proto_bytes = io.BytesIO()
             torch_out = torch.onnx._export(model, args, alt_proto_bytes, verbose=verbose,
-                                           do_constant_folding=do_constant_folding)
+                                           do_constant_folding=do_constant_folding, example_outputs=example_outputs)
+            if isinstance(model, torch.jit.ScriptModule):
+                torch_out = model(*args)
             alt_proto = load_bytes(alt_proto_bytes)
             if proto.SerializeToString() != alt_proto.SerializeToString():
                 # OK, let's try to figure out what happened.
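With the new example_outputs parameter threaded through, callers verifying a scripted model against a backend can pass the expected outputs to the exporter. A hedged usage sketch, assuming the Caffe2 ONNX backend used by the tests above and reusing the TopKModel ScriptModule from the first sketch:

    import caffe2.python.onnx.backend as c2
    import torch
    import verify  # test/onnx/verify.py, assuming test/onnx is on sys.path

    model = TopKModel()
    x = torch.randn(4, 3)
    verify.verify(model, (x,), c2,
                  example_outputs=torch.topk(x, 3, dim=0))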

torch/csrc/jit/script/init.cpp (1 addition, 1 deletion)

@@ -223,7 +223,7 @@ static std::shared_ptr<Graph> _propagate_and_assign_input_and_output_shapes(
     output_values = output_values.at(0)->node()->inputs();
   }
   AT_ASSERT(output_values.size() == outputs.size());
-  for (size_t i = 0; i < retval->outputs().size(); ++i) {
+  for (size_t i = 0; i < outputs.size(); ++i) {
     auto scalar_type = outputs[i].scalar_type();
     auto sizes = outputs[i].sizes();
     auto type =
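This one-line hunk is the actual fix. Reading the surrounding code: when a scripted graph returns a tuple, the single graph output is unpacked (output_values = output_values.at(0)->node()->inputs()), so the flattened output_values and the example outputs both have length N while retval->outputs() still has length 1. The old loop bound therefore assigned shape and dtype information to only the first flattened output; looping over outputs.size(), which the AT_ASSERT just above guarantees equals output_values.size(), covers all of them. A small Python illustration of the count mismatch, assuming the TorchScript graph bindings of this era:

    import torch

    @torch.jit.script
    def topk3(x):
        return torch.topk(x, 3, dim=0)

    # One graph-level output (a tuple)...
    print(sum(1 for _ in topk3.graph.outputs()))  # -> 1
    # ...but two flattened example outputs once the tuple is unpacked:
    values, indices = topk3(torch.randn(4, 3))
    print(len((values, indices)))  # -> 2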

torch/onnx/utils.py (4 additions, 2 deletions)

@@ -56,7 +56,7 @@ def set_training(model, mode):
 def export(model, args, f, export_params=True, verbose=False, training=False,
            input_names=None, output_names=None, aten=False, export_raw_ir=False,
            operator_export_type=None, opset_version=None, _retain_param_name=True,
-           do_constant_folding=False, strip_doc_string=True):
+           do_constant_folding=False, example_outputs=None, strip_doc_string=True):
     r"""
     Export a model into ONNX format. This exporter runs your model
     once in order to get a trace of its execution to be exported;

@@ -112,6 +112,8 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
             optimization is applied to the model during export. Constant-folding
             optimization will replace some of the ops that have all constant
             inputs, with pre-computed constant nodes.
+        example_outputs (tuple of Tensors, default None): example_outputs must be provided
+            when exporting a ScriptModule or TorchScript Function.
         strip_doc_string (bool, default True): if True, strips the field
             "doc_string" from the exported model, which information about the stack
             trace.

@@ -128,7 +130,7 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
     _export(model, args, f, export_params, verbose, training, input_names, output_names,
             operator_export_type=operator_export_type, opset_version=opset_version,
             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
-            strip_doc_string=strip_doc_string)
+            example_outputs=example_outputs, strip_doc_string=strip_doc_string)


 # ONNX can't handle constants that are lists of tensors, which can
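The docstring addition above surfaces the requirement at the public entry point as well. A hedged sketch of the documented flow through torch.onnx.export, again reusing the TopKModel ScriptModule from the first example:

    import torch

    model = TopKModel()
    x = torch.randn(4, 3)
    # For scripted models there is no trace to read output shapes from,
    # so example_outputs supplies them.
    torch.onnx.export(model, (x,), "topk.onnx",
                      example_outputs=torch.topk(x, 3, dim=0))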
