@@ -141,7 +141,9 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
141141 int dimh = 2 ;
142142 int dimw = 3 ;
143143
144- if (input->nDimension == 5 )
144+ int fiveDimensionalInput = THCTensor_ (nDimension)(state, input) == 5 ;
145+
146+ if (fiveDimensionalInput)
145147 {
146148 dimt++;
147149 dimh++;
@@ -163,7 +165,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
163165 inputHeight = THCTensor_ (size)(state, input, 2 );
164166 inputWidth = THCTensor_ (size)(state, input, 3 );
165167 }
166- else if (THCTensor_ (nDimension)(state, input) == 5 )
168+ else if (fiveDimensionalInput )
167169 {
168170 /* sizes */
169171 batchSize = THCTensor_ (size)(state, input, 0 );
@@ -200,7 +202,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
200202 --outputWidth;
201203 }
202204
203- if (input-> nDimension == 4 ) /* 4D */
205+ if (!fiveDimensionalInput ) /* 4D */
204206 {
205207 /* resize output */
206208 THCTensor_ (resize4d)(state, output, inputSlices,
@@ -217,23 +219,25 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
217219 // Index tensor packs index offsets as uchars into floats
218220 THCIndexTensor_ (resize5d)(state, indices, batchSize, inputSlices,
219221 outputTime, outputHeight, outputWidth);
222+ fiveDimensionalInput = 1 ;
220223 }
221224
222225 input = THCTensor_ (newContiguous)(state, input);
226+ if (fiveDimensionalInput) {
227+ // Collapse batch and feature dimensions
228+ output = THCTensor_ (newFoldBatchDim)(state, output);
229+
230+ THCTensor *old_input = input;
231+ input = THCTensor_ (newFoldBatchDim)(state, input);
232+ THCTensor_ (free )(state, old_input);
233+ } else {
234+ THCTensor_ (retain)(state, output);
235+ }
223236
224- // Collapse batch and feature dimensions
225237 THCDeviceTensor<real, 4 > cudaInput;
226238 THCDeviceTensor<real, 4 > cudaOutput;
227- if (THCTensor_ (nDimension)(state, input) == 4 )
228- {
229- cudaInput = toDeviceTensor<real, 4 >(state, input);
230- cudaOutput = toDeviceTensor<real, 4 >(state, output);
231- }
232- else
233- {
234- cudaInput = toDeviceTensor<real, 5 >(state, input).downcastOuter <4 >();
235- cudaOutput = toDeviceTensor<real, 5 >(state, output).downcastOuter <4 >();
236- }
239+ cudaInput = toDeviceTensor<real, 4 >(state, input);
240+ cudaOutput = toDeviceTensor<real, 4 >(state, output);
237241
238242 THLongStorage *indicesSize = THLongStorage_newWithSize (4 );
239243 int64_t indicesSizeRaw[4 ] = { batchSize * inputSlices,
@@ -281,6 +285,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
281285 }
282286
283287 THCTensor_ (free )(state, input);
288+ THCTensor_ (free )(state, output);
284289 THCIndexTensor_ (free )(state, indices1);
285290}
286291
@@ -310,13 +315,15 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
310315 int outputHeight;
311316 int outputWidth;
312317
318+ int fiveDimensionalInput = THCTensor_ (nDimension)(state, input) == 5 ;
319+
313320 THCUNN_assertSameGPU (state, 4 , input, indices, gradOutput, gradInput);
314321 THNN_ (VolumetricDilatedMaxPooling_shapeCheck)(
315322 state, input, gradOutput, indices, kT , kW , kH ,
316323 dT, dW, dH, padT, padW, padH,
317324 dilationT, dilationW, dilationH, ceilMode);
318325
319- if (THCTensor_ (nDimension)(state, input) == 4 ) /* 4D */
326+ if (!fiveDimensionalInput ) /* 4D */
320327 {
321328 batchSize = 1 ;
322329 inputSlices = THCTensor_ (size)(state, input, 0 );
@@ -336,22 +343,21 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
336343 }
337344
338345 gradOutput = THCTensor_ (newContiguous)(state, gradOutput);
346+ if (fiveDimensionalInput) {
347+ // Collapse batch and feature dimensions
348+ gradInput = THCTensor_ (newFoldBatchDim)(state, gradInput);
349+
350+ THCTensor *old_gradOutput = gradOutput;
351+ gradOutput = THCTensor_ (newFoldBatchDim)(state, gradOutput);
352+ THCTensor_ (free )(state, old_gradOutput);
353+ } else {
354+ THCTensor_ (retain)(state, gradInput);
355+ }
339356
340- // Collapse batch and feature dimensions
341357 THCDeviceTensor<real, 4 > cudaGradInput;
342358 THCDeviceTensor<real, 4 > cudaGradOutput;
343- if (THCTensor_ (nDimension)(state, input) == 4 )
344- {
345- cudaGradInput = toDeviceTensor<real, 4 >(state, gradInput);
346- cudaGradOutput = toDeviceTensor<real, 4 >(state, gradOutput);
347- }
348- else
349- {
350- cudaGradInput =
351- toDeviceTensor<real, 5 >(state, gradInput).downcastOuter <4 >();
352- cudaGradOutput =
353- toDeviceTensor<real, 5 >(state, gradOutput).downcastOuter <4 >();
354- }
359+ cudaGradInput = toDeviceTensor<real, 4 >(state, gradInput);
360+ cudaGradOutput = toDeviceTensor<real, 4 >(state, gradOutput);
355361
356362 THLongStorage *indicesSize = THLongStorage_newWithSize (4 );
357363 int64_t indicesSizeRaw[4 ] = { batchSize * inputSlices,
@@ -388,6 +394,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
388394 }
389395
390396 // cleanup
397+ THCTensor_ (free )(state, gradInput);
391398 THCTensor_ (free )(state, gradOutput);
392399 THCIndexTensor_ (free )(state, indices1);
393400}