Skip to content

Commit c5b895a

Browse files
ssnlsoumith
authored andcommitted
Try to fix TORCH_CUDA_ARCH_LIST for PyTorch again (#7936)
* try again * use DEFINED * use a loop * Minor fixes
1 parent f8e83dc commit c5b895a

File tree

1 file changed

+54
-20
lines changed

1 file changed

+54
-20
lines changed

cmake/public/cuda.cmake

Lines changed: 54 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -299,28 +299,62 @@ function(caffe2_select_nvcc_arch_flags out_variable)
299299
unset(CUDA_ARCH_PTX CACHE)
300300
endif()
301301

302-
if($ENV{TORCH_CUDA_ARCH_LIST})
303-
# Pass CUDA architecture directly
304-
set(__cuda_arch_bin $ENV{TORCH_CUDA_ARCH_LIST})
305-
message(STATUS "Set CUDA arch from TORCH_CUDA_ARCH_LIST: ${__cuda_arch_bin}")
306-
elseif(${CUDA_ARCH_NAME} STREQUAL "Kepler")
307-
set(__cuda_arch_bin "30 35")
308-
elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell")
309-
set(__cuda_arch_bin "50")
310-
elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal")
311-
set(__cuda_arch_bin "60 61")
312-
elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
313-
set(__cuda_arch_bin "70")
314-
elseif(${CUDA_ARCH_NAME} STREQUAL "All")
315-
set(__cuda_arch_bin ${Caffe2_known_gpu_archs})
316-
elseif(${CUDA_ARCH_NAME} STREQUAL "Manual")
317-
set(__cuda_arch_bin ${CUDA_ARCH_BIN})
318-
set(__cuda_arch_ptx ${CUDA_ARCH_PTX})
319-
elseif(${CUDA_ARCH_NAME} STREQUAL "Auto")
320-
caffe2_detect_installed_gpus(__cuda_arch_bin)
302+
set(CUDA_ARCH_LIST)
303+
if(DEFINED ENV{TORCH_CUDA_ARCH_LIST})
304+
set(TORCH_CUDA_ARCH_LIST $ENV{TORCH_CUDA_ARCH_LIST})
305+
string(REGEX REPLACE "[ \t]+" ";" TORCH_CUDA_ARCH_LIST "${TORCH_CUDA_ARCH_LIST}")
306+
list(APPEND CUDA_ARCH_LIST ${TORCH_CUDA_ARCH_LIST})
307+
message(STATUS "Set CUDA arch from TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}")
321308
else()
322-
message(FATAL_ERROR "Invalid CUDA_ARCH_NAME")
309+
list(APPEND CUDA_ARCH_LIST ${CUDA_ARCH_NAME})
310+
message(STATUS "Set CUDA arch from CUDA_ARCH_NAME: ${CUDA_ARCH_NAME}")
323311
endif()
312+
list(REMOVE_DUPLICATES CUDA_ARCH_LIST)
313+
314+
set(__cuda_arch_bin)
315+
set(__cuda_arch_ptx)
316+
foreach(arch_name ${CUDA_ARCH_LIST})
317+
set(arch_bin)
318+
set(arch_ptx)
319+
set(add_ptx FALSE)
320+
# Check to see if we are compiling PTX
321+
if(arch_name MATCHES "(.*)\\+PTX$")
322+
set(add_ptx TRUE)
323+
set(arch_name ${CMAKE_MATCH_1})
324+
endif()
325+
if(arch_name MATCHES "(^[0-9]\\.[0-9](\\([0-9]\\.[0-9]\\))?)$")
326+
set(arch_bin ${CMAKE_MATCH_1})
327+
set(arch_ptx ${arch_bin})
328+
else()
329+
# Look for it in our list of known architectures
330+
if(${arch_name} STREQUAL "Kepler")
331+
set(arch_bin "30 35")
332+
elseif(${arch_name} STREQUAL "Maxwell")
333+
set(arch_bin "50")
334+
elseif(${arch_name} STREQUAL "Pascal")
335+
set(arch_bin "60 61")
336+
elseif(${arch_name} STREQUAL "Volta")
337+
set(arch_bin "70")
338+
elseif(${arch_name} STREQUAL "All")
339+
set(arch_bin ${Caffe2_known_gpu_archs})
340+
elseif(${arch_name} STREQUAL "Manual")
341+
set(arch_bin ${CUDA_ARCH_BIN})
342+
set(arch_ptx ${CUDA_ARCH_PTX})
343+
set(add_ptx TRUE)
344+
elseif(${arch_name} STREQUAL "Auto")
345+
caffe2_detect_installed_gpus(arch_bin)
346+
else()
347+
message(FATAL_ERROR "Unknown CUDA architecture name ${arch_name}")
348+
endif()
349+
endif()
350+
list(APPEND __cuda_arch_bin ${arch_bin})
351+
if(add_ptx)
352+
if (NOT arch_ptx)
353+
set(arch_ptx ${arch_bin})
354+
endif()
355+
list(APPEND __cuda_arch_ptx ${arch_ptx})
356+
endif()
357+
endforeach()
324358

325359
# Remove dots and convert to lists
326360
string(REGEX REPLACE "\\." "" __cuda_arch_bin "${__cuda_arch_bin}")

0 commit comments

Comments
 (0)