65 changes: 63 additions & 2 deletions aten/src/ATen/native/cuda/LossCTC.cu
@@ -339,7 +339,7 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
+ log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];

log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb;
-} else if ((s < 2*max_target_length+1) && ((target_length == 0) || (s > 2*target_length+1) || (t >= input_length))) {
+} else if ((s < 2*max_target_length+1) && ((target_length == 0) || (s >= 2*target_length+1) || (t >= input_length))) {
log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = neginf;
}
}
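[Editor's note] About the `>` to `>=` change above: with target_length targets, the extended target sequence (blank, t_1, blank, ..., t_N, blank) has 2*target_length+1 entries, so the valid indices are s = 0 .. 2*target_length. The old condition treated s == 2*target_length+1 as in range and left its log_beta entry unwritten; the new one sets it to neginf, i.e. log(0), like every other out-of-range position. A standalone check of just this boundary (illustrative, not from the PR):

    const int target_length = 2;               // extended target l' = {_, a, _, b, _}
    const int ext_len = 2 * target_length + 1; // |l'| == 5, valid s: 0 .. 4
    for (int s = 0; s <= ext_len; s++) {
      bool out_old = s > 2 * target_length + 1;   // false at s == 5: missed
      bool out_new = s >= 2 * target_length + 1;  // true  at s == 5: caught
    }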
@@ -477,6 +477,40 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
}
}

// This zeroes the gradients corresponding to out-of-sequence positions.
// Those gradients should not be used in any model update since the
// corresponding input elements are padding.
template<typename scalar_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_zero_padded_gradients(
scalar_t* __restrict__ gradient_data, /* (T, B, D) layout */
const int64_t* __restrict__ input_lengths, /* (B, ) layout */
int64_t gr_timestep_stride,
int64_t gr_batch_stride,
int64_t gr_label_stride,
int64_t max_input_length, /* T */
int64_t batch_size, /* B */
int64_t num_labels /* D */ ) {
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
int64_t t = threadIdx.x + blockIdx.x * blockDim.x;

if (b >= batch_size || t >= max_input_length) {
return;
}

int64_t input_length = input_lengths[b]; // lengths are integral; avoid narrowing them into scalar_t
if (t >= input_length) {
  // this (t, b) position is padding: zero the gradient for every label
  for (int64_t l = 0; l < num_labels; l++) {
    gradient_data[t * gr_timestep_stride + b * gr_batch_stride + l * gr_label_stride] = 0.0f;
  }
}
}
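// Editor's note: for intuition, the kernel above has the same effect as
// this host-side ATen sketch (an illustration, not part of the PR), which
// the kernel just parallelizes over (t, b) pairs; grad is (T, B, D):
//
//   for (int64_t b = 0; b < batch_size; b++) {
//     int64_t len = input_lengths[b];
//     if (len < max_input_length) {
//       // timesteps len .. T-1 of batch entry b are padding
//       grad.narrow(0, len, max_input_length - len).select(1, b).zero_();
//     }
//   }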


// The backward. It essentially computes eq 16 by using the above kernels.
// We don't do a lot of checking as we envision this to be called only when backpropagating through a (well-checked) forward.
template<typename scalar_t, ScalarType target_scalar_type>
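[Editor's note] "eq 16" refers to Graves et al. (2006), "Connectionist Temporal Classification", which the header comment of LossCTC.cu cites and whose equations it references by number. As a reminder, my transcription of that gradient (see the paper for the derivation):

\[
\frac{\partial p(\mathbf{l} \mid \mathbf{x})}{\partial y_t^k}
  = \frac{1}{\left(y_t^k\right)^2} \sum_{s \in \mathrm{lab}(\mathbf{l},\, k)} \alpha_t(s)\, \beta_t(s)
\]

Here y_t^k is the network's softmax output for label k at time t, lab(l, k) is the set of positions of the extended target sequence l' that carry label k, and the alpha/beta kernels above supply the \alpha_t(s) and \beta_t(s) factors in log space.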
@@ -517,7 +551,9 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
auto input_lengths_t = at::tensor(input_lengths, targets.options().dtype(kLong));
tg_batch_offsets = tg_batch_offsets.cuda();

-auto log_beta = at::empty_like(log_alpha);
+Tensor log_beta = at::empty_like(log_alpha);
+log_beta.fill_(neginf);
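// Editor's note: the two added lines are equivalent to the one-liner below
// (a sketch, not the PR's code). Starting log_beta at neginf means any
// (t, s) entry the beta kernel never writes already reads as log(0),
// which the `s >=` boundary fix above relies on:
//
//   Tensor log_beta = at::full_like(log_alpha, neginf);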

Tensor grad = at::full_like(log_probs, neginf); // initialization for log(sum (alpha beta))

// As above, there may be better configurations to use.
@@ -621,6 +657,31 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
batch_size, num_labels, BLANK, zero_infinity);
THCudaCheck(cudaGetLastError()); // catch launch errors
}

// zero the invalid gradient elements that stem from padding
{
int threads_input = max_threads;
// halve threads_input until it is the smallest power of two still covering T
while (threads_input / 2 >= log_probs.size(0)) {
  threads_input /= 2;
}
// give the remaining threads of the block to the batch dimension
threads_batch = std::min(max_threads / threads_input, (int) batch_size);
dim3 block(threads_input, threads_batch);
dim3 grid(
    (log_probs.size(0) + threads_input - 1) / threads_input,
    (batch_size + threads_batch - 1) / threads_batch);
ctc_loss_zero_padded_gradients<scalar_t><<<grid, block, 0, stream>>>(
grad.data<scalar_t>(),
input_lengths_t.data<int64_t>(),
grad.stride(0),
grad.stride(1),
grad.stride(2),
grad.size(0),
grad.size(1),
grad.size(2)
);
THCudaCheck(cudaGetLastError());
}
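// Editor's worked example of the launch configuration above, assuming
// max_threads == 1024, T == log_probs.size(0) == 200, batch_size == 8
// (illustrative numbers, not from the PR):
//   threads_input: 1024 -> 512 -> 256, since 256/2 == 128 < 200 stops the loop
//   threads_batch: min(1024 / 256, 8) == 4
//   block: (256, 4),  grid: ((200+255)/256, (8+3)/4) == (1, 2)
// Each thread then owns one (t, b) pair and zeroes all D label gradients for it.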

return grad;
}
