@@ -64,6 +64,9 @@ def _find_cuda_home():
6464# it the below pattern.
6565BUILT_FROM_SOURCE_VERSION_PATTERN = re .compile (r'\d+\.\d+\.\d+\w+\+\w+' )
6666
67+ def is_binary_build ():
68+ return not BUILT_FROM_SOURCE_VERSION_PATTERN .match (torch .version .__version__ )
69+
6770
6871def check_compiler_abi_compatibility (compiler ):
6972 '''
@@ -77,7 +80,7 @@ def check_compiler_abi_compatibility(compiler):
7780 False if the compiler is (likely) ABI-incompatible with PyTorch,
7881 else True.
7982 '''
80- if BUILT_FROM_SOURCE_VERSION_PATTERN . match ( torch . version . __version__ ):
83+ if not is_binary_build ( ):
8184 return True
8285 try :
8386 check_cmd = '{}' if sys .platform == 'win32' else '{} --version'
@@ -134,6 +137,7 @@ def build_extensions(self):
134137 self ._check_abi ()
135138 for extension in self .extensions :
136139 self ._define_torch_extension_name (extension )
140+ self ._add_gnu_abi_flag_if_binary (extension )
137141
138142 # Register .cu and .cuh as valid source extensions.
139143 self .compiler .src_extensions += ['.cu' , '.cuh' ]
@@ -266,6 +270,21 @@ def _define_torch_extension_name(self, extension):
266270 else :
267271 extension .extra_compile_args .append (define )
268272
273+ def _add_gnu_abi_flag_if_binary (self , extension ):
274+ # If the version string looks like a binary build,
275+ # we know that PyTorch was compiled with gcc 4.9.2.
276+ # if the extension is compiled with gcc >= 5.1,
277+ # then we have to define _GLIBCXX_USE_CXX11_ABI=0
278+ # so that the std::string in the API is resolved to
279+ # non-C++11 symbols
280+ define = '-D_GLIBCXX_USE_CXX11_ABI=0'
281+ if is_binary_build ():
282+ if isinstance (extension .extra_compile_args , dict ):
283+ for args in extension .extra_compile_args .values ():
284+ args .append (define )
285+ else :
286+ extension .extra_compile_args .append (define )
287+
269288
270289def CppExtension (name , sources , * args , ** kwargs ):
271290 '''
@@ -785,6 +804,9 @@ def _write_ninja_file(path,
785804 common_cflags = ['-DTORCH_EXTENSION_NAME={}' .format (name )]
786805 common_cflags += ['-I{}' .format (include ) for include in includes ]
787806
807+ if is_binary_build ():
808+ common_cflags += ['-D_GLIBCXX_USE_CXX11_ABI=0' ]
809+
788810 cflags = common_cflags + ['-fPIC' , '-std=c++11' ] + extra_cflags
789811 if sys .platform == 'win32' :
790812 from distutils .spawn import _nt_quote_args
0 commit comments