Skip to content

Commit 2108b42

Browse files
committed
Fix bug in cat when dimension is not specified.
- Code was using dimension specified which was negative - Changed the cat_dimension variable to be more explicit - Fixed code to use the cat_dimension variable
1 parent bae8df6 commit 2108b42

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

generic/THCTensorMath.cu

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
8787
// loop below will overwrite the value
8888
int maxDim = dimension + 1;
8989

90-
// ldimension is the actual dimension we cat along (minus 1, for 0-based indexing)
91-
int ldimension = dimension;
90+
// cat_dimension is the actual dimension we cat along
91+
int cat_dimension = dimension;
9292

9393
for (i = 0; i < numInputs; i++)
9494
{
@@ -100,13 +100,13 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
100100
// In the event that the user specified -1 as the concat dimension, then
101101
// we want to pick the maxDim as dimension to cat along (and thus maxDim - 1 as the
102102
// value due to 0-based indexing). If the maxDim is // 0 (i.e. we are catting all
103-
// empty tensors), then we set ldimension to be 0
103+
// empty tensors), then we set cat_dimension to be 0
104104
if (dimension + TH_INDEX_BASE == -1) {
105-
ldimension = maxDim ? (maxDim - 1) : 0;
105+
cat_dimension = maxDim ? (maxDim - 1) : 0;
106106
}
107107

108108
THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
109-
THArgCheck(ldimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE);
109+
THArgCheck(cat_dimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE);
110110

111111
size = THLongStorage_newWithSize(maxDim);
112112
for(i = 0; i < maxDim; i++)
@@ -115,7 +115,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
115115
long dimSize = i < THCTensor_(nDimension)(state, inputs[0])
116116
? THCTensor_(size)(state, inputs[0], i)
117117
: THMin(THCTensor_(nDimension)(state, inputs[0]), 1);
118-
if (i == ldimension)
118+
if (i == cat_dimension)
119119
{
120120
for (j = 1; j < numInputs; j++)
121121
{
@@ -203,15 +203,15 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
203203

204204
// Template Declarations for dim = 1, 2, 3, 4
205205
#define HANDLE_CASE(DIMS) \
206-
CatArrayBatchedCopy<real, unsigned int, DIMS><<<applyGrid, applyBlock>>>(data, d_inputs, param, ldimension, param.outputStride[dimension]);
206+
CatArrayBatchedCopy<real, unsigned int, DIMS><<<applyGrid, applyBlock>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]);
207207

208208
// Now we loop
209209
offset = 0;
210210
for (i = 0; i < numInputs; i += CAT_ARRAY_BATCH_SIZE) {
211211
cohortMax = 0;
212212
for (j = 0; j < CAT_ARRAY_BATCH_SIZE && (i+j) < numInputs; ++j) {
213-
long dimSize = ldimension < THCTensor_(nDimension)(state, inputs[i+j])
214-
? THCTensor_(size)(state, inputs[i+j], ldimension)
213+
long dimSize = cat_dimension < THCTensor_(nDimension)(state, inputs[i+j])
214+
? THCTensor_(size)(state, inputs[i+j], cat_dimension)
215215
: 1;
216216

217217
stackInputs[j].input = THCTensor_(data)(state, inputs[i+j]);
@@ -267,12 +267,12 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
267267
// No reason to copy when input is empty
268268
if (!THCTensor_(nDimension)(state, inputs[j])) continue;
269269

270-
long dimSize = ldimension < THCTensor_(nDimension)(state, inputs[j])
271-
? THCTensor_(size)(state, inputs[j], ldimension)
270+
long dimSize = cat_dimension < THCTensor_(nDimension)(state, inputs[j])
271+
? THCTensor_(size)(state, inputs[j], cat_dimension)
272272
: 1;
273273

274274
THCTensor *nt = THCTensor_(newWithTensor)(state, result);
275-
THCTensor_(narrow)(state, nt, NULL, ldimension, offset, dimSize);
275+
THCTensor_(narrow)(state, nt, NULL, cat_dimension, offset, dimSize);
276276
THCTensor_(copy)(state, nt, inputs[j]);
277277
THCTensor_(free)(state, nt);
278278
offset += dimSize;

0 commit comments

Comments
 (0)