Skip to content

Commit c088e32

Browse files
committed
Replace downcastOuter with newFoldBatchDim
1 parent eb5daa9 commit c088e32

File tree

6 files changed

+144
-75
lines changed

6 files changed

+144
-75
lines changed

aten/src/THC/generic/THCTensor.c

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,20 @@ THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, THLongStorage
280280
return self;
281281
}
282282

283+
// Collapses the first two dimensions of a tensor
284+
THCTensor *THCTensor_(newFoldBatchDim)(THCState *state, THCTensor *input) {
285+
int in_dims = THCTensor_(nDimension)(state, input);
286+
THArgCheck(in_dims >= 2, 1, "Tensor needs to have at least two dimensions");
287+
THLongStorage *newSize = THLongStorage_newWithSize(in_dims - 1);
288+
newSize->data[0] = THCTensor_(size)(state, input, 0) * THCTensor_(size)(state, input, 1);
289+
for (int i = 2; i < in_dims; i++) {
290+
newSize->data[i - 1] = THCTensor_(size)(state, input, i);
291+
}
292+
THCTensor *output = THCTensor_(newView)(state, input, newSize);
293+
THLongStorage_free(newSize);
294+
return output;
295+
}
296+
283297
/* Resize */
284298
void THCTensor_(resize)(THCState *state, THCTensor *self, THLongStorage *size, THLongStorage *stride)
285299
{

aten/src/THC/generic/THCTensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ THC_API THCTensor *THCTensor_(newNarrow)(THCState *state, THCTensor *tensor, int
6767
THC_API THCTensor *THCTensor_(newTranspose)(THCState *state, THCTensor *tensor, int dimension1_, int dimension2_);
6868
THC_API THCTensor *THCTensor_(newUnfold)(THCState *state, THCTensor *tensor, int dimension_, int64_t size_, int64_t step_);
6969
THC_API THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, THLongStorage *size);
70+
THC_API THCTensor *THCTensor_(newFoldBatchDim)(THCState *state, THCTensor *input);
7071
THC_API THCTensor *THCTensor_(newExpand)(THCState *state, THCTensor *tensor, THLongStorage *size);
7172

7273
THC_API void THCTensor_(expand)(THCState *state, THCTensor *r, THCTensor *tensor, THLongStorage *sizes);

aten/src/THCUNN/generic/VolumetricAveragePooling.cu

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
128128
int dimh = 2;
129129
int dimw = 3;
130130

131-
if (input->nDimension == 5)
131+
int fiveDimensionalInput = THCTensor_(nDimension)(state, input) == 5;
132+
if (fiveDimensionalInput)
132133
{
133134
dimt++;
134135
dimh++;
@@ -139,7 +140,7 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
139140
(state, input, NULL, kT, kW, kH, dT, dW, dH,
140141
padT, padW, padH, ceil_mode);
141142

142-
if (THCTensor_(nDimension)(state, input) == 4)
143+
if (!fiveDimensionalInput) /* 4D */
143144
{
144145
/* sizes */
145146
batchSize = 1;
@@ -186,7 +187,7 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
186187
--outputWidth;
187188
}
188189

189-
if (input->nDimension == 4) /* 4D */
190+
if (!fiveDimensionalInput) /* 4D */
190191
{
191192
/* resize output */
192193
THCTensor_(resize4d)(state, output, inputSlices,
@@ -199,20 +200,21 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
199200
}
200201

201202
input = THCTensor_(newContiguous)(state, input);
203+
if (fiveDimensionalInput) {
204+
// Collapse batch and feature dimensions
205+
output = THCTensor_(newFoldBatchDim)(state, output);
206+
207+
THCTensor *old_input = input;
208+
input = THCTensor_(newFoldBatchDim)(state, input);
209+
THCTensor_(free)(state, old_input);
210+
} else {
211+
THCTensor_(retain)(state, output);
212+
}
202213

203-
// Collapse batch and feature dimensions
204214
THCDeviceTensor<real, 4> cudaInput;
205215
THCDeviceTensor<real, 4> cudaOutput;
206-
if (THCTensor_(nDimension)(state, input) == 4)
207-
{
208-
cudaInput = toDeviceTensor<real, 4>(state, input);
209-
cudaOutput = toDeviceTensor<real, 4>(state, output);
210-
}
211-
else
212-
{
213-
cudaInput = toDeviceTensor<real, 5>(state, input).downcastOuter<4>();
214-
cudaOutput = toDeviceTensor<real, 5>(state, output).downcastOuter<4>();
215-
}
216+
cudaInput = toDeviceTensor<real, 4>(state, input);
217+
cudaOutput = toDeviceTensor<real, 4>(state, output);
216218

217219
int totalZ = outputTime * inputSlices * batchSize;
218220
int offsetZ = 0;
@@ -247,7 +249,9 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
247249
offsetZ += 65535;
248250
THCudaCheck(cudaGetLastError());
249251
}
252+
250253
THCTensor_(free)(state, input);
254+
THCTensor_(free)(state, output);
251255
}
252256

253257
void THNN_(VolumetricAveragePooling_updateGradInput)(
@@ -280,7 +284,8 @@ void THNN_(VolumetricAveragePooling_updateGradInput)(
280284
int outputHeight;
281285
int outputWidth;
282286

283-
if (THCTensor_(nDimension)(state, input) == 4) /* 4D */
287+
int fiveDimensionalInput = THCTensor_(nDimension)(state, input) == 5;
288+
if (!fiveDimensionalInput) /* 4D */
284289
{
285290
batchSize = 1;
286291
inputSlices = THCTensor_(size)(state, input, 0);
@@ -306,22 +311,21 @@ void THNN_(VolumetricAveragePooling_updateGradInput)(
306311
}
307312

308313
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
314+
if (fiveDimensionalInput) {
315+
// Collapse batch and feature dimensions
316+
gradInput = THCTensor_(newFoldBatchDim)(state, gradInput);
317+
318+
THCTensor *old_gradOutput = gradOutput;
319+
gradOutput = THCTensor_(newFoldBatchDim)(state, gradOutput);
320+
THCTensor_(free)(state, old_gradOutput);
321+
} else {
322+
THCTensor_(retain)(state, gradInput);
323+
}
309324

310-
// Collapse batch and feature dimensions
311325
THCDeviceTensor<real, 4> cudaGradInput;
312326
THCDeviceTensor<real, 4> cudaGradOutput;
313-
if (THCTensor_(nDimension)(state, input) == 4)
314-
{
315-
cudaGradInput = toDeviceTensor<real, 4>(state, gradInput);
316-
cudaGradOutput = toDeviceTensor<real, 4>(state, gradOutput);
317-
}
318-
else
319-
{
320-
cudaGradInput =
321-
toDeviceTensor<real, 5>(state, gradInput).downcastOuter<4>();
322-
cudaGradOutput =
323-
toDeviceTensor<real, 5>(state, gradOutput).downcastOuter<4>();
324-
}
327+
cudaGradInput = toDeviceTensor<real, 4>(state, gradInput);
328+
cudaGradOutput = toDeviceTensor<real, 4>(state, gradOutput);
325329

326330
dim3 block(32, 8);
327331

@@ -372,6 +376,7 @@ void THNN_(VolumetricAveragePooling_updateGradInput)(
372376
}
373377
}
374378

379+
THCTensor_(free)(state, gradInput);
375380
THCTensor_(free)(state, gradOutput);
376381
}
377382

aten/src/THCUNN/generic/VolumetricDilatedMaxPooling.cu

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
141141
int dimh = 2;
142142
int dimw = 3;
143143

144-
if (input->nDimension == 5)
144+
int fiveDimensionalInput = THCTensor_(nDimension)(state, input) == 5;
145+
146+
if (fiveDimensionalInput)
145147
{
146148
dimt++;
147149
dimh++;
@@ -163,7 +165,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
163165
inputHeight = THCTensor_(size)(state, input, 2);
164166
inputWidth = THCTensor_(size)(state, input, 3);
165167
}
166-
else if (THCTensor_(nDimension)(state, input) == 5)
168+
else if (fiveDimensionalInput)
167169
{
168170
/* sizes */
169171
batchSize = THCTensor_(size)(state, input, 0);
@@ -200,7 +202,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
200202
--outputWidth;
201203
}
202204

203-
if (input->nDimension == 4) /* 4D */
205+
if (!fiveDimensionalInput) /* 4D */
204206
{
205207
/* resize output */
206208
THCTensor_(resize4d)(state, output, inputSlices,
@@ -217,23 +219,25 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
217219
// Index tensor packs index offsets as uchars into floats
218220
THCIndexTensor_(resize5d)(state, indices, batchSize, inputSlices,
219221
outputTime, outputHeight, outputWidth);
222+
fiveDimensionalInput = 1;
220223
}
221224

222225
input = THCTensor_(newContiguous)(state, input);
226+
if (fiveDimensionalInput) {
227+
// Collapse batch and feature dimensions
228+
output = THCTensor_(newFoldBatchDim)(state, output);
229+
230+
THCTensor *old_input = input;
231+
input = THCTensor_(newFoldBatchDim)(state, input);
232+
THCTensor_(free)(state, old_input);
233+
} else {
234+
THCTensor_(retain)(state, output);
235+
}
223236

224-
// Collapse batch and feature dimensions
225237
THCDeviceTensor<real, 4> cudaInput;
226238
THCDeviceTensor<real, 4> cudaOutput;
227-
if (THCTensor_(nDimension)(state, input) == 4)
228-
{
229-
cudaInput = toDeviceTensor<real, 4>(state, input);
230-
cudaOutput = toDeviceTensor<real, 4>(state, output);
231-
}
232-
else
233-
{
234-
cudaInput = toDeviceTensor<real, 5>(state, input).downcastOuter<4>();
235-
cudaOutput = toDeviceTensor<real, 5>(state, output).downcastOuter<4>();
236-
}
239+
cudaInput = toDeviceTensor<real, 4>(state, input);
240+
cudaOutput = toDeviceTensor<real, 4>(state, output);
237241

238242
THLongStorage *indicesSize = THLongStorage_newWithSize(4);
239243
int64_t indicesSizeRaw[4] = { batchSize * inputSlices,
@@ -281,6 +285,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
281285
}
282286

283287
THCTensor_(free)(state, input);
288+
THCTensor_(free)(state, output);
284289
THCIndexTensor_(free)(state, indices1);
285290
}
286291

@@ -310,13 +315,15 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
310315
int outputHeight;
311316
int outputWidth;
312317

318+
int fiveDimensionalInput = THCTensor_(nDimension)(state, input) == 5;
319+
313320
THCUNN_assertSameGPU(state, 4, input, indices, gradOutput, gradInput);
314321
THNN_(VolumetricDilatedMaxPooling_shapeCheck)(
315322
state, input, gradOutput, indices, kT, kW, kH,
316323
dT, dW, dH, padT, padW, padH,
317324
dilationT, dilationW, dilationH, ceilMode);
318325

319-
if (THCTensor_(nDimension)(state, input) == 4) /* 4D */
326+
if (!fiveDimensionalInput) /* 4D */
320327
{
321328
batchSize = 1;
322329
inputSlices = THCTensor_(size)(state, input, 0);
@@ -336,22 +343,21 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
336343
}
337344

338345
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
346+
if (fiveDimensionalInput) {
347+
// Collapse batch and feature dimensions
348+
gradInput = THCTensor_(newFoldBatchDim)(state, gradInput);
349+
350+
THCTensor *old_gradOutput = gradOutput;
351+
gradOutput = THCTensor_(newFoldBatchDim)(state, gradOutput);
352+
THCTensor_(free)(state, old_gradOutput);
353+
} else {
354+
THCTensor_(retain)(state, gradInput);
355+
}
339356

340-
// Collapse batch and feature dimensions
341357
THCDeviceTensor<real, 4> cudaGradInput;
342358
THCDeviceTensor<real, 4> cudaGradOutput;
343-
if (THCTensor_(nDimension)(state, input) == 4)
344-
{
345-
cudaGradInput = toDeviceTensor<real, 4>(state, gradInput);
346-
cudaGradOutput = toDeviceTensor<real, 4>(state, gradOutput);
347-
}
348-
else
349-
{
350-
cudaGradInput =
351-
toDeviceTensor<real, 5>(state, gradInput).downcastOuter<4>();
352-
cudaGradOutput =
353-
toDeviceTensor<real, 5>(state, gradOutput).downcastOuter<4>();
354-
}
359+
cudaGradInput = toDeviceTensor<real, 4>(state, gradInput);
360+
cudaGradOutput = toDeviceTensor<real, 4>(state, gradOutput);
355361

356362
THLongStorage *indicesSize = THLongStorage_newWithSize(4);
357363
int64_t indicesSizeRaw[4] = { batchSize * inputSlices,
@@ -388,6 +394,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
388394
}
389395

390396
// cleanup
397+
THCTensor_(free)(state, gradInput);
391398
THCTensor_(free)(state, gradOutput);
392399
THCIndexTensor_(free)(state, indices1);
393400
}

0 commit comments

Comments (0)