Unused argument causes error in JIT graph to ONNX conversion #17534

@heydavid525

Description

🐛 Bug

When a function takes an argument that it never uses, the JIT graph is created successfully, but conversion to ONNX fails.

To Reproduce

Run this script (notice that the only difference between SimpleModel1 and SimpleModel2 is whether dim is used):

from __future__ import print_function
import torch

class SimpleModel1(torch.jit.ScriptModule):
    def __init__(self):
        super(SimpleModel1, self).__init__()

    @torch.jit.script_method
    def forward(self, dim: int):
        # dim is used here, so the argument participates in the graph
        x = torch.ones([dim, 2], dtype=torch.float32)
        v = torch.ones(2, 1, dtype=torch.float32)
        v = x * v
        return x, v

model = SimpleModel1()

# This export succeeds
model_onnx = torch.onnx._export(model, torch.tensor(5), "simple1.onnx",
        verbose=True,
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
        example_outputs=(torch.zeros(2, 2), torch.zeros(2, 1)))

class SimpleModel2(torch.jit.ScriptModule):
    def __init__(self):
        super(SimpleModel2, self).__init__()

    @torch.jit.script_method
    def forward(self, dim: int):
        # dim is never used; this is the only difference from SimpleModel1
        x = torch.ones([2, 2], dtype=torch.float32)
        v = torch.ones(2, 1, dtype=torch.float32)
        v = x * v
        return x, v

model = SimpleModel2()
print(model.graph)

# This export fails with the lint assertion shown below
model_onnx = torch.onnx._export(model, torch.tensor(5), "simple2.onnx",
        verbose=True,
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
        example_outputs=(torch.zeros(2, 2), torch.zeros(2, 1)))
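
A quick way to confirm that only the export path is affected (continuing from the script above and reusing the SimpleModel2 class defined there) is to run the scripted module directly and wrap only the export in a try/except; the try/except is just there to show where the error is raised:

# Continues from the repro script above; torch is already imported.
model = SimpleModel2()
print(model.graph)          # the JIT graph is produced, with dim still listed as an input

x, v = model(5)             # the scripted forward runs even though dim is ignored
print(x.shape, v.shape)     # both broadcast to torch.Size([2, 2])

try:
    torch.onnx._export(model, torch.tensor(5), "simple2.onnx",
            verbose=True,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
            example_outputs=(torch.zeros(2, 2), torch.zeros(2, 1)))
except RuntimeError as e:
    print(e)                # only this step fails, with the lint assertion below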

This produces the following error (snippet):

Traceback (most recent call last):
  File "unused_input.py", line 38, in <module>
    example_outputs=(torch.zeros(2,2), torch.zeros(2,1)))
  File "/Users/weidai/envs/tmp/lib/python3.7/site-packages/torch/onnx/__init__.py", line 22, in _export
    return utils._export(*args, **kwargs)
  File "/Users/weidai/envs/tmp/lib/python3.7/site-packages/torch/onnx/utils.py", line 281, in _export
    example_outputs, propagate)
  File "/Users/weidai/envs/tmp/lib/python3.7/site-packages/torch/onnx/utils.py", line 227, in _model_to_graph
    graph = _optimize_graph(graph, operator_export_type)
  File "/Users/weidai/envs/tmp/lib/python3.7/site-packages/torch/onnx/utils.py", line 137, in _optimize_graph
    torch._C._jit_pass_lint(graph)
RuntimeError: std::includes(ALL_OF(sum_set), ALL_OF(all_nodes_set)) ASSERT FAILED at /Users/soumith/b101_2/2019_02_08/wheel_build_dirs/wheel_3.7/pytorch/torch/csrc/jit/ir.cpp:438, please report a bug to PyTorch. (check_graph at /Users/soumith/b101_2/2019_02_08/wheel_build_dirs/wheel_3.7/pytorch/torch/csrc/jit/ir.cpp:438)
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 64 (0x1160effc0 in libc10.dylib)
frame #1: torch::jit::Graph::lint() const + 1576 (0x11b5718c8 in libtorch.1.dylib)
frame #2: void pybind11::cpp_function::initialize<void (*&)(std::__1::shared_ptr<torch::jit::Graph>&), void, std::__1::shared_ptr<torch::jit::Graph>&, pybind11::name, pybind11::scope, pybind11::sibling>(void (*&&&)(std::__1::shared_ptr<torch::jit::Graph>&), void (*)(std::__1::shared_ptr<torch::jit::Graph>&), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&)::'lambda'(pybind11::detail::function_call&)::operator()(pybind11::detail::function_call&) const + 106 (0x115c0728a in libtorch_python.dylib)
frame #3: pybind11::cpp_function::dispatcher(_object*, _object*, _object*) + 3531 (0x1158fde8b in libtorch_python.dylib)
frame #4: _PyMethodDef_RawFastCallKeywords + 545 (0x108618668 in Python)
frame #5: _PyCFunction_FastCallKeywords + 44 (0x108617bd3 in Python)
frame #6: call_function + 636 (0x1086ad5f0 in Python)
frame #7: _PyEval_EvalFrameDefault + 7016 (0x1086a6231 in Python)
frame #8: function_code_fastcall + 112 (0x108617fae in Python)
frame #9: call_function + 753 (0x1086ad665 in Python)
frame #10: _PyEval_EvalFrameDefault + 7174 (0x1086a62cf in Python)
frame #11: _PyEval_EvalCodeWithName + 1835 (0x1086adef7 in Python)
frame #12: _PyFunction_FastCallKeywords + 225 (0x108617b98 in Python)
frame #13: call_function + 753 (0x1086ad665 in Python)
frame #14: _PyEval_EvalFrameDefault + 7174 (0x1086a62cf in Python)
frame #15: _PyEval_EvalCodeWithName + 1835 (0x1086adef7 in Python)
frame #16: _PyFunction_FastCallDict + 441 (0x108617801 in Python)
frame #17: _PyEval_EvalFrameDefault + 7807 (0x1086a6548 in Python)
frame #18: _PyEval_EvalCodeWithName + 1835 (0x1086adef7 in Python)
frame #19: _PyFunction_FastCallKeywords + 225 (0x108617b98 in Python)
frame #20: call_function + 753 (0x1086ad665 in Python)
frame #21: _PyEval_EvalFrameDefault + 7340 (0x1086a6375 in Python)
frame #22: _PyEval_EvalCodeWithName + 1835 (0x1086adef7 in Python)
frame #23: PyEval_EvalCode + 51 (0x1086a4626 in Python)
frame #24: run_mod + 54 (0x1086d32a5 in Python)
frame #25: PyRun_FileExFlags + 164 (0x1086d22c0 in Python)
frame #26: PyRun_SimpleFileExFlags + 266 (0x1086d197a in Python)
frame #27: pymain_main + 5614 (0x1086ea6a2 in Python)
frame #28: _Py_UnixMain + 56 (0x1086eaca4 in Python)
frame #29: start + 1 (0x7fff598f408d in libdyld.dylib)
frame #30: 0x0 + 2 (0x2 in ???)
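
In case it helps triage: judging from the traceback, the graph itself comes out of torch.jit fine and the assertion fires in torch._C._jit_pass_lint after _optimize_graph, i.e. the optimized export graph no longer passes the JIT lint. Assuming the trigger really is the argument never appearing in the graph (the only thing that differs between the two models), a possible workaround is to make the unused argument participate trivially. This is only a sketch based on that assumption; the class name below is made up for illustration and I have not verified that it avoids the lint failure on 1.0.1:

import torch

class SimpleModel2Patched(torch.jit.ScriptModule):
    def __init__(self):
        super(SimpleModel2Patched, self).__init__()

    @torch.jit.script_method
    def forward(self, dim: int):
        # Add dim * 0 so the argument appears in the graph without changing
        # the numerics; SimpleModel1 (which does use dim) exports, so the
        # guess is that any real use of the argument avoids the assertion.
        x = torch.ones([2, 2], dtype=torch.float32) + dim * 0
        v = torch.ones(2, 1, dtype=torch.float32)
        v = x * v
        return x, v

model = SimpleModel2Patched()
model_onnx = torch.onnx._export(model, torch.tensor(5), "simple2_patched.onnx",
        verbose=True,
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
        example_outputs=(torch.zeros(2, 2), torch.zeros(2, 1)))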

Environment

PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.14.1
GCC version: Could not collect
CMake version: version 3.13.4

Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip3] numpy==1.16.1
[pip3] torch==1.0.1.post2
[conda] Could not collect

Labels

module: onnx (Related to torch.onnx), oncall: jit (Add this issue/PR to JIT oncall triage queue)
