
Commit 6f5fb78

ssnl authored and soumith committed
Fix CTC loss for zero-length targets on GPU (#23298) (#23715)
Summary:
Fixes: #18215 at last! Also sprinkle tests...

Pull Request resolved: #23298

Differential Revision: D16582145

Pulled By: soumith

fbshipit-source-id: bc8b1a629de0c2606e70a2218ccd135f4a9cdc5d
1 parent 4f52116 commit 6f5fb78

4 files changed, +133 −50 lines
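For orientation before the diffs: the bug (#18215) was that CTC loss on CUDA misbehaved when a batch contained zero-length targets — as the `l2` fix below suggests, the kernel read `log_alpha` at index `target_length*2 - 1 == -1` for empty targets. A minimal repro sketch, assuming a PyTorch build that includes this patch (shapes chosen to mirror the tests below):

```python
import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 50 timesteps, batch of 3, 15 classes (class 0 is the blank)
log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device,
                        requires_grad=True).log_softmax(2)
targets = torch.randint(1, 15, (9,), dtype=torch.long, device=device)
input_lengths = [50, 50, 50]
target_lengths = [0, 9, 0]  # two empty targets in the batch

loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                  reduction='none')
loss.sum().backward()  # exercises the fixed CUDA backward kernels
print(loss)            # finite, non-negative loss for all three samples
```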

aten/src/ATen/native/LossCTC.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -374,7 +374,8 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef inpu
     }
   }
   if (reduction == Reduction::Mean) {
-    auto target_lengths_t = at::tensor(target_lengths, res.options());
+    auto target_lengths_t =
+        at::tensor(target_lengths, res.options()).clamp_min(1);
     return (res / target_lengths_t).mean();
   } else if (reduction == Reduction::Sum) {
     return res.sum();
```
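The CPU change only touches the mean reduction: each per-sample loss is divided by its target length, and a zero length would turn a finite loss into inf/nan. Clamping the divisor to 1 leaves non-empty samples unchanged and divides empty-target losses by 1. A sketch of the same arithmetic in Python (hypothetical values, not the ATen code):

```python
import torch

res = torch.tensor([12.3, 4.1, 7.5])  # per-sample CTC losses
t = torch.tensor([4.0, 0.0, 3.0])     # the middle target is empty

bad = (res / t).mean()                # inf: division by zero length
ok = (res / t.clamp_min(1)).mean()    # matches the patched behavior
```

Note the numerator is untouched: an empty-target sample still contributes its loss to the mean; only the division is made safe.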

aten/src/ATen/native/cuda/LossCTC.cu

Lines changed: 89 additions & 33 deletions
```diff
@@ -24,10 +24,20 @@ namespace native {
 
 namespace {
 
-// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1]) note that no bound-checking is done
-// __restrict__ impact to be measured, https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
-template<typename target_t>
-__device__ static inline int64_t get_target_prime(const target_t* __restrict__ target, int64_t offset, int64_t stride, int64_t idx, int64_t BLANK) {
+// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1])
+// so if l is l_0 l_1 ... l_(tl-1) then this looks up idx in
+// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
+// - note that no bound-checking is done
+// - it is important to only call it with idx == 0 if the target length is 0
+// - __restrict__ impact to be measured, see
+//   https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
+template <typename target_t>
+__device__ static inline int64_t get_target_prime(
+    const target_t* __restrict__ target,
+    int64_t offset,
+    int64_t stride,
+    int64_t idx,
+    int64_t BLANK) {
   if (idx % 2 == 0) {
     return BLANK;
   } else {
@@ -80,12 +90,16 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
       la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
       break;
     case 1:
-      if (target_length > 0) {
-        la = log_probs_data[lp_batch_offset + lp_char_stride * get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];
-      }
-      else {
-        la = neginf;
-      }
+      la = target_length == 0 ? neginf
+                              : log_probs_data
+                                    [lp_batch_offset +
+                                     lp_char_stride *
+                                         get_target_prime(
+                                             targets_data,
+                                             tg_batch_offset,
+                                             tg_target_stride,
+                                             1,
+                                             BLANK)];
       break;
     default:
       la = neginf;
@@ -100,16 +114,28 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
     // These two only depend on s, so we can cache them.
     int64_t current_char; // l_s in eq (6)
     bool have_three; // flag which of the two cases in eq (6) we have
-    if (s < 2*target_length+1) {
-      current_char = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
-      have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s-2, BLANK) != current_char));
+    if (s < 2 * target_length + 1 && target_length > 0) {
+      current_char = get_target_prime(
+          targets_data,
+          tg_batch_offset,
+          tg_target_stride,
+          s,
+          BLANK);
+      have_three =
+          ((s > 1) &&
+           (get_target_prime(
+                targets_data,
+                tg_batch_offset,
+                tg_target_stride,
+                s - 2,
+                BLANK) != current_char));
     } else {
       current_char = BLANK;
       have_three = false;
     }
     for (int64_t t=1; t < max_input_length; t++) {
       __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch
-      if ((t < input_length) && (target_length > 0) && (s < 2*target_length+1)) {
+      if ((t < input_length) && (s < 2 * target_length + 1)) {
         // only for valid t, s. This is equation (6) and (7), la1, la2, la3 are the three summands,
         // lamax is the maximum for the logsumexp trick.
         scalar_t la1 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * s];
@@ -146,7 +172,11 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
   // compute the loss (eq (8))
   if (threadIdx.x == 0) {
     scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)];
-    scalar_t l2 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2-1)];
+    scalar_t l2 = target_length > 0
+        ? log_alpha_data
+              [la_batch_offset + la_input_stride * (input_length - 1) +
+               la_target_stride * (target_length * 2 - 1)]
+        : neginf;
     scalar_t m = ((l1 > l2) ? l1 : l2);
     m = ((m == neginf) ? 0 : m);
     scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
@@ -236,7 +266,6 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
     threads_target /= 2;
   }
   int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
-
   dim3 block(threads_target, threads_batch);
   dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -285,8 +314,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
     scalar_t lb;
     if (s == 2*target_length) {
       lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * BLANK];
-    } else if ((target_length > 0) && (s == 2*target_length-1)) {
-      int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
+    } else if (s == 2 * target_length - 1) { // false for target_length == 0
+      int64_t current_target_prime = get_target_prime(
+          targets_data,
+          tg_batch_offset,
+          tg_target_stride,
+          s,
+          BLANK);
       lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * current_target_prime];
     } else {
       lb = neginf;
@@ -301,19 +335,29 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
     int64_t s = threadIdx.x + block_s;
     int64_t current_target_prime;
     bool have_three;
-    if (s < 2*target_length+1) {
-      current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
-      have_three = ((s < 2*target_length-1) &&
-                    (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s+2, BLANK) !=
-                     current_target_prime));
+    if (s < 2 * target_length + 1 && target_length > 0) {
+      current_target_prime = get_target_prime(
+          targets_data,
+          tg_batch_offset,
+          tg_target_stride,
+          s,
+          BLANK);
+      have_three =
+          ((s < 2 * target_length - 1) &&
+           (get_target_prime(
+                targets_data,
+                tg_batch_offset,
+                tg_target_stride,
+                s + 2,
+                BLANK) != current_target_prime));
     } else {
       current_target_prime = BLANK;
      have_three = false;
    }
     // now go backward in t. Note that we need to skip the last timestep that we did above.
     for (int64_t t=max_input_length-2; t>=0; t--) {
       __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch item
-      if ((t < input_length-1) && (target_length > 0) && (s < 2*target_length+1)) {
+      if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
         scalar_t lb1 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * s];
         scalar_t lbmax = lb1;
         scalar_t lb2, lb3;
@@ -339,8 +383,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
           + log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];
 
         log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb;
-      } else if ((s < 2*max_target_length+1) && ((target_length == 0) || (s >= 2*target_length+1) || (t >= input_length))) {
-        log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = neginf;
+      } else if (
+          (s < 2 * max_target_length + 1) &&
+          (((target_length == 0) && (s > 0)) || (s >= 2 * target_length + 1) ||
+           (t >= input_length))) {
+        log_beta_data
+            [lb_batch_offset + lb_input_stride * t + lb_target_stride * s] =
+                neginf;
       }
     }
   }
@@ -448,8 +497,13 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
 
     // collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s]
     for (int s = 0; s < 2*max_target_length+1; s++) {
-      if ((target_length > 0) && (s < 2*target_length+1)) {
-        int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
+      if (s < 2 * target_length + 1) { // if target_length == 0, s == 0
+        int64_t current_target_prime = get_target_prime(
+            targets_data,
+            tg_batch_offset,
+            tg_target_stride,
+            s,
+            BLANK);
         scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s]
           + log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s]);
         scalar_t& lcab = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * current_target_prime];
@@ -569,7 +623,6 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
   {
     dim3 block(threads_target, threads_batch);
     dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
-
     ctc_loss_backward_log_beta_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
       (log_beta.data<scalar_t>(),
        log_probs.data<scalar_t>(), input_lengths_t.data<int64_t>(), log_probs.size(0),
@@ -612,12 +665,16 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
     // For the non-blank characters, we use a kernel to compute the subtrahend.
     // Again we might configure block and grid in a better way.
     int threads_target = max_threads;
-    while (threads_target / 2 >= max_target_length) {
+    while (threads_target / 2 >= max_target_length && threads_target > 1) {
       threads_target /= 2;
     }
     int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
     dim3 block(threads_target, threads_batch);
-    dim3 grid((max_target_length + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
+    dim3 grid(
+        std::max<int>(
+            (max_target_length + threads_target - 1) / threads_target, 1),
+        (batch_size + threads_batch - 1) / threads_batch,
+        1);
     ctc_loss_backward_collect_nonblank_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
       (grad.data<scalar_t>(),
        grad_out.data<scalar_t>(), grad_out.stride(0),
@@ -635,13 +692,12 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
   } else { // small problem, use naive algorithm
     // Still no block/grid configuration guru...
     int threads_input = max_threads;
-    while (threads_input / 2 >= log_probs.size(0)) {
+    while (threads_input / 2 >= log_probs.size(0) && threads_input > 1) {
      threads_input /= 2;
    }
     threads_batch = std::min(max_threads / threads_input, (int) batch_size);
     dim3 block(threads_input, threads_batch);
     dim3 grid((log_probs.size(0) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch);
-
     ctc_loss_backward_collect_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
       (grad.data<scalar_t>(),
        grad_out.data<scalar_t>(), grad_out.stride(0),
```
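Most of the kernel edits revolve around get_target_prime, which looks up position idx in the augmented sequence l' = BLANK l_0 BLANK l_1 ... BLANK l_(tl-1) BLANK without materializing it; the new target_length > 0 guards ensure that for an empty target it is only reached with idx == 0 (always BLANK). A hedged Python rendering of the same lookup (the CUDA version reads through a strided __restrict__ pointer):

```python
BLANK = 0

def get_target_prime(target, offset, stride, idx, blank=BLANK):
    # Even positions of l' are blanks; odd position 2*i + 1 is l_i.
    if idx % 2 == 0:
        return blank
    return target[offset + stride * (idx // 2)]

# target l = [3, 5] augments to l' = [BLANK, 3, BLANK, 5, BLANK]
assert [get_target_prime([3, 5], 0, 1, i) for i in range(5)] == [0, 3, 0, 5, 0]
```

With the guards in place, the empty-target case is handled uniformly: only s == 0 (the single BLANK state) is valid, which is also why the l2 summand in eq (8) becomes neginf when target_length == 0.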

test/test_autograd.py

Lines changed: 28 additions & 11 deletions
```diff
@@ -1609,25 +1609,42 @@ def test_ctc_loss(self):
         target_length = 15
         gradcheck_input_size = 10
 
-        # device, input_length
-        tests = [('cpu', 150, False),
-                 ('cpu', 150, True)]
+        ZERO_NONE = 0
+        ZERO_SOME = 1
+        ZERO_ALL = 2
+
+        # device, input_length, vary_lengths, zero_lengths
+        tests = [('cpu', 150, False, ZERO_NONE),
+                 ('cpu', 150, True, ZERO_NONE),
+                 ('cpu', 50, True, ZERO_SOME),
+                 ('cpu', 50, True, ZERO_ALL)]
         if torch.cuda.is_available():
-            tests += [('cuda', 50, False),
-                      ('cuda', 150, False),
-                      ('cuda', 50, True),
-                      ('cuda', 150, True)]
-
-        for device, input_length, vary_lengths in tests:
+            tests += [('cuda', 50, False, ZERO_NONE),
+                      ('cuda', 150, False, ZERO_NONE),
+                      ('cuda', 50, True, ZERO_NONE),
+                      ('cuda', 150, True, ZERO_NONE),
+                      ('cuda', 50, True, ZERO_SOME),
+                      ('cuda', 150, True, ZERO_SOME),
+                      ('cuda', 50, True, ZERO_ALL),
+                      ('cuda', 150, True, ZERO_ALL)]
+
+        for device, input_length, vary_lengths, zero_mode in tests:
             targets = torch.randint(1, num_labels, (batch_size, target_length),
                                     device=device, dtype=torch.long)
             x = torch.randn(gradcheck_input_size, device=device, requires_grad=True)
             tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1,
                                        device=device)
             input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
                               if vary_lengths or i == 0 else input_length) for i in range(batch_size)]
-            target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
-                               if vary_lengths else target_length) for i in range(batch_size)]
+            if zero_mode == ZERO_ALL:
+                target_lengths = [0 for _ in range(batch_size)]
+            else:
+                target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
+                                   if vary_lengths else target_length) for _ in range(batch_size)]
+                if zero_mode == ZERO_SOME:
+                    idxes = torch.randint(0, batch_size, (10,))
+                    for i in idxes:
+                        target_lengths[i] = 0
 
             def ctc_after_softmax(x):
                 x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels]
```

test/test_nn.py

Lines changed: 14 additions & 5 deletions
```diff
@@ -5599,20 +5599,29 @@ def test_CTCLoss_lengthchecks_cpu(self):
         with self.assertRaises(RuntimeError):
             torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
 
-    def test_CTCLoss_empty_target_cpu(self):
+    def _test_CTCLoss_empty_target(self, device):
         target_lengths = [0, 0, 0]
         input_lengths = [50, 50, 50]
-        targets = torch.randint(1, 15, (0,), dtype=torch.int)
-        log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
+        targets = torch.randint(1, 15, (0,), dtype=torch.long, device=device)
+        log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
         loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
         self.assertTrue((loss >= 0).all().item())
+        self.assertAlmostEqual(-log_probs.sum(0)[:, 0], loss)
 
         target_lengths = [0, 9, 0]
         input_lengths = [50, 50, 50]
-        targets = torch.randint(1, 15, (9,), dtype=torch.int)
-        log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
+        targets = torch.randint(1, 15, (9,), dtype=torch.long, device=device)
+        log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
         loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
         self.assertTrue((loss >= 0).all().item())
+        self.assertAlmostEqual(-log_probs.sum(0)[[0, 2], 0], loss[[0, 2]])
+
+    def test_CTCLoss_empty_target_cpu(self):
+        self._test_CTCLoss_empty_target('cpu')
+
+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    def test_CTCLoss_empty_target_cuda(self):
+        self._test_CTCLoss_empty_target('cuda')
 
     @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
     def test_CTCLoss_zero_infinity(self):
```
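The new assertions encode the closed form that makes these tests meaningful: with an empty target, the only valid alignment emits the blank at every timestep, so the per-sample loss reduces to -sum_t log_probs[t, b, blank]. A small standalone check of that identity, assuming blank index 0:

```python
import torch
import torch.nn.functional as F

log_probs = torch.randn(50, 1, 15, dtype=torch.double).log_softmax(2)
loss = F.ctc_loss(log_probs, targets=torch.zeros(0, dtype=torch.long),
                  input_lengths=[50], target_lengths=[0], reduction='none')
expected = -log_probs.sum(0)[:, 0]  # minus the summed blank log-probs
assert torch.allclose(loss, expected)
```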
