Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions src/backend/cuda/cudnnModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ auto cudnnVersionComponents(size_t version) {
return make_tuple(major, minor, patch);
}

auto cudaRuntimeVersionComponents(size_t version) {
auto major = version / 1000;
auto minor = (version - (major * 1000)) / 10;
return make_tuple(major, minor);
}

cudnnModule::cudnnModule()
: module({"cudnn"}, {"", "64_7", "64_8", "64_6", "64_5", "64_4"}, {""},
cudnnVersions.size(), cudnnVersions.data()) {
Expand Down Expand Up @@ -83,8 +89,7 @@ cudnnModule::cudnnModule()
major, minor);
}

std::tie(rtmajor, rtminor, std::ignore) =
cudnnVersionComponents(cudnn_rtversion);
std::tie(rtmajor, rtminor) = cudaRuntimeVersionComponents(cudnn_rtversion);

AF_TRACE("cuDNN Version: {}.{}.{} cuDNN CUDA Runtime: {}.{}", major, minor,
patch, rtmajor, rtminor);
Expand All @@ -101,15 +106,16 @@ cudnnModule::cudnnModule()

int afcuda_runtime = 0;
cudaRuntimeGetVersion(&afcuda_runtime);
if (afcuda_runtime != static_cast<int>(cudnn_version)) {
if (afcuda_runtime != static_cast<int>(cudnn_rtversion)) {
getLogger()->warn(
"WARNING: ArrayFire CUDA Runtime({}) and cuDNN CUDA "
"Runtime({}.{}) do not match. For maximum compatibility, make sure "
"Runtime({}) do not match. For maximum compatibility, make sure "
"the two versions match.(Ignoring check)",
// NOTE: the int version formats from CUDA and cuDNN are different
// so we are using int_version_to_string for the ArrayFire CUDA
// runtime
int_version_to_string(afcuda_runtime), rtmajor, rtminor);
int_version_to_string(afcuda_runtime),
int_version_to_string(cudnn_rtversion));
}

MODULE_FUNCTION_INIT(cudnnConvolutionBackwardData);
Expand All @@ -135,16 +141,6 @@ cudnnModule::cudnnModule()
MODULE_FUNCTION_INIT(cudnnSetStream);
MODULE_FUNCTION_INIT(cudnnSetTensor4dDescriptor);

// Check to see if the cuDNN runtime is compatible with the current device
cudaDeviceProp prop = getDeviceProp(getActiveDeviceId());
if (!checkDeviceWithRuntime(cudnn_rtversion, {prop.major, prop.minor})) {
string error_message = fmt::format(
"Error: cuDNN CUDA Runtime({}.{}) does not support the "
"current device's compute capability(sm_{}{}).",
rtmajor, rtminor, prop.major, prop.minor);
AF_ERROR(error_message, AF_ERR_RUNTIME);
}

if (!module.symbolsLoaded()) {
string error_message =
"Error loading cuDNN symbols. ArrayFire was unable to load some "
Expand Down