@@ -41,41 +41,71 @@ __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
4141 }
4242}
4343
// Maximum threads per block the backward kernel is compiled for; used in the
// __launch_bounds__ of MaxPoolBackward below.
const int BACKWARD_THREADS = 256;

// Returns the first pooled index whose (dilated) window can reach the padded
// input coordinate `padded_pos`. `kernel_extent` is the dilated window size,
// (kernel - 1) * dilation + 1.
// Strides 1 and 2 are special-cased so the division uses a literal divisor
// ("templating benefits without actually templating").
__device__ __forceinline__ int MaxPoolBackwardWindowBegin(
    const int padded_pos, const int kernel_extent, const int stride) {
  if (padded_pos < kernel_extent) {
    return 0;
  }
  const int past = padded_pos - kernel_extent;
  if (stride == 1) {
    return past + 1;
  }
  if (stride == 2) {
    return past / 2 + 1;
  }
  return past / stride + 1;
}

// Returns one past the last pooled index whose window can start at or before
// the padded input coordinate `padded_pos`, clamped to `pooled_size`.
// Same stride special-casing as MaxPoolBackwardWindowBegin.
__device__ __forceinline__ int MaxPoolBackwardWindowEnd(
    const int padded_pos, const int stride, const int pooled_size) {
  if (stride == 1) {
    return min(padded_pos + 1, pooled_size);
  }
  if (stride == 2) {
    return min(padded_pos / 2 + 1, pooled_size);
  }
  return min(padded_pos / stride + 1, pooled_size);
}

// Backward pass of (dilated) max pooling.
//
// For every input position (n, c, h, w) the kernel sums the entries of
// `top_diff` whose recorded argmax in `top_mask` equals this position and
// writes the sum to `bottom_diff`.
//
// Work distribution:
//   * CUDA_KERNEL_LOOP covers the height * width positions of one plane
//     (presumably driven by the x dimension of the launch -- the macro is
//     defined outside this file section).
//   * blockIdx.y / gridDim.y stride over the batch dimension `num`.
//   * blockIdx.z / gridDim.z stride over `channels`.
//
// Parameters:
//   nthreads       unused by this implementation; kept so existing callers
//                  keep compiling.
//   top_diff       gradients w.r.t. the pooled output, indexed as
//                  ((n * channels + c) * pooled_height + ph) * pooled_width + pw.
//   top_mask       per pooled element, the flattened in-plane input index
//                  (h * width + w) of the maximum, offset by TH_INDEX_BASE.
//   bottom_diff    output gradients, indexed as
//                  ((n * channels + c) * height + h) * width + w.
//   remaining ints pooling geometry (sizes, kernel, stride, padding, dilation).
//
// __launch_bounds__: at most BACKWARD_THREADS threads per block and at least
// 2048 / BACKWARD_THREADS (= 8) resident blocks per multiprocessor.
template <typename Dtype, typename AccType>
__launch_bounds__(BACKWARD_THREADS, 2048 / BACKWARD_THREADS)
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
    const int64_t* top_mask, const int num, const int channels,
    const int height, const int width, const int pooled_height,
    const int pooled_width, const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w,
    Dtype* bottom_diff) {
  // Dilated window sizes along each axis.
  const int extent_h = (kernel_h - 1) * dilation_h + 1;
  const int extent_w = (kernel_w - 1) * dilation_w + 1;

  CUDA_KERNEL_LOOP(index, height * width) {
    // Decompose the in-plane index; `index` itself equals h * width + w.
    const int h = index / width;
    const int w = index - h * width;

    // Range of pooled outputs [phstart, phend) x [pwstart, pwend) whose
    // windows may contain (h, w). It depends only on (h, w), so it is
    // computed once and reused for every (n, c) plane.
    const int phstart = MaxPoolBackwardWindowBegin(h + pad_h, extent_h, stride_h);
    const int phend = MaxPoolBackwardWindowEnd(h + pad_h, stride_h, pooled_height);
    const int pwstart = MaxPoolBackwardWindowBegin(w + pad_w, extent_w, stride_w);
    const int pwend = MaxPoolBackwardWindowEnd(w + pad_w, stride_w, pooled_width);

    for (int n = blockIdx.y; n < num; n += gridDim.y) {
      for (int c = blockIdx.z; c < channels; c += gridDim.z) {
        const int plane = n * channels + c;
        const int offset = plane * pooled_height * pooled_width;
        // Derive per-plane pointers from the untouched base pointers.
        // Advancing top_diff/top_mask in place here would accumulate the
        // offsets across loop iterations and read the wrong plane (or out of
        // bounds) from the second iteration on.
        const Dtype* plane_diff = top_diff + offset;
        const int64_t* plane_mask = top_mask + offset;

        AccType gradient = AccType(0);
        // Fast path: when exactly one pooled element can reference (h, w),
        // skip the loop machinery.
        if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) {
          for (int ph = phstart; ph < phend; ++ph) {
            for (int pw = pwstart; pw < pwend; ++pw) {
              const int pooled_index = ph * pooled_width + pw;
              if (plane_mask[pooled_index] - TH_INDEX_BASE == index) {
                gradient += ScalarConvert<Dtype, AccType>::to(plane_diff[pooled_index]);
              }
            }
          }
        } else {
          const int pooled_index = phstart * pooled_width + pwstart;
          if (plane_mask[pooled_index] - TH_INDEX_BASE == index) {
            gradient += ScalarConvert<Dtype, AccType>::to(plane_diff[pooled_index]);
          }
        }
        bottom_diff[plane * height * width + index] = ScalarConvert<AccType, Dtype>::to(gradient);
      }
    }
  }
}
81111
0 commit comments