@@ -339,7 +339,7 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
           + log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];
 
       log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb;
-    } else if ((s < 2*max_target_length+1) && ((target_length == 0) || (s > 2*target_length+1) || (t >= input_length))) {
+    } else if ((s < 2*max_target_length+1) && ((target_length == 0) || (s >= 2*target_length+1) || (t >= input_length))) {
       log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = neginf;
     }
   }
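A concrete (hypothetical) number makes the off-by-one visible: with target_length = 2 the extended target has 2*target_length+1 = 5 entries, so the valid indices are s = 0..4. The old test s > 2*target_length+1 (i.e. s > 5) never matched s == 5, so this branch skipped that out-of-range entry; the new s >= 2*target_length+1 sets it to neginf like the rest of the padding region.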
@@ -477,6 +477,40 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
   }
 }
 
+// This zeroes the gradients corresponding to out-of-sequence positions.
+// Those gradients should not be used in any model update since the input
+// elements are padded.
+template <typename scalar_t>
+__global__ void
+#if defined(__HIP_PLATFORM_HCC__)
+C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+#endif
+ctc_loss_zero_padded_gradients(
+    scalar_t* __restrict__ gradient_data,       /* (T, B, D) layout */
+    const int64_t* __restrict__ input_lengths,  /* (B, ) layout */
+    int64_t gr_timestep_stride,
+    int64_t gr_batch_stride,
+    int64_t gr_label_stride,
+    int64_t max_input_length,                   /* T */
+    int64_t batch_size,                         /* B */
+    int64_t num_labels                          /* D */ ) {
+  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
+  int64_t t = threadIdx.x + blockIdx.x * blockDim.x;
+
+  if (b >= batch_size || t >= max_input_length) {
+    return;
+  }
+
+  scalar_t input_length = input_lengths[b];
+  if (t >= input_length) {
+    for (int l = 0; l < num_labels; l++)
+      gradient_data[
+          t * gr_timestep_stride + b * gr_batch_stride + l * gr_label_stride]
+          = 0.0f;
+  }
+}
+
+
 // The backward. It essentially computes eq 16 by using the above kernels.
 // We don't do a lot of checking as we envision this to be called only when backpropagating through a (well-checked) forward.
 template <typename scalar_t, ScalarType target_scalar_type>
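For reference, the effect of the kernel added above can be sketched as a plain CPU loop over the (T, B, D) gradient buffer. This is only an illustrative sketch, not part of the patch; the helper name zero_padded_gradients_cpu is made up here, and its arguments are assumed to carry the same meaning as the kernel parameters:

#include <cstdint>

// Hypothetical CPU equivalent (not in the patch): zero grad[t, b, :] for every
// padded timestep t >= input_lengths[b], addressing the buffer with the same
// (T, B, D) strides that the kernel receives.
void zero_padded_gradients_cpu(float* grad, const int64_t* input_lengths,
                               int64_t t_stride, int64_t b_stride, int64_t d_stride,
                               int64_t max_input_length, int64_t batch_size,
                               int64_t num_labels) {
  for (int64_t b = 0; b < batch_size; b++) {
    for (int64_t t = input_lengths[b]; t < max_input_length; t++) {
      for (int64_t d = 0; d < num_labels; d++) {
        grad[t * t_stride + b * b_stride + d * d_stride] = 0.0f;
      }
    }
  }
}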
@@ -517,7 +551,9 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
   auto input_lengths_t = at::tensor(input_lengths, targets.options().dtype(kLong));
   tg_batch_offsets = tg_batch_offsets.cuda();
 
-  auto log_beta = at::empty_like(log_alpha);
+  Tensor log_beta = at::empty_like(log_alpha);
+  log_beta.fill_(neginf);
+
   Tensor grad = at::full_like(log_probs, neginf); // initialization for log(sum (alpha beta))
 
   // As above, there may be better configurations to use.
@@ -621,6 +657,31 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
           batch_size, num_labels, BLANK, zero_infinity);
     THCudaCheck(cudaGetLastError()); // catch launch errors
   }
+
+  // zero those invalid gradient elements due to padding
+  {
+    int threads_input = max_threads;
+    while (threads_input / 2 >= log_probs.size(0)) {
+      threads_input /= 2;
+    }
+    threads_batch = std::min(max_threads / threads_input, (int) batch_size);
+    dim3 block(threads_input, threads_batch);
+    dim3 grid(
+        (log_probs.size(0) + threads_input - 1) / threads_input,
+        (batch_size + threads_batch - 1) / threads_batch);
+    ctc_loss_zero_padded_gradients<scalar_t><<<grid, block, 0, stream>>>(
+        grad.data<scalar_t>(),
+        input_lengths_t.data<int64_t>(),
+        grad.stride(0),
+        grad.stride(1),
+        grad.stride(2),
+        grad.size(0),
+        grad.size(1),
+        grad.size(2)
+        );
+    THCudaCheck(cudaGetLastError());
+  }
+
   return grad;
 }
 
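As a worked example of the launch configuration in the last hunk (all numbers hypothetical): with max_threads = 1024, log_probs.size(0) = 200 and batch_size = 64, the while loop halves threads_input 1024 -> 512 -> 256 and stops there because 256 / 2 = 128 < 200; threads_batch = min(1024 / 256, 64) = 4; and the grid comes out as ((200 + 255) / 256, (64 + 3) / 4) = (1, 16), i.e. one block column over timesteps and 16 rows over the batch, with each 256 x 4 block staying within the 1024-thread limit.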