Conversation

@yongjik (Contributor) commented Feb 5, 2018

  • OffsetInfo and OffsetIterator pre-compute the necessary coordinate
    change along each dimension, so that each successive offset can be
    computed using only addition/subtraction/comparisons.

  • Added IntDivider which supports "magic division" for uint32_t, thus
    eliminating integer divisions altogether for offset calculation, as
    long as indices fit in 32 bits.

  • In code paths with statically determined dimensions (Dims=1 or 2),
    kernel arguments now contain only the necessary data (instead of
    MAX_CUTORCH_DIMS of everything).

  • Fixed index overflow errors: for tensors with >= 2G elements, we used
    to have incorrect results or an infinite loop inside the kernel.

TODO: The following pattern is broken for tensors with >= 2G elements.
It will result in overflow, even if IndexType is uint64_t. Need
to search and replace them.

for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
     linearIndex < totalElements;
     linearIndex += gridDim.x * blockDim.x) {
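
(The overflow happens because blockIdx.x, blockDim.x, gridDim.x, and threadIdx.x are 32-bit unsigned values, so both the initial index and the stride are computed in 32-bit arithmetic and wrap at 2^32 before being widened to IndexType. A minimal sketch of an overflow-safe variant follows; the kernel name and the Op functor are placeholders, and this is illustrative rather than the exact code in the PR.)

// Hypothetical overflow-safe grid-stride loop: widen to IndexType *before*
// the 32-bit built-ins are multiplied, so the index arithmetic is done in
// 64 bits whenever IndexType is 64-bit.
template <typename IndexType, typename Op>
__global__ void pointwiseApplySketch(IndexType totalElements, Op op) {
  const IndexType start  = (IndexType) blockIdx.x * blockDim.x + threadIdx.x;
  const IndexType stride = (IndexType) gridDim.x * blockDim.x;

  for (IndexType linearIndex = start;
       linearIndex < totalElements;
       linearIndex += stride) {
    op(linearIndex);
  }
}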

@yongjik (Contributor, Author) commented Feb 5, 2018

This PR improves some float operations by ~20% (and some operations on ByteTensor by up to ~45%), but in general the performance impact seems small, unless one uses a lot of non-contiguous tensors and/or broadcasting with large dimensions.

Here's an example where I could get ~20% improvement on GTX 1080:

A = torch.cuda.FloatTensor(1000, 256)
B = torch.cuda.FloatTensor(128)
A = A[:, :128]
A.pow_(B)  # Improves from ~7.6 to ~6.0 usec

I found at least one case where it becomes slower by ~5%, but such cases seem to be rare, so I still think it's a net performance win on average, although small.

A = torch.cuda.IntTensor(2048, 2048)
A = A[:, :2000]
B = torch.cuda.IntTensor(2000).fill_(10)
A.remainder_(B)  # Changes from ~180 to 189 usec.

Raw benchmark results are https://github.com/yongjik/pt_test/tree/master/results/offset in case anybody's interested.

  • This PR also fixes overflow errors with large tensors (2G elements or more), which makes me suspect that nobody has actually been using 64-bit index math so far...

@soumith (Contributor) commented Feb 5, 2018

@pytorchbot add to whitelist

@wickedfoo (Contributor) left a comment

I'm not convinced that the code as you have it results in a performance win, and it makes the code a lot more complicated. 7.6 us to 6.0 us is within the realm of noise, and such changes are sensitive to heuristics used in the register allocator and in other places.

Replacing the linear index with a per-dimension index will bloat out the register count, and the code within the new iteration stuff looks like it has divergent/predicated execution paths as well.

However, I do believe that integer division by constants via multiply/shift is worth trying. Your magic number division algorithm can be simplified by restricting its usage to divisors from 2 up to the maximum signed int (see comments).

Can you do a more minimal diff that keeps the old kernel structure and the linear index -> offset lookup, but tries the faster version of the magic constant division algorithm, with a fallback to normal integer div/mod if it falls outside the range under consideration?

For performance testing, I would concentrate on sufficiently large tensor sizes, say a transposed tensor several hundred MB in size on which you perform pointwise operations. A kernel that executes in just a few microseconds is, I think, likely to fall within the margin of noise.

Also, I would inspect the SASS to see what instructions were being emitted before for integer div/mod (from what I recall when I looked a long time ago, it tries to map it to a floating-point inverse), and see what instructions it actually issues for umulhi as well.
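
For concreteness, here is a minimal host-side sketch of the "round-up" magic-division construction, restricted to divisors no larger than INT32_MAX as suggested above. The struct name and layout are illustrative rather than the exact IntDivider implementation; on the device, div() would reduce to an __umulhi, an add, and a shift, along the lines of the snippets quoted below. Note that capping both the divisor and the dividend at INT32_MAX keeps the intermediate t + n from wrapping around 32 bits.

#include <cassert>
#include <cstdint>
#include <initializer_list>

// Hypothetical MagicDivider: precompute a magic multiplier and a shift so
// that n / d becomes one multiply-high, one add, and one shift.
struct MagicDivider {
  uint32_t divisor;
  uint32_t magic;   // "m1" in the snippets below
  uint32_t shift;   // ceil(log2(divisor))

  MagicDivider() : divisor(1), magic(1), shift(0) {}  // default: divide by 1

  explicit MagicDivider(uint32_t d) : divisor(d) {
    assert(d >= 1 && d <= 0x7fffffffu);               // 1 <= d <= INT32_MAX
    shift = 0;
    while ((uint64_t(1) << shift) < d) ++shift;
    // magic = floor(2^32 * (2^shift - d) / d) + 1; always fits in 32 bits.
    uint64_t m = ((uint64_t(1) << 32) * ((uint64_t(1) << shift) - d)) / d + 1;
    magic = static_cast<uint32_t>(m);
  }

  // Host-side equivalent of: t = __umulhi(n, magic); return (t + n) >> shift;
  // Valid for n <= INT32_MAX, so t + n stays below 2^32.
  uint32_t div(uint32_t n) const {
    uint32_t t = static_cast<uint32_t>((uint64_t(n) * magic) >> 32);
    return (t + n) >> shift;
  }
};

// Quick self-check against ordinary division.
int main() {
  for (uint32_t d : {1u, 2u, 3u, 7u, 128u, 2000u, 2147483647u}) {
    MagicDivider md(d);
    for (uint32_t n : {0u, 1u, d - 1, d, 123456789u, 2147483647u}) {
      assert(md.div(n) == n / d);
    }
  }
  return 0;
}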

#ifdef __CUDA_ARCH__
// 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and
// 'm1'.
unsigned int t = __umulhi(n, m1);

// 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and
// 'm1'.
unsigned int t = __umulhi(n, m1);
unsigned int t2 = t + ((n - t) >> s1);

{
bool carry = false;

for (int i = dims - 1; i > 0; --i) {

bool carry = false;

for (int i = dims - 1; i > 0; --i) {
IndexType index = indices[i] + increments[i] + (IndexType) carry;

typename IndexType,
int ADims>
#if __CUDA_ARCH__ >= 350
__launch_bounds__(THC_APPLY_THREADS_PER_BLOCK, THC_APPLY_BLOCKS_PER_SM)

IndexType next = index + step;

// The second condition is necessary to handle overflow (e.g., when step is
// 2GB and limit is 3GB, assuming 32-bit index).

const OffsetInfo<Tb, IndexType, BDims> b,
IndexType totalElements,
Op op) {
for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;

@yongjik (Contributor, Author) commented Feb 6, 2018

Hi @wickedfoo, thanks for the detailed review, and I understand your point that the code is too complicated for the (rather unimpressive) speedup. I'll try just using the constant division algorithm and get back to you. Might take a few days.

On the other hand, I do think there's a measurable speedup for some cases. One case I found:

A = torch.cuda.FloatTensor(4096, 4096)
B = torch.cuda.FloatTensor(2048)
A = A[:, :2048]
A.pow_(B)  # Changes from ~389 us to ~315 us (speedup ~23%)

Ironically, using an even larger tensor doesn't show a larger speedup, because then (I suppose) memory bandwidth dominates everything.

@ngimel (Collaborator) commented Feb 6, 2018

@yongjik also take a look at https://github.com/milakov/int_fastdiv

@wickedfoo (Contributor) commented:

Also the reason that IndexType was unsigned int rather than signed int for 32-bit indices was that I discovered that div/mod was faster on unsigned rather than signed types (relevant for K40 and M40 I think?), but this may no longer be the case with more recent architectures. It might be worth comparing both signed and unsigned int32 for IndexType to see if there's still a difference.

@wickedfoo (Contributor) commented:

I think the integer division by magic constants code in the Caffe2 source will be faster than int_fastdiv if you exclude the -1 / 1 case. They're more or less the same code, except that you avoid this additional work:

https://github.com/milakov/int_fastdiv/blob/master/int_fastdiv.h#L126

@Stonesjtu (Contributor) commented:

@yongjik I suffered a lot tuning the indexSelectKernel, and I found that the kernel is mainly limited by read_throughput. I suspect we should focus more on reducing memory read operations and improving locality, rather than only on computational cost.

- Also changed canUse32BitIndexMath so that the max index for 32-bit
  math is INT32_MAX, instead of UINT32_MAX.  It also simplifies the
  division operation.
@yongjik changed the title from "Use pre-computed offset increments to avoid int division inside kernels." to "Use fast integer division algorithm to avoid division ops inside kernels." on Feb 13, 2018
@yongjik (Contributor, Author) commented Feb 13, 2018

Hi @wickedfoo, I updated the code to remove the increment stuff and only leave the int division algorithm. Could you take another look?

Regarding signed vs. unsigned integers, I think the point is moot, because (in the references I found) the fast division algorithm for signed integers always takes more operations than the unsigned version. So I don't think signed indices would give us any benefit here.

@yongjik (Contributor, Author) commented Feb 21, 2018

Hi guys, any thoughts on this PR?

@yongjik (Contributor, Author) commented Feb 28, 2018

Hi @wickedfoo, could you give your opinion? If this PR still looks like too much complication, I understand if you don't want to merge this, but I'd appreciate a decision rather than this PR staying in limbo forever. Thanks!

@soumith (Contributor) commented Mar 1, 2018

@yongjik I think he does not get GitHub notification emails. I will ping him directly. Sorry for the delay.

@wickedfoo (Contributor) commented:

Looking now.

@wickedfoo (Contributor) left a comment

Looks good to me. Any idea what the performance change of this is (i.e., is it worth it, and for what sizes)?

__host__ __device__ T* get(IndexType linearIndex) const {
  IndexType offset = 0;

  for (int i = tinfo.dims - 1; i > 0; --i) {
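
To make the quoted hunk easier to follow, here is a hypothetical completion of the get() loop, reusing the MagicDivider sketch from earlier in this thread. The field names (data, dims, sizes, strides) and MAX_DIMS are stand-ins inferred from context rather than copied from the PR, so the real OffsetInfo will differ in details: each iteration strips off the coordinate along dimension i with one magic division and a multiply-add, and whatever remains in linearIndex is the coordinate along dimension 0.

// Host-side sketch only (CUDA qualifiers omitted); assumes 32-bit indices and
// the MagicDivider struct sketched above.
struct OffsetCalcSketch {
  static const int MAX_DIMS = 25;   // stand-in for MAX_CUTORCH_DIMS

  float* data;
  int dims;
  MagicDivider sizes[MAX_DIMS];     // one precomputed divider per dimension size
  uint32_t strides[MAX_DIMS];

  float* get(uint32_t linearIndex) const {
    uint32_t offset = 0;
    for (int i = dims - 1; i > 0; --i) {
      uint32_t q     = sizes[i].div(linearIndex);            // linearIndex / size[i]
      uint32_t coord = linearIndex - q * sizes[i].divisor;   // linearIndex % size[i]
      linearIndex = q;
      offset += coord * strides[i];
    }
    // The remaining quotient is the coordinate along dimension 0.
    return data + offset + linearIndex * strides[0];
  }
};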

@@ -0,0 +1,89 @@
#ifndef THC_OFFSET_INFO_INC

@yongjik (Contributor, Author) commented Mar 2, 2018

Hi @wickedfoo, thanks for the review.

I ran several hundred configurations of tensor operations (on GTX 1080 / CUDA 9.1), including add, mul, tanh, pow, and remainder, on tensors with sizes ranging between 1000x64 and 4096x4096. About 1/3 of them showed a speedup of 5% or more, with some cases improving by up to ~50%. For the rest, performance didn't change meaningfully (within 5%). I saw several cases of a ~5% slowdown, but those operations were too small (< 20 us) and I couldn't reproduce them reliably: it might just be random noise.

The biggest win I could find was:

A = torch.cuda.HalfTensor(4096, 4096)
B = torch.cuda.HalfTensor(2048)
C = torch.cuda.HalfTensor(4096, 4096)

# ~266 us to ~172 us (speedup 55%)
torch.mul(A[:, :2048], B, out=C[:, :2048])

A = torch.cuda.HalfTensor(8192, 8192)
B = torch.cuda.HalfTensor(4096)
C = torch.cuda.HalfTensor(8192, 8192)

# ~1.06 to ~0.69 ms (speedup 54%)
torch.mul(A[:, :4096], B, out=C[:, :4096])

We also have speedup for float operations, though not as dramatic:

# ~26 to ~21 us (speedup 24%, but it's too noisy so might not be accurate)
nrows, ncols, ncols2 = 1024, 1024, 512

# ~365 to 312 us (speedup 17%)
nrows, ncols, ncols2 = 4096, 4096, 2048

# ~1.38 to 1.17 ms (speedup 18%)
nrows, ncols, ncols2 = 8000, 8000, 4000

A = torch.cuda.FloatTensor(nrows, ncols)
B = torch.cuda.FloatTensor(ncols2)
A = A[:, :ncols2]

A.pow_(B)

For some other float operations, I observed a speedup of ~25% for mid-size tensors (around 1000x128), but the gain becomes smaller as tensors get bigger (~9% for 1024x1024, ~3% for 8000x3000), probably because memory latency dominates everything for these tensors.

@soumith (Contributor) commented Mar 3, 2018

@yongjik thanks a lot for the PR, it's good to go.
@ezyang can you check if the failing perf-regression test is legit or a false positive?

@yongjik (Contributor, Author) commented Mar 3, 2018

I don't know if I'm doing it right, but I followed the advice of the failed test log and ran cd .jenkins/perf_test/ && bash test_gpu_speed_mlstm.sh:

On the clean branch (2726550, ran three times):

Runtime stats in seconds:
{"mean": 2.3622500000000004, "sigma": 0.01499624953113275}
{"mean": 2.37195, "sigma": 0.01663272376972571}
{"mean": 2.3788499999999995, "sigma": 0.01288128487380043}

With this PR on top of it:

{"mean": 2.3658499999999996, "sigma": 0.018205150370156273}
{"mean": 2.3644000000000007, "sigma": 0.021511392330576876}
{"mean": 2.36015, "sigma": 0.01142048597915163}

So I think there's no meaningful difference on GTX 1080, but other GPUs might report different numbers, I guess.

@ezyang (Contributor) commented Mar 4, 2018

The GPU perf tests have been flaky recently, so you should ignore them for the purposes of assessing this PR.

@soumith merged commit c713c66 into pytorch:master on Mar 4, 2018
@soumith (Contributor) commented Mar 4, 2018

Thanks @yongjik. Sorry for the delay in review.

@yongjik (Contributor, Author) commented Mar 4, 2018

No worries! Half of the delay was mine, after all. Thanks for the review.
