3 changes: 2 additions & 1 deletion aten/src/ATen/native/LossCTC.cpp
@@ -374,7 +374,8 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef inpu
}
}
if (reduction == Reduction::Mean) {
auto target_lengths_t = at::tensor(target_lengths, res.options());
auto target_lengths_t =
at::tensor(target_lengths, res.options()).clamp_min(1);
return (res / target_lengths_t).mean();
} else if (reduction == Reduction::Sum) {
return res.sum();
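The clamp above keeps the Mean reduction well defined when a target length is zero: each per-sample loss is divided by max(target_length, 1) before averaging, so an empty target no longer produces a division by zero. A minimal Python sketch of the resulting behaviour (illustrative shapes and values; current torch API assumed):

import torch

log_probs = torch.randn(50, 2, 15, dtype=torch.double).log_softmax(2)
targets = torch.randint(1, 15, (3,), dtype=torch.long)  # flattened targets; sample 0 contributes none
input_lengths = [50, 50]
target_lengths = [0, 3]                                  # sample 0 has an empty target

# with the clamp in place, sample 0's loss is divided by 1 instead of 0,
# so the mean over the batch stays finite
loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths,
                                    target_lengths, reduction='mean')
assert torch.isfinite(loss)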
122 changes: 89 additions & 33 deletions aten/src/ATen/native/cuda/LossCTC.cu
@@ -24,10 +24,20 @@ namespace native {

namespace {

// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1]) note that no bound-checking is done
// __restrict__ impact to be measured, https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
template<typename target_t>
__device__ static inline int64_t get_target_prime(const target_t* __restrict__ target, int64_t offset, int64_t stride, int64_t idx, int64_t BLANK) {
// this ad-hoc helper converts from targets (l in [1]) to augmented targets (l' in [1])
// so if l is l_0 l_1 ... l_(tl-1) then this looks up idx in
// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
// - note that no bound-checking is done
// - it is important to only call it with idx == 0 if the target length is 0
// - __restrict__ impact to be measured, see
// https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
template <typename target_t>
__device__ static inline int64_t get_target_prime(
const target_t* __restrict__ target,
int64_t offset,
int64_t stride,
int64_t idx,
int64_t BLANK) {
if (idx % 2 == 0) {
return BLANK;
} else {
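For reference, a minimal Python rendering of the l to l' lookup described in the comment above (flat 0-indexed target list, no offset/stride handling, no bounds checking; the helper name is the kernel's, the rest is illustrative):

def get_target_prime(target, idx, blank=0):
    # l' = BLANK l_0 BLANK l_1 ... BLANK l_(tl-1) BLANK:
    # even positions are BLANK, odd position 2k+1 maps to l_k
    return blank if idx % 2 == 0 else target[idx // 2]

# e.g. target = [3, 5] corresponds to l' = [BLANK, 3, BLANK, 5, BLANK]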
@@ -80,12 +90,16 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
break;
case 1:
if (target_length > 0) {
la = log_probs_data[lp_batch_offset + lp_char_stride * get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];
}
else {
la = neginf;
}
la = target_length == 0 ? neginf
: log_probs_data
[lp_batch_offset +
lp_char_stride *
get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
1,
BLANK)];
break;
default:
la = neginf;
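In other words, the t = 0 initialization fills only the first two states of l': s = 0 gets the blank's log-probability, s = 1 gets the first target character's (or -inf when the target is empty, which the new guard provides), and every other state is -inf. A hedged Python sketch for a single batch element (illustrative names and shapes):

import math
import torch

neginf = -math.inf
BLANK = 0
T, num_labels = 4, 5
log_probs = torch.randn(T, num_labels).log_softmax(1)   # one batch element: log_probs[t, c]
targets = [3, 2]                                        # use [] for an empty target
target_length = len(targets)

log_alpha = torch.full((T, 2 * target_length + 1), neginf)
log_alpha[0, 0] = log_probs[0, BLANK]
if target_length > 0:
    log_alpha[0, 1] = log_probs[0, targets[0]]          # first non-blank state, l'_1 = l_0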
@@ -100,16 +114,28 @@
// These two only depend on s, so we can cache them.
int64_t current_char; // l_s in eq (6)
bool have_three; // flag which of the two cases in eq (6) we have
if (s < 2*target_length+1) {
current_char = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s-2, BLANK) != current_char));
if (s < 2 * target_length + 1 && target_length > 0) {
current_char = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
have_three =
((s > 1) &&
(get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s - 2,
BLANK) != current_char));
} else {
current_char = BLANK;
have_three = false;
}
for (int64_t t=1; t < max_input_length; t++) {
__syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch
if ((t < input_length) && (target_length > 0) && (s < 2*target_length+1)) {
if ((t < input_length) && (s < 2 * target_length + 1)) {
// only for valid t, s. This implements equations (6) and (7); la1, la2, la3 are the three summands,
// lamax is the maximum for the logsumexp trick.
scalar_t la1 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * s];
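The body of the loop is the standard CTC forward recursion (equations (6) and (7) of the paper referenced as [1] in the comments): alpha_t(s) is the log-sum-exp of up to three predecessors at t-1 (states s, s-1, and s-2 when have_three holds), plus the current frame's log-probability of l'_s. A compact, self-contained Python sketch of one such update (illustrative function names):

import math

def logsumexp3(a, b, c):
    # log(exp(a) + exp(b) + exp(c)) with the max-shift trick,
    # mirroring la1/la2/la3 and lamax in the kernel
    m = max(a, b, c)
    if m == -math.inf:
        return -math.inf
    return math.log(math.exp(a - m) + math.exp(b - m) + math.exp(c - m)) + m

def alpha_update(log_alpha_prev, s, lp_current_char, have_three):
    # log_alpha_prev: row of log_alpha at time t-1
    la1 = log_alpha_prev[s]
    la2 = log_alpha_prev[s - 1] if s > 0 else -math.inf
    la3 = log_alpha_prev[s - 2] if have_three else -math.inf
    return logsumexp3(la1, la2, la3) + lp_current_char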
@@ -146,7 +172,11 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
// compute the loss (eq (8))
if (threadIdx.x == 0) {
scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)];
scalar_t l2 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2-1)];
scalar_t l2 = target_length > 0
? log_alpha_data
[la_batch_offset + la_input_stride * (input_length - 1) +
la_target_stride * (target_length * 2 - 1)]
: neginf;
scalar_t m = ((l1 > l2) ? l1 : l2);
m = ((m == neginf) ? 0 : m);
scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
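The loss read-out (equation (8)) takes logaddexp of the last two states of l' at the final valid frame; with the guard above, l2 is -inf for an empty target so only the all-blank state contributes, and resetting m to 0 when both terms are -inf keeps the shifted exponentials well defined instead of producing NaN. A Python sketch of the same computation (illustrative names):

import math

def ctc_log_likelihood(log_alpha_last, target_length):
    # log_alpha_last: row of log_alpha at t = input_length - 1
    l1 = log_alpha_last[2 * target_length]
    l2 = log_alpha_last[2 * target_length - 1] if target_length > 0 else -math.inf
    m = max(l1, l2)
    m = 0.0 if m == -math.inf else m             # keep l1 - m and l2 - m well defined
    return math.log(math.exp(l1 - m) + math.exp(l2 - m)) + m

# the per-sample loss stored by the kernel is the negative of this value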
@@ -236,7 +266,6 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int) batch_size);

dim3 block(threads_target, threads_batch);
dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -285,8 +314,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
scalar_t lb;
if (s == 2*target_length) {
lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * BLANK];
} else if ((target_length > 0) && (s == 2*target_length-1)) {
int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
} else if (s == 2 * target_length - 1) { // false for target_length == 0
int64_t current_target_prime = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * current_target_prime];
} else {
lb = neginf;
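Symmetrically to the forward pass, log_beta at the last valid frame is -inf everywhere except the final blank state s = 2*target_length and, for non-empty targets only, s = 2*target_length - 1 (the last target character); for target_length == 0 that second index is never reached, which is why the explicit check could be dropped in favor of the comment above. A Python sketch for one batch element (illustrative names):

import math

def init_log_beta_last_frame(log_probs_last, targets, blank=0):
    # log_probs_last[c]: log-probabilities at t = input_length - 1
    target_length = len(targets)
    log_beta_last = [-math.inf] * (2 * target_length + 1)
    log_beta_last[2 * target_length] = log_probs_last[blank]
    if target_length > 0:                         # s == 2*target_length - 1 only exists then
        log_beta_last[2 * target_length - 1] = log_probs_last[targets[-1]]
    return log_beta_last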
@@ -301,19 +335,29 @@
int64_t s = threadIdx.x + block_s;
int64_t current_target_prime;
bool have_three;
if (s < 2*target_length+1) {
current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
have_three = ((s < 2*target_length-1) &&
(get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s+2, BLANK) !=
current_target_prime));
if (s < 2 * target_length + 1 && target_length > 0) {
current_target_prime = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
have_three =
((s < 2 * target_length - 1) &&
(get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s + 2,
BLANK) != current_target_prime));
} else {
current_target_prime = BLANK;
have_three = false;
}
// now go backward in t. Note that we need to skip the last timestep that we did above.
for (int64_t t=max_input_length-2; t>=0; t--) {
__syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch item
if ((t < input_length-1) && (target_length > 0) && (s < 2*target_length+1)) {
if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
scalar_t lb1 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * s];
scalar_t lbmax = lb1;
scalar_t lb2, lb3;
@@ -339,8 +383,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
+ log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];

log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb;
} else if ((s < 2*max_target_length+1) && ((target_length == 0) || (s >= 2*target_length+1) || (t >= input_length))) {
log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = neginf;
} else if (
(s < 2 * max_target_length + 1) &&
(((target_length == 0) && (s > 0)) || (s >= 2 * target_length + 1) ||
(t >= input_length))) {
log_beta_data
[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] =
neginf;
}
}
}
@@ -448,8 +497,13 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,

// collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s]
for (int s = 0; s < 2*max_target_length+1; s++) {
if ((target_length > 0) && (s < 2*target_length+1)) {
int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
if (s < 2 * target_length + 1) { // if target_length == 0, s == 0
int64_t current_target_prime = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s]
+ log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s]);
scalar_t& lcab = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * current_target_prime];
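The collection step accumulates, for every character c, the log-sum-exp over all states s with l'_s = c of log_alpha[t, s] + log_beta[t, s]; since s == 0 is the only state for an empty target, the simplified s < 2*target_length + 1 guard still visits exactly the valid states. A Python sketch of the per-timestep accumulation (illustrative names; the final gradient expression is only summarized, hedged, in the closing comment):

import math

def logaddexp(a, b):
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return math.log(math.exp(a - m) + math.exp(b - m)) + m

def collect(log_alpha_t, log_beta_t, targets, num_labels, blank=0):
    # collected[c] "log+=" log_alpha[t, s] + log_beta[t, s] over all s with l'_s == c
    target_length = len(targets)
    collected = [-math.inf] * num_labels
    for s in range(2 * target_length + 1):
        c = blank if s % 2 == 0 else targets[s // 2]
        collected[c] = logaddexp(collected[c], log_alpha_t[s] + log_beta_t[s])
    return collected

# the gradient w.r.t. log_probs[t, c] is then roughly
# exp(log_probs[t, c]) - exp(collected[c] - log_likelihood), scaled by grad_out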
@@ -569,7 +623,6 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
{
dim3 block(threads_target, threads_batch);
dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);

ctc_loss_backward_log_beta_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(log_beta.data<scalar_t>(),
log_probs.data<scalar_t>(), input_lengths_t.data<int64_t>(), log_probs.size(0),
@@ -612,12 +665,16 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
// For the non-blank characters, we use a kernel to compute the subtrahend.
// Again we might configure block and grid in a better way.
int threads_target = max_threads;
while (threads_target / 2 >= max_target_length) {
while (threads_target / 2 >= max_target_length && threads_target > 1) {
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
dim3 block(threads_target, threads_batch);
dim3 grid((max_target_length + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
dim3 grid(
std::max<int>(
(max_target_length + threads_target - 1) / threads_target, 1),
(batch_size + threads_batch - 1) / threads_batch,
1);
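Both new guards exist for the empty-target case: with max_target_length == 0, the halving loop's condition threads_target / 2 >= max_target_length is always true, so without the threads_target > 1 guard the loop would never terminate, and the unguarded grid x-dimension (max_target_length + threads_target - 1) / threads_target would round down to 0, an invalid launch configuration. A small Python check of the guarded behaviour (illustrative values):

max_threads, max_target_length = 1024, 0

threads_target = max_threads
while threads_target // 2 >= max_target_length and threads_target > 1:
    threads_target //= 2
assert threads_target == 1   # the "> 1" guard stops the halving at 1

grid_x = max((max_target_length + threads_target - 1) // threads_target, 1)
assert grid_x == 1           # without max(..., 1) this would be 0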
ctc_loss_backward_collect_nonblank_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(grad.data<scalar_t>(),
grad_out.data<scalar_t>(), grad_out.stride(0),
@@ -635,13 +692,12 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
} else { // small problem, use naive algorithm
// Still no block/grid configuration guru...
int threads_input = max_threads;
while (threads_input / 2 >= log_probs.size(0)) {
while (threads_input / 2 >= log_probs.size(0) && threads_input > 1) {
threads_input /= 2;
}
threads_batch = std::min(max_threads / threads_input, (int) batch_size);
dim3 block(threads_input, threads_batch);
dim3 grid((log_probs.size(0) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch);

ctc_loss_backward_collect_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(grad.data<scalar_t>(),
grad_out.data<scalar_t>(), grad_out.stride(0),
39 changes: 28 additions & 11 deletions test/test_autograd.py
@@ -1609,25 +1609,42 @@ def test_ctc_loss(self):
target_length = 15
gradcheck_input_size = 10

# device, input_length
tests = [('cpu', 150, False),
('cpu', 150, True)]
ZERO_NONE = 0
ZERO_SOME = 1
ZERO_ALL = 2

# device, input_length, vary_lengths, zero_lengths
tests = [('cpu', 150, False, ZERO_NONE),
('cpu', 150, True, ZERO_NONE),
('cpu', 50, True, ZERO_SOME),
('cpu', 50, True, ZERO_ALL)]
if torch.cuda.is_available():
tests += [('cuda', 50, False),
('cuda', 150, False),
('cuda', 50, True),
('cuda', 150, True)]

for device, input_length, vary_lengths in tests:
tests += [('cuda', 50, False, ZERO_NONE),
('cuda', 150, False, ZERO_NONE),
('cuda', 50, True, ZERO_NONE),
('cuda', 150, True, ZERO_NONE),
('cuda', 50, True, ZERO_SOME),
('cuda', 150, True, ZERO_SOME),
('cuda', 50, True, ZERO_ALL),
('cuda', 150, True, ZERO_ALL)]

for device, input_length, vary_lengths, zero_mode in tests:
targets = torch.randint(1, num_labels, (batch_size, target_length),
device=device, dtype=torch.long)
x = torch.randn(gradcheck_input_size, device=device, requires_grad=True)
tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1,
device=device)
input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
if vary_lengths or i == 0 else input_length) for i in range(batch_size)]
target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
if vary_lengths else target_length) for i in range(batch_size)]
if zero_mode == ZERO_ALL:
target_lengths = [0 for _ in range(batch_size)]
else:
target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
if vary_lengths else target_length) for _ in range(batch_size)]
if zero_mode == ZERO_SOME:
idxes = torch.randint(0, batch_size, (10,))
for i in idxes:
target_lengths[i] = 0

def ctc_after_softmax(x):
x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels]
19 changes: 14 additions & 5 deletions test/test_nn.py
@@ -5599,20 +5599,29 @@ def test_CTCLoss_lengthchecks_cpu(self):
with self.assertRaises(RuntimeError):
torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)

def test_CTCLoss_empty_target_cpu(self):
def _test_CTCLoss_empty_target(self, device):
target_lengths = [0, 0, 0]
input_lengths = [50, 50, 50]
targets = torch.randint(1, 15, (0,), dtype=torch.int)
log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
targets = torch.randint(1, 15, (0,), dtype=torch.long, device=device)
log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
self.assertTrue((loss >= 0).all().item())
self.assertAlmostEqual(-log_probs.sum(0)[:, 0], loss)

target_lengths = [0, 9, 0]
input_lengths = [50, 50, 50]
targets = torch.randint(1, 15, (9,), dtype=torch.int)
log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
targets = torch.randint(1, 15, (9,), dtype=torch.long, device=device)
log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
self.assertTrue((loss >= 0).all().item())
self.assertAlmostEqual(-log_probs.sum(0)[[0, 2], 0], loss[[0, 2]])

def test_CTCLoss_empty_target_cpu(self):
self._test_CTCLoss_empty_target('cpu')

@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
def test_CTCLoss_empty_target_cuda(self):
self._test_CTCLoss_empty_target('cuda')
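
The expected values in these checks follow directly from the model: with an empty target the only admissible alignment is the all-blank path, so that sample's loss is -sum_t log p_t(blank), which is what -log_probs.sum(0)[:, 0] computes for blank index 0. A hedged standalone illustration with a single sample (current torch API assumed):

import torch

T, C = 50, 15
log_probs = torch.randn(T, 1, C, dtype=torch.double).log_softmax(2)
targets = torch.zeros(0, dtype=torch.long)          # empty target
loss = torch.nn.functional.ctc_loss(log_probs, targets, [T], [0], reduction='none')

# the only valid alignment emits blank (index 0) at every step
manual = -log_probs[:, 0, 0].sum()
assert torch.allclose(loss, manual)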

@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
def test_CTCLoss_zero_infinity(self):