Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 81 additions & 18 deletions src/backend/common/DependencyModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <common/DependencyModule.hpp>
#include <common/Logger.hpp>
#include <common/module_loading.hpp>

#include <algorithm>
#include <string>

Expand All @@ -19,30 +20,73 @@
#include <dlfcn.h>
#endif

using common::Version;
using std::make_tuple;
using std::string;
using std::to_string;
using std::vector;

constexpr Version NullVersion{-1, -1, -1};

#ifdef OS_WIN
#include <Windows.h>

static const char* librarySuffix = ".dll";
static const char* libraryPrefix = "";

namespace {
vector<string> libNames(const std::string& name, const string& suffix,
const Version& ver = NullVersion) {
UNUSED(ver); // Windows DLL files are not version suffixed
return {name + suffix + librarySuffix};
}
} // namespace

#elif defined(OS_MAC)

static const char* librarySuffix = ".dylib";
static const char* libraryPrefix = "lib";

namespace {
vector<string> libNames(const std::string& name, const string& suffix,
const Version& ver = NullVersion) {
UNUSED(suffix);
const string noVerName = libraryPrefix + name + librarySuffix;
if (ver != NullVersion) {
const string infix = "." + to_string(std::get<0>(ver)) + ".";
return {libraryPrefix + name + infix + librarySuffix, noVerName};
} else {
return {noVerName};
}
}
} // namespace

#elif defined(OS_LNX)

static const char* librarySuffix = ".so";
static const char* libraryPrefix = "lib";
#else
#error "Unsupported platform"
#endif

using std::string;
using std::vector;

namespace {
vector<string> libNames(const std::string& name, const string& suffix,
const Version& ver = NullVersion) {
UNUSED(suffix);
const string noVerName = libraryPrefix + name + librarySuffix;
if (ver != NullVersion) {
const string soname("." + to_string(std::get<0>(ver)));

std::string libName(const std::string& name) {
return libraryPrefix + name + librarySuffix;
const string vsfx = "." + to_string(std::get<0>(ver)) + "." +
to_string(std::get<1>(ver)) + "." +
to_string(std::get<2>(ver));
return {noVerName + vsfx, noVerName + soname, noVerName};
} else {
return {noVerName};
}
}
} // namespace

#else
#error "Unsupported platform"
#endif

namespace common {

DependencyModule::DependencyModule(const char* plugin_file_name,
Expand All @@ -51,11 +95,11 @@ DependencyModule::DependencyModule(const char* plugin_file_name,
// TODO(umar): Implement handling of non-standard paths
UNUSED(paths);
if (plugin_file_name) {
string filename = libName(plugin_file_name);
AF_TRACE("Attempting to load: {}", filename);
handle = loadLibrary(filename.c_str());
auto fileNames = libNames(plugin_file_name, "");
AF_TRACE("Attempting to load: {}", fileNames[0]);
handle = loadLibrary(fileNames[0].c_str());
if (handle) {
AF_TRACE("Found: {}", filename);
AF_TRACE("Found: {}", fileNames[0]);
} else {
AF_TRACE("Unable to open {}", plugin_file_name);
}
Expand All @@ -64,17 +108,36 @@ DependencyModule::DependencyModule(const char* plugin_file_name,

DependencyModule::DependencyModule(const vector<string>& plugin_base_file_name,
const vector<string>& suffixes,
const vector<string>& paths)
const vector<string>& paths,
const size_t verListSize,
const Version* versions)
: handle(nullptr), logger(common::loggerFactory("platform")) {
for (const string& base_name : plugin_base_file_name) {
for (const string& path : paths) {
UNUSED(path);
for (const string& suffix : suffixes) {
string filename = libName(base_name + suffix);
AF_TRACE("Attempting to load: {}", filename);
handle = loadLibrary(filename.c_str());
#if !defined(OS_WIN)
// For a non-windows OS, i.e. most likely unix, shared library
// names have versions suffix based on the version. Lookup for
// libraries for given versions and proceed to a simple name
// lookup if versioned library is not found.
for (size_t v = 0; v < verListSize; v++) {
auto fileNames = libNames(base_name, suffix, versions[v]);
for (auto& fileName : fileNames) {
AF_TRACE("Attempting to load: {}", fileName);
handle = loadLibrary(fileName.c_str());
if (handle) {
AF_TRACE("Found: {}", fileName);
return;
}
}
}
#endif
auto fileNames = libNames(base_name, suffix);
AF_TRACE("Attempting to load: {}", fileNames[0]);
handle = loadLibrary(fileNames[0].c_str());
if (handle) {
AF_TRACE("Found: {}", filename);
AF_TRACE("Found: {}", fileNames[0]);
return;
}
}
Expand Down
8 changes: 7 additions & 1 deletion src/backend/common/DependencyModule.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
********************************************************/

#pragma once

#include <common/Logger.hpp>
#include <common/defines.hpp>
#include <common/module_loading.hpp>

#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

Expand All @@ -22,6 +24,8 @@ class logger;
}
namespace common {

using Version = std::tuple<int, int, int>; // major, minor, patch

/// Allows you to create classes which dynamically load dependencies at runtime
///
/// Creates a dependency module which will dynamically load a library
Expand All @@ -39,7 +43,9 @@ class DependencyModule {

DependencyModule(const std::vector<std::string>& plugin_base_file_name,
const std::vector<std::string>& suffixes,
const std::vector<std::string>& paths);
const std::vector<std::string>& paths,
const size_t verListSize = 0,
const Version* versions = nullptr);

~DependencyModule() noexcept;

Expand Down
9 changes: 8 additions & 1 deletion src/backend/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,14 @@ endfunction()

if(AF_INSTALL_STANDALONE)
if(AF_WITH_CUDNN)
afcu_collect_libs(cudnn)
if(WIN32)
set(cudnn_lib "${cuDNN_DLL_LIBRARY}")
else()
get_filename_component(cudnn_lib "${cuDNN_LINK_LIBRARY}" REALPATH)
endif()
install(FILES ${cudnn_lib}
DESTINATION ${AF_INSTALL_LIB_DIR}
COMPONENT cuda_dependencies)
endif()

afcu_collect_libs(nvrtc FULL_VERSION)
Expand Down
21 changes: 20 additions & 1 deletion src/backend/cuda/cudnnModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,32 @@
#include <device_manager.hpp>
#include <utility.hpp>

#include <array>
#include <string>
#include <tuple>

using common::Version;
using std::make_tuple;
using std::string;

namespace cuda {

// clang-format off
// Latest version from each minor releases are enlisted below
constexpr std::array<common::Version, 10> cudnnVersions = {
make_tuple(7, 6, 5),
make_tuple(7, 5, 1),
make_tuple(7, 4, 2),
make_tuple(7, 3, 1),
make_tuple(7, 2, 1),
make_tuple(7, 1, 4),
make_tuple(7, 0, 5),
make_tuple(6, 0, 21),
make_tuple(5, 1, 10),
make_tuple(4, 0, 7)
};
// clang-format on

spdlog::logger* cudnnModule::getLogger() const noexcept {
return module.getLogger();
}
Expand All @@ -34,7 +52,8 @@ auto cudnnVersionComponents(size_t version) {
}

cudnnModule::cudnnModule()
: module({"cudnn"}, {"", "64_7", "64_8", "64_6", "64_5", "64_4"}, {""}) {
: module({"cudnn"}, {"", "64_7", "64_8", "64_6", "64_5", "64_4"}, {""},
cudnnVersions.size(), cudnnVersions.data()) {
if (!module.isLoaded()) {
AF_TRACE(
"WARNING: Unable to load cuDNN: {}"
Expand Down