aten/src/THC/THCApply.cuh: 111 additions & 6 deletions (117 changes)
@@ -12,6 +12,101 @@
// copying or temporary storage.
//

// Rearrange dimensions for pointwise operations so that strides are in
// decreasing order as much as possible, so that kernels have better memory
// access patterns.
//
// For example, consider a binary operation on two "transposed" 2-dim tensors:
// sizes: 256 512
// aInfo->strides: 1 256
// bInfo->strides: 1 256
//
// Given this, each concurrent memory access inside kernelPointwiseApply2() is
// exactly 256 elements apart, resulting in poor performance.
//
// This function exchanges dimensions so that memory access is contiguous:
// sizes: 512 256
// aInfo->strides: 256 1
// bInfo->strides: 256 1
//
// (Actually, it becomes even better because now collapseDims() can turn each
// input into one contiguous array.)
//
// In general, given M (<=3) TensorInfo's with N dimensions, we can view each
// strides[i] (0 <= i < N) as an M-tuple. For each pair i < j, we exchange
// strides[i] and strides[j] if
// (1) strides[i][k] < strides[j][k] for some k (0 <= k < M)
// (exchanging them will benefit input #k), and
// (2) strides[i][k] <= strides[j][k] for all k
// (exchanging them will not make any input worse).
template <typename T1, typename IndexType,
          typename T2 = void, typename T3 = void>
void rearrangeDims(TensorInfo<T1, IndexType>* aInfo,
                   TensorInfo<T2, IndexType>* bInfo = nullptr,
                   TensorInfo<T3, IndexType>* cInfo = nullptr) {
  int numInfos = 1;
  int dims = aInfo->dims;
  IndexType *sizes[3] = { aInfo->sizes, };
  IndexType *strides[3] = { aInfo->strides, };

  if (bInfo != nullptr) {
    ++numInfos;
    if (bInfo->dims != dims) return;
    sizes[1] = bInfo->sizes;
    strides[1] = bInfo->strides;
  }

  if (cInfo != nullptr) {
    ++numInfos;
    if (cInfo->dims != dims) return;
    sizes[2] = cInfo->sizes;
    strides[2] = cInfo->strides;
  }

  // Bail out if sizes do not match: we are using "deprecated pointwise
  // behavior" among tensors of different shapes but same number of elements.
  for (int i = 1; i < numInfos; ++i) {
    for (int j = 0; j < dims; ++j) {
      if (sizes[i][j] != sizes[0][j]) return;
    }
  }

  for (int i = 0; i < dims - 1; ++i) {
    // No need to consider dimensions of size 1.
    if (sizes[0][i] == 1) continue;

    for (int j = i + 1; j < dims; ++j) {
      if (sizes[0][j] == 1) continue;

      // Compare the relative sizes of strides between dim #i and dim #j.
      bool hasIncreasingStrides = false;
      bool hasDecreasingStrides = false;

      for (int k = 0; k < numInfos; k++) {
        IndexType stride_i = strides[k][i];
        IndexType stride_j = strides[k][j];
        if (stride_i < stride_j) {
          hasIncreasingStrides = true;
        } else if (stride_i > stride_j) {
          hasDecreasingStrides = true;
        }
      }

      if (hasIncreasingStrides && !hasDecreasingStrides) {
        for (int k = 0; k < numInfos; k++) {
          IndexType size = sizes[k][i];
          sizes[k][i] = sizes[k][j];
          sizes[k][j] = size;

          IndexType stride = strides[k][i];
          strides[k][i] = strides[k][j];
          strides[k][j] = stride;
        }
      }
    }
  }
}
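The following is a minimal standalone sketch (illustration only, not part of this diff) of the exchange criterion documented in the comment above, applied to the 256 x 512 "transposed" example. MiniTensorInfo is a hypothetical stand-in for the real TensorInfo type, reduced to the dims/sizes/strides fields the criterion looks at:

#include <cstdio>

// Hypothetical stand-in for TensorInfo, just enough for this illustration.
struct MiniTensorInfo {
  int dims;
  unsigned int sizes[2];
  unsigned int strides[2];
};

int main() {
  // Two "transposed" 2-dim inputs: sizes 256 512, strides 1 256.
  MiniTensorInfo a = {2, {256, 512}, {1, 256}};
  MiniTensorInfo b = {2, {256, 512}, {1, 256}};
  MiniTensorInfo* infos[] = {&a, &b};

  // The only pair is (i=0, j=1). For every input, strides[0] = 1 < 256 =
  // strides[1], so condition (1) holds for some input and condition (2)
  // holds for all of them: swapping dims 0 and 1 helps at least one input
  // and hurts none.
  for (MiniTensorInfo* t : infos) {
    unsigned int size = t->sizes[0];
    t->sizes[0] = t->sizes[1];
    t->sizes[1] = size;

    unsigned int stride = t->strides[0];
    t->strides[0] = t->strides[1];
    t->strides[1] = stride;
  }

  // Result: sizes 512 256, strides 256 1. Strides now decrease left to
  // right, so a subsequent collapseDims() could fold each input into a
  // single contiguous dimension.
  printf("a: sizes {%u, %u}, strides {%u, %u}\n",
         a.sizes[0], a.sizes[1], a.strides[0], a.strides[1]);
  printf("b: sizes {%u, %u}, strides {%u, %u}\n",
         b.sizes[0], b.sizes[1], b.strides[0], b.strides[1]);
  return 0;
}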

// Threads per block for our apply kernel
// FIXME: use occupancy calculator instead
#define THC_APPLY_THREADS_PER_BLOCK 32 * 16
@@ -197,6 +292,7 @@ bool THC_pointwiseApply1(THCState* state,
if (TensorUtils<TensorTypeA>::canUse32BitIndexMath(state, a)) {
TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
getTensorInfo<TensorTypeA, unsigned int>(state, a);
rearrangeDims(&aInfo);
aInfo.collapseDims();
#if CUDA_VERSION < 9000
if (!aInfo.isContiguous())
@@ -206,6 +302,7 @@
} else {
TensorInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t> aInfo =
getTensorInfo<TensorTypeA, uint64_t>(state, a);
rearrangeDims(&aInfo);
aInfo.collapseDims();

// For large tensors, we only compile the completely contiguous
@@ -359,10 +456,12 @@ bool THC_pointwiseApply2(THCState* state,
TensorUtils<TensorTypeB>::canUse32BitIndexMath(state, b)) {
TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
getTensorInfo<TensorTypeA, unsigned int>(state, a);
aInfo.collapseDims();

TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned int> bInfo =
getTensorInfo<TensorTypeB, unsigned int>(state, b);

rearrangeDims(&aInfo, &bInfo);
aInfo.collapseDims();
bInfo.collapseDims();
#if CUDA_VERSION < 9000
if (!(aInfo.isContiguous() && bInfo.isContiguous()))
@@ -373,10 +472,12 @@
} else {
TensorInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t> aInfo =
getTensorInfo<TensorTypeA, uint64_t>(state, a);
aInfo.collapseDims();

TensorInfo<typename TensorUtils<TensorTypeB>::DataType, uint64_t> bInfo =
getTensorInfo<TensorTypeB, uint64_t>(state, b);

rearrangeDims(&aInfo, &bInfo);
aInfo.collapseDims();
bInfo.collapseDims();

// For large tensors, we only compile the completely contiguous
@@ -566,14 +667,16 @@ bool THC_pointwiseApply3(THCState* state,
TensorUtils<TensorTypeC>::canUse32BitIndexMath(state, c)) {
TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
getTensorInfo<TensorTypeA, unsigned int>(state, a);
aInfo.collapseDims();

TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned int> bInfo =
getTensorInfo<TensorTypeB, unsigned int>(state, b);
bInfo.collapseDims();

TensorInfo<typename TensorUtils<TensorTypeC>::DataType, unsigned int> cInfo =
getTensorInfo<TensorTypeC, unsigned int>(state, c);

rearrangeDims(&aInfo, &bInfo, &cInfo);
aInfo.collapseDims();
bInfo.collapseDims();
cInfo.collapseDims();

#if CUDA_VERSION < 9000
Expand All @@ -584,14 +687,16 @@ bool THC_pointwiseApply3(THCState* state,
} else {
TensorInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t> aInfo =
getTensorInfo<TensorTypeA, uint64_t>(state, a);
aInfo.collapseDims();

TensorInfo<typename TensorUtils<TensorTypeB>::DataType, uint64_t> bInfo =
getTensorInfo<TensorTypeB, uint64_t>(state, b);
bInfo.collapseDims();

TensorInfo<typename TensorUtils<TensorTypeC>::DataType, uint64_t> cInfo =
getTensorInfo<TensorTypeC, uint64_t>(state, c);

rearrangeDims(&aInfo, &bInfo, &cInfo);
aInfo.collapseDims();
bInfo.collapseDims();
cInfo.collapseDims();

// For large tensors, we only compile the completely contiguous