1111
1212template <typename Dtype>
1313__global__ void cuda_VolumetricDilatedMaxPooling_updateOutput (
14- THCDeviceTensor< Dtype, 4 > input ,
14+ Dtype* inputData, int inputT, int inputH, int inputW ,
1515 THCDeviceTensor<THCIndex_t, 4 > indices,
1616 THCDeviceTensor<Dtype, 4 > output,
1717 int kT , int kH , int kW ,
@@ -27,56 +27,53 @@ __global__ void cuda_VolumetricDilatedMaxPooling_updateOutput(
2727
2828 if (oRow < output.getSize (2 ) && oColumn < output.getSize (3 ))
2929 {
30- int iColumn = oColumn * dW - padW;
31- int iRow = oRow * dH - padH;
32- int iFrame = oFrame * dT - padT;
33-
34- int maxColumn = 0 ;
35- int maxRow = 0 ;
36- int maxFrame = 0 ;
30+ int tStart = oFrame * dT - padT;
31+ int hStart = oRow * dH - padH;
32+ int wStart = oColumn * dW - padW;
33+ int tEnd = fminf (tStart + (kT - 1 ) * dilationT + 1 , inputT);
34+ int hEnd = fminf (hStart + (kH - 1 ) * dilationH + 1 , inputH);
35+ int wEnd = fminf (wStart + (kW - 1 ) * dilationW + 1 , inputW);
36+
37+ while (tStart < 0 )
38+ tStart += dilationT;
39+ while (hStart < 0 )
40+ hStart += dilationH;
41+ while (wStart < 0 )
42+ wStart += dilationW;
43+
44+ int index = 0 ;
45+ int maxIndex = -1 ;
46+ inputData += slice * inputT * inputH * inputW;
3747
3848 Dtype max = THCNumerics<Dtype>::min ();
3949
40- for (int frame = 0 ; frame < kT ; ++frame )
50+ for (int t = tStart; t < tEnd; t += dilationT )
4151 {
42- if (iFrame + frame * dilationT < input. getSize ( 1 ) && iFrame + frame * dilationT >= 0 )
52+ for ( int h = hStart; h < hEnd; h += dilationH )
4353 {
44- for (int row = 0 ; row < kH ; ++row )
54+ for (int w = wStart; w < wEnd; w += dilationW )
4555 {
46- if (iRow + row * dilationH < input.getSize (2 ) && iRow + row * dilationH >= 0 )
56+ index = t * inputH * inputW + h * inputW + w;
57+ Dtype val = inputData[index];
58+
59+ if (max < val)
4760 {
48- for (int column = 0 ; column < kW ; ++column)
49- {
50- if (iColumn + column * dilationW < input.getSize (3 ) && iColumn + column * dilationW >= 0 )
51- {
52- Dtype val = input[slice][iFrame + frame * dilationT][iRow + row * dilationH][iColumn + column * dilationW];
53-
54- if (max < val)
55- {
56- max = val;
57- maxColumn = column;
58- maxRow = row;
59- maxFrame = frame;
60- }
61- }
62- }
61+ max = val;
62+ maxIndex = index;
6363 }
6464 }
6565 }
6666 }
6767
6868 output[slice][oFrame][oRow][oColumn] = max;
69- THCIndex_t *idx = &indices[slice][oFrame][oRow][oColumn];
70- ((unsigned char *)(idx))[0 ] = maxFrame;
71- ((unsigned char *)(idx))[1 ] = maxRow;
72- ((unsigned char *)(idx))[2 ] = maxColumn;
73- ((unsigned char *)(idx))[3 ] = 0 ;
69+ indices[slice][oFrame][oRow][oColumn] = maxIndex + TH_INDEX_BASE;
7470 }
7571}
7672
7773template <int KERNEL_WIDTH, typename Dtype>
7874__global__ void cuda_VolumetricDilatedMaxPooling_updateOutput (
79- THCDeviceTensor<Dtype, 4 > input, THCDeviceTensor<THCIndex_t, 4 > indices,
75+ Dtype* inputData, int inputT, int inputH, int inputW,
76+ THCDeviceTensor<THCIndex_t, 4 > indices,
8077 THCDeviceTensor<Dtype, 4 > output,
8178 int kT , int kH ,
8279 int dT, int dH, int dW,
@@ -91,58 +88,54 @@ __global__ void cuda_VolumetricDilatedMaxPooling_updateOutput(
9188
9289 if (oRow < output.getSize (2 ) && oColumn < output.getSize (3 ))
9390 {
94- int iColumn = oColumn * dW - padW;
95- int iRow = oRow * dH - padH;
96- int iFrame = oFrame * dT - padT;
97-
98- int maxColumn = 0 ;
99- int maxRow = 0 ;
100- int maxFrame;
91+ int tStart = oFrame * dT - padT;
92+ int hStart = oRow * dH - padH;
93+ int wStart = oColumn * dW - padW;
94+ int tEnd = fminf (tStart + (kT - 1 ) * dilationT + 1 , inputT);
95+ int hEnd = fminf (hStart + (kH - 1 ) * dilationH + 1 , inputH);
96+ int wEnd = fminf (wStart + (KERNEL_WIDTH - 1 ) * dilationW + 1 , inputW);
97+
98+ while (tStart < 0 )
99+ tStart += dilationT;
100+ while (hStart < 0 )
101+ hStart += dilationH;
102+ while (wStart < 0 )
103+ wStart += dilationW;
104+
105+ int index = 0 ;
106+ int maxIndex = -1 ;
101107
102108 Dtype max = THCNumerics<Dtype>::min ();
103109
104- for (int frame = 0 ; frame < kT ; ++frame )
110+ for (int t = tStart; t < tEnd; t += dilationT )
105111 {
106- if (iFrame + frame * dilationT < input. getSize ( 1 ) && iFrame + frame * dilationT >= 0 )
112+ for ( int h = hStart; h < hEnd; h += dilationH )
107113 {
108- for (int row = 0 ; row < kH ; ++row )
114+ for (int w = wStart; w < wEnd; w += dilationW )
109115 {
110- if (iRow + row * dilationH < input.getSize (2 ) && iRow + row * dilationH >= 0 )
116+ index = t * inputH * inputW + h * inputW + w;
117+ Dtype val = inputData[slice * inputT * inputH * inputW + index];
118+
119+ if (max < val)
111120 {
112- for (int column = 0 ; column < KERNEL_WIDTH; ++column)
113- {
114- if (iColumn + column * dilationW < input.getSize (3 ) && iColumn + column * dilationW >= 0 )
115- {
116- Dtype val = input[slice][iFrame + frame * dilationT][iRow + row * dilationH][iColumn + column * dilationW];
117-
118- if (max < val)
119- {
120- max = val;
121- maxColumn = column;
122- maxRow = row;
123- maxFrame = frame;
124- }
125- }
126- }
121+ max = val;
122+ maxIndex = index;
127123 }
128124 }
129125 }
130126 }
131127
132128 output[slice][oFrame][oRow][oColumn] = max;
133- THCIndex_t *idx = &indices[slice][oFrame][oRow][oColumn];
134- ((unsigned char *)(idx))[0 ] = maxFrame;
135- ((unsigned char *)(idx))[1 ] = maxRow;
136- ((unsigned char *)(idx))[2 ] = maxColumn;
137- ((unsigned char *)(idx))[3 ] = 0 ;
129+ indices[slice][oFrame][oRow][oColumn] = maxIndex + TH_INDEX_BASE;
138130 }
139131}
140132
141133template <typename Dtype>
142134__global__ void cuda_VolumetricDilatedMaxPooling_updateGradInput (
143135 THCDeviceTensor<Dtype, 4 > gradOutput,
144136 THCDeviceTensor<THCIndex_t, 4 > indices,
145- THCDeviceTensor<Dtype, 4 > gradInput,
137+ Dtype* gradInputData,
138+ int inputT, int inputH, int inputW,
146139 int dT, int dH, int dW,
147140 int padT, int padH, int padW,
148141 int dilationT, int dilationH, int dilationW,
@@ -155,12 +148,11 @@ __global__ void cuda_VolumetricDilatedMaxPooling_updateGradInput(
155148
156149 if (oRow < gradOutput.getSize (2 ) && oColumn < gradOutput.getSize (3 ))
157150 {
158- THCIndex_t *idx = &indices[slice][oFrame][oRow][oColumn];
159- int iFrame = ((unsigned char *)(idx))[0 ] * dilationT + oFrame * dT - padT;
160- int iRow = ((unsigned char *)(idx))[1 ] * dilationH + oRow * dH - padH;
161- int iColumn = ((unsigned char *)(idx))[2 ] * dilationW + oColumn * dW - padW;
162- atomicAdd (&gradInput[slice][iFrame][iRow][iColumn],
163- gradOutput[slice][oFrame][oRow][oColumn]);
151+ int maxIndex = indices[slice][oFrame][oRow][oColumn] - TH_INDEX_BASE;
152+ if (maxIndex != -1 ) {
153+ atomicAdd (&gradInputData[slice * inputT * inputH * inputW + maxIndex],
154+ gradOutput[slice][oFrame][oRow][oColumn]);
155+ }
164156 }
165157}
166158
0 commit comments