@@ -299,28 +299,62 @@ function(caffe2_select_nvcc_arch_flags out_variable)
299299 unset (CUDA_ARCH_PTX CACHE )
300300 endif ()
301301
302- if ($ENV{TORCH_CUDA_ARCH_LIST} )
303- # Pass CUDA architecture directly
304- set (__cuda_arch_bin $ENV{TORCH_CUDA_ARCH_LIST} )
305- message (STATUS "Set CUDA arch from TORCH_CUDA_ARCH_LIST: ${__cuda_arch_bin} " )
306- elseif (${CUDA_ARCH_NAME} STREQUAL "Kepler" )
307- set (__cuda_arch_bin "30 35" )
308- elseif (${CUDA_ARCH_NAME} STREQUAL "Maxwell" )
309- set (__cuda_arch_bin "50" )
310- elseif (${CUDA_ARCH_NAME} STREQUAL "Pascal" )
311- set (__cuda_arch_bin "60 61" )
312- elseif (${CUDA_ARCH_NAME} STREQUAL "Volta" )
313- set (__cuda_arch_bin "70" )
314- elseif (${CUDA_ARCH_NAME} STREQUAL "All" )
315- set (__cuda_arch_bin ${Caffe2_known_gpu_archs} )
316- elseif (${CUDA_ARCH_NAME} STREQUAL "Manual" )
317- set (__cuda_arch_bin ${CUDA_ARCH_BIN} )
318- set (__cuda_arch_ptx ${CUDA_ARCH_PTX} )
319- elseif (${CUDA_ARCH_NAME} STREQUAL "Auto" )
320- caffe2_detect_installed_gpus(__cuda_arch_bin)
302+ set (CUDA_ARCH_LIST)
303+ if (DEFINED ENV{TORCH_CUDA_ARCH_LIST})
304+ set (TORCH_CUDA_ARCH_LIST $ENV{TORCH_CUDA_ARCH_LIST} )
305+ string (REGEX REPLACE "[ \t ]+" ";" TORCH_CUDA_ARCH_LIST "${TORCH_CUDA_ARCH_LIST} " )
306+ list (APPEND CUDA_ARCH_LIST ${TORCH_CUDA_ARCH_LIST} )
307+ message (STATUS "Set CUDA arch from TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST} " )
321308 else ()
322- message (FATAL_ERROR "Invalid CUDA_ARCH_NAME" )
309+ list (APPEND CUDA_ARCH_LIST ${CUDA_ARCH_NAME} )
310+ message (STATUS "Set CUDA arch from CUDA_ARCH_NAME: ${CUDA_ARCH_NAME} " )
323311 endif ()
312+ list (REMOVE_DUPLICATES CUDA_ARCH_LIST)
313+
314+ set (__cuda_arch_bin)
315+ set (__cuda_arch_ptx)
316+ foreach (arch_name ${CUDA_ARCH_LIST} )
317+ set (arch_bin)
318+ set (arch_ptx)
319+ set (add_ptx FALSE )
320+ # Check to see if we are compiling PTX
321+ if (arch_name MATCHES "(.*)\\ +PTX$" )
322+ set (add_ptx TRUE )
323+ set (arch_name ${CMAKE_MATCH_1} )
324+ endif ()
325+ if (arch_name MATCHES "(^[0-9]\\ .[0-9](\\ ([0-9]\\ .[0-9]\\ ))?)$" )
326+ set (arch_bin ${CMAKE_MATCH_1} )
327+ set (arch_ptx ${arch_bin} )
328+ else ()
329+ # Look for it in our list of known architectures
330+ if (${arch_name} STREQUAL "Kepler" )
331+ set (arch_bin "30 35" )
332+ elseif (${arch_name} STREQUAL "Maxwell" )
333+ set (arch_bin "50" )
334+ elseif (${arch_name} STREQUAL "Pascal" )
335+ set (arch_bin "60 61" )
336+ elseif (${arch_name} STREQUAL "Volta" )
337+ set (arch_bin "70" )
338+ elseif (${arch_name} STREQUAL "All" )
339+ set (arch_bin ${Caffe2_known_gpu_archs} )
340+ elseif (${arch_name} STREQUAL "Manual" )
341+ set (arch_bin ${CUDA_ARCH_BIN} )
342+ set (arch_ptx ${CUDA_ARCH_PTX} )
343+ set (add_ptx TRUE )
344+ elseif (${arch_name} STREQUAL "Auto" )
345+ caffe2_detect_installed_gpus(arch_bin)
346+ else ()
347+ message (FATAL_ERROR "Unknown CUDA architecture name ${arch_name} " )
348+ endif ()
349+ endif ()
350+ list (APPEND __cuda_arch_bin ${arch_bin} )
351+ if (add_ptx)
352+ if (NOT arch_ptx)
353+ set (arch_ptx ${arch_bin} )
354+ endif ()
355+ list (APPEND __cuda_arch_ptx ${arch_ptx} )
356+ endif ()
357+ endforeach ()
324358
325359 # Remove dots and convert to lists
326360 string (REGEX REPLACE "\\ ." "" __cuda_arch_bin "${__cuda_arch_bin} " )
0 commit comments