Skip to content

Commit f08f222

Browse files
committed
add _GLIBCXX_USE_CXX11_ABI=0 to cpp_extensions when binary built pytorch is detected
1 parent 8f91617 commit f08f222

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

torch/utils/cpp_extension.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def _find_cuda_home():
6464
# it the below pattern.
6565
BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+')
6666

67+
def is_binary_build():
68+
return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)
69+
6770

6871
def check_compiler_abi_compatibility(compiler):
6972
'''
@@ -77,7 +80,7 @@ def check_compiler_abi_compatibility(compiler):
7780
False if the compiler is (likely) ABI-incompatible with PyTorch,
7881
else True.
7982
'''
80-
if BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__):
83+
if not is_binary_build():
8184
return True
8285
try:
8386
check_cmd = '{}' if sys.platform == 'win32' else '{} --version'
@@ -134,6 +137,7 @@ def build_extensions(self):
134137
self._check_abi()
135138
for extension in self.extensions:
136139
self._define_torch_extension_name(extension)
140+
self._add_gnu_abi_flag_if_binary(extension)
137141

138142
# Register .cu and .cuh as valid source extensions.
139143
self.compiler.src_extensions += ['.cu', '.cuh']
@@ -266,6 +270,21 @@ def _define_torch_extension_name(self, extension):
266270
else:
267271
extension.extra_compile_args.append(define)
268272

273+
def _add_gnu_abi_flag_if_binary(self, extension):
274+
# If the version string looks like a binary build,
275+
# we know that PyTorch was compiled with gcc 4.9.2.
276+
# if the extension is compiled with gcc >= 5.1,
277+
# then we have to define _GLIBCXX_USE_CXX11_ABI=0
278+
# so that the std::string in the API is resolved to
279+
# non-C++11 symbols
280+
define = '-D_GLIBCXX_USE_CXX11_ABI=0'
281+
if is_binary_build():
282+
if isinstance(extension.extra_compile_args, dict):
283+
for args in extension.extra_compile_args.values():
284+
args.append(define)
285+
else:
286+
extension.extra_compile_args.append(define)
287+
269288

270289
def CppExtension(name, sources, *args, **kwargs):
271290
'''
@@ -785,6 +804,9 @@ def _write_ninja_file(path,
785804
common_cflags = ['-DTORCH_EXTENSION_NAME={}'.format(name)]
786805
common_cflags += ['-I{}'.format(include) for include in includes]
787806

807+
if is_binary_build():
808+
common_cflags += ['-D_GLIBCXX_USE_CXX11_ABI=0']
809+
788810
cflags = common_cflags + ['-fPIC', '-std=c++11'] + extra_cflags
789811
if sys.platform == 'win32':
790812
from distutils.spawn import _nt_quote_args

0 commit comments

Comments
 (0)