Skip to content

Commit 24e7b82

Browse files
pritamdamaniafacebook-github-bot
authored andcommitted
Add metadata for torch jit TracedModules. (#17640)
Summary: Pull Request resolved: #17640 Pull Request resolved: #17311 I've extended our model metadata framework in this diff to support traced modules as well. Re-used a lot of components from the previous implementation of ScriptModule metadata. Tracing is a little different from Scripting since you can't just create a subclass of TopLevelTraceModule (type returned by torch.jit.trace) and attach metadata the way we did for ScriptModule. As a result, I've introduced a separate API torch.fb.jit_trace which returns an instance of TracedModuleWithMetadata which is a subclass of TopLevelTracedModule. As a result, we can now attach metadata to this instance. Reviewed By: dzhulgakov Differential Revision: D14117966 fbshipit-source-id: 3eee5eef733cb8d6a219c02e2f41d08698eca326
1 parent 320c697 commit 24e7b82

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

torch/jit/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,8 @@ def trace(func,
591591
check_trace=True,
592592
check_inputs=None,
593593
check_tolerance=1e-5,
594-
_force_outplace=False):
594+
_force_outplace=False,
595+
_module_class=None):
595596
"""
596597
Trace a function and return an executable trace that will be optimized
597598
using just-in-time compilation.
@@ -657,7 +658,10 @@ def trace(func,
657658
# done primarily so that weird iterables fail here and not pybind11 code
658659
elif not isinstance(example_inputs, tuple):
659660
example_inputs = tuple(example_inputs)
660-
module = TopLevelTracedModule(func, **executor_options)
661+
if _module_class:
662+
module = _module_class(func, **executor_options)
663+
else:
664+
module = TopLevelTracedModule(func, **executor_options)
661665
var_lookup_fn = _create_interpreter_name_lookup_fn(0)
662666
module._create_method_from_trace('forward', func, example_inputs,
663667
var_lookup_fn, _force_outplace)

0 commit comments

Comments
 (0)