forked from kevinlin311tw/Caffe-DeepBinaryCode
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path cudnn_conv_layer.cpp
More file actions
130 lines (111 loc) · 4.23 KB
/
cudnn_conv_layer.cpp
File metadata and controls
130 lines (111 loc) · 4.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#ifdef USE_CUDNN
#include <vector>
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/im2col.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
// Number of CUDA streams (and matching cuDNN handles) allocated per
// convolution group: LayerSetUp() creates group_ * CUDNN_STREAMS_PER_GROUP
// of each, and the destructor releases the same number.
// Set to three for the benefit of the backward pass, which
// can use separate streams for calculating the gradient w.r.t.
// bias, filter weights, and bottom data for each group independently.
#define CUDNN_STREAMS_PER_GROUP 3
/**
 * One-time setup for the cuDNN convolution layer.
 *
 * Runs the generic ConvolutionLayer setup, then:
 *  - allocates group_ * CUDNN_STREAMS_PER_GROUP CUDA streams and cuDNN
 *    handles, binding each handle to its own stream;
 *  - resets the (lazily grown) convolution workspace to empty;
 *  - computes the per-group weight/bias pointer offsets;
 *  - creates the filter descriptor, plus one bottom/top/convolution
 *    descriptor per bottom blob (their shapes are filled in by Reshape());
 *  - creates the bias descriptor when the layer has a bias term.
 * Finally marks handles_setup_ so the destructor knows there is something
 * to release.
 *
 * @param bottom input blobs; one descriptor triple is created per entry.
 * @param top    output blobs (forwarded to the base-class setup).
 */
template <typename Dtype>
void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  ConvolutionLayer<Dtype>::LayerSetUp(bottom, top);

  // One CUDA stream + one cuDNN handle per (group, stream slot) pair.
  const int num_streams = this->group_ * CUDNN_STREAMS_PER_GROUP;
  stream_ = new cudaStream_t[num_streams];
  handle_ = new cudnnHandle_t[num_streams];

  // The workspace starts out empty; it is sized on demand elsewhere.
  workspaceSizeInBytes = 0;
  workspace = NULL;

  for (int s = 0; s < num_streams; ++s) {
    CUDA_CHECK(cudaStreamCreate(&stream_[s]));
    CUDNN_CHECK(cudnnCreate(&handle_[s]));
    CUDNN_CHECK(cudnnSetStream(handle_[s], stream_[s]));
  }

  // Each group sees only its slice of the input/output channels.
  const int out_per_group = this->num_output_ / this->group_;
  const int in_per_group = this->channels_ / this->group_;

  // Element offsets between consecutive groups' weights and biases.
  weight_offset_ = out_per_group * in_per_group
      * this->kernel_h_ * this->kernel_w_;
  bias_offset_ = out_per_group;

  // Filter descriptor shared by every group.
  cudnn::createFilterDesc<Dtype>(&filter_desc_,
      out_per_group, in_per_group,
      this->kernel_h_, this->kernel_w_);

  // Create (unshaped) descriptors for each bottom/top pair; Reshape() sets
  // their dimensions and strides.
  for (size_t b = 0; b < bottom.size(); ++b) {
    cudnnTensorDescriptor_t input_desc;
    cudnn::createTensor4dDesc<Dtype>(&input_desc);
    bottom_descs_.push_back(input_desc);

    cudnnTensorDescriptor_t output_desc;
    cudnn::createTensor4dDesc<Dtype>(&output_desc);
    top_descs_.push_back(output_desc);

    cudnnConvolutionDescriptor_t convolution_desc;
    cudnn::createConvolutionDesc<Dtype>(&convolution_desc);
    conv_descs_.push_back(convolution_desc);
  }

  // Tensor descriptor for bias.
  if (this->bias_term_) {
    cudnn::createTensor4dDesc<Dtype>(&bias_desc_);
  }

  handles_setup_ = true;
}
/**
 * Recomputes shape-dependent state after the base ConvolutionLayer reshape.
 *
 * Sets the per-group bottom/top element offsets and (re)configures each
 * bottom/top tensor descriptor and convolution descriptor created in
 * LayerSetUp(). Each descriptor describes a single group's channel slice:
 * the channel count is per-group, but the batch stride spans all channels
 * of the full blob, so per-group data is addressed by pointer offset.
 * The bias descriptor, if present, is shaped 1 x (num_output_/group_) x 1 x 1.
 *
 * @param bottom input blobs; one descriptor set is updated per entry.
 * @param top    output blobs (forwarded to the base-class reshape).
 */
template <typename Dtype>
void CuDNNConvolutionLayer<Dtype>::Reshape(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  ConvolutionLayer<Dtype>::Reshape(bottom, top);

  const int in_per_group = this->channels_ / this->group_;
  const int out_per_group = this->num_output_ / this->group_;
  const int in_plane = this->height_ * this->width_;
  const int out_plane = this->height_out_ * this->width_out_;

  // Distance (in elements) from one group's data to the next.
  bottom_offset_ = in_per_group * in_plane;
  top_offset_ = out_per_group * out_plane;

  for (size_t b = 0; b < bottom.size(); ++b) {
    // Input: per-group channels, full-blob batch stride, dense NCHW layout.
    cudnn::setTensor4dDesc<Dtype>(&bottom_descs_[b],
        this->num_, in_per_group,
        this->height_, this->width_,
        this->channels_ * in_plane, in_plane,
        this->width_, 1);
    // Output: same layout scheme on the output side.
    cudnn::setTensor4dDesc<Dtype>(&top_descs_[b],
        this->num_, out_per_group,
        this->height_out_, this->width_out_,
        this->num_output_ * out_plane, out_plane,
        this->width_out_, 1);
    cudnn::setConvolutionDesc<Dtype>(&conv_descs_[b], bottom_descs_[b],
        filter_desc_, this->pad_h_, this->pad_w_,
        this->stride_h_, this->stride_w_);
  }

  // Tensor descriptor for bias.
  if (this->bias_term_) {
    cudnn::setTensor4dDesc<Dtype>(&bias_desc_, 1, out_per_group, 1, 1);
  }
}
/**
 * Releases every cuDNN/CUDA resource acquired by LayerSetUp() and the
 * convolution workspace.
 *
 * Does nothing if LayerSetUp() never completed (handles_setup_ is false),
 * because in that case the descriptors, streams and handles were never
 * created. Errors from the destroy/free calls are deliberately ignored:
 * a destructor must not abort.
 */
template <typename Dtype>
CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
  // Check that handles have been setup before destroying.
  if (!handles_setup_) { return; }

  for (size_t i = 0; i < bottom_descs_.size(); i++) {
    cudnnDestroyTensorDescriptor(bottom_descs_[i]);
    cudnnDestroyTensorDescriptor(top_descs_[i]);
    cudnnDestroyConvolutionDescriptor(conv_descs_[i]);
  }
  if (this->bias_term_) {
    cudnnDestroyTensorDescriptor(bias_desc_);
  }
  cudnnDestroyFilterDescriptor(filter_desc_);

  for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
    cudaStreamDestroy(stream_[g]);
    cudnnDestroy(handle_[g]);
  }

  // Free the convolution workspace. LayerSetUp() starts it at NULL with size
  // 0, and it is presumably grown with cudaMalloc by the GPU passes (not
  // visible in this file); previously it was never released, leaking device
  // memory for every layer instance. cudaFree(NULL) is a no-op, so this is
  // safe when no workspace was ever allocated.
  cudaFree(workspace);

  delete [] stream_;
  delete [] handle_;
}
// Explicit template instantiation of CuDNNConvolutionLayer via Caffe's
// INSTANTIATE_CLASS macro (defined elsewhere; presumably for float and
// double — confirm against the macro definition).
INSTANTIATE_CLASS(CuDNNConvolutionLayer);
} // namespace caffe
#endif