@@ -364,14 +364,14 @@ def reference_args(args):
364364
365365 def get_trace_outputs (declaration ):
366366 if declaration ['return_type' ] == 'std::vector<Tensor>' :
367- return 'flatten_tensor ({})' .format (declaration ['returns' ][0 ]['name' ])
367+ return 'flatten_tensor_args ({})' .format (declaration ['returns' ][0 ]['name' ])
368368 elif name .endswith ('_out' ):
369369 output_args = [arg ['name' ] for arg in arguments
370370 if arg .get ('output' , False )]
371371 return '{' + ', ' .join (output_args ) + '}'
372372 trace_outs = [r ['name' ] for r in declaration ['returns' ]]
373373 if any (ret ['dynamic_type' ] == 'TensorList' for ret in declaration ['returns' ]):
374- return CodeTemplate ("flatten_tensor ( ${outs} )" ).substitute (outs = trace_outs )
374+ return CodeTemplate ("flatten_tensor_args ( ${outs} )" ).substitute (outs = trace_outs )
375375 else :
376376 return CodeTemplate ("{ ${outs} }" ).substitute (outs = trace_outs )
377377
@@ -408,7 +408,7 @@ def emit_record_trace(env):
408408 local ['tensor_args' ] = [arg ['name' ] for arg in tensor_args ]
409409 if any (arg ['simple_type' ] == 'TensorList' for arg in tensor_args ):
410410 # Allocate a temporary vector with flatten and pass it in
411- local ['trace_inputs' ] = CodeTemplate ("flatten_tensor ( $tensor_args )" ).substitute (local )
411+ local ['trace_inputs' ] = CodeTemplate ("flatten_tensor_args ( $tensor_args )" ).substitute (local )
412412 else :
413413 local ['trace_inputs' ] = CodeTemplate ("{ ${tensor_args} }" ).substitute (local )
414414
@@ -496,7 +496,7 @@ def emit_history():
496496 fn = 'rebase' if modifies_arguments and not is_view else 'set'
497497 output_names = [r ['name' ] for r in differentiable_outputs ]
498498 # TODO: flatten allocates a std::vector, which could be expensive
499- outs = CodeTemplate ("flatten_tensor ( ${outs} )" ).substitute (outs = output_names )
499+ outs = CodeTemplate ("flatten_tensor_args ( ${outs} )" ).substitute (outs = output_names )
500500 return SET_HISTORY .substitute (fn = fn , differentiable_outputs = outs )
501501
502502 def emit_save_outputs ():
0 commit comments