Skip to content

Commit b460a19

Browse files
Yongqiang Wang authored and facebook-github-bot committed
Per discussion at #21244, fix bugs in (#21392)
Summary: Pull Request resolved: #21392 as discussed at #21244, we found some values in log_beta are not properly initialized. This diff will 1) initialize all log_beta to -inf; 2) fix a tricky compare condition; 3) set all the gradient elements corresponding to padding to zero. Offline experiments show that this diff can fix previously seen NaN loss. Differential Revision: D15637977 fbshipit-source-id: 477008a5e11aae946bd2aa401ab7e0c513421af0
1 parent 42b2f56 commit b460a19

File tree

1 file changed

+63
-2
lines changed

1 file changed

+63
-2
lines changed

aten/src/ATen/native/cuda/LossCTC.cu

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
339339
+ log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];
340340

341341
log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb;
342-
} else if ((s < 2*max_target_length+1) && ((target_length == 0) || (s > 2*target_length+1) || (t >= input_length))) {
342+
} else if ((s < 2*max_target_length+1) && ((target_length == 0) || (s >= 2*target_length+1) || (t >= input_length))) {
343343
log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = neginf;
344344
}
345345
}
@@ -477,6 +477,40 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
477477
}
478478
}
479479

480+
// Zeroes gradient entries that correspond to out-of-sequence (padded)
// positions. Those gradients must not contribute to any model update,
// because the underlying input elements are padding.
//
// Launch layout: 2D grid/block — x indexes the timestep t, y indexes the
// batch element b. Each thread clears one (t, b, :) slice of the gradient
// when t lies beyond that sequence's true length.
template <typename scalar_t>
__global__ void
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_zero_padded_gradients(
    scalar_t* __restrict__ gradient_data,      /* (T, B, D) layout */
    const int64_t* __restrict__ input_lengths, /* (B, ) layout */
    int64_t gr_timestep_stride,
    int64_t gr_batch_stride,
    int64_t gr_label_stride,
    int64_t max_input_length, /* T */
    int64_t batch_size,       /* B */
    int64_t num_labels /* D */) {
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  int64_t t = threadIdx.x + blockIdx.x * blockDim.x;

  // Grid rarely divides (T, B) evenly; guard against overhang threads.
  if (b >= batch_size || t >= max_input_length) {
    return;
  }

  // Keep the length as int64_t. The original stored it in scalar_t, which
  // loses precision for long sequences (e.g. half cannot represent integers
  // above 2048 exactly) and can corrupt the `t >= input_length` test.
  int64_t input_length = input_lengths[b];
  if (t >= input_length) {
    for (int64_t l = 0; l < num_labels; l++) {
      gradient_data[t * gr_timestep_stride + b * gr_batch_stride +
                    l * gr_label_stride] = scalar_t(0);
    }
  }
}
512+
513+
480514
// The backward. It essentially computes eq 16 by using the above kernels.
481515
// We don't do a lot of checking as we envision this to be called only when backpropagating through a (well-checked) forward.
482516
template<typename scalar_t, ScalarType target_scalar_type>
@@ -517,7 +551,9 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
517551
auto input_lengths_t = at::tensor(input_lengths, targets.options().dtype(kLong));
518552
tg_batch_offsets = tg_batch_offsets.cuda();
519553

520-
auto log_beta = at::empty_like(log_alpha);
554+
Tensor log_beta = at::empty_like(log_alpha);
555+
log_beta.fill_(neginf);
556+
521557
Tensor grad = at::full_like(log_probs, neginf); // initialization for log(sum (alpha beta))
522558

523559
// As above, there may be better configurations to use.
@@ -621,6 +657,31 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
621657
batch_size, num_labels, BLANK, zero_infinity);
622658
THCudaCheck(cudaGetLastError()); // catch launch errors
623659
}
660+
661+
// zero those invalid gradient elements due to padding
662+
{
663+
int threads_input = max_threads;
664+
while (threads_input / 2 >= log_probs.size(0)) {
665+
threads_input /= 2;
666+
}
667+
threads_batch = std::min(max_threads / threads_input, (int) batch_size);
668+
dim3 block(threads_input, threads_batch);
669+
dim3 grid(
670+
(log_probs.size(0) + threads_input-1)/threads_input,
671+
(batch_size+threads_batch-1)/threads_batch);
672+
ctc_loss_zero_padded_gradients<scalar_t><<<grid, block, 0, stream>>>(
673+
grad.data<scalar_t>(),
674+
input_lengths_t.data<int64_t>(),
675+
grad.stride(0),
676+
grad.stride(1),
677+
grad.stride(2),
678+
grad.size(0),
679+
grad.size(1),
680+
grad.size(2)
681+
);
682+
THCudaCheck(cudaGetLastError());
683+
}
684+
624685
return grad;
625686
}
626687

0 commit comments

Comments
 (0)