@@ -73,8 +73,12 @@ def write_outputs(self, filename):
7373 '' .join (name + ";" for name in sorted (self .filenames )))
7474 self .outputs_written = True
7575
76- def write (self , filename , s ):
76+ def write (self , filename , s , env = None ):
7777 filename = '{}/{}' .format (options .output_dir , filename )
78+ if isinstance (s , CodeTemplate ):
79+ assert env is not None
80+ env ['generated_comment' ] = "@" + "generated by aten/src/ATen/gen.py"
81+ s = s .substitute (env )
7882 self ._write_if_changed (filename , s )
7983 if filename not in self .filenames :
8084 self .undeclared_files .append (filename )
@@ -316,19 +320,19 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations
316320 if density != 'Sparse' :
317321 # there are no special storage types for Sparse, they are composed
318322 # of Dense tensors
319- fm .write (env ['Storage' ] + ".cpp" , STORAGE_DERIVED_CPP . substitute ( env ) )
320- fm .write (env ['Storage' ] + ".h" , STORAGE_DERIVED_H . substitute ( env ) )
323+ fm .write (env ['Storage' ] + ".cpp" , STORAGE_DERIVED_CPP , env )
324+ fm .write (env ['Storage' ] + ".h" , STORAGE_DERIVED_H , env )
321325 env ['TensorDenseOrSparse' ] = TENSOR_DENSE_CPP .substitute (env )
322326 env ['THTensor_nDimension' ] = 'tensor->nDimension'
323327 else :
324328 env ['TensorDenseOrSparse' ] = TENSOR_SPARSE_CPP .substitute (env )
325329 env ['THTensor_nDimension' ] = 'tensor->nDimensionI + tensor->nDimensionV'
326330
327- fm .write (env ['Type' ] + ".cpp" , TYPE_DERIVED_CPP . substitute ( env ) )
328- fm .write (env ['Type' ] + ".h" , TYPE_DERIVED_H . substitute ( env ) )
331+ fm .write (env ['Type' ] + ".cpp" , TYPE_DERIVED_CPP , env )
332+ fm .write (env ['Type' ] + ".h" , TYPE_DERIVED_H , env )
329333
330- fm .write (env ['Tensor' ] + ".cpp" , TENSOR_DERIVED_CPP . substitute ( env ) )
331- fm .write (env ['Tensor' ] + ".h" , TENSOR_DERIVED_H . substitute ( env ) )
334+ fm .write (env ['Tensor' ] + ".cpp" , TENSOR_DERIVED_CPP , env )
335+ fm .write (env ['Tensor' ] + ".h" , TENSOR_DERIVED_H , env )
332336
333337 type_register = TYPE_REGISTER .substitute (backend = env ['Backend' ], scalar_type = scalar_name , type_name = env ['Type' ])
334338 if env ['DenseBackend' ] == 'CPU' :
@@ -410,7 +414,7 @@ def generate_outputs():
410414 fm = file_manager
411415 if env ['name' ] == 'CUDA' :
412416 fm = cuda_file_manager
413- fm .write (fname , GENERATOR_DERIVED . substitute ( env ) )
417+ fm .write (fname , GENERATOR_DERIVED , env )
414418
415419 # note: this will fill in top_env['type/tensor_method_declarations/definitions']
416420 # and modify the declarations to include any information that will all_backends
@@ -426,19 +430,19 @@ def generate_outputs():
426430 all_types .append (generate_storage_type_and_tensor (
427431 backend , density , scalar_type , declarations ))
428432
429- file_manager .write ('Type.h' , TYPE_H . substitute ( top_env ) )
430- file_manager .write ('Type.cpp' , TYPE_CPP . substitute ( top_env ) )
433+ file_manager .write ('Type.h' , TYPE_H , top_env )
434+ file_manager .write ('Type.cpp' , TYPE_CPP , top_env )
431435
432- cuda_file_manager .write ('RegisterCUDA.h' , REGISTER_CUDA_H . substitute ( top_env ) )
433- cuda_file_manager .write ('RegisterCUDA.cpp' , REGISTER_CUDA_CPP . substitute ( top_env ) )
436+ cuda_file_manager .write ('RegisterCUDA.h' , REGISTER_CUDA_H , top_env )
437+ cuda_file_manager .write ('RegisterCUDA.cpp' , REGISTER_CUDA_CPP , top_env )
434438
435- file_manager .write ('Tensor.h' , TENSOR_H . substitute ( top_env ) )
436- file_manager .write ('TensorMethods.h' , TENSOR_METHODS_H . substitute ( top_env ) )
437- file_manager .write ('Functions.h' , FUNCTIONS_H . substitute ( top_env ) )
439+ file_manager .write ('Tensor.h' , TENSOR_H , top_env )
440+ file_manager .write ('TensorMethods.h' , TENSOR_METHODS_H , top_env )
441+ file_manager .write ('Functions.h' , FUNCTIONS_H , top_env )
438442
439443 file_manager .write ('CPUCopy.cpp' , copy_wrapper .create (all_types , 'CPU' ))
440444 cuda_file_manager .write ('CUDACopy.cpp' , copy_wrapper .create (all_types , 'CUDA' ))
441- file_manager .write ('NativeFunctions.h' , NATIVE_FUNCTIONS_H . substitute ( top_env ) )
445+ file_manager .write ('NativeFunctions.h' , NATIVE_FUNCTIONS_H , top_env )
442446
443447 file_manager .check_all_files_written ()
444448 cuda_file_manager .check_all_files_written ()
0 commit comments