22#define THC_GENERIC_FILE " generic/THCTensorIndex.cu"
33#else
44
5+ // Check tensor dimensions for index operations, and return the slice size.
6+ // src can be nullptr in case of indexFill: in that case it is ignored.
7+ static ptrdiff_t THCTensor_ (getSliceSize)(THCState *state, THCTensor *dst,
8+ int dim,
9+ THCudaLongTensor *index,
10+ THCTensor *src)
11+ {
12+ int dstDims = THCTensor_ (nDimension)(state, dst);
13+ int srcDims = (src == nullptr ) ? dstDims : THCTensor_ (nDimension)(state, src);
14+
15+ THArgCheck (THCudaLongTensor_nDimension (state, index) == 1 , 4 ,
16+ " expecting vector of indices" );
17+ THArgCheck (dim >= 0 && dim < dstDims, 2 , " Indexing dim is out of bounds" );
18+
19+ ptrdiff_t dstSliceSize = 1 ;
20+ for (int d = 0 ; d < dstDims; d++) {
21+ if (d != dim) {
22+ dstSliceSize *= dst->size [d];
23+ }
24+ }
25+
26+ if (src == nullptr ) return dstSliceSize;
27+
28+ THArgCheck (dim < srcDims, 3 , " Indexing dim is out of bounds" );
29+ THArgCheck (THCudaLongTensor_nElement (state, index) == src->size [dim], 4 ,
30+ " length of src.size[dim] is not equal to length of indices" );
31+
32+ ptrdiff_t srcSliceSize = 1 ;
33+ bool mismatch = false ;
34+
35+ if (dstDims != srcDims) mismatch = true ;
36+
37+ for (int d = 0 ; d < srcDims; d++) {
38+ if (d != dim) {
39+ srcSliceSize *= src->size [d];
40+ if (!mismatch && dst->size [d] != src->size [d]) mismatch = true ;
41+ }
42+ }
43+
44+ THArgCheck (dstSliceSize == srcSliceSize, 2 ,
45+ " Source/destination tensor have different slice sizes (%ld vs %ld)" ,
46+ dstSliceSize, srcSliceSize);
47+
48+ if (mismatch) {
49+ static bool warningShown = false ;
50+ if (!warningShown) {
51+ warningShown = true ;
52+ fprintf (stderr,
53+ " Warning: source/destination slices have same size but different "
54+ " shape for an index operation. This behavior is deprecated.\n " );
55+ }
56+ }
57+
58+ return dstSliceSize;
59+ }
60+
561void THCTensor_ (indexCopy_long)(THCState *state, THCTensor *dst, int dim, THLongTensor *indices, THCTensor *src)
662{
763 THCAssertSameGPU (THCTensor_ (checkGPU)(state, 2 , dst, src));
@@ -26,27 +82,18 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT
2682 dims = THCudaLongTensor_nDimension (state, indices);
2783 THArgCheck (dims <= MAX_CUTORCH_DIMS, 4 , CUTORCH_DIM_WARNING);
2884
29- ptrdiff_t numIndices = THCudaLongTensor_nElement (state, indices);
30-
31- int srcDims = THCTensor_ (nDimension)(state, src);
32- cudaStream_t stream = THCState_getCurrentStream (state);
33-
34- THArgCheck (THCudaLongTensor_nDimension (state, indices) == 1 , 3 ,
35- " expecting vector of indices" );
36- THArgCheck (dim < srcDims, 4 , " Indexing dim is out of bounds" );
37- THArgCheck (srcDims > 0 , 2 , " Source tensor is empty" );
38- THArgCheck (numIndices == src->size [dim], 4 , " length of src.size[dim] is not equal to length of indices" );
39-
40- int indContig = THCudaLongTensor_isContiguous (state, indices);
41-
4285 // The `src` is partitioned into two parts:
4386 // -the size of each slice we are indexing, which is the
4487 // total size of the tensor ignoring dimension `dim`;
4588 // -the number of indices we are choosing, which is the total size
4689 // of the tensor `indices`.
90+ ptrdiff_t sliceSize = THCTensor_ (getSliceSize)(state, dst, dim, indices, src);
4791 ptrdiff_t srcTotalSize = THCTensor_ (nElement)(state, src);
4892 int64_t dstCopyDimSize = THCTensor_ (size)(state, dst, dim);
49- ptrdiff_t sliceSize = srcTotalSize / numIndices;
93+
94+ ptrdiff_t numIndices = THCudaLongTensor_nElement (state, indices);
95+ cudaStream_t stream = THCState_getCurrentStream (state);
96+ int indContig = THCudaLongTensor_isContiguous (state, indices);
5097
5198 int mpc = THCState_getCurrentDeviceProperties (state)->multiProcessorCount ;
5299
@@ -221,27 +268,18 @@ void THCTensor_(indexAdd)(THCState *state, THCTensor *dst, int dim, THCudaLongTe
221268 dims = THCudaLongTensor_nDimension (state, indices);
222269 THArgCheck (dims <= MAX_CUTORCH_DIMS, 4 , CUTORCH_DIM_WARNING);
223270
224- ptrdiff_t numIndices = THCudaLongTensor_nElement (state, indices);
225-
226- int srcDims = THCTensor_ (nDimension)(state, src);
227- cudaStream_t stream = THCState_getCurrentStream (state);
228-
229- THArgCheck (THCudaLongTensor_nDimension (state, indices) == 1 , 3 ,
230- " expecting vector of indices" );
231- THArgCheck (dim < srcDims, 4 , " Indexing dim is out of bounds" );
232- THArgCheck (srcDims > 0 , 2 , " Source tensor is empty" );
233- THArgCheck (numIndices == src->size [dim], 4 , " length of src.size[dim] is not equal to length of indices" );
234-
235- int indContig = THCudaLongTensor_isContiguous (state, indices);
236-
237271 // The `src` is partitioned into two parts:
238272 // -the size of each slice we are indexing, which is the
239273 // total size of the tensor ignoring dimension `dim`;
240274 // -the number of indices we are choosing, which is the total size
241275 // of the tensor `indices`.
276+ ptrdiff_t sliceSize = THCTensor_ (getSliceSize)(state, dst, dim, indices, src);
242277 ptrdiff_t srcTotalSize = THCTensor_ (nElement)(state, src);
243278 int64_t dstAddDimSize = THCTensor_ (size)(state, dst, dim);
244- ptrdiff_t sliceSize = srcTotalSize / numIndices;
279+
280+ ptrdiff_t numIndices = THCudaLongTensor_nElement (state, indices);
281+ cudaStream_t stream = THCState_getCurrentStream (state);
282+ int indContig = THCudaLongTensor_isContiguous (state, indices);
245283
246284 int mpc = THCState_getCurrentDeviceProperties (state)->multiProcessorCount ;
247285
@@ -346,26 +384,19 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT
346384 dims = THCudaLongTensor_nDimension (state, indices);
347385 THArgCheck (dims <= MAX_CUTORCH_DIMS, 4 , CUTORCH_DIM_WARNING);
348386
349- ptrdiff_t numIndices = THCudaLongTensor_nElement (state, indices);
350-
351- int srcDims = THCTensor_ (nDimension)(state, dst);
352- cudaStream_t stream = THCState_getCurrentStream (state);
353-
354- THArgCheck (THCudaLongTensor_nDimension (state, indices) == 1 , 3 ,
355- " expecting vector of indices" );
356- THArgCheck (dim < srcDims, 4 , " Indexing dim is out of bounds" );
357- THArgCheck (srcDims > 0 , 2 , " Source tensor is empty" );
358-
359- int indContig = THCudaLongTensor_isContiguous (state, indices);
360-
361387 // The `src` is partitioned into two parts:
362388 // -the size of each slice we are indexing, which is the
363389 // total size of the tensor ignoring dimension `dim`;
364390 // -the number of indices we are choosing, which is the total size
365391 // of the tensor `indices`.
392+ ptrdiff_t sliceSize =
393+ THCTensor_ (getSliceSize)(state, dst, dim, indices, nullptr );
366394 ptrdiff_t dstTotalSize = THCTensor_ (nElement)(state, dst);
367395 int64_t dstFillDimSize = THCTensor_ (size)(state, dst, dim);
368- ptrdiff_t sliceSize = dstTotalSize / dstFillDimSize;
396+
397+ ptrdiff_t numIndices = THCudaLongTensor_nElement (state, indices);
398+ cudaStream_t stream = THCState_getCurrentStream (state);
399+ int indContig = THCudaLongTensor_isContiguous (state, indices);
369400
370401 int mpc = THCState_getCurrentDeviceProperties (state)->multiProcessorCount ;
371402
0 commit comments