
Commit e6c7b38

Cache cufft plans (#8344)

* cache cufft plans
* use an LRU cache
* suffix CuFFTParams members with _
* import print_function for py2
* lint
* fix potential race; add dummy impl for CPU-only builds
* cpp formatting; remove nccl makefile change
* Use CUDA hooks instead
* comments and doc
* update the error message
* move LRU cache to a separate file and native::detail namespace
* update comment
* specify NOTE location in CuFFTPlanCache.h
* update disabled_features.yaml to make AMD CI work
* another fix for AMD CI in disabled_features.yaml
* Wrap cufft_plan_cache_* methods in __HIP_PLATFORM_HCC__
* improve the notes
* lint
* revert onnx change
* put back inlining for CUFFT_CHECK

1 parent fed44cb, commit e6c7b38

21 files changed: +834, -278 lines
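The plan cache itself lives in the new aten/src/ATen/native/cuda/CuFFTPlanCache.h, which is one of the 21 files but is not among the hunks shown below. As a rough illustration of the "use an LRU cache" and "fix potential race" items in the commit message, a minimal LRU-map sketch might look like the following; the key type, hashing, locking, and eviction details here are assumptions for illustration, not the actual CuFFTPlanCache code.

// Hypothetical sketch of an LRU plan cache (NOT the actual CuFFTPlanCache.h code).
// The real cache keys on CuFFTParams (with a custom hash) and stores cuFFT plan
// handles; here K and V are left generic.
#include <cstdint>
#include <list>
#include <mutex>
#include <unordered_map>
#include <utility>

template <typename K, typename V>
class LRUCache {
 public:
  explicit LRUCache(int64_t max_size) : max_size_(max_size) {}

  // Returns the cached value for `key`, constructing and inserting it if absent.
  template <typename Factory>
  V& get_or_emplace(const K& key, Factory make_value) {
    std::lock_guard<std::mutex> lock(mutex_);  // guard against concurrent use
    auto it = map_.find(key);
    if (it != map_.end()) {
      // Move the hit to the front of the usage list (most recently used).
      usage_.splice(usage_.begin(), usage_, it->second);
      return it->second->second;
    }
    // Evict least recently used entries until there is room.
    while (static_cast<int64_t>(map_.size()) >= max_size_ && !usage_.empty()) {
      map_.erase(usage_.back().first);
      usage_.pop_back();
    }
    usage_.emplace_front(key, make_value());
    map_[key] = usage_.begin();
    return usage_.front().second;
  }

  int64_t size() {
    std::lock_guard<std::mutex> lock(mutex_);
    return static_cast<int64_t>(map_.size());
  }

  void clear() {
    std::lock_guard<std::mutex> lock(mutex_);
    map_.clear();
    usage_.clear();
  }

 private:
  int64_t max_size_;
  std::list<std::pair<K, V>> usage_;  // front = most recently used
  std::unordered_map<K, typename std::list<std::pair<K, V>>::iterator> map_;
  std::mutex mutex_;
};

The cufft_get_plan_cache_size_impl / cufft_set_plan_cache_max_size_impl / cufft_clear_plan_cache_impl functions called through the hooks below would presumably be thin wrappers over operations like size(), clear(), and a max-size setter on a cache of this shape.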

aten/src/ATen/cuda/detail/CUDAHooks.cpp

Lines changed: 33 additions & 0 deletions
@@ -5,6 +5,7 @@
 #include <ATen/Error.h>
 #include <ATen/RegisterCUDA.h>
 #include <ATen/cuda/CUDAConfig.h>
+#include <ATen/native/cuda/CuFFTPlanCache.h>
 #include <ATen/cuda/PinnedMemoryAllocator.h>
 #include <ATen/detail/CUDAHooksInterface.h>
 
@@ -165,6 +166,38 @@ double CUDAHooks::batchnormMinEpsilonCuDNN() const {
 #endif
 }
 
+int64_t CUDAHooks::cuFFTGetPlanCacheMaxSize() const {
+#ifndef __HIP_PLATFORM_HCC__
+  return at::native::detail::cufft_get_plan_cache_max_size_impl();
+#else
+  AT_ERROR("cuFFT with HIP is not supported");
+#endif
+}
+
+void CUDAHooks::cuFFTSetPlanCacheMaxSize(int64_t max_size) const {
+#ifndef __HIP_PLATFORM_HCC__
+  at::native::detail::cufft_set_plan_cache_max_size_impl(max_size);
+#else
+  AT_ERROR("cuFFT with HIP is not supported");
+#endif
+}
+
+int64_t CUDAHooks::cuFFTGetPlanCacheSize() const {
+#ifndef __HIP_PLATFORM_HCC__
+  return at::native::detail::cufft_get_plan_cache_size_impl();
+#else
+  AT_ERROR("cuFFT with HIP is not supported");
+#endif
+}
+
+void CUDAHooks::cuFFTClearPlanCache() const {
+#ifndef __HIP_PLATFORM_HCC__
+  at::native::detail::cufft_clear_plan_cache_impl();
+#else
+  AT_ERROR("cuFFT with HIP is not supported");
+#endif
+}
+
 int CUDAHooks::getNumGPUs() const {
   int count;
   auto err = cudaGetDeviceCount(&count);

aten/src/ATen/cuda/detail/CUDAHooks.h

Lines changed: 4 additions & 0 deletions
@@ -25,6 +25,10 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool supportsDilatedConvolutionWithCuDNN() const override;
   long versionCuDNN() const override;
   double batchnormMinEpsilonCuDNN() const override;
+  int64_t cuFFTGetPlanCacheMaxSize() const override;
+  void cuFFTSetPlanCacheMaxSize(int64_t max_size) const override;
+  int64_t cuFFTGetPlanCacheSize() const override;
+  void cuFFTClearPlanCache() const override;
   int getNumGPUs() const override;
 };
 
aten/src/ATen/detail/CUDAHooksInterface.h

Lines changed: 16 additions & 0 deletions
@@ -108,6 +108,22 @@ struct AT_API CUDAHooksInterface {
         "cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library");
   }
 
+  virtual int64_t cuFFTGetPlanCacheMaxSize() const {
+    AT_ERROR("cannot access cuFFT plan cache without ATen_cuda library");
+  }
+
+  virtual void cuFFTSetPlanCacheMaxSize(int64_t max_size) const {
+    AT_ERROR("cannot access cuFFT plan cache without ATen_cuda library");
+  }
+
+  virtual int64_t cuFFTGetPlanCacheSize() const {
+    AT_ERROR("cannot access cuFFT plan cache without ATen_cuda library");
+  }
+
+  virtual void cuFFTClearPlanCache() const {
+    AT_ERROR("cannot access cuFFT plan cache without ATen_cuda library");
+  }
+
   virtual int getNumGPUs() const {
     return 0;
   }
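Together, the two headers above are the hooks pattern this PR leans on: the CPU-side interface ships virtual defaults that fail loudly when the ATen_cuda library is absent, and the CUDA build overrides them with real implementations. The following stand-alone sketch illustrates that pattern only; the names, the hard-wired registration, and main() are hypothetical and simplified, not the ATen implementation.

// Illustration of the hooks pattern (hypothetical names, simplified registration).
#include <memory>
#include <stdexcept>

struct HooksInterface {
  virtual ~HooksInterface() = default;
  // Default: the CUDA library is not linked in, so the feature is unavailable.
  virtual long planCacheMaxSize() const {
    throw std::runtime_error("cannot access cuFFT plan cache without the CUDA library");
  }
};

// Compiled into the CUDA library only.
struct CUDAHooksImpl : HooksInterface {
  long planCacheMaxSize() const override {
    return 16;  // would query the real plan cache here
  }
};

// The CPU library holds the interface pointer; in ATen the CUDA library
// registers its implementation at load time rather than being hard-wired.
HooksInterface& getHooks() {
  static std::unique_ptr<HooksInterface> hooks =
      std::make_unique<CUDAHooksImpl>();
  return *hooks;
}

int main() {
  // CPU-side code calls through the interface unconditionally; it only fails
  // at runtime if no CUDA hooks were ever registered.
  return static_cast<int>(getHooks().planCacheMaxSize());
}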

aten/src/ATen/function_wrapper.py

Lines changed: 10 additions & 4 deletions
@@ -906,6 +906,9 @@ def native_get_return_types(option):
         if isinstance(t_raw, string_type):
             t = t_raw
             name = None
+        elif t_raw is None:
+            t = 'void'
+            name = None
         else:
             t = t_raw['type']
             name = t_raw['name']
@@ -967,14 +970,14 @@ def find_formal(formal_name, formals):
         option['const_mark'] = '' if option['inplace'] else ' const'
 
         is_method = 'method' in option['variants']
-        is_function = 'function' in option['variants']
+        is_namespace_function = 'function' in option['variants']
         is_factory_method = find_formal('TensorOptions', formals)
-        is_deprecated_factory_method = \
-            formals[0]['dynamic_type'] == 'Type' and option['return_type'] == 'Tensor' and option['deprecated']
+        is_deprecated_factory_method = len(formals) > 0 and \
+            formals[0]['dynamic_type'] == 'Type' and \
+            option['return_type'] == 'Tensor' and option['deprecated']
         needs_native_definition = not is_deprecated_factory_method
 
         has_dispatch = dispatch_tensor or dispatch_type
-        is_namespace_function = is_function and (has_dispatch or is_factory_method)
 
         option['method_prefix_derived'] = ''
         option['device_guard_declaration'] = device_guard(option, formals, is_factory_method)
@@ -1045,6 +1048,9 @@ def find_formal(formal_name, formals):
             option['inferred_type'] = dispatch_type['name']
         elif dispatch_tensor:
             option['inferred_type'] = 'infer_type({})'.format(dispatch_tensor)
+        else:
+            # doesn't depend on a specific type, use undefined float
+            option['inferred_type'] = 'at::getType(at::Backend::Undefined, at::ScalarType::Float)'
         declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION
         top_env['function_declarations'].append(declaration.substitute(env))
         if is_factory_method:

aten/src/ATen/native/SpectralOps.cpp

Lines changed: 22 additions & 3 deletions
@@ -7,6 +7,7 @@
 #include "ATen/ATen.h"
 #include "ATen/Config.h"
 #include "ATen/NativeFunctions.h"
+#include "ATen/detail/CUDAHooksInterface.h"
 #include "ATen/native/SpectralOpsUtils.h"
 
 #include <algorithm>
@@ -36,7 +37,7 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
     throw std::runtime_error(ss.str());
   }
 
-  auto signal_tensor_ndim = signal_ndim + static_cast<int>(complex_input); // add complex dim
+  auto signal_tensor_ndim = signal_ndim + static_cast<int64_t>(complex_input); // add complex dim
   if (self.dim() < signal_tensor_ndim) {
     std::ostringstream ss;
     ss << "Given signal_ndim=" << signal_ndim << ", expected an input tensor "
@@ -83,7 +84,7 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
        << signal_ndim << "D, but got signal_sizes=" << signal_sizes;
     throw std::runtime_error(ss.str());
   }
-  std::vector<int64_t> output_sizes(signal_ndim + 1 + static_cast<int>(complex_output));
+  std::vector<int64_t> output_sizes(signal_ndim + 1 + static_cast<int64_t>(complex_output));
   output_sizes[0] = input.size(0); // batch size
   std::vector<int64_t> checked_signal_sizes(signal_ndim);
   for (int64_t i = 0; i < signal_ndim; i++) {
@@ -133,7 +134,7 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
     // slightly faster path for non-batch mode
     output = output.squeeze(0);
   } else if (batch_ndim > 1) {
-    auto output_ndim = self.dim() + static_cast<int>(complex_output) - static_cast<int>(complex_input);
+    auto output_ndim = self.dim() + static_cast<int64_t>(complex_output) - static_cast<int64_t>(complex_input);
     std::vector<int64_t> unflatten_output_shape(output_ndim);
     std::copy(self_shape.begin(), self_shape.begin() + batch_ndim, unflatten_output_shape.begin());
     std::copy(output_sizes.begin() + 1, output_sizes.end(), unflatten_output_shape.begin() + batch_ndim);
@@ -142,6 +143,24 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
   return output;
 }
 
+// We call the following methods via CUDA hooks because they are really only
+// valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details.
+int64_t _cufft_get_plan_cache_max_size() {
+  return detail::getCUDAHooks().cuFFTGetPlanCacheMaxSize();
+}
+
+void _cufft_set_plan_cache_max_size(int64_t max_size) {
+  detail::getCUDAHooks().cuFFTSetPlanCacheMaxSize(max_size);
+}
+
+int64_t _cufft_get_plan_cache_size() {
+  return detail::getCUDAHooks().cuFFTGetPlanCacheSize();
+}
+
+void _cufft_clear_plan_cache() {
+  detail::getCUDAHooks().cuFFTClearPlanCache();
+}
+
 Tensor fft(const Tensor& self, const int64_t signal_ndim, const bool normalized) {
   return _fft(self, signal_ndim, /* complex_input */ true,
               /* complex_output */ true, /* inverse */ false, {}, normalized,
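With the pieces above in place, the plan cache can be controlled from C++ through the four new native functions, which route through the CUDA hooks and raise the "cannot access cuFFT plan cache without ATen_cuda library" error on CPU-only builds. The following usage sketch is not part of this commit and assumes the declarations are exposed via ATen/NativeFunctions.h as usual for native functions, on a build with the ATen_cuda library present.

// Usage sketch of the new plan-cache controls (assumes an ATen build with CUDA).
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <iostream>

int main() {
  // Cap the number of cuFFT plans kept by the LRU cache.
  at::native::_cufft_set_plan_cache_max_size(8);

  std::cout << "max size:     " << at::native::_cufft_get_plan_cache_max_size() << "\n";
  std::cout << "cached plans: " << at::native::_cufft_get_plan_cache_size() << "\n";

  // Drop every cached plan, e.g. to release GPU memory between workloads.
  at::native::_cufft_clear_plan_cache();
  return 0;
}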
