Skip to content

Commit 63d6afd

Browse files
ngimel authored and soumith committed
improve performance of maxpooling backwards (#4106)
1 parent 8c3e1b7 commit 63d6afd

File tree

2 files changed

+63
-24
lines changed

2 files changed

+63
-24
lines changed

torch/lib/THCUNN/SpatialDilatedMaxPooling.cu

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,41 +41,71 @@ __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
4141
}
4242
}
4343

44+
const int BACKWARD_THREADS = 256;
4445

4546
template <typename Dtype, typename AccType>
47+
__launch_bounds__(BACKWARD_THREADS,2048/BACKWARD_THREADS)
4648
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
4749
const int64_t* top_mask, const int num, const int channels,
4850
const int height, const int width, const int pooled_height,
4951
const int pooled_width, const int kernel_h, const int kernel_w,
5052
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
5153
const int dilation_h, const int dilation_w,
5254
Dtype* bottom_diff) {
53-
CUDA_KERNEL_LOOP(index, nthreads) {
54-
// find out the local index
55-
// find out the local offset
56-
int w = index % width;
57-
int h = (index / width) % height;
58-
int c = (index / width / height) % channels;
59-
int n = index / width / height / channels;
60-
int phstart =
61-
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) / stride_h + 1;
62-
int phend = min((h + pad_h) / stride_h + 1, pooled_height);
63-
int pwstart =
55+
CUDA_KERNEL_LOOP(index, height*width) {
56+
int h = index/width;
57+
int w = index - h * width;
58+
//get some templating performance benefits without actually templating
59+
int phstart, phend, pwstart, pwend;
60+
if (stride_h == 1) {
61+
phstart =
62+
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) + 1;
63+
phend = min((h + pad_h) + 1, pooled_height);
64+
} else if (stride_h == 2) {
65+
phstart =
66+
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) / 2 + 1;
67+
phend = min((h + pad_h) / 2 + 1, pooled_height);
68+
} else {
69+
phstart =
70+
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) / stride_h + 1;
71+
phend = min((h + pad_h) / stride_h + 1, pooled_height);
72+
}
73+
if (stride_w == 1) {
74+
pwstart =
75+
(w + pad_w < ((kernel_w - 1) * dilation_w + 1)) ? 0 : (w + pad_w - ((kernel_w - 1) * dilation_w + 1)) + 1;
76+
pwend = min((w + pad_w) + 1, pooled_width);
77+
} else if (stride_w == 2) {
78+
pwstart =
79+
(w + pad_w < ((kernel_w - 1) * dilation_w + 1)) ? 0 : (w + pad_w - ((kernel_w - 1) * dilation_w + 1)) / 2 + 1;
80+
pwend = min((w + pad_w) / 2 + 1, pooled_width);
81+
} else {
82+
pwstart =
6483
(w + pad_w < ((kernel_w - 1) * dilation_w + 1)) ? 0 : (w + pad_w - ((kernel_w - 1) * dilation_w + 1)) / stride_w + 1;
65-
int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
84+
pwend = min((w + pad_w) / stride_w + 1, pooled_width);
85+
}
86+
for (int n = blockIdx.y; n < num; n += gridDim.y)
87+
for (int c = blockIdx.z; c < channels; c+= gridDim.z) {
6688

67-
AccType gradient = AccType(0);
68-
int offset = (n * channels + c) * pooled_height * pooled_width;
69-
top_diff += offset;
70-
top_mask += offset;
71-
for (int ph = phstart; ph < phend; ++ph) {
72-
for (int pw = pwstart; pw < pwend; ++pw) {
73-
if (top_mask[ph * pooled_width + pw] - TH_INDEX_BASE == h * width + w) {
74-
gradient += ScalarConvert<Dtype, AccType>::to(top_diff[ph * pooled_width + pw]);
89+
AccType gradient = AccType(0);
90+
int offset = (n * channels + c) * pooled_height * pooled_width;
91+
top_diff += offset;
92+
top_mask += offset;
93+
//get some templating performance benefits without actually templating
94+
if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) {
95+
for (int ph = phstart; ph < phend; ++ph) {
96+
for (int pw = pwstart; pw < pwend; ++pw) {
97+
if (top_mask[ph * pooled_width + pw] - TH_INDEX_BASE == h * width + w) {
98+
gradient += ScalarConvert<Dtype, AccType>::to(top_diff[ph * pooled_width + pw]);
99+
}
100+
}
75101
}
102+
} else {
103+
if (top_mask[phstart * pooled_width + pwstart] - TH_INDEX_BASE == h * width + w) {
104+
gradient += ScalarConvert<Dtype, AccType>::to(top_diff[phstart * pooled_width + pwstart]);
105+
}
106+
}
107+
bottom_diff[(n*channels+c)*height*width+index] = ScalarConvert<AccType, Dtype>::to(gradient);
76108
}
77-
}
78-
bottom_diff[index] = ScalarConvert<AccType, Dtype>::to(gradient);
79109
}
80110
}
81111

torch/lib/THCUNN/generic/SpatialDilatedMaxPooling.cu

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,17 @@ void THNN_(SpatialDilatedMaxPooling_updateGradInput)(
217217
THCTensor_(resizeAs)(state, gradInput, input);
218218

219219
int count = THCTensor_(nElement)(state, input);
220-
221-
MaxPoolBackward<real, accreal> <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
220+
dim3 grid;
221+
int imgcount = nInputCols * nInputRows;
222+
const int blocks = (imgcount + BACKWARD_THREADS - 1) / BACKWARD_THREADS;
223+
grid.x = blocks;
224+
grid.y = batchSize;
225+
grid.z = nInputPlane;
226+
uint64_t maxGridY = THCState_getCurrentDeviceProperties(state)->maxGridSize[1];
227+
uint64_t maxGridZ = THCState_getCurrentDeviceProperties(state)->maxGridSize[2];
228+
if (maxGridY < grid.y) grid.y = maxGridY;
229+
if (maxGridZ < grid.z) grid.z = maxGridZ;
230+
MaxPoolBackward<real, accreal> <<< grid, BACKWARD_THREADS, 0, THCState_getCurrentStream(state) >>>
222231
(count,
223232
THCTensor_(data)(state, gradOutput),
224233
THCIndexTensor_(data)(state, indices),

0 commit comments

Comments (0)