18 changes: 16 additions & 2 deletions test/onnx/test_pytorch_onnx_caffe2.py
@@ -77,6 +77,11 @@ def wrapper(self):
 def do_export(model, inputs, *args, **kwargs):
     f = io.BytesIO()
     out = torch.onnx._export(model, inputs, f, *args, **kwargs)
+    if isinstance(model, torch.jit.ScriptModule):
Member:

Can we fix `out` (e.g., unpack it) instead of directly re-running the model?

Collaborator (Author):

I don't understand the "fix" in your comment... When exporting a ScriptModule we don't run the model, so we get out=None. The out here is used to compare against caffe2_out, so it shouldn't hurt to run the model if it's a ScriptModule?

Collaborator (Author):

This part is not related to the "fix" in the PR title; it updates the current test infra to be able to test ScriptModule models. The actual fix is in _propagate_and_assign_input_and_output_shapes in torch/csrc/jit/script/init.cpp.

+        # Special case for common case of passing a single Tensor
+        if isinstance(inputs, torch.Tensor):
+            inputs = (inputs,)
+        out = model(*inputs)
     return f.getvalue(), out


@@ -178,7 +183,7 @@ def run_actual_test(self, model, train, batch_size, state_dict=None,

         # Verify the model runs the same in Caffe2
         verify.verify(model, input, c2, rtol=rtol, atol=atol,
-                      do_constant_folding=do_constant_folding)
+                      example_outputs=example_outputs, do_constant_folding=do_constant_folding)

     def run_model_test(self, model, train, batch_size, state_dict=None,
                        input=None, use_gpu=True, rtol=0.001, atol=1e-7,
@@ -1581,10 +1586,19 @@ def test_topk(self):
         class TopKModel(torch.nn.Module):
             def forward(self, input):
                 return torch.topk(input, 3)
-        model = TopKModel()
+
         x = torch.arange(1., 6.)
         self.run_model_test(TopKModel(), train=False, input=x, batch_size=BATCH_SIZE)

+    def test_topk_script(self):
+        class TopKModel(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, input):
+                return torch.topk(input, 3, dim=0)
+
+        x = torch.randn(4, 3, requires_grad=True)
+        self.run_model_test(TopKModel(), train=False, input=(x,), batch_size=BATCH_SIZE, example_outputs=torch.topk(x, 3, dim=0))

     def test_floor(self):
         class FloorModel(torch.nn.Module):
             def forward(self, input):
10 changes: 7 additions & 3 deletions test/onnx/verify.py
@@ -244,7 +244,7 @@ def set_training(model, mode):


 def verify(model, args, backend, verbose=False, training=False, rtol=1e-3, atol=1e-7,
-           test_args=2, do_constant_folding=False):
+           test_args=2, do_constant_folding=False, example_outputs=None):
"""
Export a model into ONNX, import it into a specified ONNX backend, and then
on a few random inputs verify that PyTorch and the backend produced the same
@@ -358,14 +358,18 @@ def load_bytes(b):
     with set_training(model, training):
         proto_bytes = io.BytesIO()
         torch_out = torch.onnx._export(model, args, proto_bytes, verbose=verbose,
-                                       do_constant_folding=do_constant_folding)
+                                       do_constant_folding=do_constant_folding, example_outputs=example_outputs)
+        if isinstance(model, torch.jit.ScriptModule):
+            torch_out = model(*args)
         proto = load_bytes(proto_bytes)
         prepared = backend.prepare(proto)

     def run(args):
         alt_proto_bytes = io.BytesIO()
         torch_out = torch.onnx._export(model, args, alt_proto_bytes, verbose=verbose,
-                                       do_constant_folding=do_constant_folding)
+                                       do_constant_folding=do_constant_folding, example_outputs=example_outputs)
+        if isinstance(model, torch.jit.ScriptModule):
+            torch_out = model(*args)
         alt_proto = load_bytes(alt_proto_bytes)
         if proto.SerializeToString() != alt_proto.SerializeToString():
             # OK, let's try to figure out what happened.
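For reference, a minimal sketch of how the extended verify() signature is driven for a ScriptModule, mirroring test_topk_script above (the Caffe2 backend import and the model are illustrative, not part of this diff):

    import torch
    import caffe2.python.onnx.backend as c2  # ONNX backend used by these tests
    import verify  # test/onnx/verify.py

    class TopKModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            return torch.topk(input, 3, dim=0)

    x = torch.randn(4, 3)
    # A ScriptModule is not run during export, so example_outputs supplies the
    # output structure; verify() then reruns the model to obtain torch_out.
    verify.verify(TopKModel(), (x,), c2,
                  example_outputs=torch.topk(x, 3, dim=0))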
2 changes: 1 addition & 1 deletion torch/csrc/jit/script/init.cpp
@@ -223,7 +223,7 @@ static std::shared_ptr<Graph> _propagate_and_assign_input_and_output_shapes(
     output_values = output_values.at(0)->node()->inputs();
   }
   AT_ASSERT(output_values.size() == outputs.size());
-  for (size_t i = 0; i < retval->outputs().size(); ++i) {
+  for (size_t i = 0; i < outputs.size(); ++i) {
     auto scalar_type = outputs[i].scalar_type();
     auto sizes = outputs[i].sizes();
     auto type =
6 changes: 4 additions & 2 deletions torch/onnx/utils.py
@@ -56,7 +56,7 @@ def set_training(model, mode):
 def export(model, args, f, export_params=True, verbose=False, training=False,
            input_names=None, output_names=None, aten=False, export_raw_ir=False,
            operator_export_type=None, opset_version=None, _retain_param_name=True,
-           do_constant_folding=False, strip_doc_string=True):
+           do_constant_folding=False, example_outputs=None, strip_doc_string=True):
r"""
Export a model into ONNX format. This exporter runs your model
once in order to get a trace of its execution to be exported;
@@ -112,6 +112,8 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
             optimization is applied to the model during export. Constant-folding
             optimization will replace some of the ops that have all constant
             inputs, with pre-computed constant nodes.
+        example_outputs (tuple of Tensors, default None): example_outputs must be provided
+            when exporting a ScriptModule or TorchScript Function.
         strip_doc_string (bool, default True): if True, strips the field
             "doc_string" from the exported model, which contains information about the stack
             trace.
@@ -128,7 +130,7 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
     _export(model, args, f, export_params, verbose, training, input_names, output_names,
             operator_export_type=operator_export_type, opset_version=opset_version,
             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
-            strip_doc_string=strip_doc_string)
+            example_outputs=example_outputs, strip_doc_string=strip_doc_string)


# ONNX can't handle constants that are lists of tensors, which can
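Taken together, a minimal end-to-end sketch of the documented usage (model and tensors are illustrative; with this change, torch.onnx.export forwards example_outputs through to _export):

    import io
    import torch

    class TopKModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            return torch.topk(input, 3, dim=0)

    x = torch.randn(4, 3)
    f = io.BytesIO()
    # example_outputs is required here: a ScriptModule is compiled from
    # TorchScript rather than traced, so the exporter cannot run the model
    # to discover the output types and shapes on its own.
    torch.onnx.export(TopKModel(), (x,), f,
                      example_outputs=torch.topk(x, 3, dim=0))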