-
Notifications
You must be signed in to change notification settings - Fork 183
Expand file tree
/
Copy pathpycudnn.cpp
More file actions
162 lines (143 loc) · 6.93 KB
/
pycudnn.cpp
File metadata and controls
162 lines (143 loc) · 6.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#include <utility>
#include "pybind11/pybind11.h"
#include "pybind11/cast.h"
#include "pybind11/stl.h"
#include "cudnn_frontend.h"
namespace py = pybind11;
using namespace pybind11::literals;
namespace cudnn_frontend {
// Process-wide handle to the cuDNN shared library. It starts out null and is
// overwritten by set_dlhandle_cudnn() (exposed to Python as
// `_set_dlhandle_cudnn`). The handle type differs per platform: HMODULE on
// Windows, an opaque void* elsewhere.
#if defined(_WIN32)
HMODULE cudnn_dlhandle{nullptr};
#else
void *cudnn_dlhandle{nullptr};
#endif
namespace python_bindings {
// Raise the C++ exception that corresponds to a C++ FE error code.
// pybind11 translates the thrown C++ exceptions into Python exceptions:
//   std::invalid_argument                          -> ValueError
//   std::runtime_error                             -> RuntimeError
//   cudnn_frontend::cudnnGraphNotSupportedException -> cudnnGraphNotSupportedError
//     (registered with py::register_exception in this file's module init)
//
// @param cond        Nothing happens unless this is true.
// @param error_code  Selects which exception type is thrown.
// @param error_msg   Message carried by the thrown exception.
// @throws one of the exceptions above when `cond` is true and `error_code` is
//         not OK. Codes not listed in the switch still raise std::runtime_error
//         instead of being silently ignored.
void
throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::string const &error_msg) {
    if (!cond) {
        return;
    }

    // Intentionally no `default:` label so -Wswitch keeps reporting enumerators
    // that are missing from this mapping; the fallback after the switch handles
    // them safely at runtime.
    switch (error_code) {
        case cudnn_frontend::error_code_t::OK:
            return;

        // Mapped to std::invalid_argument (Python ValueError).
        case cudnn_frontend::error_code_t::ATTRIBUTE_NOT_SET:
        case cudnn_frontend::error_code_t::SHAPE_DEDUCTION_FAILED:
        case cudnn_frontend::error_code_t::INVALID_TENSOR_NAME:
        case cudnn_frontend::error_code_t::INVALID_VARIANT_PACK:
            throw std::invalid_argument(error_msg);

        // Mapped to cudnnGraphNotSupportedException (Python cudnnGraphNotSupportedError).
        case cudnn_frontend::error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED:
        case cudnn_frontend::error_code_t::HEURISTIC_QUERY_FAILED:
        case cudnn_frontend::error_code_t::UNSUPPORTED_GRAPH_FORMAT:
        case cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED:
            throw cudnn_frontend::cudnnGraphNotSupportedException(error_msg.c_str());

        // Mapped to std::runtime_error (Python RuntimeError).
        case cudnn_frontend::error_code_t::GRAPH_EXECUTION_FAILED:
        case cudnn_frontend::error_code_t::CUDNN_BACKEND_API_FAILED:
        case cudnn_frontend::error_code_t::CUDA_API_FAILED:
        case cudnn_frontend::error_code_t::INVALID_CUDA_DEVICE:
        case cudnn_frontend::error_code_t::HANDLE_ERROR:
        case cudnn_frontend::error_code_t::INVALID_VALUE:
        case cudnn_frontend::error_code_t::NVRTC_COMPILATION_FAILED:
            throw std::runtime_error(error_msg);
    }

    // Reached only for an error code not handled above (e.g. an enumerator
    // added to error_code_t later, or a value cast from an out-of-range int).
    // `cond` is true, so the caller expects a failure to be reported.
    throw std::runtime_error(error_msg);
}
// pybinds for pygraph class
void
init_pygraph_submodule(py::module_ &);
// pybinds for kernel_cache class
void
create_kernel_cache_submodule(py::module_ &);
// pybinds for all properties and helpers
void
init_properties(py::module_ &);
void
set_dlhandle_cudnn(std::intptr_t dlhandle) {
#ifdef _WIN32
cudnn_dlhandle = reinterpret_cast<HMODULE>(dlhandle);
#else
cudnn_dlhandle = reinterpret_cast<void *>(dlhandle);
#endif
}
// Entry point of the `_compiled_module` Python extension: registers the version
// queries, the submodules, the dl-handle setter, the custom exception type and
// (with cuDNN >= 9.22.0, per the CUDNN_VERSION guard) the causal conv1d kernels.
PYBIND11_MODULE(_compiled_module, m) {
    m.def("backend_version", &detail::get_backend_version);
    m.def("backend_version_string", &detail::get_backend_version_string);

    init_properties(m);
    init_pygraph_submodule(m);

    // Python passes an integer handle that set_dlhandle_cudnn stores in the
    // global cudnn_dlhandle.
    m.def("_set_dlhandle_cudnn", &set_dlhandle_cudnn);

    // Makes cudnnGraphNotSupportedException (thrown e.g. by throw_if) surface in
    // Python as `_compiled_module.cudnnGraphNotSupportedError`.
    py::register_exception<cudnnGraphNotSupportedException>(m, "cudnnGraphNotSupportedError");

#if CUDNN_VERSION >= 92200
    // Arguments are given py::arg names (matching the C++ parameters) so Python
    // callers may pass them by keyword; positional calls behave as before.
    m.def(
        "causal_conv1d_forward",
        [](std::intptr_t stream,
           std::intptr_t x_ptr,
           std::intptr_t weight_ptr,
           std::intptr_t bias_ptr,
           std::intptr_t out_ptr,
           int batch,
           int dim,
           int seq_len,
           int kernel_size,
           int data_type,
           int activation) {
            auto status = detail::causal_conv1d_forward(reinterpret_cast<cudaStream_t>(stream),
                                                        reinterpret_cast<const void *>(x_ptr),
                                                        reinterpret_cast<const void *>(weight_ptr),
                                                        reinterpret_cast<const void *>(bias_ptr),
                                                        reinterpret_cast<void *>(out_ptr),
                                                        batch,
                                                        dim,
                                                        seq_len,
                                                        kernel_size,
                                                        static_cast<cudnnDataType_t>(data_type),
                                                        static_cast<cudnnCausalConv1dActivation_t>(activation));
            if (status != 0)
                throw std::runtime_error("cudnnCausalConv1dForward failed with status " + std::to_string(status));
        },
        "Causal 1-D convolution forward (cudnnCausalConv1dForward) enqueued on `stream`.\n\n"
        "`stream` and all *_ptr arguments are raw addresses passed as Python ints.\n"
        "`data_type` is a cudnnDataType_t value and `activation` a\n"
        "cudnnCausalConv1dActivation_t value.\n"
        "Raises RuntimeError if the cuDNN call returns a non-zero status.",
        py::arg("stream"),
        py::arg("x_ptr"),
        py::arg("weight_ptr"),
        py::arg("bias_ptr"),
        py::arg("out_ptr"),
        py::arg("batch"),
        py::arg("dim"),
        py::arg("seq_len"),
        py::arg("kernel_size"),
        py::arg("data_type"),
        py::arg("activation"));

    m.def(
        "causal_conv1d_backward",
        [](std::intptr_t stream,
           std::intptr_t x_ptr,
           std::intptr_t weight_ptr,
           std::intptr_t bias_ptr,
           std::intptr_t dy_ptr,
           std::intptr_t dx_ptr,
           std::intptr_t dweight_ptr,
           std::intptr_t dbias_ptr,
           int batch,
           int dim,
           int seq_len,
           int kernel_size,
           int data_type,
           int dw_data_type,
           int activation) {
            auto status = detail::causal_conv1d_backward(reinterpret_cast<cudaStream_t>(stream),
                                                         reinterpret_cast<const void *>(x_ptr),
                                                         reinterpret_cast<const void *>(weight_ptr),
                                                         reinterpret_cast<const void *>(bias_ptr),
                                                         reinterpret_cast<const void *>(dy_ptr),
                                                         reinterpret_cast<void *>(dx_ptr),
                                                         reinterpret_cast<void *>(dweight_ptr),
                                                         reinterpret_cast<void *>(dbias_ptr),
                                                         batch,
                                                         dim,
                                                         seq_len,
                                                         kernel_size,
                                                         static_cast<cudnnDataType_t>(data_type),
                                                         static_cast<cudnnDataType_t>(dw_data_type),
                                                         static_cast<cudnnCausalConv1dActivation_t>(activation));
            if (status != 0)
                throw std::runtime_error("cudnnCausalConv1dBackward failed with status " + std::to_string(status));
        },
        "Causal 1-D convolution backward (cudnnCausalConv1dBackward) enqueued on `stream`.\n\n"
        "`stream` and all *_ptr arguments are raw addresses passed as Python ints.\n"
        "`data_type` and `dw_data_type` are cudnnDataType_t values and `activation`\n"
        "a cudnnCausalConv1dActivation_t value.\n"
        "Raises RuntimeError if the cuDNN call returns a non-zero status.",
        py::arg("stream"),
        py::arg("x_ptr"),
        py::arg("weight_ptr"),
        py::arg("bias_ptr"),
        py::arg("dy_ptr"),
        py::arg("dx_ptr"),
        py::arg("dweight_ptr"),
        py::arg("dbias_ptr"),
        py::arg("batch"),
        py::arg("dim"),
        py::arg("seq_len"),
        py::arg("kernel_size"),
        py::arg("data_type"),
        py::arg("dw_data_type"),
        py::arg("activation"));
#endif
}
} // namespace python_bindings
} // namespace cudnn_frontend