@@ -66,21 +66,30 @@ __global__ void indexCopySmallIndex(TensorInfo<T, IndexType> dst,
6666// the number of indices chosen is small, then the
6767// indexCopySmallIndex kernel is a better choice to reduce memory
6868// accesses.
69- template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim>
69+ template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim,
70+ bool IndexIsMajor>
7071__global__ void indexCopyLargeIndex (TensorInfo<T, IndexType> dst,
7172 TensorInfo<T, IndexType> src,
7273 TensorInfo<int64_t , IndexType> indices,
7374 int dstCopyDim,
7475 int srcCopyDim,
76+ IndexType totalSize,
7577 IndexType innerSize,
7678 int64_t dstCopyDimSize) {
7779 // We stride over the output including the indexed dimension
7880 // (totalSize), and calculate the destination index point based on that
7981 for (IndexType linearIndex = blockIdx .x * blockDim .x + threadIdx .x ;
80- linearIndex < innerSize * indices. sizes [ 0 ] ;
82+ linearIndex < totalSize ;
8183 linearIndex += gridDim .x * blockDim .x ) {
82- IndexType srcIndex = linearIndex / innerSize;
83- IndexType elementInSlice = linearIndex % innerSize;
84+ IndexType srcIndex, elementInSlice;
85+ if (IndexIsMajor) {
86+ srcIndex = linearIndex / innerSize;
87+ elementInSlice = linearIndex % innerSize;
88+ }
89+ else {
90+ elementInSlice = linearIndex / innerSize;
91+ srcIndex = linearIndex % innerSize;
92+ }
8493
8594 // Lua indices begin at 1
8695 IndexType dstIndex =
@@ -148,21 +157,30 @@ __global__ void indexAddSmallIndex(TensorInfo<T, IndexType> dst,
148157// the number of indices chosen is small, then the
149158// indexAddSmallIndex kernel is a better choice to reduce memory
150159// accesses.
151- template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim>
160+ template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim,
161+ bool IndexIsMajor>
152162__global__ void indexAddLargeIndex (TensorInfo<T, IndexType> dst,
153163 TensorInfo<T, IndexType> src,
154164 TensorInfo<int64_t , IndexType> indices,
155165 int dstAddDim,
156166 int srcAddDim,
167+ IndexType totalSize,
157168 IndexType innerSize,
158169 int64_t dstAddDimSize) {
159170 // We stride over the output including the indexed dimension
160171 // (totalSize), and calculate the destination index point based on that
161172 for (IndexType linearIndex = blockIdx .x * blockDim .x + threadIdx .x ;
162- linearIndex < innerSize * indices. sizes [ 0 ] ;
173+ linearIndex < totalSize ;
163174 linearIndex += gridDim .x * blockDim .x ) {
164- IndexType srcIndex = linearIndex / innerSize;
165- IndexType elementInSlice = linearIndex % innerSize;
175+ IndexType srcIndex, elementInSlice;
176+ if (IndexIsMajor) {
177+ srcIndex = linearIndex / innerSize;
178+ elementInSlice = linearIndex % innerSize;
179+ }
180+ else {
181+ elementInSlice = linearIndex / innerSize;
182+ srcIndex = linearIndex % innerSize;
183+ }
166184
167185 // Lua indices begin at 1
168186 IndexType dstIndex =
@@ -225,20 +243,29 @@ __global__ void indexFillSmallIndex(TensorInfo<T, IndexType> dst,
225243// the number of indices chosen is small, then the
226244// indexFillSmallIndex kernel is a better choice to reduce memory
227245// accesses.
228- template <typename T, typename IndexType, int DstDim, int IdxDim>
246+ template <typename T, typename IndexType, int DstDim, int IdxDim,
247+ bool IndexIsMajor>
229248__global__ void indexFillLargeIndex (TensorInfo<T, IndexType> dst,
230249 TensorInfo<int64_t , IndexType> indices,
231250 int dstFillDim,
251+ IndexType totalSize,
232252 IndexType innerSize,
233253 int64_t dstFillDimSize,
234254 T val) {
235255 // We stride over the output including the indexed dimension
236256 // (totalSize), and calculate the destination index point based on that
237257 for (IndexType linearIndex = blockIdx .x * blockDim .x + threadIdx .x ;
238- linearIndex < innerSize * indices. sizes [ 0 ] ;
258+ linearIndex < totalSize ;
239259 linearIndex += gridDim .x * blockDim .x ) {
240- IndexType dstIndex = linearIndex / innerSize;
241- IndexType elementInSlice = linearIndex % innerSize;
260+ IndexType dstIndex, elementInSlice;
261+ if (IndexIsMajor) {
262+ dstIndex = linearIndex / innerSize;
263+ elementInSlice = linearIndex % innerSize;
264+ }
265+ else {
266+ elementInSlice = linearIndex / innerSize;
267+ dstIndex = linearIndex % innerSize;
268+ }
242269
243270 // Lua indices begin at 1
244271 IndexType dstIndex_ =
@@ -302,7 +329,8 @@ __global__ void indexSelectSmallIndex(TensorInfo<T, IndexType> dst,
302329// the number of indices chosen is small, then the
303330// indexSelectSmallIndex kernel is a better choice to reduce memory
304331// accesses.
305- template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim>
332+ template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim,
333+ bool IndexIsMajor>
306334__global__ void indexSelectLargeIndex (TensorInfo<T, IndexType> dst,
307335 TensorInfo<T, IndexType> src,
308336 TensorInfo<int64_t , IndexType> indices,
@@ -316,8 +344,15 @@ __global__ void indexSelectLargeIndex(TensorInfo<T, IndexType> dst,
316344 for (IndexType linearIndex = blockIdx .x * blockDim .x + threadIdx .x ;
317345 linearIndex < totalSize;
318346 linearIndex += gridDim .x * blockDim .x ) {
319- IndexType dstIndex = linearIndex / innerSize;
320- IndexType elementInSlice = linearIndex % innerSize;
347+ IndexType dstIndex, elementInSlice;
348+ if (IndexIsMajor) {
349+ dstIndex = linearIndex / innerSize;
350+ elementInSlice = linearIndex % innerSize;
351+ }
352+ else {
353+ elementInSlice = linearIndex / innerSize;
354+ dstIndex = linearIndex % innerSize;
355+ }
321356
322357 // Lua indices begin at 1
323358 IndexType srcIndex =
0 commit comments