33 changes: 33 additions & 0 deletions aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -5,6 +5,7 @@
#include <ATen/Error.h>
#include <ATen/RegisterCUDA.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/native/cuda/CuFFTPlanCache.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <ATen/detail/CUDAHooksInterface.h>

@@ -165,6 +166,38 @@ double CUDAHooks::batchnormMinEpsilonCuDNN() const {
#endif
}

int64_t CUDAHooks::cuFFTGetPlanCacheMaxSize() const {
#ifndef __HIP_PLATFORM_HCC__
return at::native::detail::cufft_get_plan_cache_max_size_impl();
#else
AT_ERROR("cuFFT with HIP is not supported");
#endif
}

void CUDAHooks::cuFFTSetPlanCacheMaxSize(int64_t max_size) const {
#ifndef __HIP_PLATFORM_HCC__
at::native::detail::cufft_set_plan_cache_max_size_impl(max_size);
#else
AT_ERROR("cuFFT with HIP is not supported");
#endif
}

int64_t CUDAHooks::cuFFTGetPlanCacheSize() const {
#ifndef __HIP_PLATFORM_HCC__
return at::native::detail::cufft_get_plan_cache_size_impl();
#else
AT_ERROR("cuFFT with HIP is not supported");
#endif
}

void CUDAHooks::cuFFTClearPlanCache() const {
#ifndef __HIP_PLATFORM_HCC__
at::native::detail::cufft_clear_plan_cache_impl();
#else
AT_ERROR("cuFFT with HIP is not supported");
#endif
}

int CUDAHooks::getNumGPUs() const {
int count;
auto err = cudaGetDeviceCount(&count);
4 changes: 4 additions & 0 deletions aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -25,6 +25,10 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool supportsDilatedConvolutionWithCuDNN() const override;
long versionCuDNN() const override;
double batchnormMinEpsilonCuDNN() const override;
int64_t cuFFTGetPlanCacheMaxSize() const override;
void cuFFTSetPlanCacheMaxSize(int64_t max_size) const override;
int64_t cuFFTGetPlanCacheSize() const override;
void cuFFTClearPlanCache() const override;
int getNumGPUs() const override;
};

16 changes: 16 additions & 0 deletions aten/src/ATen/detail/CUDAHooksInterface.h
@@ -108,6 +108,22 @@ struct AT_API CUDAHooksInterface {
"cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library");
}

virtual int64_t cuFFTGetPlanCacheMaxSize() const {
AT_ERROR("cannot access cuFFT plan cache without ATen_cuda library");
}

virtual void cuFFTSetPlanCacheMaxSize(int64_t max_size) const {
AT_ERROR("cannot access cuFFT plan cache without ATen_cuda library");
}

virtual int64_t cuFFTGetPlanCacheSize() const {
AT_ERROR("cannot access cuFFT plan cache without ATen_cuda library");
}

virtual void cuFFTClearPlanCache() const {
AT_ERROR("cannot access cuFFT plan cache without ATen_cuda library");
}

virtual int getNumGPUs() const {
return 0;
}
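
For context, the three files above implement ATen's usual hooks pattern: CUDAHooksInterface declares virtual methods whose default bodies raise an error, and the ATen_cuda library supplies CUDAHooks, which overrides them, so a CPU-only build still links and only errors if the plan cache is actually touched. The following is a minimal self-contained sketch of that pattern; the class and function names are illustrative, not ATen's real API.

#include <cstdint>
#include <stdexcept>

// Base interface: safe to link everywhere; the default bodies fail loudly,
// mirroring the AT_ERROR fallbacks in CUDAHooksInterface.h above.
struct HooksInterface {
  virtual ~HooksInterface() = default;
  virtual int64_t cufftPlanCacheMaxSize() const {
    throw std::runtime_error("cannot access cuFFT plan cache without the CUDA library");
  }
  virtual void clearCufftPlanCache() const {
    throw std::runtime_error("cannot access cuFFT plan cache without the CUDA library");
  }
};

// Stand-in for the override that the CUDA library supplies (CUDAHooks above).
struct CUDAHooksImpl : HooksInterface {
  int64_t cufftPlanCacheMaxSize() const override {
    return 64;  // a real build would forward to the cuFFT plan cache here
  }
  void clearCufftPlanCache() const override { /* forward to the real cache */ }
};

// ATen looks the implementation up through a registry at runtime; constructing
// it directly keeps this sketch short.
const HooksInterface& getHooks() {
  static CUDAHooksImpl hooks;
  return hooks;
}

// Backend-agnostic callers (e.g. SpectralOps.cpp below) only ever see the
// interface, so they need no CUDA headers.
int64_t cufftPlanCacheMaxSize() {
  return getHooks().cufftPlanCacheMaxSize();
}
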
14 changes: 10 additions & 4 deletions aten/src/ATen/function_wrapper.py
@@ -906,6 +906,9 @@ def native_get_return_types(option):
if isinstance(t_raw, string_type):
t = t_raw
name = None
elif t_raw is None:
t = 'void'
name = None
else:
t = t_raw['type']
name = t_raw['name']
@@ -967,14 +970,14 @@ def find_formal(formal_name, formals):
option['const_mark'] = '' if option['inplace'] else ' const'

is_method = 'method' in option['variants']
is_function = 'function' in option['variants']
is_namespace_function = 'function' in option['variants']
is_factory_method = find_formal('TensorOptions', formals)
is_deprecated_factory_method = \
formals[0]['dynamic_type'] == 'Type' and option['return_type'] == 'Tensor' and option['deprecated']
is_deprecated_factory_method = len(formals) > 0 and \
formals[0]['dynamic_type'] == 'Type' and \
option['return_type'] == 'Tensor' and option['deprecated']
needs_native_definition = not is_deprecated_factory_method

has_dispatch = dispatch_tensor or dispatch_type
is_namespace_function = is_function and (has_dispatch or is_factory_method)

option['method_prefix_derived'] = ''
option['device_guard_declaration'] = device_guard(option, formals, is_factory_method)
Expand Down Expand Up @@ -1045,6 +1048,9 @@ def find_formal(formal_name, formals):
option['inferred_type'] = dispatch_type['name']
elif dispatch_tensor:
option['inferred_type'] = 'infer_type({})'.format(dispatch_tensor)
else:
# doesn't depend on a specific type, use undefined float
option['inferred_type'] = 'at::getType(at::Backend::Undefined, at::ScalarType::Float)'
declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION
top_env['function_declarations'].append(declaration.substitute(env))
if is_factory_method:
25 changes: 22 additions & 3 deletions aten/src/ATen/native/SpectralOps.cpp
@@ -7,6 +7,7 @@
#include "ATen/ATen.h"
#include "ATen/Config.h"
#include "ATen/NativeFunctions.h"
#include "ATen/detail/CUDAHooksInterface.h"
#include "ATen/native/SpectralOpsUtils.h"

#include <algorithm>
@@ -36,7 +37,7 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
throw std::runtime_error(ss.str());
}

auto signal_tensor_ndim = signal_ndim + static_cast<int>(complex_input); // add complex dim
auto signal_tensor_ndim = signal_ndim + static_cast<int64_t>(complex_input); // add complex dim
if (self.dim() < signal_tensor_ndim) {
std::ostringstream ss;
ss << "Given signal_ndim=" << signal_ndim << ", expected an input tensor "
@@ -83,7 +84,7 @@
<< signal_ndim << "D, but got signal_sizes=" << signal_sizes;
throw std::runtime_error(ss.str());
}
std::vector<int64_t> output_sizes(signal_ndim + 1 + static_cast<int>(complex_output));
std::vector<int64_t> output_sizes(signal_ndim + 1 + static_cast<int64_t>(complex_output));
output_sizes[0] = input.size(0); // batch size
std::vector<int64_t> checked_signal_sizes(signal_ndim);
for (int64_t i = 0; i < signal_ndim; i++) {
@@ -133,7 +134,7 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
// slightly faster path for non-batch mode
output = output.squeeze(0);
} else if (batch_ndim > 1) {
auto output_ndim = self.dim() + static_cast<int>(complex_output) - static_cast<int>(complex_input);
auto output_ndim = self.dim() + static_cast<int64_t>(complex_output) - static_cast<int64_t>(complex_input);
std::vector<int64_t> unflatten_output_shape(output_ndim);
std::copy(self_shape.begin(), self_shape.begin() + batch_ndim, unflatten_output_shape.begin());
std::copy(output_sizes.begin() + 1, output_sizes.end(), unflatten_output_shape.begin() + batch_ndim);
@@ -142,6 +143,24 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
return output;
}

// We call the following methods via CUDA hooks because they are really only
// valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details.
int64_t _cufft_get_plan_cache_max_size() {
return detail::getCUDAHooks().cuFFTGetPlanCacheMaxSize();
}

void _cufft_set_plan_cache_max_size(int64_t max_size) {
detail::getCUDAHooks().cuFFTSetPlanCacheMaxSize(max_size);
}

int64_t _cufft_get_plan_cache_size() {
return detail::getCUDAHooks().cuFFTGetPlanCacheSize();
}

void _cufft_clear_plan_cache() {
detail::getCUDAHooks().cuFFTClearPlanCache();
}

Tensor fft(const Tensor& self, const int64_t signal_ndim, const bool normalized) {
return _fft(self, signal_ndim, /* complex_input */ true,
/* complex_output */ true, /* inverse */ false, {}, normalized,
Expand Down
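
As a hedged usage sketch of the four wrappers above (the at:: bindings and exact signatures are assumptions here; the codegen/YAML side of the change is not shown in this diff), a caller reaches the plan cache without including any CUDA headers, and on a build without ATen_cuda the hook raises the "cannot access cuFFT plan cache" error instead:

#include <ATen/ATen.h>
#include <cstdint>
#include <iostream>

// Assumed generated bindings for the native functions declared above; the
// actual exposure of these functions happens outside this diff.
void inspect_cufft_plan_cache() {
  at::_cufft_set_plan_cache_max_size(16);   // cap the number of cached plans
  int64_t max_size = at::_cufft_get_plan_cache_max_size();
  int64_t cur_size = at::_cufft_get_plan_cache_size();
  std::cout << cur_size << " of " << max_size << " plan slots in use\n";
  at::_cufft_clear_plan_cache();            // drop every cached plan
}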