 // copying or temporary storage.
 //
 
+// Rearrange dimensions for pointwise operations so that strides are in
+// decreasing order as much as possible, so that kernels have better memory
+// access patterns.
+//
+// For example, consider a binary operation on two "transposed" 2-dim tensors:
+//    sizes:          256 512
+//    aInfo->strides:   1 256
+//    bInfo->strides:   1 256
+//
+// Given this, each concurrent memory access inside kernelPointwiseApply2() is
+// exactly 256 elements apart, resulting in poor performance.
+//
+// This function exchanges dimensions so that memory access is contiguous:
+//    sizes:          512 256
+//    aInfo->strides: 256   1
+//    bInfo->strides: 256   1
+//
+// (Actually, it becomes even better because now collapseDims() can turn each
+// input into one contiguous array.)
+//
+// In general, given M (<= 3) TensorInfos with N dimensions, we can view each
+// strides[i] (0 <= i < N) as an M-tuple. Given each pair i < j, we exchange
+// strides[i] and strides[j] if
+//    (1) strides[i][k] < strides[j][k] for some k (0 <= k < M)
+//        (exchanging them will benefit input #k), and
+//    (2) strides[i][k] <= strides[j][k] for all k
+//        (exchanging them will not make any input worse).
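+//
+// (Illustrative note: with, say, strides[i] = (1, 256) and strides[j] =
+// (256, 1), condition (1) holds for one input but condition (2) fails for
+// the other, in either order; the pair is left untouched, since no exchange
+// can help one input without hurting another.)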
+template <typename T1, typename IndexType,
+          typename T2 = void, typename T3 = void>
+void rearrangeDims(TensorInfo<T1, IndexType>* aInfo,
+                   TensorInfo<T2, IndexType>* bInfo = nullptr,
+                   TensorInfo<T3, IndexType>* cInfo = nullptr) {
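+  // Gather the size/stride arrays of all provided tensors into parallel
+  // arrays, so the loops below can treat each strides[i] as an M-tuple.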
+  int numInfos = 1;
+  int dims = aInfo->dims;
+  IndexType *sizes[3] = { aInfo->sizes, };
+  IndexType *strides[3] = { aInfo->strides, };
+
+  if (bInfo != nullptr) {
+    ++numInfos;
+    if (bInfo->dims != dims) return;
+    sizes[1] = bInfo->sizes;
+    strides[1] = bInfo->strides;
+  }
+
+  if (cInfo != nullptr) {
+    ++numInfos;
+    if (cInfo->dims != dims) return;
+    sizes[2] = cInfo->sizes;
+    strides[2] = cInfo->strides;
+  }
+
+  // Bail out if sizes do not match: we are using "deprecated pointwise
+  // behavior" among tensors of different shapes but same number of elements.
+  for (int i = 1; i < numInfos; ++i) {
+    for (int j = 0; j < dims; ++j) {
+      if (sizes[i][j] != sizes[0][j]) return;
+    }
+  }
+
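+  // Exchange sort over all dimension pairs: for each pair i < j, swap the
+  // two dimensions when conditions (1) and (2) above hold.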
+  for (int i = 0; i < dims - 1; ++i) {
+    // No need to consider dimensions of size 1.
+    if (sizes[0][i] == 1) continue;
+
+    for (int j = i + 1; j < dims; ++j) {
+      if (sizes[0][j] == 1) continue;
+
+      // Compare the relative sizes of strides between dim #i and dim #j.
+      bool hasIncreasingStrides = false;
+      bool hasDecreasingStrides = false;
+
+      for (int k = 0; k < numInfos; k++) {
+        IndexType stride_i = strides[k][i];
+        IndexType stride_j = strides[k][j];
+        if (stride_i < stride_j) {
+          hasIncreasingStrides = true;
+        } else if (stride_i > stride_j) {
+          hasDecreasingStrides = true;
+        }
+      }
+
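+      // Swap dims #i and #j only when at least one input benefits
+      // (condition 1) and none is made worse (condition 2).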
+      if (hasIncreasingStrides && !hasDecreasingStrides) {
+        for (int k = 0; k < numInfos; k++) {
+          IndexType size = sizes[k][i];
+          sizes[k][i] = sizes[k][j];
+          sizes[k][j] = size;
+
+          IndexType stride = strides[k][i];
+          strides[k][i] = strides[k][j];
+          strides[k][j] = stride;
+        }
+      }
+    }
+  }
+}
+
 // Threads per block for our apply kernel
 // FIXME: use occupancy calculator instead
 #define THC_APPLY_THREADS_PER_BLOCK 32 * 16
@@ -197,6 +292,7 @@ bool THC_pointwiseApply1(THCState* state,
   if (TensorUtils<TensorTypeA>::canUse32BitIndexMath(state, a)) {
     TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
       getTensorInfo<TensorTypeA, unsigned int>(state, a);
+    rearrangeDims(&aInfo);
     aInfo.collapseDims();
 #if CUDA_VERSION < 9000
     if (!aInfo.isContiguous())
@@ -206,6 +302,7 @@ bool THC_pointwiseApply1(THCState* state,
   } else {
     TensorInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t> aInfo =
       getTensorInfo<TensorTypeA, uint64_t>(state, a);
+    rearrangeDims(&aInfo);
     aInfo.collapseDims();
 
     // For large tensors, we only compile the completely contiguous
@@ -359,10 +456,12 @@ bool THC_pointwiseApply2(THCState* state,
       TensorUtils<TensorTypeB>::canUse32BitIndexMath(state, b)) {
     TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
       getTensorInfo<TensorTypeA, unsigned int>(state, a);
-    aInfo.collapseDims();
 
     TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned int> bInfo =
       getTensorInfo<TensorTypeB, unsigned int>(state, b);
+
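+    // Rearrange strides before collapsing, so that collapseDims() can merge
+    // the now-adjacent contiguous dimensions of both inputs.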
+    rearrangeDims(&aInfo, &bInfo);
+    aInfo.collapseDims();
     bInfo.collapseDims();
 #if CUDA_VERSION < 9000
     if (!(aInfo.isContiguous() && bInfo.isContiguous()))
@@ -373,10 +472,12 @@ bool THC_pointwiseApply2(THCState* state,
   } else {
     TensorInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t> aInfo =
       getTensorInfo<TensorTypeA, uint64_t>(state, a);
-    aInfo.collapseDims();
 
     TensorInfo<typename TensorUtils<TensorTypeB>::DataType, uint64_t> bInfo =
       getTensorInfo<TensorTypeB, uint64_t>(state, b);
+
+    rearrangeDims(&aInfo, &bInfo);
+    aInfo.collapseDims();
     bInfo.collapseDims();
 
     // For large tensors, we only compile the completely contiguous
@@ -566,14 +667,16 @@ bool THC_pointwiseApply3(THCState* state,
       TensorUtils<TensorTypeC>::canUse32BitIndexMath(state, c)) {
     TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
       getTensorInfo<TensorTypeA, unsigned int>(state, a);
-    aInfo.collapseDims();
 
     TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned int> bInfo =
       getTensorInfo<TensorTypeB, unsigned int>(state, b);
-    bInfo.collapseDims();
 
     TensorInfo<typename TensorUtils<TensorTypeC>::DataType, unsigned int> cInfo =
       getTensorInfo<TensorTypeC, unsigned int>(state, c);
+
+    rearrangeDims(&aInfo, &bInfo, &cInfo);
+    aInfo.collapseDims();
+    bInfo.collapseDims();
     cInfo.collapseDims();
 
 #if CUDA_VERSION < 9000
@@ -584,14 +687,16 @@ bool THC_pointwiseApply3(THCState* state,
   } else {
     TensorInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t> aInfo =
       getTensorInfo<TensorTypeA, uint64_t>(state, a);
-    aInfo.collapseDims();
 
     TensorInfo<typename TensorUtils<TensorTypeB>::DataType, uint64_t> bInfo =
       getTensorInfo<TensorTypeB, uint64_t>(state, b);
-    bInfo.collapseDims();
 
     TensorInfo<typename TensorUtils<TensorTypeC>::DataType, uint64_t> cInfo =
       getTensorInfo<TensorTypeC, uint64_t>(state, c);
+
+    rearrangeDims(&aInfo, &bInfo, &cInfo);
+    aInfo.collapseDims();
+    bInfo.collapseDims();
     cInfo.collapseDims();
 
     // For large tensors, we only compile the completely contiguous