Skip to content

Commit 6ff04fb

Browse files
yongjiksoumith
authored andcommitted
Add check for slice shape match in index_copy_ and index_add_. (#4342)
Emits a warning if slices have the same size but different shapes. (It shouldn't be allowed, but it was, so some code might be unknowingly depending on the behavior.) Also refactored argument checking code, including index_fill_.
1 parent ab5b03e commit 6ff04fb

File tree

1 file changed

+72
-41
lines changed

1 file changed

+72
-41
lines changed

torch/lib/THC/generic/THCTensorIndex.cu

Lines changed: 72 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,62 @@
22
#define THC_GENERIC_FILE "generic/THCTensorIndex.cu"
33
#else
44

5+
// Check tensor dimensions for index operations, and return the slice size.
6+
// src can be nullptr in case of indexFill: in that case it is ignored.
7+
static ptrdiff_t THCTensor_(getSliceSize)(THCState *state, THCTensor *dst,
8+
int dim,
9+
THCudaLongTensor *index,
10+
THCTensor *src)
11+
{
12+
int dstDims = THCTensor_(nDimension)(state, dst);
13+
int srcDims = (src == nullptr) ? dstDims : THCTensor_(nDimension)(state, src);
14+
15+
THArgCheck(THCudaLongTensor_nDimension(state, index) == 1, 4,
16+
"expecting vector of indices");
17+
THArgCheck(dim >= 0 && dim < dstDims, 2, "Indexing dim is out of bounds");
18+
19+
ptrdiff_t dstSliceSize = 1;
20+
for (int d = 0; d < dstDims; d++) {
21+
if (d != dim) {
22+
dstSliceSize *= dst->size[d];
23+
}
24+
}
25+
26+
if (src == nullptr) return dstSliceSize;
27+
28+
THArgCheck(dim < srcDims, 3, "Indexing dim is out of bounds");
29+
THArgCheck(THCudaLongTensor_nElement(state, index) == src->size[dim], 4,
30+
"length of src.size[dim] is not equal to length of indices");
31+
32+
ptrdiff_t srcSliceSize = 1;
33+
bool mismatch = false;
34+
35+
if (dstDims != srcDims) mismatch = true;
36+
37+
for (int d = 0; d < srcDims; d++) {
38+
if (d != dim) {
39+
srcSliceSize *= src->size[d];
40+
if (!mismatch && dst->size[d] != src->size[d]) mismatch = true;
41+
}
42+
}
43+
44+
THArgCheck(dstSliceSize == srcSliceSize, 2,
45+
"Source/destination tensor have different slice sizes (%ld vs %ld)",
46+
dstSliceSize, srcSliceSize);
47+
48+
if (mismatch) {
49+
static bool warningShown = false;
50+
if (!warningShown) {
51+
warningShown = true;
52+
fprintf(stderr,
53+
"Warning: source/destination slices have same size but different "
54+
"shape for an index operation. This behavior is deprecated.\n");
55+
}
56+
}
57+
58+
return dstSliceSize;
59+
}
60+
561
void THCTensor_(indexCopy_long)(THCState *state, THCTensor *dst, int dim, THLongTensor *indices, THCTensor *src)
662
{
763
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, dst, src));
@@ -26,27 +82,18 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT
2682
dims = THCudaLongTensor_nDimension(state, indices);
2783
THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);
2884

29-
ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices);
30-
31-
int srcDims = THCTensor_(nDimension)(state, src);
32-
cudaStream_t stream = THCState_getCurrentStream(state);
33-
34-
THArgCheck(THCudaLongTensor_nDimension(state, indices) == 1, 3,
35-
"expecting vector of indices");
36-
THArgCheck(dim < srcDims, 4, "Indexing dim is out of bounds");
37-
THArgCheck(srcDims > 0, 2, "Source tensor is empty");
38-
THArgCheck(numIndices == src->size[dim], 4, "length of src.size[dim] is not equal to length of indices");
39-
40-
int indContig = THCudaLongTensor_isContiguous(state, indices);
41-
4285
// The `src` is partitioned into two parts:
4386
// -the size of each slice we are indexing, which is the
4487
// total size of the tensor ignoring dimension `dim`;
4588
// -the number of indices we are choosing, which is the total size
4689
// of the tensor `indices`.
90+
ptrdiff_t sliceSize = THCTensor_(getSliceSize)(state, dst, dim, indices, src);
4791
ptrdiff_t srcTotalSize = THCTensor_(nElement)(state, src);
4892
int64_t dstCopyDimSize = THCTensor_(size)(state, dst, dim);
49-
ptrdiff_t sliceSize = srcTotalSize / numIndices;
93+
94+
ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices);
95+
cudaStream_t stream = THCState_getCurrentStream(state);
96+
int indContig = THCudaLongTensor_isContiguous(state, indices);
5097

5198
int mpc = THCState_getCurrentDeviceProperties(state)->multiProcessorCount;
5299

@@ -216,27 +263,18 @@ void THCTensor_(indexAdd)(THCState *state, THCTensor *dst, int dim, THCudaLongTe
216263
dims = THCudaLongTensor_nDimension(state, indices);
217264
THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);
218265

219-
ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices);
220-
221-
int srcDims = THCTensor_(nDimension)(state, src);
222-
cudaStream_t stream = THCState_getCurrentStream(state);
223-
224-
THArgCheck(THCudaLongTensor_nDimension(state, indices) == 1, 3,
225-
"expecting vector of indices");
226-
THArgCheck(dim < srcDims, 4, "Indexing dim is out of bounds");
227-
THArgCheck(srcDims > 0, 2, "Source tensor is empty");
228-
THArgCheck(numIndices == src->size[dim], 4, "length of src.size[dim] is not equal to length of indices");
229-
230-
int indContig = THCudaLongTensor_isContiguous(state, indices);
231-
232266
// The `src` is partitioned into two parts:
233267
// -the size of each slice we are indexing, which is the
234268
// total size of the tensor ignoring dimension `dim`;
235269
// -the number of indices we are choosing, which is the total size
236270
// of the tensor `indices`.
271+
ptrdiff_t sliceSize = THCTensor_(getSliceSize)(state, dst, dim, indices, src);
237272
ptrdiff_t srcTotalSize = THCTensor_(nElement)(state, src);
238273
int64_t dstAddDimSize = THCTensor_(size)(state, dst, dim);
239-
ptrdiff_t sliceSize = srcTotalSize / numIndices;
274+
275+
ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices);
276+
cudaStream_t stream = THCState_getCurrentStream(state);
277+
int indContig = THCudaLongTensor_isContiguous(state, indices);
240278

241279
int mpc = THCState_getCurrentDeviceProperties(state)->multiProcessorCount;
242280

@@ -341,26 +379,19 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT
341379
dims = THCudaLongTensor_nDimension(state, indices);
342380
THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);
343381

344-
ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices);
345-
346-
int srcDims = THCTensor_(nDimension)(state, dst);
347-
cudaStream_t stream = THCState_getCurrentStream(state);
348-
349-
THArgCheck(THCudaLongTensor_nDimension(state, indices) == 1, 3,
350-
"expecting vector of indices");
351-
THArgCheck(dim < srcDims, 4, "Indexing dim is out of bounds");
352-
THArgCheck(srcDims > 0, 2, "Source tensor is empty");
353-
354-
int indContig = THCudaLongTensor_isContiguous(state, indices);
355-
356382
// The `src` is partitioned into two parts:
357383
// -the size of each slice we are indexing, which is the
358384
// total size of the tensor ignoring dimension `dim`;
359385
// -the number of indices we are choosing, which is the total size
360386
// of the tensor `indices`.
387+
ptrdiff_t sliceSize =
388+
THCTensor_(getSliceSize)(state, dst, dim, indices, nullptr);
361389
ptrdiff_t dstTotalSize = THCTensor_(nElement)(state, dst);
362390
int64_t dstFillDimSize = THCTensor_(size)(state, dst, dim);
363-
ptrdiff_t sliceSize = dstTotalSize / dstFillDimSize;
391+
392+
ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices);
393+
cudaStream_t stream = THCState_getCurrentStream(state);
394+
int indContig = THCudaLongTensor_isContiguous(state, indices);
364395

365396
int mpc = THCState_getCurrentDeviceProperties(state)->multiProcessorCount;
366397

0 commit comments

Comments
 (0)