Skip to content

Commit 9bf5e40

Browse files
ezyangsoumith
authored andcommitted
Refactor cudnn code layout / make build more robust. (#4201)
* Refactor cudnn code layout / make build more robust. When I previously moved cuDNN into ATen, I wasn't too familiar with the ATen native function directory layout, and so I did a number of suboptimal things. This commit fixes those problems. - If NO_CUDA was set but cuDNN is installed on your system, we'd incorrectly assume that CUDNN was enabled, to hilarious effect. - We now distinguish between cudnn implementation files and cudnn native function files. The native files now live in ATen/native/cudnn, and are *unconditionally compiled*, even when we are not building with cuDNN. This means that we can unconditionally declare cudnn functions in yaml and they are always available, even if they are broken. The cuDNN specific files live in 'cudnn', they are *never* installed, and they are used purely for implementation purposes. I had to add stub implementations of all ATen functions to achieve this. - I had written headers for at::native functions manually, but codegen will generate them for me automatically. So I deleted the headers. That lets me get rid of some header install logic as well. - There's a new note about ATen preprocessor philosophy.
1 parent 94ff31f commit 9bf5e40

File tree

19 files changed

+260
-224
lines changed

19 files changed

+260
-224
lines changed

aten/CMakeLists.txt

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,6 @@ ELSE()
158158
MESSAGE(STATUS "MAGMA not found. Compiling without MAGMA support")
159159
ENDIF()
160160

161-
find_package(CuDNN)
162-
IF(NOT CUDNN_FOUND)
163-
MESSAGE(STATUS "CuDNN not found. Compiling without CUDNN support")
164-
set(AT_CUDNN_ENABLED 0)
165-
ELSE()
166-
INCLUDE_DIRECTORIES(${CUDNN_INCLUDE_DIRS})
167-
set(AT_CUDNN_ENABLED 1)
168-
ENDIF()
169-
170-
171161
# ARM specific flags
172162
FIND_PACKAGE(ARM)
173163
IF (ASIMD_FOUND)
@@ -386,16 +376,22 @@ else()
386376
add_subdirectory(src/THCS)
387377
endif()
388378

379+
find_package(CuDNN)
380+
IF(NOT AT_CUDA_ENABLED OR NOT CUDNN_FOUND)
381+
MESSAGE(STATUS "CuDNN not found. Compiling without CUDNN support")
382+
set(AT_CUDNN_ENABLED 0)
383+
ELSE()
384+
INCLUDE_DIRECTORIES(${CUDNN_INCLUDE_DIRS})
385+
set(AT_CUDNN_ENABLED 1)
386+
ENDIF()
387+
389388
set(cwrap_files
390389
${CMAKE_CURRENT_SOURCE_DIR}/src/ATen/Declarations.cwrap
391390
${CMAKE_CURRENT_SOURCE_DIR}/src/THNN/generic/THNN.h
392391
${CMAKE_CURRENT_SOURCE_DIR}/src/THCUNN/generic/THCUNN.h
393392
${CMAKE_CURRENT_SOURCE_DIR}/src/ATen/nn.yaml
394393
${CMAKE_CURRENT_SOURCE_DIR}/src/ATen/native/native_functions.yaml
395394
)
396-
if(CUDNN_FOUND)
397-
set(cwrap_files ${cwrap_files} ${CMAKE_CURRENT_SOURCE_DIR}/src/ATen/cudnn/cuDNN.yaml)
398-
endif()
399395

400396
include_directories(
401397
${CMAKE_CURRENT_SOURCE_DIR}/src/THNN

aten/README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,17 @@ of this processing occurs, it's helpful to look at the generated file
241241
ATen methods in a uniform manner. This file is utilized by PyTorch
242242
which further extends the ATen interface with support for automatic
243243
differentation.
244+
245+
#### Note [ATen preprocessor philosophy]
246+
247+
ATen is designed to be simple to use, and one of the things this implies is
248+
that it should not be necessary to use preprocessor macros when using ATen;
249+
we would rather provide all symbols, even for functionality that is not
250+
available on the system ATen is running on.
251+
252+
This means that internally inside ATen, whereas other libraries might
253+
simply omit source files for, e.g., CuDNN, when CuDNN libraries are not
254+
installed, ATen will always build these source files, compiling stub
255+
functions for anything that is not available. ATen never uses
256+
`AT_ENABLED_CUDA()` in header files, and all types in ATen's public API
257+
are always available no matter your build configuration.

aten/src/ATen/CMakeLists.txt

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ CONFIGURE_FILE(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")
9494
FILE(GLOB base_h RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.h")
9595
FILE(GLOB base_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
9696
FILE(GLOB native_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/*.cpp")
97+
FILE(GLOB native_cudnn_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/cudnn/*.cpp")
9798

98-
FILE(GLOB cudnn_h RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cudnn/*.h")
9999
FILE(GLOB cudnn_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cudnn/*.cpp")
100100

101101
FILE(GLOB all_python RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.py")
@@ -140,7 +140,7 @@ ADD_CUSTOM_TARGET(aten_files_are_generated
140140
)
141141

142142

143-
SET(all_cpp ${base_cpp} ${native_cpp} ${generated_cpp} ${ATen_CPU_SRCS})
143+
SET(all_cpp ${base_cpp} ${native_cpp} ${native_cudnn_cpp} ${generated_cpp} ${ATen_CPU_SRCS})
144144

145145
INCLUDE_DIRECTORIES(${ATen_CPU_INCLUDE})
146146
IF(NOT NO_CUDA)
@@ -283,9 +283,6 @@ INSTALL(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake"
283283
FOREACH(HEADER ${base_h})
284284
INSTALL(FILES ${HEADER} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen)
285285
ENDFOREACH()
286-
FOREACH(HEADER ${cudnn_h})
287-
INSTALL(FILES ${HEADER} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen/cudnn)
288-
ENDFOREACH()
289286
FOREACH(HEADER ${generated_h})
290287
INSTALL(FILES ${CMAKE_CURRENT_BINARY_DIR}/${HEADER}
291288
DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen)

aten/src/ATen/Config.h.in

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,7 @@
66

77
#define AT_CUDA_ENABLED() @AT_CUDA_ENABLED@
88
#define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@
9+
10+
#if !AT_CUDA_ENABLED() && AT_CUDNN_ENABLED()
11+
#error "Cannot enable CuDNN without CUDA"
12+
#endif

aten/src/ATen/cudnn/AffineGridGenerator.h

Lines changed: 0 additions & 15 deletions
This file was deleted.

aten/src/ATen/cudnn/BatchNorm.h

Lines changed: 0 additions & 18 deletions
This file was deleted.

aten/src/ATen/cudnn/Conv.h

Lines changed: 0 additions & 67 deletions
This file was deleted.

aten/src/ATen/cudnn/GridSampler.h

Lines changed: 0 additions & 14 deletions
This file was deleted.

aten/src/ATen/cudnn/cuDNN.yaml

Lines changed: 0 additions & 59 deletions
This file was deleted.

aten/src/ATen/gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def declare_outputs():
321321
def generate_outputs():
322322
cwrap_files = [f for f in files if f.endswith('.cwrap')]
323323
nn_files = [f for f in files if f.endswith('nn.yaml') or f.endswith('.h')]
324-
native_files = [f for f in files if f.endswith('native_functions.yaml') or f.endswith('cuDNN.yaml')]
324+
native_files = [f for f in files if f.endswith('native_functions.yaml')]
325325

326326
declarations = [d
327327
for file in cwrap_files

0 commit comments

Comments
 (0)