Skip to content

Commit ce8c9d9

Browse files
peterjc123facebook-github-bot
authored andcommitted
Fix cuda detection script (#22527)
Summary: Fixes #22507 Pull Request resolved: #22527 Differential Revision: D16126220 Pulled By: ezyang fbshipit-source-id: eb05141282b0f058324da1b3d3cb34566f222a67
1 parent d4464d3 commit ce8c9d9

File tree

3 files changed

+24
-32
lines changed

3 files changed

+24
-32
lines changed

tools/setup_helpers/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import sys
23

34

45
def escape_path(path):
@@ -7,3 +8,17 @@ def escape_path(path):
78
if os.path.sep != '/' and path is not None:
89
return path.replace(os.path.sep, '/')
910
return path
11+
12+
13+
def which(thefile):
14+
path = os.environ.get("PATH", os.defpath).split(os.pathsep)
15+
for d in path:
16+
fname = os.path.join(d, thefile)
17+
fnames = [fname]
18+
if sys.platform == 'win32':
19+
exts = os.environ.get('PATHEXT', '').split(os.pathsep)
20+
fnames += [fname + ext for ext in exts]
21+
for name in fnames:
22+
if os.access(name, os.F_OK | os.X_OK) and not os.path.isdir(name):
23+
return name
24+
return None

tools/setup_helpers/cmake.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import distutils.sysconfig
1212
from distutils.version import LooseVersion
1313

14-
from . import escape_path
14+
from . import escape_path, which
1515
from .env import (IS_64BIT, IS_DARWIN, IS_WINDOWS,
1616
DEBUG, REL_WITH_DEB_INFO,
1717
check_env_flag, check_negative_env_flag)
@@ -22,20 +22,6 @@
2222
from .numpy_ import USE_NUMPY, NUMPY_INCLUDE_DIR
2323

2424

25-
def _which(thefile):
26-
path = os.environ.get("PATH", os.defpath).split(os.pathsep)
27-
for d in path:
28-
fname = os.path.join(d, thefile)
29-
fnames = [fname]
30-
if IS_WINDOWS:
31-
exts = os.environ.get('PATHEXT', '').split(os.pathsep)
32-
fnames += [fname + ext for ext in exts]
33-
for name in fnames:
34-
if os.access(name, os.F_OK | os.X_OK) and not os.path.isdir(name):
35-
return name
36-
return None
37-
38-
3925
def _mkdir_p(d):
4026
try:
4127
os.makedirs(d)
@@ -47,7 +33,7 @@ def _mkdir_p(d):
4733
# Use ninja if it is on the PATH. Previous version of PyTorch required the
4834
# ninja python package, but we no longer use it, so we do not have to import it
4935
USE_NINJA = (not check_negative_env_flag('USE_NINJA') and
50-
_which('ninja') is not None)
36+
which('ninja') is not None)
5137

5238

5339
class CMake:
@@ -80,9 +66,9 @@ def _get_cmake_command():
8066
cmake_command = 'cmake'
8167
if IS_WINDOWS:
8268
return cmake_command
83-
cmake3 = _which('cmake3')
69+
cmake3 = which('cmake3')
8470
if cmake3 is not None:
85-
cmake = _which('cmake')
71+
cmake = which('cmake')
8672
if cmake is not None:
8773
bare_version = CMake._get_version(cmake)
8874
if (bare_version < LooseVersion("3.5.0") and

tools/setup_helpers/cuda.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,18 @@
44
import ctypes.util
55
from subprocess import Popen, PIPE
66

7+
from . import escape_path, which
78
from .env import IS_WINDOWS, IS_LINUX, IS_DARWIN, check_env_flag, check_negative_env_flag
89

910
LINUX_HOME = '/usr/local/cuda'
1011
WINDOWS_HOME = glob.glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
1112

1213

1314
def find_nvcc():
14-
if IS_WINDOWS:
15-
proc = Popen(['where', 'nvcc.exe'], stdout=PIPE, stderr=PIPE)
16-
else:
17-
proc = Popen(['which', 'nvcc'], stdout=PIPE, stderr=PIPE)
18-
out, err = proc.communicate()
19-
out = out.decode().strip()
20-
if len(out) > 0:
21-
if IS_WINDOWS:
22-
if out.find('\r\n') != -1:
23-
out = out.split('\r\n')[0]
24-
out = os.path.abspath(os.path.join(os.path.dirname(out), ".."))
25-
out = out.replace('\\', '/')
26-
out = str(out)
27-
return os.path.dirname(out)
15+
nvcc = which('nvcc')
16+
if nvcc is not None:
17+
nvcc = escape_path(nvcc)
18+
return os.path.dirname(nvcc)
2819
else:
2920
return None
3021

0 commit comments

Comments
 (0)