
Commit 343c15c

yongjik authored and soumith committed
Rearrange dimensions for pointwise operations for better performance. (#4174)
* Rearrange dimensions for pointwise operations for better performance.

  In the existing code, pointwise operations on transposed tensors process data "column by column", resulting in poor performance. The worst case happens when all operands are transposed tensors. This change tries to "un-transpose" tensors in that case, so that memory access patterns are as sequential as possible.

* More explanation of what rearrangeDims() does.

* Fixed a very important (and stupid) typo.
1 parent 85ea548 commit 343c15c
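
As a rough illustration of the idea (a minimal standalone sketch, not part of this commit; the simplified rearrange() below stands in for the real rearrangeDims(), which operates on THC TensorInfo structs):

    // Sketch: apply the dimension-exchange rule from this commit to the
    // "two transposed 256x512 operands" example, where both operands start
    // with sizes (256, 512) and strides (1, 256).
    #include <cstdio>
    #include <cstdint>
    #include <utility>

    // Swap dims i and j for every operand if the swap helps at least one
    // operand's stride ordering and hurts none (same rule as rearrangeDims()).
    void rearrange(int dims, int numOps, uint64_t sizes[][2], uint64_t strides[][2]) {
      for (int i = 0; i < dims - 1; ++i) {
        for (int j = i + 1; j < dims; ++j) {
          bool increasing = false, decreasing = false;
          for (int k = 0; k < numOps; ++k) {
            if (strides[k][i] < strides[k][j]) increasing = true;
            else if (strides[k][i] > strides[k][j]) decreasing = true;
          }
          if (increasing && !decreasing) {
            for (int k = 0; k < numOps; ++k) {
              std::swap(sizes[k][i], sizes[k][j]);
              std::swap(strides[k][i], strides[k][j]);
            }
          }
        }
      }
    }

    int main() {
      uint64_t sizes[2][2]   = { {256, 512}, {256, 512} };
      uint64_t strides[2][2] = { {1, 256}, {1, 256} };
      rearrange(2, 2, sizes, strides);
      // Prints sizes (512, 256) and strides (256, 1) for both operands:
      // the innermost dimension is now traversed with stride 1.
      for (int k = 0; k < 2; ++k) {
        printf("operand %d: sizes (%llu, %llu), strides (%llu, %llu)\n", k,
               (unsigned long long)sizes[k][0], (unsigned long long)sizes[k][1],
               (unsigned long long)strides[k][0], (unsigned long long)strides[k][1]);
      }
      return 0;
    }

After the swap, collapseDims() can fold each operand into a single contiguous run, which is what the parenthetical comment in the diff below refers to.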


torch/lib/THC/THCApply.cuh

Lines changed: 111 additions & 6 deletions
@@ -12,6 +12,101 @@
 // copying or temporary storage.
 //
 
+// Rearrange dimensions for pointwise operations so that strides are in
+// decreasing order as much as possible, so that kernels have better memory
+// access patterns.
+//
+// For example, consider a binary operation on two "transposed" 2-dim tensors:
+//    sizes:          256 512
+//    aInfo->strides:   1 256
+//    bInfo->strides:   1 256
+//
+// Given this, each concurrent memory access inside kernelPointwiseApply2() is
+// exactly 256 elements apart, resulting in poor performance.
+//
+// This function exchanges dimensions so that memory access is contiguous:
+//    sizes:          512 256
+//    aInfo->strides: 256   1
+//    bInfo->strides: 256   1
+//
+// (Actually, it becomes even better because now collapseDims() can turn each
+// input into one contiguous array.)
+//
+// In general, given M (<=3) TensorInfo's with N dimensions, we can view each
+// strides[i] (0 <= i < N) as an M-tuple.  Given each pair i < j, we exchange
+// strides[i] and [j] if
+//    (1) strides[i][k] < strides[j][k] for some k (0 <= k < M)
+//        (exchanging them will benefit input #k), and
+//    (2) strides[i][k] <= strides[j][k] for all k
+//        (exchanging them will not make any input worse).
+template <typename T1, typename IndexType,
+          typename T2 = void, typename T3 = void>
+void rearrangeDims(TensorInfo<T1, IndexType>* aInfo,
+                   TensorInfo<T2, IndexType>* bInfo = nullptr,
+                   TensorInfo<T3, IndexType>* cInfo = nullptr) {
+  int numInfos = 1;
+  int dims = aInfo->dims;
+  IndexType *sizes[3] = { aInfo->sizes, };
+  IndexType *strides[3] = { aInfo->strides, };
+
+  if (bInfo != nullptr) {
+    ++numInfos;
+    if (bInfo->dims != dims) return;
+    sizes[1] = bInfo->sizes;
+    strides[1] = bInfo->strides;
+  }
+
+  if (cInfo != nullptr) {
+    ++numInfos;
+    if (cInfo->dims != dims) return;
+    sizes[2] = cInfo->sizes;
+    strides[2] = cInfo->strides;
+  }
+
+  // Bail out if sizes do not match: we are using "deprecated pointwise
+  // behavior" among tensors of different shapes but same number of elements.
+  for (int i = 1; i < numInfos; ++i) {
+    for (int j = 0; j < dims; ++j) {
+      if (sizes[i][j] != sizes[0][j]) return;
+    }
+  }
+
+  for (int i = 0; i < dims - 1; ++i) {
+    // No need to consider dimensions of size 1.
+    if (sizes[0][i] == 1) continue;
+
+    for (int j = i + 1; j < dims; ++j) {
+      if (sizes[0][j] == 1) continue;
+
+      // Compare the relative sizes of strides between dim #i and dim #j.
+      bool hasIncreasingStrides = false;
+      bool hasDecreasingStrides = false;
+
+      for (int k = 0; k < numInfos; k++) {
+        IndexType stride_i = strides[k][i];
+        IndexType stride_j = strides[k][j];
+        if (stride_i < stride_j) {
+          hasIncreasingStrides = true;
+        } else if (stride_i > stride_j) {
+          hasDecreasingStrides = true;
+        }
+      }
+
+      if (hasIncreasingStrides && !hasDecreasingStrides) {
+        for (int k = 0; k < numInfos; k++) {
+          IndexType size = sizes[k][i];
+          sizes[k][i] = sizes[k][j];
+          sizes[k][j] = size;
+
+          IndexType stride = strides[k][i];
+          strides[k][i] = strides[k][j];
+          strides[k][j] = stride;
+        }
+      }
+    }
+  }
+}
+
 // Threads per block for our apply kernel
 // FIXME: use occupancy calculator instead
 #define THC_APPLY_THREADS_PER_BLOCK 32 * 16
@@ -197,6 +292,7 @@ bool THC_pointwiseApply1(THCState* state,
   if (TensorUtils<TensorTypeA>::canUse32BitIndexMath(state, a)) {
     TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
       getTensorInfo<TensorTypeA, unsigned int>(state, a);
+    rearrangeDims(&aInfo);
     aInfo.collapseDims();
 #if CUDA_VERSION < 9000
     if (!aInfo.isContiguous())
@@ -206,6 +302,7 @@ bool THC_pointwiseApply1(THCState* state,
   } else {
     TensorInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t> aInfo =
       getTensorInfo<TensorTypeA, uint64_t>(state, a);
+    rearrangeDims(&aInfo);
     aInfo.collapseDims();
 
     // For large tensors, we only compile the completely contiguous
@@ -359,10 +456,12 @@ bool THC_pointwiseApply2(THCState* state,
       TensorUtils<TensorTypeB>::canUse32BitIndexMath(state, b)) {
     TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
       getTensorInfo<TensorTypeA, unsigned int>(state, a);
-    aInfo.collapseDims();
 
     TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned int> bInfo =
       getTensorInfo<TensorTypeB, unsigned int>(state, b);
+
+    rearrangeDims(&aInfo, &bInfo);
+    aInfo.collapseDims();
     bInfo.collapseDims();
 #if CUDA_VERSION < 9000
     if (!(aInfo.isContiguous() && bInfo.isContiguous()))
@@ -373,10 +472,12 @@ bool THC_pointwiseApply2(THCState* state,
   } else {
     TensorInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t> aInfo =
       getTensorInfo<TensorTypeA, uint64_t>(state, a);
-    aInfo.collapseDims();
 
     TensorInfo<typename TensorUtils<TensorTypeB>::DataType, uint64_t> bInfo =
       getTensorInfo<TensorTypeB, uint64_t>(state, b);
+
+    rearrangeDims(&aInfo, &bInfo);
+    aInfo.collapseDims();
     bInfo.collapseDims();
 
     // For large tensors, we only compile the completely contiguous
@@ -566,14 +667,16 @@ bool THC_pointwiseApply3(THCState* state,
       TensorUtils<TensorTypeC>::canUse32BitIndexMath(state, c)) {
     TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
       getTensorInfo<TensorTypeA, unsigned int>(state, a);
-    aInfo.collapseDims();
 
     TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned int> bInfo =
       getTensorInfo<TensorTypeB, unsigned int>(state, b);
-    bInfo.collapseDims();
 
     TensorInfo<typename TensorUtils<TensorTypeC>::DataType, unsigned int> cInfo =
       getTensorInfo<TensorTypeC, unsigned int>(state, c);
+
+    rearrangeDims(&aInfo, &bInfo, &cInfo);
+    aInfo.collapseDims();
+    bInfo.collapseDims();
     cInfo.collapseDims();
 
 #if CUDA_VERSION < 9000
@@ -584,14 +687,16 @@ bool THC_pointwiseApply3(THCState* state,
   } else {
     TensorInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t> aInfo =
       getTensorInfo<TensorTypeA, uint64_t>(state, a);
-    aInfo.collapseDims();
 
     TensorInfo<typename TensorUtils<TensorTypeB>::DataType, uint64_t> bInfo =
       getTensorInfo<TensorTypeB, uint64_t>(state, b);
-    bInfo.collapseDims();
 
     TensorInfo<typename TensorUtils<TensorTypeC>::DataType, uint64_t> cInfo =
       getTensorInfo<TensorTypeC, uint64_t>(state, c);
+
+    rearrangeDims(&aInfo, &bInfo, &cInfo);
+    aInfo.collapseDims();
+    bInfo.collapseDims();
     cInfo.collapseDims();
 
     // For large tensors, we only compile the completely contiguous
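
Conditions (1) and (2) in the comment above also cover the case where operands disagree on layout; a minimal standalone sketch (not from the commit) of that bail-out behavior, reusing only the stride-comparison rule:

    // Sketch: when one operand is row-major and the other column-major,
    // swapping the two dims would help one operand but hurt the other, so
    // both hasIncreasingStrides and hasDecreasingStrides become true and
    // rearrangeDims() leaves the dimensions untouched.
    #include <cstdio>
    #include <cstdint>

    int main() {
      const int numOps = 2;
      uint64_t strides[numOps][2] = { {512, 1},     // a: row-major 256x512
                                      {1, 256} };   // b: column-major 256x512
      bool increasing = false, decreasing = false;
      for (int k = 0; k < numOps; ++k) {            // compare dims i=0, j=1
        if (strides[k][0] < strides[k][1]) increasing = true;
        else if (strides[k][0] > strides[k][1]) decreasing = true;
      }
      printf("swap dims 0 and 1? %s\n", (increasing && !decreasing) ? "yes" : "no");
      return 0;
    }

In this case neither ordering is uniformly better, so the dimensions are left as they were given.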
