Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3153,8 +3153,32 @@ def checkScriptRaisesRegex(self, script, inputs, exception, regex,
ge = torch.jit.script(script, optimize)
ge(*inputs)

def test_submodule_twice(self):
def test_tracing_multiple_methods(self):
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(1, 1, 3)

def forward(self, x):
return self.conv(x)

def weighted_kernel_sum(self, weight):
return weight * self.conv.weight

example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
n = Net()
module = torch.jit.trace_module(n, inputs)

check_inputs = []
for i in range(2):
check_weight = torch.rand(1, 1, 3, 3)
check_forward_input = torch.rand(1, 1, 3, 3)
check_inputs.append({'forward' : check_forward_input, 'weighted_kernel_sum' : check_weight})
module = torch.jit.trace_module(n, inputs, True, True, check_inputs)

def test_submodule_twice(self):
@torch.jit.script
def foo(x):
return x * x
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/script/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ void initJitScriptBindings(PyObject* module) {
return tensors;
})
.def_property_readonly("schema", &Method::getSchema)
.def_property_readonly("name", &Method::name)
.def_property_readonly("code", [](Method& self) {
std::ostringstream ss;
std::vector<at::Tensor> tensors;
Expand Down
166 changes: 148 additions & 18 deletions torch/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,23 +469,41 @@ def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None):

# Check the traced module against a set of user-provided validation inputs
@torch.no_grad()
def _check_trace(check_inputs, func, executor_options, module, check_tolerance, force_outplace):
def _check_trace(check_inputs, func, executor_options, traced_func, check_tolerance, force_outplace, is_trace_module):
# Note: tracing is independent of optimizations, which consume the trace
executor_options['optimize'] = False
for inputs in check_inputs:

if isinstance(inputs, torch.Tensor):
inputs = (inputs,)
check_mod = torch.jit.trace(
func,
_clone_inputs(inputs),
check_trace=False,
_force_outplace=force_outplace,
**executor_options)

if is_trace_module:
copied_dict = {}
for name, data in inputs.items():
copied_dict[name] = _clone_inputs(data)
check_mod = torch.jit.trace_module(
func.__self__,
copied_dict,
check_trace=False,
_force_outplace=force_outplace,
**executor_options)
check_mod_func = check_mod._c._get_method(traced_func.name)
inputs = inputs[traced_func.name]
if isinstance(inputs, torch.Tensor):
inputs = (inputs,)
else:
check_mod = torch.jit.trace(
func,
_clone_inputs(inputs),
check_trace=False,
_force_outplace=force_outplace,
**executor_options)
check_mod_func = check_mod

def graph_diagnostic_info():
mod_canonicalized = torch._C._jit_pass_canonicalize(module.graph)
mod_canonicalized = torch._C._jit_pass_canonicalize(traced_func.graph)
torch._C._jit_pass_erase_shape_information(mod_canonicalized)
check_canonicalized = torch._C._jit_pass_canonicalize(check_mod.graph)
check_canonicalized = torch._C._jit_pass_canonicalize(check_mod_func.graph)
torch._C._jit_pass_erase_shape_information(check_canonicalized)

graph_diff_errors = None
Expand Down Expand Up @@ -558,7 +576,7 @@ def maybe_warn_nondeterministic():
if has_warned[0]:
return
has_warned[0] = True
nondeterm_ops = [op for op in module.graph.nodes() if op.isNondeterministic()]
nondeterm_ops = [op for op in traced_func.graph.nodes() if op.isNondeterministic()]
if len(nondeterm_ops) > 0:
nondeterministic_ops_warning = "Trace had nondeterministic nodes. "
nondeterministic_ops_warning += "Did you forget call .eval() on your model? Nodes:\n"
Expand All @@ -582,10 +600,10 @@ def compare_outputs(original, reference, match_what):

return all_ok

traced_outs = run_mod_and_filter_tensor_outputs(module, inputs, 'trace')
traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, 'trace')
fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, 'Python function')
if compare_outputs(traced_outs, fn_outs, 'Python function'):
check_outs = run_mod_and_filter_tensor_outputs(check_mod, inputs, 'repeated trace')
check_outs = run_mod_and_filter_tensor_outputs(check_mod_func, inputs, 'repeated trace')
compare_outputs(traced_outs, check_outs, 'repeated trace')

diag_info = graph_diagnostic_info()
Expand All @@ -599,12 +617,23 @@ def ignore_lib_warnings():
# We ignore warnings from all submodules excluding the JIT, because we need them e.g. for _check_trace
warnings.filterwarnings('ignore', category=TracerWarning, module='torch.(?!jit)')


# We ignore the tracer warnings coming form inside the library, because all our shape
# checks in nn will trigger them.
TracerWarning.ignore_lib_warnings()
torch._C._tracer_warn_use_python()

def make_tuple(example_inputs):
if isinstance(example_inputs, (torch.Tensor, dict)):
return (example_inputs,)
# done primarily so that weird iterables fail here and not pybind11 code
if not isinstance(example_inputs, tuple):
return tuple(example_inputs)
return example_inputs

def make_module(mod, _module_class, executor_options):
if _module_class is None:
_module_class = TopLevelTracedModule
return _module_class(mod, **executor_options)

def trace(func,
example_inputs,
Expand Down Expand Up @@ -665,10 +694,22 @@ def trace(func,
sub-modules and parameters as ``func``.

Example::
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(1, 1, 3)

def f(x):
return x * 2
traced_f = torch.jit.trace(f, torch.rand(1))
def forward(self, x):
return self.conv(x)

def weighted_kernel_sum(self, weight):
return weight * self.conv.weight

example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)
n = Net()
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
module = torch.jit.trace_module(n, inputs)

"""
if not _enabled:
Expand Down Expand Up @@ -699,12 +740,101 @@ def f(x):
# Check the trace against new traces created from user-specified inputs
if check_trace:
if check_inputs is not None:
_check_trace(check_inputs, func, executor_options, traced, check_tolerance, _force_outplace)
_check_trace(check_inputs, func, executor_options, traced, check_tolerance, _force_outplace, False)
else:
_check_trace([example_inputs], func, executor_options, traced, check_tolerance, _force_outplace)
_check_trace([example_inputs], func, executor_options, traced, check_tolerance, _force_outplace, False)

return traced

def trace_module(mod,
inputs,
optimize=True,
check_trace=True,
check_inputs=None,
check_tolerance=1e-5,
_force_outplace=False,
_module_class=None):
"""
Trace a function and return an executable ``ScriptModule`` that will be optimized
using just-in-time compilation.

.. warning::

Tracing only correctly records functions and modules which are not data
dependent (e.g., do not have conditionals on data in tensors) and do not have
any untracked external dependencies (e.g., perform input/output or
access global variables). If you trace such models, you may silently get
incorrect results on subsequent invocations of the model. The tracer
will try to emit warnings when doing something that may cause an
incorrect trace to be produced.

Arguments:
mod (torch.nn.Module): a ``torch.nn.Module`` containing methods whose names are
specified in ``example_inputs``. The given methods will be compiled
as a part of a single `ScriptModule`
example_inputs (dict): a dict containing sample inputs indexed by method names in ``mod``
The inputs will be passed to methods whose names correspond to inputs'
keys while tracing.
``{ 'forward' : example_forward_input, 'method2': example_method2_input}``
Keyword arguments:
optimize (bool, optional): whether or not to apply optimizations. Default: ``True``.
check_trace (bool, optional): check if the same inputs run through
traced code produce the same outputs. Default: ``True``. You might want
to disable this if, for example, your network contains non-
deterministic ops or if you are sure that the network is correct despite
a checker failure.

check_inputs (list of dicts, optional): A list of dicts of input arguments that should be used
to check the trace against what is expected. Each tuple
is equivalent to a set of input arguments that would
be specified in ``example_inputs``. For best results, pass in a
set of checking inputs representative of the space of
shapes and types of inputs you expect the network to see.
If not specified, the original ``example_inputs`` are used for checking
check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure.
This can be used to relax the checker strictness in the event that
results diverge numerically for a known reason, such as operator fusion.

Returns:
A ``ScriptModule`` object with a single ``forward()`` method containing the traced code.
When ``func`` is a ``torch.nn.Module``, the returned ``ScriptModule`` will have the same set of
sub-modules and parameters as ``func``.

Example::

def f(x):
return x * 2
traced_f = torch.jit.trace(f, torch.rand(1))

"""
if not _enabled:
return mod
executor_options = {'optimize': bool(optimize)}
var_lookup_fn = _create_interpreter_name_lookup_fn(0)

if not isinstance(mod, torch.nn.Module):
raise AttributeError("expected torch.nn.Module as the first argument")

if not isinstance(inputs, dict):
raise AttributeError("expected a dictionary of (method_name, input) pairs")

module = make_module(mod, _module_class, executor_options)

for method_name, example_inputs in inputs.items():

func = getattr(mod, method_name)
example_inputs = make_tuple(example_inputs)
module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, _force_outplace)
check_trace_method = module._c._get_method(method_name)

# Check the trace against new traces created from user-specified inputs
if check_trace:
if check_inputs is not None:
_check_trace(check_inputs, func, executor_options, check_trace_method, check_tolerance, _force_outplace, True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this line is covered by the tests. check_inputs would need to be different for each method but it looks like we are using the same inputs here.

else:
_check_trace([inputs], func, executor_options, check_trace_method, check_tolerance, _force_outplace, True)

return module

class CompilationUnit(object):
def __init__(self, lang=None, optimize=True, _frames_up=0):
Expand Down