@@ -56,7 +56,7 @@ def set_training(model, mode):
5656def export (model , args , f , export_params = True , verbose = False , training = False ,
5757 input_names = None , output_names = None , aten = False , export_raw_ir = False ,
5858 operator_export_type = None , opset_version = None , _retain_param_name = True ,
59- do_constant_folding = False , strip_doc_string = True ):
59+ do_constant_folding = False , example_outputs = None , strip_doc_string = True ):
6060 r"""
6161 Export a model into ONNX format. This exporter runs your model
6262 once in order to get a trace of its execution to be exported;
@@ -112,6 +112,8 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
112112 optimization is applied to the model during export. Constant-folding
113113 optimization will replace some of the ops that have all constant
114114 inputs, with pre-computed constant nodes.
115+ example_outputs (tuple of Tensors, default None): example_outputs must be provided
116+ when exporting a ScriptModule or TorchScript Function.
115117 strip_doc_string (bool, default True): if True, strips the field
116118 "doc_string" from the exported model, which information about the stack
117119 trace.
@@ -128,7 +130,7 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
128130 _export (model , args , f , export_params , verbose , training , input_names , output_names ,
129131 operator_export_type = operator_export_type , opset_version = opset_version ,
130132 _retain_param_name = _retain_param_name , do_constant_folding = do_constant_folding ,
131- strip_doc_string = strip_doc_string )
133+ example_outputs = example_outputs , strip_doc_string = strip_doc_string )
132134
133135
134136# ONNX can't handle constants that are lists of tensors, which can
0 commit comments