
Conversation

@mingfeima (Collaborator) commented Apr 25, 2019

This PR aims at improving topk() performance on CPU. This is useful when computing beam search in Transformer and BERT.

Given a tensor x of size [N, C] to which we want to apply x.topk(K), the current logic sequentially loops over the N dimension and performs a quick select along the C dimension to find the top K elements.

Performance can be further improved in two ways:

  • The loop over the N dimension can be parallelized
  • A faster selection algorithm can be used along the C dimension (after a fair amount of experimentation, std::partial_sort looks the most promising); a rough sketch of the combined approach is shown below
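For illustration only, here is a minimal sketch of the combined idea (parallelize over N, std::partial_sort over (value, index) pairs along C). The function name, signature, and buffer layout are illustrative assumptions, not the actual ATen kernel in this PR:

```cpp
// Sketch: parallel over rows (N), partial sort over columns (C).
// Assumes a contiguous [N, C] float input; outputs are [N, K].
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

void topk_rows(const float* x, int64_t N, int64_t C, int64_t K,
               float* values, int64_t* indices) {
  #pragma omp parallel for
  for (int64_t n = 0; n < N; ++n) {
    // Gather (value, index) pairs for this row.
    std::vector<std::pair<float, int64_t>> row(C);
    for (int64_t c = 0; c < C; ++c) {
      row[c] = {x[n * C + c], c};
    }
    // Only the first K elements end up sorted (descending by value);
    // the remaining C - K elements are left in unspecified order.
    std::partial_sort(row.begin(), row.begin() + K, row.end(),
                      [](const auto& a, const auto& b) { return a.first > b.first; });
    for (int64_t k = 0; k < K; ++k) {
      values[n * K + k] = row[k].first;
      indices[n * K + k] = row[k].second;
    }
  }
}
```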

So I compared three versions:

  1. vanilla: sequential + quick select
  2. reference PR #19737 ([Don't merge] optimize topk on cpu using parallel and quick select): parallel + quick select
  3. this PR: parallel + partial sort

with the following benchmark, on a Xeon 8180 (2×28 cores @ 2.5 GHz):

import torch
from time import time

num_iters = 1000

def bench_topk(N=8, C=168560, k=10):
    a = torch.randn(N, C)
    # warm up
    for i in range(100):
        torch.topk(a, k)
    
    t = 0
    for i in range(num_iters):
        a = torch.randn(N, C)
        start = time()
        value, indice = torch.topk(a, k)
        t += time() - start
    print("#[%d, %d] times: %f ms" % (N, C, t / num_iters * 1000))

Ns = [10, 20, 30]
Cs = [10000, 20000, 40000, 80000, 160000, 320000]

for n in Ns:
    for c in Cs:
        bench_topk(N=n, C=c)

vanilla: sequential + quick select

#[10, 10000] times: 0.746740 ms
#[10, 20000] times: 1.437399 ms
#[10, 40000] times: 2.832455 ms
#[10, 80000] times: 5.649426 ms
#[10, 160000] times: 11.309466 ms
#[10, 320000] times: 22.798765 ms
#[20, 10000] times: 1.511303 ms
#[20, 20000] times: 2.822024 ms
#[20, 40000] times: 5.564770 ms
#[20, 80000] times: 11.443044 ms
#[20, 160000] times: 22.747731 ms
#[20, 320000] times: 46.234449 ms
#[30, 10000] times: 2.214045 ms
#[30, 20000] times: 4.236179 ms
#[30, 40000] times: 8.418577 ms
#[30, 80000] times: 17.067578 ms
#[30, 160000] times: 33.826214 ms
#[30, 320000] times: 68.109420 ms

reference PR: parallel + quick select

#[10, 10000] times: 0.271649 ms
#[10, 20000] times: 0.593016 ms
#[10, 40000] times: 1.133518 ms
#[10, 80000] times: 2.082355 ms
#[10, 160000] times: 4.049928 ms
#[10, 320000] times: 7.321285 ms
#[20, 10000] times: 0.315255 ms
#[20, 20000] times: 0.539054 ms
#[20, 40000] times: 1.000675 ms
#[20, 80000] times: 1.914586 ms
#[20, 160000] times: 4.437122 ms
#[20, 320000] times: 8.822445 ms
#[30, 10000] times: 0.347209 ms
#[30, 20000] times: 0.589947 ms
#[30, 40000] times: 1.102814 ms
#[30, 80000] times: 2.112201 ms
#[30, 160000] times: 5.186837 ms
#[30, 320000] times: 10.523023 ms

this PR: parallel + partial sort

#[10, 10000] times: 0.150284 ms
#[10, 20000] times: 0.220089 ms
#[10, 40000] times: 0.521875 ms
#[10, 80000] times: 0.965593 ms
#[10, 160000] times: 2.312356 ms
#[10, 320000] times: 4.759422 ms
#[20, 10000] times: 0.167630 ms
#[20, 20000] times: 0.265607 ms
#[20, 40000] times: 0.471477 ms
#[20, 80000] times: 0.974572 ms
#[20, 160000] times: 3.269645 ms
#[20, 320000] times: 6.538608 ms
#[30, 10000] times: 0.204976 ms
#[30, 20000] times: 0.342833 ms
#[30, 40000] times: 0.589381 ms
#[30, 80000] times: 1.398579 ms
#[30, 160000] times: 3.904077 ms
#[30, 320000] times: 9.681224 ms

In summary, version 2 is 5x faster than vanilla on average, and version 3 is 8.6x faster than vanilla.
On the Fairseq Transformer, the default parameters on the wmt14 dataset produce a topk input of size [8, 168560], and this operator becomes 3x faster with this PR.

Member

It's unfortunate that we need to copy the data inside the TensorAccessor.

@ezyang do you know how hard it would be to make TensorAccessor work with std::sort and the like? Would that be something we would want to do at some point?

Contributor

You can just get a raw data pointer if it's contiguous and sort in-place. You're unlikely to get std::sort working natively on strided stuff though.

Member

Well, if TensorAccessor respected the following requirements, it would work:

Type requirements

My question is more: how much effort would it take to make this happen, if it is possible at all?

Collaborator (Author)

As @ezyang said, when dim=-1 (the sorting dimension is contiguous), it is possible to sort directly on the raw data pointer; but if it is non-contiguous, that is not going to work.
I measured the time cost of this copy; it is not that big compared to the sorting. The current topk also does a copy here, since the sorting needs to swap the value/index pairs together.

Anyway, if TensorAccessor and std::vector had some sort of in-place conversion, it would be more convenient. A small illustration of the contiguous vs. non-contiguous case is sketched below.
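To make the contiguous vs. non-contiguous point concrete, here is a small, hypothetical sketch (not code from this PR): when the sort dimension has stride 1 you can sort the raw buffer in place; otherwise you have to gather into a temporary buffer first, which is the copy discussed above.

```cpp
// Sketch: sort one logical row of length C whose elements are `stride` apart.
#include <algorithm>
#include <cstdint>
#include <vector>

void sort_row(float* data, int64_t C, int64_t stride) {
  if (stride == 1) {
    // Contiguous: sort directly on the raw pointer, no copy needed.
    std::sort(data, data + C);
  } else {
    // Strided: gather into a temporary vector, sort, scatter back.
    std::vector<float> tmp(C);
    for (int64_t i = 0; i < C; ++i) tmp[i] = data[i * stride];
    std::sort(tmp.begin(), tmp.end());
    for (int64_t i = 0; i < C; ++i) data[i * stride] = tmp[i];
  }
}
```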

@zou3519 added the triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Apr 26, 2019
@fmassa (Member) commented Apr 27, 2019

While not conflicting with this PR, #18344 is relevant, as it optimizes kthvalue for float, and one could use kthvalue to implement topk

@cpuhrsch (Contributor)

@fmassa Have we looked at faiss as a source for the top-k implementation? In particular for the GPU.

@mingfeima (Collaborator, Author)

While not conflicting with this PR, #18344 is relevant, as it optimizes kthvalue for float, and one could use kthvalue to implement topk

Yes, this one does not conflict with #18344. If the AVX2 partition turns out to be faster than std::partial_sort, we can switch from std::partial_sort to quick_select_template very easily, just like the reference PR #19737.

@fmassa (Member) commented Apr 28, 2019

@cpuhrsch I believe the kernels in faiss were (are?) the same as the ones in PyTorch. Indeed, @wickedfoo was the one who added support for topk on the GPU back in the Lua days, and I'm not sure if there are differences nowadays that would be good to port back

@soumith (Contributor) commented Apr 28, 2019

The ones in faiss are significantly different from the ones in pytorch, though jhj is the person who wrote both. The ones in faiss are also significantly more complex than what pytorch needs.

Contributor

This looks very close to TensorIterators. In fact, dim_apply overall looks like it could be replaced by that. This would yield much cleaner and usually faster code. @VitalyFedyunin, please provide guidance.

Member

I believe the reason why TensorIterator can't be used here is that it can't take a set of input values and return a single output value (which is what sort / topk need). But it would be great to extend TensorIterator to support this use case.

@cpuhrsch (Contributor)

@mingfeima - could you also provide benchmarks for varying numbers of threads? In particular single core performance.

@mingfeima (Collaborator, Author) commented Apr 29, 2019

@cpuhrsch scaling results with OMP_NUM_THREADS=1, 2, 4, 8, 16.
before:

### topk: OMP_NUM_THREADS=1 ###
#[10, 10000] times: 0.749171 ms
#[10, 20000] times: 1.449934 ms
#[10, 40000] times: 2.854897 ms
#[10, 80000] times: 5.720388 ms
#[10, 160000] times: 11.393573 ms
#[10, 320000] times: 22.958385 ms
### topk: OMP_NUM_THREADS=2 ###
#[10, 10000] times: 0.748674 ms
#[10, 20000] times: 1.452205 ms
#[10, 40000] times: 2.847788 ms
#[10, 80000] times: 5.707714 ms
#[10, 160000] times: 11.470940 ms
#[10, 320000] times: 22.871153 ms
### topk: OMP_NUM_THREADS=4 ###
#[10, 10000] times: 0.746231 ms
#[10, 20000] times: 1.448718 ms
#[10, 40000] times: 2.868835 ms
#[10, 80000] times: 5.650503 ms
#[10, 160000] times: 11.451153 ms
#[10, 320000] times: 22.694164 ms
### topk: OMP_NUM_THREADS=8 ###
#[10, 10000] times: 0.745902 ms
#[10, 20000] times: 1.451694 ms
#[10, 40000] times: 2.853764 ms
#[10, 80000] times: 5.688034 ms
#[10, 160000] times: 11.397316 ms
#[10, 320000] times: 22.893178 ms
### topk: OMP_NUM_THREADS=16 ###
#[10, 10000] times: 0.743552 ms
#[10, 20000] times: 1.459475 ms
#[10, 40000] times: 2.864395 ms
#[10, 80000] times: 5.682968 ms
#[10, 160000] times: 11.432610 ms
#[10, 320000] times: 22.753903 ms

after:

### topk: OMP_NUM_THREADS=1 ###
#[10, 10000] times: 0.739171 ms
#[10, 20000] times: 1.347100 ms
#[10, 40000] times: 2.909180 ms
#[10, 80000] times: 5.716366 ms
#[10, 160000] times: 12.710830 ms
#[10, 320000] times: 28.739215 ms
### topk: OMP_NUM_THREADS=2 ###
#[10, 10000] times: 0.471362 ms
#[10, 20000] times: 0.796563 ms
#[10, 40000] times: 1.673725 ms
#[10, 80000] times: 3.335849 ms
#[10, 160000] times: 6.459378 ms
#[10, 320000] times: 15.167443 ms
### topk: OMP_NUM_THREADS=4 ###
#[10, 10000] times: 0.323723 ms
#[10, 20000] times: 0.518589 ms
#[10, 40000] times: 1.106971 ms
#[10, 80000] times: 2.216807 ms
#[10, 160000] times: 4.901155 ms
#[10, 320000] times: 11.800200 ms
### topk: OMP_NUM_THREADS=8 ###
#[10, 10000] times: 0.266110 ms
#[10, 20000] times: 0.396013 ms
#[10, 40000] times: 0.824147 ms
#[10, 80000] times: 1.694105 ms
#[10, 160000] times: 3.393351 ms
#[10, 320000] times: 8.041087 ms
### topk: OMP_NUM_THREADS=16 ###
#[10, 10000] times: 0.180306 ms
#[10, 20000] times: 0.267109 ms
#[10, 40000] times: 0.548724 ms
#[10, 80000] times: 1.096469 ms
#[10, 160000] times: 2.346745 ms
#[10, 320000] times: 5.103983 ms

This patch adds parallelization only, so it won't help on a single core. But as soon as you use two or more cores, it helps.

When parallelized, partial_sort performs slightly better than quick select, probably because quick select has an average time complexity of O(n) but a worst case of O(n²), and in the thread pool you have to wait for the very last worker to finish. Anyway, it is very easy to switch the inner selection algorithm if we find a better option, such as #18344; a sketch of the two interchangeable inner steps is shown below.
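As a hedged illustration of how easily the inner step can be swapped (the names below are made up for this sketch, not taken from the PR): std::partial_sort returns the top K already sorted, while a quick select via std::nth_element leaves the top K in unspecified order and only pays for sorting them when sorted output is requested.

```cpp
// Sketch: two interchangeable inner selection steps over (value, index) pairs.
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

using Elem = std::pair<float, int64_t>;
static bool greater_by_value(const Elem& a, const Elem& b) { return a.first > b.first; }

// Option A: partial sort -- the top K end up sorted in descending order.
void topk_partial_sort(std::vector<Elem>& row, int64_t K) {
  std::partial_sort(row.begin(), row.begin() + K, row.end(), greater_by_value);
}

// Option B: quick select -- the top K end up in the first K slots, unsorted;
// sort them afterwards only if the caller asked for sorted output.
void topk_quick_select(std::vector<Elem>& row, int64_t K, bool sorted) {
  std::nth_element(row.begin(), row.begin() + K - 1, row.end(), greater_by_value);
  if (sorted) {
    std::sort(row.begin(), row.begin() + K, greater_by_value);
  }
}
```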

I was looking at the Transformer in MLPerf, so my scenario is multi-core. Are you specifically looking at the single-core scenario? In that case, a SIMD sort would be needed.

@wickedfoo (Contributor)

This thread seems to be about the CPU but there was a question about the GPU.

The GPU top-k code in (Lua) Torch is different than the top-k code in Caffe2 is different than the top-k code in Faiss. I wrote all three.

The code in Torch supports arbitrary k (1 <= k <= n) via radix sort. It can be slow because it requires multiple passes over the input. The topk call in Torch needs to support arbitrary k, but it can have a specialization for small k (e.g., k <= 1024 or 2048).

The code in Caffe2 I think supports small k only, but could be wrong on that. It contains code from a very early version of Faiss.

The code in Faiss is fastest, but it only supports k <= 2048.

@VitalyFedyunin (Contributor) commented Apr 29, 2019

@mingfeima can you please compare before and after for k = 10, 100, C/10, C/2, C-5?

@mingfeima (Collaborator, Author)

@mingfeima can you please compare before and after for k = 10, 100, C/10, C/2, C-5?

@VitalyFedyunin

  1. Any particular range for C?
  2. How many threads? (This patch won't help for a single thread.)
  3. Does C-5 mean sorting the whole channel except the last five elements?
  4. If k is relatively large (for example C/2), should topk use sorted=True or sorted=False?

I was assuming that k is not very large (in beam search it is beam_size * 2 or even less), so whether the k values are sorted or not makes no real difference. But if you have scenarios with relatively large k, it will.

@fmassa (Member) commented Apr 30, 2019

@mingfeima it would also be good to run it with a single thread, to assess the impact of performing a copy of the data into a vector to feed to the partial_sort.

Here is another example where topk is used (for mask-rcnn).
The typical sizes for the tensor objectness are:

objectness = torch.rand(2, 100000)
pre_nms_top_n = 2000
objectness, topk_idx = objectness.topk(pre_nms_top_n, dim=1, sorted=True)

@VitalyFedyunin (Contributor)

  1. Try various C: 10,000, 40,000, 320,000.
  2. Compare a single thread and 20 threads.
  3. Yes.
  4. sorted=False

@mingfeima force-pushed the topk/partial_sort branch from 1f225f3 to da1f9c8 on May 23, 2019 05:06
@mingfeima (Collaborator, Author) commented May 23, 2019

@VitalyFedyunin @fmassa Hi, all the proposed patterns have been tested to give broader coverage:

  1. C = 10000, 40000, 320000
  2. K = 10, 50, 100, C/10, C/2, C-5
  3. Test with 20 threads and 1 thread
  4. Test with sorted=True and sorted=False

All the details including:

  1. test pattern
  2. benchmark scripts
  3. performance results
  4. optimization strategy
  5. topk design from other frameworks

are listed in the gist; it would be too much to show in this thread.

To sum up, with the latest code in this PR, topk() compared to the original:

  1. 20 threads and sorted=True: 8.8x speedup.
  2. 20 threads and sorted=False: 7.5x speedup.
  3. 1 thread and sorted=True: 1.8x speedup.
  4. 1 thread and sorted=False: 1.7x speedup.

The speedup in the single-thread runs comes from using std::partial_sort instead of quick select.

@Jianhui-Li

@cpuhrsch

@cpuhrsch (Contributor)

cc @VitalyFedyunin for TensorIterator.

@mingfeima - could I ask you to rebase the PR so we can import it?

@VitalyFedyunin (Contributor)

Writing a dim_apply implementation for TensorIterator is a nice thing to have, but might be out of scope for this PR. I will add it to our wishlist table.

@mingfeima force-pushed the topk/partial_sort branch from da1f9c8 to 0d459b7 on June 26, 2019 08:07
@pytorchbot added the module: cpu (CPU specific problem (e.g., perf, algorithm)) label Jun 26, 2019
@mingfeima (Collaborator, Author)

@VitalyFedyunin The CPU kernels have been moved to native/cpu/SortingKernel.cpp

Performance is slightly better compared with the last version:

  1. 20 threads and sorted=True: 9.4x speedup (last version 8.8x).
  2. 20 threads and sorted=False: 8.2x speedup (last version 7.5x).
  3. 1 thread and sorted=True: 1.8x speedup (last version 1.8x).
  4. 1 thread and sorted=False: 1.7x speedup (last version 1.7x).

@facebook-github-bot (Contributor) left a comment

@VitalyFedyunin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@VitalyFedyunin self-requested a review July 11, 2019 17:39
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jul 12, 2019
Pull Request resolved: pytorch/pytorch#19736

Differential Revision: D16204820

Pulled By: VitalyFedyunin

fbshipit-source-id: ea70562c9149a0d832cf5872a891042ebd74fc63
@facebook-github-bot (Contributor)

@VitalyFedyunin merged this pull request in 10c14ad.

facebook-github-bot pushed a commit that referenced this pull request Jul 18, 2019

Summary:
#19736 was reverted because it was suspected of breaking master; trying to reapply.
Pull Request resolved: #22865

Differential Revision: D16265457

Pulled By: VitalyFedyunin

fbshipit-source-id: 784bd6405471f15a8a49ebd0f3e98160d7d0679e
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jul 18, 2019