@@ -116,6 +116,8 @@ def load(name,
116116 The loaded PyTorch extension as a Python module.
117117 '''
118118
119+ verify_ninja_availability ()
120+
119121 # Allows sources to be a single path or a list of paths.
120122 if isinstance (sources , str ):
121123 sources = [sources ]
@@ -140,6 +142,14 @@ def load(name,
140142 return _import_module_from_library (name , build_directory )
141143
142144
145+ def verify_ninja_availability ():
146+ with open (os .devnull , 'wb' ) as devnull :
147+ try :
148+ subprocess .check_call ('ninja --version' .split (), stdout = devnull )
149+ except OSError :
150+ raise RuntimeError ("Ninja is required to load C++ extensions" )
151+
152+
143153def _get_build_directory (name , verbose ):
144154 root_extensions_directory = os .environ .get ('TORCH_EXTENSIONS_DIR' )
145155 if root_extensions_directory is None :
@@ -183,64 +193,60 @@ def _import_module_from_library(module_name, path):
183193
184194def _write_ninja_file (path , name , sources , extra_cflags , extra_ldflags ,
185195 extra_include_paths ):
186- try :
187- import ninja
188- except ImportError :
189- raise RuntimeError ("Ninja is required to load C++ extensions. "
190- "Install it with 'pip install ninja'." )
196+ # Version 1.3 is required for the `deps` directive.
197+ config = ['ninja_required_version = 1.3' ]
198+ config .append ('cxx = {}' .format (os .environ .get ('CXX' , 'c++' )))
199+
200+ # Turn into absolute paths so we can emit them into the ninja build
201+ # file wherever it is.
202+ sources = [os .path .abspath (file ) for file in sources ]
203+ includes = [os .path .abspath (file ) for file in extra_include_paths ]
204+
205+ # include_paths() gives us the location of torch/torch.h
206+ includes += include_paths ()
207+ # sysconfig.get_paths()['include'] gives us the location of Python.h
208+ includes .append (sysconfig .get_paths ()['include' ])
209+
210+ cflags = ['-fPIC' , '-std=c++11' ]
211+ cflags += ['-I{}' .format (include ) for include in includes ]
212+ cflags += extra_cflags
213+ flags = ['cflags = {}' .format (' ' .join (cflags ))]
214+
215+ ldflags = ['-shared' ] + extra_ldflags
216+ # The darwin linker needs explicit consent to ignore unresolved symbols
217+ if sys .platform == 'darwin' :
218+ ldflags .append ('-undefined dynamic_lookup' )
219+ flags .append ('ldflags = {}' .format (' ' .join (ldflags )))
220+
221+ # See https://ninja-build.org/build.ninja.html for reference.
222+ compile_rule = ['rule compile' ]
223+ compile_rule .append (
224+ ' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out' )
225+ compile_rule .append (' depfile = $out.d' )
226+ compile_rule .append (' deps = gcc' )
227+ compile_rule .append ('' )
228+
229+ link_rule = ['rule link' ]
230+ link_rule .append (' command = $cxx $ldflags $in -o $out' )
231+
232+ # Emit one build rule per source to enable incremental build.
233+ object_files = []
234+ build = []
235+ for source_file in sources :
236+ # '/path/to/file.cpp' -> 'file'
237+ file_name = os .path .splitext (os .path .basename (source_file ))[0 ]
238+ target = '{}.o' .format (file_name )
239+ object_files .append (target )
240+ build .append ('build {}: compile {}' .format (target , source_file ))
241+
242+ library_target = '{}.so' .format (name )
243+ link = ['build {}: link {}' .format (library_target , ' ' .join (object_files ))]
244+
245+ default = ['default {}' .format (library_target )]
246+
247+ # 'Blocks' should be separated by newlines, for visual benefit.
248+ blocks = [config , flags , compile_rule , link_rule , build , link , default ]
191249 with open (path , 'w' ) as build_file :
192- writer = ninja .Writer (build_file )
193- # Version 1.3 is required for the `deps` directive.
194- writer .variable ('ninja_required_version' , '1.3' )
195- writer .variable ('cxx' , os .environ .get ('CXX' , 'c++' ))
196- writer .newline ()
197-
198- # Turn into absolute paths so we can emit them into the ninja build
199- # file wherever it is.
200- sources = [os .path .abspath (file ) for file in sources ]
201- includes = [os .path .abspath (file ) for file in extra_include_paths ]
202-
203- # include_paths() gives us the location of torch/torch.h
204- includes += include_paths ()
205- # sysconfig.get_paths()['include'] gives us the location of Python.h
206- includes .append (sysconfig .get_paths ()['include' ])
207-
208- cflags = ['-fPIC' , '-std=c++11' ]
209- cflags += ['-I{}' .format (include ) for include in includes ]
210- cflags += extra_cflags
211- writer .variable ('cflags' , ' ' .join (cflags ))
212-
213- ldflags = ['-shared' ] + extra_ldflags
214- # The darwin linker needs explicit consent to ignore unresolved symbols
215- if sys .platform == 'darwin' :
216- ldflags .append ('-undefined dynamic_lookup' )
217- writer .variable ('ldflags' , ' ' .join (ldflags ))
218- writer .newline ()
219-
220- # See https://ninja-build.org/build.ninja.html for reference.
221- writer .rule (
222- 'compile' ,
223- command = '$cxx -MMD -MF $out.d $cflags -c $in -o $out' ,
224- depfile = '$out.d' ,
225- deps = 'gcc' )
226- writer .newline ()
227-
228- writer .rule ('link' , command = '$cxx $ldflags $in -o $out' )
229- writer .newline ()
230-
231- # Emit one build rule per source to enable incremental build.
232- object_files = []
233- for source_file in sources :
234- # '/path/to/file.cpp' -> 'file'
235- file_name = os .path .splitext (os .path .basename (source_file ))[0 ]
236- target = '{}.o' .format (file_name )
237- object_files .append (target )
238- writer .build (outputs = target , rule = 'compile' , inputs = source_file )
239- writer .newline ()
240-
241- library_target = '{}.so' .format (name )
242- writer .build (outputs = library_target , rule = 'link' , inputs = object_files )
243- writer .newline ()
244-
245- writer .default (library_target )
246- writer .close ()
250+ for block in blocks :
251+ lines = '\n ' .join (block )
252+ build_file .write ('{}\n \n ' .format (lines ))
0 commit comments