Skip to content

Commit 363de58

Browse files
li-roy authored and soumith committed
implement double backwards for MaxPool3d (#5328)
* implement double backwards for MaxPool3d
* change MaxUnpool3d to use same indices as MaxPool3d
* fix nits
1 parent 04461fa commit 363de58

File tree

9 files changed

+164
-232
lines changed

9 files changed

+164
-232
lines changed

aten/src/THCUNN/VolumetricDilatedMaxPooling.cu

Lines changed: 63 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
template <typename Dtype>
1313
__global__ void cuda_VolumetricDilatedMaxPooling_updateOutput(
14-
THCDeviceTensor<Dtype, 4> input,
14+
Dtype* inputData, int inputT, int inputH, int inputW,
1515
THCDeviceTensor<THCIndex_t, 4> indices,
1616
THCDeviceTensor<Dtype, 4> output,
1717
int kT, int kH, int kW,
@@ -27,56 +27,53 @@ __global__ void cuda_VolumetricDilatedMaxPooling_updateOutput(
2727

2828
if (oRow < output.getSize(2) && oColumn < output.getSize(3))
2929
{
30-
int iColumn = oColumn * dW - padW;
31-
int iRow = oRow * dH - padH;
32-
int iFrame = oFrame * dT - padT;
33-
34-
int maxColumn = 0;
35-
int maxRow = 0;
36-
int maxFrame = 0;
30+
int tStart = oFrame * dT - padT;
31+
int hStart = oRow * dH - padH;
32+
int wStart = oColumn * dW - padW;
33+
int tEnd = fminf(tStart + (kT - 1) * dilationT + 1, inputT);
34+
int hEnd = fminf(hStart + (kH - 1) * dilationH + 1, inputH);
35+
int wEnd = fminf(wStart + (kW - 1) * dilationW + 1, inputW);
36+
37+
while(tStart < 0)
38+
tStart += dilationT;
39+
while(hStart < 0)
40+
hStart += dilationH;
41+
while(wStart < 0)
42+
wStart += dilationW;
43+
44+
int index = 0;
45+
int maxIndex = -1;
46+
inputData += slice * inputT * inputH * inputW;
3747

3848
Dtype max = THCNumerics<Dtype>::min();
3949

40-
for (int frame = 0; frame < kT; ++frame)
50+
for (int t = tStart; t < tEnd; t += dilationT)
4151
{
42-
if (iFrame + frame * dilationT < input.getSize(1) && iFrame + frame * dilationT >= 0)
52+
for (int h = hStart; h < hEnd; h += dilationH)
4353
{
44-
for (int row = 0; row < kH; ++row)
54+
for (int w = wStart; w < wEnd; w += dilationW)
4555
{
46-
if (iRow + row * dilationH < input.getSize(2) && iRow + row * dilationH >= 0)
56+
index = t * inputH * inputW + h * inputW + w;
57+
Dtype val = inputData[index];
58+
59+
if (max < val)
4760
{
48-
for (int column = 0; column < kW; ++column)
49-
{
50-
if (iColumn + column * dilationW < input.getSize(3) && iColumn + column * dilationW >= 0)
51-
{
52-
Dtype val = input[slice][iFrame + frame * dilationT][iRow + row * dilationH][iColumn + column * dilationW];
53-
54-
if (max < val)
55-
{
56-
max = val;
57-
maxColumn = column;
58-
maxRow = row;
59-
maxFrame = frame;
60-
}
61-
}
62-
}
61+
max = val;
62+
maxIndex = index;
6363
}
6464
}
6565
}
6666
}
6767

6868
output[slice][oFrame][oRow][oColumn] = max;
69-
THCIndex_t *idx = &indices[slice][oFrame][oRow][oColumn];
70-
((unsigned char*)(idx))[0] = maxFrame;
71-
((unsigned char*)(idx))[1] = maxRow;
72-
((unsigned char*)(idx))[2] = maxColumn;
73-
((unsigned char*)(idx))[3] = 0;
69+
indices[slice][oFrame][oRow][oColumn] = maxIndex + TH_INDEX_BASE;
7470
}
7571
}
7672

7773
template <int KERNEL_WIDTH, typename Dtype>
7874
__global__ void cuda_VolumetricDilatedMaxPooling_updateOutput(
79-
THCDeviceTensor<Dtype, 4> input, THCDeviceTensor<THCIndex_t, 4> indices,
75+
Dtype* inputData, int inputT, int inputH, int inputW,
76+
THCDeviceTensor<THCIndex_t, 4> indices,
8077
THCDeviceTensor<Dtype, 4> output,
8178
int kT, int kH,
8279
int dT, int dH, int dW,
@@ -91,58 +88,54 @@ __global__ void cuda_VolumetricDilatedMaxPooling_updateOutput(
9188

9289
if (oRow < output.getSize(2) && oColumn < output.getSize(3))
9390
{
94-
int iColumn = oColumn * dW - padW;
95-
int iRow = oRow * dH - padH;
96-
int iFrame = oFrame * dT - padT;
97-
98-
int maxColumn = 0;
99-
int maxRow = 0;
100-
int maxFrame;
91+
int tStart = oFrame * dT - padT;
92+
int hStart = oRow * dH - padH;
93+
int wStart = oColumn * dW - padW;
94+
int tEnd = fminf(tStart + (kT - 1) * dilationT + 1, inputT);
95+
int hEnd = fminf(hStart + (kH - 1) * dilationH + 1, inputH);
96+
int wEnd = fminf(wStart + (KERNEL_WIDTH - 1) * dilationW + 1, inputW);
97+
98+
while(tStart < 0)
99+
tStart += dilationT;
100+
while(hStart < 0)
101+
hStart += dilationH;
102+
while(wStart < 0)
103+
wStart += dilationW;
104+
105+
int index = 0;
106+
int maxIndex = -1;
101107

102108
Dtype max = THCNumerics<Dtype>::min();
103109

104-
for (int frame = 0; frame < kT; ++frame)
110+
for (int t = tStart; t < tEnd; t += dilationT)
105111
{
106-
if (iFrame + frame * dilationT < input.getSize(1) && iFrame + frame * dilationT >= 0)
112+
for (int h = hStart; h < hEnd; h += dilationH)
107113
{
108-
for (int row = 0; row < kH; ++row)
114+
for (int w = wStart; w < wEnd; w += dilationW)
109115
{
110-
if (iRow + row * dilationH < input.getSize(2) && iRow + row * dilationH >= 0)
116+
index = t * inputH * inputW + h * inputW + w;
117+
Dtype val = inputData[slice * inputT * inputH * inputW + index];
118+
119+
if (max < val)
111120
{
112-
for (int column = 0; column < KERNEL_WIDTH; ++column)
113-
{
114-
if (iColumn + column * dilationW < input.getSize(3) && iColumn + column * dilationW >= 0)
115-
{
116-
Dtype val = input[slice][iFrame + frame * dilationT][iRow + row * dilationH][iColumn + column * dilationW];
117-
118-
if (max < val)
119-
{
120-
max = val;
121-
maxColumn = column;
122-
maxRow = row;
123-
maxFrame = frame;
124-
}
125-
}
126-
}
121+
max = val;
122+
maxIndex = index;
127123
}
128124
}
129125
}
130126
}
131127

132128
output[slice][oFrame][oRow][oColumn] = max;
133-
THCIndex_t *idx = &indices[slice][oFrame][oRow][oColumn];
134-
((unsigned char*)(idx))[0] = maxFrame;
135-
((unsigned char*)(idx))[1] = maxRow;
136-
((unsigned char*)(idx))[2] = maxColumn;
137-
((unsigned char*)(idx))[3] = 0;
129+
indices[slice][oFrame][oRow][oColumn] = maxIndex + TH_INDEX_BASE;
138130
}
139131
}
140132

141133
template <typename Dtype>
142134
__global__ void cuda_VolumetricDilatedMaxPooling_updateGradInput(
143135
THCDeviceTensor<Dtype, 4> gradOutput,
144136
THCDeviceTensor<THCIndex_t, 4> indices,
145-
THCDeviceTensor<Dtype, 4> gradInput,
137+
Dtype* gradInputData,
138+
int inputT, int inputH, int inputW,
146139
int dT, int dH, int dW,
147140
int padT, int padH, int padW,
148141
int dilationT, int dilationH, int dilationW,
@@ -155,12 +148,11 @@ __global__ void cuda_VolumetricDilatedMaxPooling_updateGradInput(
155148

156149
if (oRow < gradOutput.getSize(2) && oColumn < gradOutput.getSize(3))
157150
{
158-
THCIndex_t *idx = &indices[slice][oFrame][oRow][oColumn];
159-
int iFrame = ((unsigned char*)(idx))[0] * dilationT + oFrame * dT - padT;
160-
int iRow = ((unsigned char*)(idx))[1] * dilationH + oRow * dH - padH;
161-
int iColumn = ((unsigned char*)(idx))[2] * dilationW + oColumn * dW - padW;
162-
atomicAdd(&gradInput[slice][iFrame][iRow][iColumn],
163-
gradOutput[slice][oFrame][oRow][oColumn]);
151+
int maxIndex = indices[slice][oFrame][oRow][oColumn] - TH_INDEX_BASE;
152+
if (maxIndex != -1) {
153+
atomicAdd(&gradInputData[slice * inputT * inputH * inputW + maxIndex],
154+
gradOutput[slice][oFrame][oRow][oColumn]);
155+
}
164156
}
165157
}
166158

aten/src/THCUNN/VolumetricMaxUnpooling.cu

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ template <typename Dtype>
1212
__global__ void cuda_VolumetricMaxUnpooling_updateOutput(
1313
THCDeviceTensor<Dtype, 4> input,
1414
THCDeviceTensor<THCIndex_t, 4> indices,
15-
THCDeviceTensor<Dtype, 4> output,
15+
Dtype* outputData,
16+
int oT, int oH, int oW,
1617
int dT, int dH, int dW,
1718
int padT, int padH, int padW, int offsetZ)
1819
{
@@ -23,23 +24,16 @@ __global__ void cuda_VolumetricMaxUnpooling_updateOutput(
2324

2425
if (iRow < input.getSize(2) && iColumn < input.getSize(3))
2526
{
26-
int64_t start_t = iFrame * dT - padT;
27-
int64_t start_h = iRow * dH - padH;
28-
int64_t start_w = iColumn * dW - padW;
29-
3027
Dtype val = input[slice][iFrame][iRow][iColumn];
31-
32-
THCIndex_t *idx = &indices[slice][iFrame][iRow][iColumn];
33-
int64_t maxz = ((unsigned char*)(idx))[0];
34-
int64_t maxy = ((unsigned char*)(idx))[1];
35-
int64_t maxx = ((unsigned char*)(idx))[2];
36-
output[slice][start_t + maxz][start_h + maxy][start_w + maxx] = val;
28+
int64_t index = indices[slice][iFrame][iRow][iColumn];
29+
outputData[slice*oT*oH*oW + index] = val;
3730
}
3831
}
3932

4033
template <typename Dtype>
4134
__global__ void cuda_VolumetricMaxUnpooling_updateGradInput(
42-
THCDeviceTensor<Dtype, 4> gradOutput,
35+
Dtype* gradOutputData,
36+
int oT, int oH, int oW,
4337
THCDeviceTensor<THCIndex_t, 4> indices,
4438
THCDeviceTensor<Dtype, 4> gradInput,
4539
int dT, int dH, int dW,
@@ -52,18 +46,8 @@ __global__ void cuda_VolumetricMaxUnpooling_updateGradInput(
5246

5347
if (iRow < gradInput.getSize(2) && iColumn < gradInput.getSize(3))
5448
{
55-
56-
int64_t start_t = iFrame * dT - padT;
57-
int64_t start_h = iRow * dH - padH;
58-
int64_t start_w = iColumn * dW - padW;
59-
60-
THCIndex_t *idx = &indices[slice][iFrame][iRow][iColumn];
61-
int64_t maxz = ((unsigned char*)(idx))[0];
62-
int64_t maxy = ((unsigned char*)(idx))[1];
63-
int64_t maxx = ((unsigned char*)(idx))[2];
64-
65-
Dtype grad_val = gradOutput[slice][start_t + maxz][start_h + maxy][start_w + maxx];
66-
49+
int64_t index = indices[slice][iFrame][iRow][iColumn];
50+
Dtype grad_val = gradOutputData[slice*oT*oH*oW + index];
6751
gradInput[slice][iFrame][iRow][iColumn] = grad_val;
6852
}
6953
}

aten/src/THCUNN/generic/VolumetricDilatedMaxPooling.cu

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
#define UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
66
cuda_VolumetricDilatedMaxPooling_updateOutput<KW><<<grid, block, \
77
0, THCState_getCurrentStream(state)>>>( \
8-
cudaInput, cudaIndices, cudaOutput, kT, kH, dT, dH, dW, padT, padH, padW,\
8+
inputData, inputTime, inputHeight, inputWidth, \
9+
cudaIndices, cudaOutput, kT, kH, dT, dH, dW, padT, padH, padW,\
910
dilationT, dilationH, dilationW, offsetZ); \
1011
break
1112

@@ -233,10 +234,10 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
233234
} else {
234235
THCTensor_(retain)(state, output);
235236
}
237+
238+
real* inputData = THCTensor_(data)(state, input);
236239

237-
THCDeviceTensor<real, 4> cudaInput;
238240
THCDeviceTensor<real, 4> cudaOutput;
239-
cudaInput = toDeviceTensor<real, 4>(state, input);
240241
cudaOutput = toDeviceTensor<real, 4>(state, output);
241242

242243
THLongStorage *indicesSize = THLongStorage_newWithSize(4);
@@ -275,7 +276,8 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
275276
default:
276277
cuda_VolumetricDilatedMaxPooling_updateOutput<<<grid, block,
277278
0, THCState_getCurrentStream(state)>>>(
278-
cudaInput, cudaIndices, cudaOutput,
279+
inputData, inputTime, inputHeight, inputWidth,
280+
cudaIndices, cudaOutput,
279281
kT, kH, kW, dT, dH, dW,
280282
padT, padH, padW, dilationT, dilationH, dilationW, offsetZ);
281283
}
@@ -306,14 +308,14 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
306308
// TODO: gradOutput shape check
307309
// Resize and initialize result tensor.
308310
THCTensor_(resizeAs)(state, gradInput, input);
311+
THCTensor_(newContiguous)(state, gradInput);
309312
THCTensor_(zero)(state, gradInput);
310313

311314
int batchSize;
312315
int inputSlices;
313316

314-
int outputTime;
315-
int outputHeight;
316-
int outputWidth;
317+
int outputTime, outputHeight, outputWidth;
318+
int inputTime, inputHeight, inputWidth;
317319

318320
int fiveDimensionalInput = THCTensor_(nDimension)(state, input) == 5;
319321

@@ -331,6 +333,9 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
331333
outputTime = THCTensor_(size)(state, gradOutput, 1);
332334
outputHeight = THCTensor_(size)(state, gradOutput, 2);
333335
outputWidth = THCTensor_(size)(state, gradOutput, 3);
336+
inputTime = THCTensor_(size)(state, gradInput, 1);
337+
inputHeight = THCTensor_(size)(state, gradInput, 2);
338+
inputWidth = THCTensor_(size)(state, gradInput, 3);
334339
}
335340
else
336341
{
@@ -340,6 +345,9 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
340345
outputTime = THCTensor_(size)(state, gradOutput, 2);
341346
outputHeight = THCTensor_(size)(state, gradOutput, 3);
342347
outputWidth = THCTensor_(size)(state, gradOutput, 4);
348+
inputTime = THCTensor_(size)(state, gradInput, 2);
349+
inputHeight = THCTensor_(size)(state, gradInput, 3);
350+
inputWidth = THCTensor_(size)(state, gradInput, 4);
343351
}
344352

345353
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
@@ -354,10 +362,9 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
354362
THCTensor_(retain)(state, gradInput);
355363
}
356364

357-
THCDeviceTensor<real, 4> cudaGradInput;
358365
THCDeviceTensor<real, 4> cudaGradOutput;
359-
cudaGradInput = toDeviceTensor<real, 4>(state, gradInput);
360366
cudaGradOutput = toDeviceTensor<real, 4>(state, gradOutput);
367+
real* gradInputData = THCTensor_(data)(state, gradInput);
361368

362369
THLongStorage *indicesSize = THLongStorage_newWithSize(4);
363370
int64_t indicesSizeRaw[4] = { batchSize * inputSlices,
@@ -384,7 +391,8 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
384391
0, THCState_getCurrentStream(state)>>>(
385392
cudaGradOutput,
386393
cudaIndices,
387-
cudaGradInput,
394+
gradInputData,
395+
inputTime, inputHeight, inputWidth,
388396
dT, dH, dW,
389397
padT, padH, padW,
390398
dilationT, dilationH, dilationW, offsetZ);

0 commit comments

Comments
 (0)