22#define THC_GENERIC_FILE " generic/THCTensorIndex.cu"
33#else
44
5+ // Check tensor dimensions for index operations, and return the slice size.
6+ // src can be nullptr in case of indexFill: in that case it is ignored.
7+ static ptrdiff_t THCTensor_ (getSliceSize)(THCState *state, THCTensor *dst,
8+ int dim,
9+ THCudaLongTensor *index,
10+ THCTensor *src)
11+ {
12+ int dstDims = THCTensor_ (nDimension)(state, dst);
13+ int srcDims = (src == nullptr ) ? dstDims : THCTensor_ (nDimension)(state, src);
14+
15+ THArgCheck (THCudaLongTensor_nDimension (state, index) == 1 , 4 ,
16+ " expecting vector of indices" );
17+ THArgCheck (dim >= 0 && dim < dstDims, 2 , " Indexing dim is out of bounds" );
18+
19+ ptrdiff_t dstSliceSize = 1 ;
20+ for (int d = 0 ; d < dstDims; d++) {
21+ if (d != dim) {
22+ dstSliceSize *= dst->size [d];
23+ }
24+ }
25+
26+ if (src == nullptr ) return dstSliceSize;
27+
28+ THArgCheck (dim < srcDims, 3 , " Indexing dim is out of bounds" );
29+ THArgCheck (THCudaLongTensor_nElement (state, index) == src->size [dim], 4 ,
30+ " length of src.size[dim] is not equal to length of indices" );
31+
32+ ptrdiff_t srcSliceSize = 1 ;
33+ bool mismatch = false ;
34+
35+ if (dstDims != srcDims) mismatch = true ;
36+
37+ for (int d = 0 ; d < srcDims; d++) {
38+ if (d != dim) {
39+ srcSliceSize *= src->size [d];
40+ if (!mismatch && dst->size [d] != src->size [d]) mismatch = true ;
41+ }
42+ }
43+
44+ THArgCheck (dstSliceSize == srcSliceSize, 2 ,
45+ " Source/destination tensor have different slice sizes (%ld vs %ld)" ,
46+ dstSliceSize, srcSliceSize);
47+
48+ if (mismatch) {
49+ static bool warningShown = false ;
50+ if (!warningShown) {
51+ warningShown = true ;
52+ fprintf (stderr,
53+ " Warning: source/destination slices have same size but different "
54+ " shape for an index operation. This behavior is deprecated.\n " );
55+ }
56+ }
57+
58+ return dstSliceSize;
59+ }
60+
561void THCTensor_ (indexCopy_long)(THCState *state, THCTensor *dst, int dim, THLongTensor *indices, THCTensor *src)
662{
763 THCAssertSameGPU (THCTensor_ (checkGPU)(state, 2 , dst, src));
@@ -26,27 +82,18 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT
2682 dims = THCudaLongTensor_nDimension (state, indices);
2783 THArgCheck (dims <= MAX_CUTORCH_DIMS, 4 , CUTORCH_DIM_WARNING);
2884
29- ptrdiff_t numIndices = THCudaLongTensor_nElement (state, indices);
30-
31- int srcDims = THCTensor_ (nDimension)(state, src);
32- cudaStream_t stream = THCState_getCurrentStream (state);
33-
34- THArgCheck (THCudaLongTensor_nDimension (state, indices) == 1 , 3 ,
35- " expecting vector of indices" );
36- THArgCheck (dim < srcDims, 4 , " Indexing dim is out of bounds" );
37- THArgCheck (srcDims > 0 , 2 , " Source tensor is empty" );
38- THArgCheck (numIndices == src->size [dim], 4 , " length of src.size[dim] is not equal to length of indices" );
39-
40- int indContig = THCudaLongTensor_isContiguous (state, indices);
41-
4285 // The `src` is partitioned into two parts:
4386 // -the size of each slice we are indexing, which is the
4487 // total size of the tensor ignoring dimension `dim`;
4588 // -the number of indices we are choosing, which is the total size
4689 // of the tensor `indices`.
90+ ptrdiff_t sliceSize = THCTensor_ (getSliceSize)(state, dst, dim, indices, src);
4791 ptrdiff_t srcTotalSize = THCTensor_ (nElement)(state, src);
4892 int64_t dstCopyDimSize = THCTensor_ (size)(state, dst, dim);
49- ptrdiff_t sliceSize = srcTotalSize / numIndices;
93+
94+ ptrdiff_t numIndices = THCudaLongTensor_nElement (state, indices);
95+ cudaStream_t stream = THCState_getCurrentStream (state);
96+ int indContig = THCudaLongTensor_isContiguous (state, indices);
5097
5198 int mpc = THCState_getCurrentDeviceProperties (state)->multiProcessorCount ;
5299
@@ -216,27 +263,18 @@ void THCTensor_(indexAdd)(THCState *state, THCTensor *dst, int dim, THCudaLongTe
216263 dims = THCudaLongTensor_nDimension (state, indices);
217264 THArgCheck (dims <= MAX_CUTORCH_DIMS, 4 , CUTORCH_DIM_WARNING);
218265
219- ptrdiff_t numIndices = THCudaLongTensor_nElement (state, indices);
220-
221- int srcDims = THCTensor_ (nDimension)(state, src);
222- cudaStream_t stream = THCState_getCurrentStream (state);
223-
224- THArgCheck (THCudaLongTensor_nDimension (state, indices) == 1 , 3 ,
225- " expecting vector of indices" );
226- THArgCheck (dim < srcDims, 4 , " Indexing dim is out of bounds" );
227- THArgCheck (srcDims > 0 , 2 , " Source tensor is empty" );
228- THArgCheck (numIndices == src->size [dim], 4 , " length of src.size[dim] is not equal to length of indices" );
229-
230- int indContig = THCudaLongTensor_isContiguous (state, indices);
231-
232266 // The `src` is partitioned into two parts:
233267 // -the size of each slice we are indexing, which is the
234268 // total size of the tensor ignoring dimension `dim`;
235269 // -the number of indices we are choosing, which is the total size
236270 // of the tensor `indices`.
271+ ptrdiff_t sliceSize = THCTensor_ (getSliceSize)(state, dst, dim, indices, src);
237272 ptrdiff_t srcTotalSize = THCTensor_ (nElement)(state, src);
238273 int64_t dstAddDimSize = THCTensor_ (size)(state, dst, dim);
239- ptrdiff_t sliceSize = srcTotalSize / numIndices;
274+
275+ ptrdiff_t numIndices = THCudaLongTensor_nElement (state, indices);
276+ cudaStream_t stream = THCState_getCurrentStream (state);
277+ int indContig = THCudaLongTensor_isContiguous (state, indices);
240278
241279 int mpc = THCState_getCurrentDeviceProperties (state)->multiProcessorCount ;
242280
@@ -341,26 +379,19 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT
341379 dims = THCudaLongTensor_nDimension (state, indices);
342380 THArgCheck (dims <= MAX_CUTORCH_DIMS, 4 , CUTORCH_DIM_WARNING);
343381
344- ptrdiff_t numIndices = THCudaLongTensor_nElement (state, indices);
345-
346- int srcDims = THCTensor_ (nDimension)(state, dst);
347- cudaStream_t stream = THCState_getCurrentStream (state);
348-
349- THArgCheck (THCudaLongTensor_nDimension (state, indices) == 1 , 3 ,
350- " expecting vector of indices" );
351- THArgCheck (dim < srcDims, 4 , " Indexing dim is out of bounds" );
352- THArgCheck (srcDims > 0 , 2 , " Source tensor is empty" );
353-
354- int indContig = THCudaLongTensor_isContiguous (state, indices);
355-
356382 // The `src` is partitioned into two parts:
357383 // -the size of each slice we are indexing, which is the
358384 // total size of the tensor ignoring dimension `dim`;
359385 // -the number of indices we are choosing, which is the total size
360386 // of the tensor `indices`.
387+ ptrdiff_t sliceSize =
388+ THCTensor_ (getSliceSize)(state, dst, dim, indices, nullptr );
361389 ptrdiff_t dstTotalSize = THCTensor_ (nElement)(state, dst);
362390 int64_t dstFillDimSize = THCTensor_ (size)(state, dst, dim);
363- ptrdiff_t sliceSize = dstTotalSize / dstFillDimSize;
391+
392+ ptrdiff_t numIndices = THCudaLongTensor_nElement (state, indices);
393+ cudaStream_t stream = THCState_getCurrentStream (state);
394+ int indContig = THCudaLongTensor_isContiguous (state, indices);
364395
365396 int mpc = THCState_getCurrentDeviceProperties (state)->multiProcessorCount ;
366397
0 commit comments