Skip to content

Commit d7da504

Browse files
yongjik authored and soumith committed
Add check for slice shape match in index_copy_ and index_add_. (#4342)
Emits a warning if slices have the same size but different shapes. (It shouldn't be allowed, but it was, so some code might be unknowingly depending on the behavior.) Also refactored argument checking code, including index_fill_.
1 parent 5b91b24 commit d7da504

File tree

1 file changed

+72
-41
lines changed

1 file changed

+72
-41
lines changed

aten/src/THC/generic/THCTensorIndex.cu

Lines changed: 72 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,62 @@
22
#define THC_GENERIC_FILE "generic/THCTensorIndex.cu"
33
#else
44

5+
// Check tensor dimensions for index operations, and return the slice size.
6+
// src can be nullptr in case of indexFill: in that case it is ignored.
7+
static ptrdiff_t THCTensor_(getSliceSize)(THCState *state, THCTensor *dst,
8+
int dim,
9+
THCudaLongTensor *index,
10+
THCTensor *src)
11+
{
12+
int dstDims = THCTensor_(nDimension)(state, dst);
13+
int srcDims = (src == nullptr) ? dstDims : THCTensor_(nDimension)(state, src);
14+
15+
THArgCheck(THCudaLongTensor_nDimension(state, index) == 1, 4,
16+
"expecting vector of indices");
17+
THArgCheck(dim >= 0 && dim < dstDims, 2, "Indexing dim is out of bounds");
18+
19+
ptrdiff_t dstSliceSize = 1;
20+
for (int d = 0; d < dstDims; d++) {
21+
if (d != dim) {
22+
dstSliceSize *= dst->size[d];
23+
}
24+
}
25+
26+
if (src == nullptr) return dstSliceSize;
27+
28+
THArgCheck(dim < srcDims, 3, "Indexing dim is out of bounds");
29+
THArgCheck(THCudaLongTensor_nElement(state, index) == src->size[dim], 4,
30+
"length of src.size[dim] is not equal to length of indices");
31+
32+
ptrdiff_t srcSliceSize = 1;
33+
bool mismatch = false;
34+
35+
if (dstDims != srcDims) mismatch = true;
36+
37+
for (int d = 0; d < srcDims; d++) {
38+
if (d != dim) {
39+
srcSliceSize *= src->size[d];
40+
if (!mismatch && dst->size[d] != src->size[d]) mismatch = true;
41+
}
42+
}
43+
44+
THArgCheck(dstSliceSize == srcSliceSize, 2,
45+
"Source/destination tensor have different slice sizes (%ld vs %ld)",
46+
dstSliceSize, srcSliceSize);
47+
48+
if (mismatch) {
49+
static bool warningShown = false;
50+
if (!warningShown) {
51+
warningShown = true;
52+
fprintf(stderr,
53+
"Warning: source/destination slices have same size but different "
54+
"shape for an index operation. This behavior is deprecated.\n");
55+
}
56+
}
57+
58+
return dstSliceSize;
59+
}
60+
561
void THCTensor_(indexCopy_long)(THCState *state, THCTensor *dst, int dim, THLongTensor *indices, THCTensor *src)
662
{
763
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, dst, src));
@@ -26,27 +82,18 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT
2682
dims = THCudaLongTensor_nDimension(state, indices);
2783
THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);
2884

29-
ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices);
30-
31-
int srcDims = THCTensor_(nDimension)(state, src);
32-
cudaStream_t stream = THCState_getCurrentStream(state);
33-
34-
THArgCheck(THCudaLongTensor_nDimension(state, indices) == 1, 3,
35-
"expecting vector of indices");
36-
THArgCheck(dim < srcDims, 4, "Indexing dim is out of bounds");
37-
THArgCheck(srcDims > 0, 2, "Source tensor is empty");
38-
THArgCheck(numIndices == src->size[dim], 4, "length of src.size[dim] is not equal to length of indices");
39-
40-
int indContig = THCudaLongTensor_isContiguous(state, indices);
41-
4285
// The `src` is partitioned into two parts:
4386
// -the size of each slice we are indexing, which is the
4487
// total size of the tensor ignoring dimension `dim`;
4588
// -the number of indices we are choosing, which is the total size
4689
// of the tensor `indices`.
90+
ptrdiff_t sliceSize = THCTensor_(getSliceSize)(state, dst, dim, indices, src);
4791
ptrdiff_t srcTotalSize = THCTensor_(nElement)(state, src);
4892
int64_t dstCopyDimSize = THCTensor_(size)(state, dst, dim);
49-
ptrdiff_t sliceSize = srcTotalSize / numIndices;
93+
94+
ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices);
95+
cudaStream_t stream = THCState_getCurrentStream(state);
96+
int indContig = THCudaLongTensor_isContiguous(state, indices);
5097

5198
int mpc = THCState_getCurrentDeviceProperties(state)->multiProcessorCount;
5299

@@ -221,27 +268,18 @@ void THCTensor_(indexAdd)(THCState *state, THCTensor *dst, int dim, THCudaLongTe
221268
dims = THCudaLongTensor_nDimension(state, indices);
222269
THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);
223270

224-
ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices);
225-
226-
int srcDims = THCTensor_(nDimension)(state, src);
227-
cudaStream_t stream = THCState_getCurrentStream(state);
228-
229-
THArgCheck(THCudaLongTensor_nDimension(state, indices) == 1, 3,
230-
"expecting vector of indices");
231-
THArgCheck(dim < srcDims, 4, "Indexing dim is out of bounds");
232-
THArgCheck(srcDims > 0, 2, "Source tensor is empty");
233-
THArgCheck(numIndices == src->size[dim], 4, "length of src.size[dim] is not equal to length of indices");
234-
235-
int indContig = THCudaLongTensor_isContiguous(state, indices);
236-
237271
// The `src` is partitioned into two parts:
238272
// -the size of each slice we are indexing, which is the
239273
// total size of the tensor ignoring dimension `dim`;
240274
// -the number of indices we are choosing, which is the total size
241275
// of the tensor `indices`.
276+
ptrdiff_t sliceSize = THCTensor_(getSliceSize)(state, dst, dim, indices, src);
242277
ptrdiff_t srcTotalSize = THCTensor_(nElement)(state, src);
243278
int64_t dstAddDimSize = THCTensor_(size)(state, dst, dim);
244-
ptrdiff_t sliceSize = srcTotalSize / numIndices;
279+
280+
ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices);
281+
cudaStream_t stream = THCState_getCurrentStream(state);
282+
int indContig = THCudaLongTensor_isContiguous(state, indices);
245283

246284
int mpc = THCState_getCurrentDeviceProperties(state)->multiProcessorCount;
247285

@@ -346,26 +384,19 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT
346384
dims = THCudaLongTensor_nDimension(state, indices);
347385
THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);
348386

349-
ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices);
350-
351-
int srcDims = THCTensor_(nDimension)(state, dst);
352-
cudaStream_t stream = THCState_getCurrentStream(state);
353-
354-
THArgCheck(THCudaLongTensor_nDimension(state, indices) == 1, 3,
355-
"expecting vector of indices");
356-
THArgCheck(dim < srcDims, 4, "Indexing dim is out of bounds");
357-
THArgCheck(srcDims > 0, 2, "Source tensor is empty");
358-
359-
int indContig = THCudaLongTensor_isContiguous(state, indices);
360-
361387
// The `src` is partitioned into two parts:
362388
// -the size of each slice we are indexing, which is the
363389
// total size of the tensor ignoring dimension `dim`;
364390
// -the number of indices we are choosing, which is the total size
365391
// of the tensor `indices`.
392+
ptrdiff_t sliceSize =
393+
THCTensor_(getSliceSize)(state, dst, dim, indices, nullptr);
366394
ptrdiff_t dstTotalSize = THCTensor_(nElement)(state, dst);
367395
int64_t dstFillDimSize = THCTensor_(size)(state, dst, dim);
368-
ptrdiff_t sliceSize = dstTotalSize / dstFillDimSize;
396+
397+
ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices);
398+
cudaStream_t stream = THCState_getCurrentStream(state);
399+
int indContig = THCudaLongTensor_isContiguous(state, indices);
369400

370401
int mpc = THCState_getCurrentDeviceProperties(state)->multiProcessorCount;
371402

0 commit comments

Comments
 (0)