Add CTC loss #9628
Conversation
The CPU and CUDA variants are a direct transposition of Graves et al.'s description of the algorithm, with the modification that it is in log space. There is also a binding for the (much faster) CuDNN implementation.
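For orientation, here is a minimal single-example sketch of the log-space alpha recursion from Graves et al. that the CPU and CUDA variants follow. This is an illustrative Python reimplementation, not the PR's actual C++/CUDA code:
```
import torch

def ctc_neg_log_likelihood(log_probs, targets, blank=0):
    # Illustrative single-example CTC forward pass in log space;
    # log_probs: (T, num_labels), targets: (S,) without blanks.
    T, S = log_probs.size(0), targets.size(0)
    # Extended target: blanks interleaved with labels, length 2*S + 1.
    ext = torch.full((2 * S + 1,), blank, dtype=targets.dtype)
    ext[1::2] = targets
    neginf = float('-inf')
    log_alpha = torch.full((T, 2 * S + 1), neginf)
    log_alpha[0, 0] = log_probs[0, blank]
    if S > 0:
        log_alpha[0, 1] = log_probs[0, ext[1]]
    for t in range(1, T):
        for s in range(2 * S + 1):
            terms = [log_alpha[t - 1, s]]
            if s > 0:
                terms.append(log_alpha[t - 1, s - 1])
            # Skipping the intermediate blank is allowed only when the
            # current label is not blank and differs from the label two back.
            if s > 1 and ext[s] != blank and ext[s] != ext[s - 2]:
                terms.append(log_alpha[t - 1, s - 2])
            log_alpha[t, s] = torch.logsumexp(torch.stack(terms), 0) + log_probs[t, ext[s]]
    # Total probability: paths ending in the last blank or the last label.
    return -torch.logsumexp(log_alpha[T - 1, -2:], 0)
```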
So I experimented a bit.
(Benchmark idea from @galv, http://danielgalvez.me/jekyll/update/2017/12/29/benchmarking-ctc-implementations.html, but with different random target lengths and my own errors; the comparison is unfair to warpctc because its timing includes moving the result to the CPU and the gradients back.)
facebook-github-bot
left a comment
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
The CuDNN implementation is likely following this GTC talk, http://on-demand.gputechconf.com/gtc/2016/presentation/s6383-minmin-sun-speech-recognition.pdf; perhaps you can use it to speed up your cuda implementation.
Awesome, thank you very much! I'll check that out.
ssnl
left a comment
I did a brief pass on the CPU part because it seems further optimization will be done on the CUDA side. The implementation is straightforward and seems correct (I didn't look into the maths details), and indeed has room for optimization, as mentioned in the comments I left.
I'm still working on incorporating Simon's feedback and the tests, but I wanted to push the cuda updates. For small to medium batch sizes we are still slower than CuDNN, but I don't think it's unreasonably slow. For large batch sizes, we are actually pretty fast. The deviation between cuda and cudnn is of order 1e-4 in the gradients. I haven't checked which is closer to the true value (by comparing to the full-log-space results with double). I still need to do the tests. I might yet merge the beta calculation into that of alpha for cuda, as I think that would give some speedup. Also on my list of potential speedups is writing a kernel for the larger bits of the gradient calculation currently done in ATen.
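(For reference, the kind of double-precision comparison meant above could look like this sketch; it uses the public ctc_loss API and illustrative sizes, not the internal test setup:)
```
import torch
import torch.nn.functional as F

T, N, C, S = 50, 16, 11, 30
torch.manual_seed(0)
# Double-precision gradients serve as the "true value" reference.
log_probs64 = torch.randn(T, N, C, dtype=torch.double).log_softmax(2).requires_grad_()
targets = torch.randint(1, C, (N, S), dtype=torch.long)
input_lengths = [T] * N
target_lengths = [S] * N
loss64 = F.ctc_loss(log_probs64, targets, input_lengths, target_lengths, blank=0)
ref_grad, = torch.autograd.grad(loss64, log_probs64)

# Recompute in float32 (CPU, CUDA, or CuDNN) and compare to the reference.
log_probs32 = log_probs64.detach().float().requires_grad_()
loss32 = F.ctc_loss(log_probs32, targets, input_lengths, target_lengths, blank=0)
grad32, = torch.autograd.grad(loss32, log_probs32)
print((grad32.double() - ref_grad).abs().max().item())
```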
Nice cuda improvements! Many people are averse to the nondeterministic results that atomicAdd would introduce; I think that is why cudnn added a deterministic gradient calculation that does not follow the GTC paper. So it's more of a question for the core developers: are we ok with nondeterminism?
I think we already have nondeterminism e.g. in
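(The nondeterminism of atomicAdd comes from floating-point addition not being associative, so the accumulation order changes the last bits of the result. A minimal plain-Python illustration of the underlying effect:)
```
import random

random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(100000)]
# Summing the same numbers in two orders differs in the last bits;
# unordered atomicAdds therefore make results run-to-run nondeterministic.
print(sum(xs) - sum(reversed(xs)))  # typically a tiny nonzero value
```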
Based on SsnL's review comments. Thank you!
To get new location of getCurrentCUDAStream
@pytorchbot retest this please
So, it seems that things are converging. Some questions:
While one can always add more optimizations (and it's kind of fun, too), I think the above are the top questions I could use your input on.
The situation with the cudnn ctc requiring int targets seems bad enough that I'm inclined to make the target type a template parameter and, in the cuda variant, move the targets to cuda if they are on the CPU.
The Windows build failure is about an ATen function I use. I could use a hint as to what goes wrong.
To get the missing symbol on Windows
So I found TensorAccessors and I couldn't resist switching to them for the CPU version (30 fewer lines of code and much clearer indexing). I also changed the CPU and CUDA versions to also accept int labels (and move them to the GPU as needed). The idea is that if you program against CuDNN (where you have to have CPU int labels) you can switch to CUDA/CPU as needed. The remaining two CI failures (running out of heap on Windows somewhere else, some timeout for the ROCm one) look like things I cannot do much about. I had to adapt the test infrastructure a bit to cover all cases (so additional arguments and not converting targets).
facebook-github-bot
left a comment
ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: I hope this helps me for the windows build failure in pytorch#9628 . Pull Request resolved: pytorch#9904 Differential Revision: D9026715 Pulled By: soumith fbshipit-source-id: bb97d41d060823f5a37bfc9a1659815b8b9f4eab
@pytorchbot retest this please
In order to use CuDNN, the following must be satisfied: :attr:`targets` must be in concatenated format, all :attr:`input_lengths` must be `T`, :math:`blank=0`, :attr:`target_lengths` :math:`\leq 256`, and the integer arguments must be of :class:`torch.IntTensor`.
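A hedged sketch of inputs that satisfy these constraints (sizes are illustrative, and whether the CuDNN path is actually taken also depends on the build):
```
import torch
import torch.nn.functional as F

T, N, C, S = 50, 16, 20, 30  # illustrative; S <= 256
log_probs = torch.randn(T, N, C).log_softmax(2).cuda()
input_lengths = torch.full((N,), T, dtype=torch.int)   # all equal to T
target_lengths = torch.full((N,), S, dtype=torch.int)  # IntTensor
# Concatenated 1-d CPU int targets; labels in 1..C-1 since blank=0.
targets = torch.randint(1, C, (int(target_lengths.sum()),), dtype=torch.int)
loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0)
```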
Summary:
The CPU and CUDA variants are a direct transposition of Graves et al.'s description of the algorithm, with the modification that it is in log space.
There is also a binding for the (much faster) CuDNN implementation.
This could eventually fix #3420
I still need to add tests (TestNN seems much more elaborate than the other testing) and fix the bugs that invariably turn up during the testing. Also, I want to add some more code comments.
I could use feedback on all sorts of things, including:
- Type handling (cuda vs. cpu for the int tensors, dtype for the int tensors)
- Input convention. I use log probs because that is what the gradients are for.
- Launch parameters for the kernels
- Errors and omissions and anything else I'm not even aware of.
Thank you for looking!
In terms of performance, it looks superficially comparable to WarpCTC (but I have not systematically investigated this).
I have read CuDNN is much faster than other implementations because it does *not* use log space, but also because its gathering step is much, much faster (I avoided trying tricky things there; they seem to contribute to warpctc's fragility). I might think some more about which existing torch function (scatter or index...) I could learn from for that step.
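(For illustration, that gathering step can be expressed with `torch.gather`; this is a hypothetical single-example sketch, not the kernel in this PR:)
```
import torch

T, C, S = 50, 11, 30
log_probs = torch.randn(T, C).log_softmax(1)     # one sequence
targets = torch.randint(1, C, (S,))
ext = torch.zeros(2 * S + 1, dtype=torch.long)   # blanks interleaved, blank=0
ext[1::2] = targets
# One gather fetches log p_t(ext[s]) for all t and s at once,
# instead of indexing label-by-label inside the recursion.
gathered = log_probs.gather(1, ext.unsqueeze(0).expand(T, -1))
print(gathered.shape)  # torch.Size([50, 61])
```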
Average timings for the kernels from nvprof for one problem size:
```
CuDNN:
60.464us compute_alphas_and_betas
16.755us compute_grads_deterministic
Cuda:
121.06us ctc_loss_backward_collect_gpu_kernel (= grads)
109.88us ctc_loss_gpu_kernel (= alphas)
98.517us ctc_loss_backward_betas_gpu_kernel (= betas)
WarpCTC:
299.74us compute_betas_and_grad_kernel
66.977us compute_alpha_kernel
```
Of course, I still have the (silly) outer blocks loop rather than computing consecutive `s` in each thread, which I might change, and there are a few other things where one could look for better implementations.
Finally, it might not be unreasonable to start with these implementations, as the performance of the loss has to be seen in the context of the entire training computation, so this would likely dilute the relative speedup considerably.
My performance measuring testing script:
```
import timeit
import sys
import torch
num_labels = 10
target_length = 30
input_length = 50
eps = 1e-5
BLANK = 0  # num_labels
batch_size = 16
torch.manual_seed(5)
activations = torch.randn(input_length, batch_size, num_labels + 1)
log_probs = torch.log_softmax(activations, 2)
probs = torch.exp(log_probs)
targets = torch.randint(1, num_labels+1, (batch_size * target_length,), dtype=torch.long)
targets_2d = targets.view(batch_size, target_length)
target_lengths = torch.tensor(batch_size*[target_length])
input_lengths = torch.tensor(batch_size*[input_length])
activations = log_probs.detach()
def time_cuda_ctc_loss(grout, *args):
    torch.cuda.synchronize()
    culo, culog_alpha = torch._ctc_loss(*args)
    g, = torch.autograd.grad(culo, args[0], grout)
    torch.cuda.synchronize()

def time_cudnn_ctc_loss(grout, *args):
    torch.cuda.synchronize()
    culo, cugra = torch._cudnn_ctc_loss(*args)
    g, = torch.autograd.grad(culo, args[0], grout)
    torch.cuda.synchronize()

def time_warp_ctc_loss(grout, *args):
    torch.cuda.synchronize()
    culo = warpctc.ctc_loss(*args, blank_label=BLANK, size_average=False, length_average=False, reduce=False)
    g, = torch.autograd.grad(culo, args[0], grout)
    torch.cuda.synchronize()

if sys.argv[1] == 'cuda':
    lpcu = log_probs.float().cuda().detach().requires_grad_()
    args = [lpcu, targets_2d.cuda(), input_lengths.cuda(), target_lengths.cuda(), BLANK]
    grout = lpcu.new_ones((batch_size,))
    torch.cuda.synchronize()
    print(timeit.repeat("time_cuda_ctc_loss(grout, *args)", number=1000, globals=globals()))
elif sys.argv[1] == 'cudnn':
    lpcu = log_probs.float().cuda().detach().requires_grad_()
    args = [lpcu, targets.int(), input_lengths.int(), target_lengths.int(), BLANK, True]
    grout = lpcu.new_ones((batch_size,))
    torch.cuda.synchronize()
    print(timeit.repeat("time_cudnn_ctc_loss(grout, *args)", number=1000, globals=globals()))
elif sys.argv[1] == 'warpctc':
    import warpctc
    activations = activations.cuda().detach().requires_grad_()
    args = [activations, input_lengths.int(), targets.int(), target_lengths.int()]
    grout = activations.new_ones((batch_size,), device='cpu')
    torch.cuda.synchronize()
    print(timeit.repeat("time_warp_ctc_loss(grout, *args)", number=1000, globals=globals()))
```
I'll also link to a notebook that I used for writing up the algorithm in simple form; I then tested the implementations against it.
Pull Request resolved: pytorch/pytorch#9628
Differential Revision: D8952453
Pulled By: ezyang
fbshipit-source-id: 18e073f40c2d01a7c96c1cdd41f6c70a06e35860
As reported by someone internally, the example is broken. Can we allow input_lengths and target_lengths to optionally also be a Tensor, not just a tuple of ints?
Oh. I'll fix that and look at allowing tensors.
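(A minimal sketch of the kind of normalization this would need; the helper name is hypothetical and not the code that was eventually merged:)
```
import torch

def as_length_tensor(lengths, dtype=torch.long):
    # Hypothetical helper: accept a tuple/list of ints or a Tensor for
    # input_lengths / target_lengths and normalize to a 1-d tensor.
    if isinstance(lengths, torch.Tensor):
        return lengths.to(dtype)
    return torch.tensor(lengths, dtype=dtype)

print(as_length_tensor((50, 50, 50)))                           # tuple of ints
print(as_length_tensor(torch.full((3,), 50, dtype=torch.int)))  # Tensor
```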