Skip to content

Commit 26f038a

Browse files
yongjik authored and soumith committed
Improve memory access patterns for index operations. (#4493)
Currently, index operation kernels work in "source/destination index-major order". (E.g., if thread count equals slice size, each thread will process slice #0 in lockstep, and then slice #1, and so on.) However, when elements inside each "slice" are separated by large strides (e.g., selecting columns of a matrix), it is better to switch to "elementInSlice-major order". For example, each thread can process element #0 of every slice, and then element #1 of every slice, and so on.
1 parent 3321cdc commit 26f038a

File tree

2 files changed

+178
-55
lines changed

2 files changed

+178
-55
lines changed

torch/lib/THC/THCTensorIndex.cu

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,21 +66,30 @@ __global__ void indexCopySmallIndex(TensorInfo<T, IndexType> dst,
6666
// the number of indices chosen is small, then the
6767
// indexCopySmallIndex kernel is a better choice to reduce memory
6868
// accesses.
69-
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim>
69+
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim,
70+
bool IndexIsMajor>
7071
__global__ void indexCopyLargeIndex(TensorInfo<T, IndexType> dst,
7172
TensorInfo<T, IndexType> src,
7273
TensorInfo<int64_t, IndexType> indices,
7374
int dstCopyDim,
7475
int srcCopyDim,
76+
IndexType totalSize,
7577
IndexType innerSize,
7678
int64_t dstCopyDimSize) {
7779
// We stride over the output including the indexed dimension
7880
// (totalSize), and calculate the destination index point based on that
7981
for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
80-
linearIndex < innerSize * indices.sizes[0];
82+
linearIndex < totalSize;
8183
linearIndex += gridDim.x * blockDim.x) {
82-
IndexType srcIndex = linearIndex / innerSize;
83-
IndexType elementInSlice = linearIndex % innerSize;
84+
IndexType srcIndex, elementInSlice;
85+
if (IndexIsMajor) {
86+
srcIndex = linearIndex / innerSize;
87+
elementInSlice = linearIndex % innerSize;
88+
}
89+
else {
90+
elementInSlice = linearIndex / innerSize;
91+
srcIndex = linearIndex % innerSize;
92+
}
8493

8594
// Lua indices begin at 1
8695
IndexType dstIndex =
@@ -148,21 +157,30 @@ __global__ void indexAddSmallIndex(TensorInfo<T, IndexType> dst,
148157
// the number of indices chosen is small, then the
149158
// indexAddSmallIndex kernel is a better choice to reduce memory
150159
// accesses.
151-
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim>
160+
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim,
161+
bool IndexIsMajor>
152162
__global__ void indexAddLargeIndex(TensorInfo<T, IndexType> dst,
153163
TensorInfo<T, IndexType> src,
154164
TensorInfo<int64_t, IndexType> indices,
155165
int dstAddDim,
156166
int srcAddDim,
167+
IndexType totalSize,
157168
IndexType innerSize,
158169
int64_t dstAddDimSize) {
159170
// We stride over the output including the indexed dimension
160171
// (totalSize), and calculate the destination index point based on that
161172
for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
162-
linearIndex < innerSize * indices.sizes[0];
173+
linearIndex < totalSize;
163174
linearIndex += gridDim.x * blockDim.x) {
164-
IndexType srcIndex = linearIndex / innerSize;
165-
IndexType elementInSlice = linearIndex % innerSize;
175+
IndexType srcIndex, elementInSlice;
176+
if (IndexIsMajor) {
177+
srcIndex = linearIndex / innerSize;
178+
elementInSlice = linearIndex % innerSize;
179+
}
180+
else {
181+
elementInSlice = linearIndex / innerSize;
182+
srcIndex = linearIndex % innerSize;
183+
}
166184

167185
// Lua indices begin at 1
168186
IndexType dstIndex =
@@ -225,20 +243,29 @@ __global__ void indexFillSmallIndex(TensorInfo<T, IndexType> dst,
225243
// the number of indices chosen is small, then the
226244
// indexFillSmallIndex kernel is a better choice to reduce memory
227245
// accesses.
228-
template <typename T, typename IndexType, int DstDim, int IdxDim>
246+
template <typename T, typename IndexType, int DstDim, int IdxDim,
247+
bool IndexIsMajor>
229248
__global__ void indexFillLargeIndex(TensorInfo<T, IndexType> dst,
230249
TensorInfo<int64_t, IndexType> indices,
231250
int dstFillDim,
251+
IndexType totalSize,
232252
IndexType innerSize,
233253
int64_t dstFillDimSize,
234254
T val) {
235255
// We stride over the output including the indexed dimension
236256
// (totalSize), and calculate the destination index point based on that
237257
for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
238-
linearIndex < innerSize * indices.sizes[0];
258+
linearIndex < totalSize;
239259
linearIndex += gridDim.x * blockDim.x) {
240-
IndexType dstIndex = linearIndex / innerSize;
241-
IndexType elementInSlice = linearIndex % innerSize;
260+
IndexType dstIndex, elementInSlice;
261+
if (IndexIsMajor) {
262+
dstIndex = linearIndex / innerSize;
263+
elementInSlice = linearIndex % innerSize;
264+
}
265+
else {
266+
elementInSlice = linearIndex / innerSize;
267+
dstIndex = linearIndex % innerSize;
268+
}
242269

243270
// Lua indices begin at 1
244271
IndexType dstIndex_ =
@@ -302,7 +329,8 @@ __global__ void indexSelectSmallIndex(TensorInfo<T, IndexType> dst,
302329
// the number of indices chosen is small, then the
303330
// indexSelectSmallIndex kernel is a better choice to reduce memory
304331
// accesses.
305-
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim>
332+
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim,
333+
bool IndexIsMajor>
306334
__global__ void indexSelectLargeIndex(TensorInfo<T, IndexType> dst,
307335
TensorInfo<T, IndexType> src,
308336
TensorInfo<int64_t, IndexType> indices,
@@ -316,8 +344,15 @@ __global__ void indexSelectLargeIndex(TensorInfo<T, IndexType> dst,
316344
for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
317345
linearIndex < totalSize;
318346
linearIndex += gridDim.x * blockDim.x) {
319-
IndexType dstIndex = linearIndex / innerSize;
320-
IndexType elementInSlice = linearIndex % innerSize;
347+
IndexType dstIndex, elementInSlice;
348+
if (IndexIsMajor) {
349+
dstIndex = linearIndex / innerSize;
350+
elementInSlice = linearIndex % innerSize;
351+
}
352+
else {
353+
elementInSlice = linearIndex / innerSize;
354+
dstIndex = linearIndex % innerSize;
355+
}
321356

322357
// Lua indices begin at 1
323358
IndexType srcIndex =

0 commit comments

Comments
 (0)