@@ -534,6 +534,52 @@ def cpp_compile_command(
534534 ).strip ()
535535
536536
537+ class AotCodeCache :
538+ cache = dict ()
539+ clear = staticmethod (cache .clear )
540+
541+ @classmethod
542+ def compile (cls , source_code ):
543+ from .codegen .wrapper import CppWrapperCodeGen
544+
545+ # TODO: update cpp_compile_command for different platforms
546+ picked_vec_isa = pick_vec_isa ()
547+ key , input_path = write (
548+ source_code ,
549+ "cpp" ,
550+ code_hash (repr (cpp_compile_command ("i" , "o" , vec_isa = picked_vec_isa ))),
551+ )
552+ if key not in cls .cache :
553+ from filelock import FileLock
554+
555+ lock_dir = get_lock_dir ()
556+ lock = FileLock (os .path .join (lock_dir , key + ".lock" ), timeout = LOCK_TIMEOUT )
557+ with lock :
558+ output_so = (
559+ os .path .join (os .getcwd (), f"{ config .aot_codegen_output_prefix } .so" )
560+ if config .aot_codegen_output_prefix
561+ else f"{ input_path [:- 3 ]} .so"
562+ )
563+
564+ output_header = f"{ output_so [:- 3 ]} .h"
565+ with open (output_header , "w" ) as header_file :
566+ header_file .writelines ("#include <torch/torch.h>\n \n " )
567+ header_file .writelines (f"{ CppWrapperCodeGen .decl_str } ;\n " )
568+
569+ log .info (f"AOT-Inductor compiles code into: { output_so } " )
570+ if not os .path .exists (output_so ):
571+ cmd = cpp_compile_command (
572+ input = input_path , output = output_so , vec_isa = picked_vec_isa
573+ ).split (" " )
574+ try :
575+ subprocess .check_output (cmd , stderr = subprocess .STDOUT )
576+ except subprocess .CalledProcessError as e :
577+ raise exc .CppCompileError (cmd , e .output ) from e
578+
579+ cls .cache [key ] = output_so
580+ return cls .cache [key ]
581+
582+
537583class CppCodeCache :
538584 cache = dict ()
539585 clear = staticmethod (cache .clear )
0 commit comments