Fix CTC loss for zero-length targets on GPU (#23298)
Summary:
Fixes: pytorch/pytorch#18215 at last!

Also sprinkle tests...
Pull Request resolved: pytorch/pytorch#23298

Differential Revision: D16582145

Pulled By: soumith

fbshipit-source-id: bc8b1a629de0c2606e70a2218ccd135f4a9cdc5d
t-vi authored and facebook-github-bot committed Jul 31, 2019
1 parent dc4f92e commit 89de07c
Showing 2 changed files with 91 additions and 34 deletions.
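For context, a minimal repro sketch of the scenario this commit fixes, assuming a CUDA device and the standard `torch.nn.functional.ctc_loss` API; the shapes and values are made up for illustration:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: T time steps, N batch items, C classes, blank index 0.
T, N, C = 50, 3, 20
log_probs = torch.randn(T, N, C, device="cuda").log_softmax(2).requires_grad_()

# The second sample has a zero-length target -- the case fixed by this commit.
target_lengths = torch.tensor([10, 0, 7], device="cuda")
targets = torch.randint(1, C, (int(target_lengths.sum()),), device="cuda")
input_lengths = torch.full((N,), T, dtype=torch.long, device="cuda")

loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                  blank=0, reduction="mean")
loss.backward()  # with this fix, loss and gradients should be finite on GPU
print(loss.item())
```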
3 changes: 2 additions & 1 deletion aten/src/ATen/native/LossCTC.cpp
@@ -374,7 +374,8 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef inpu
}
}
if (reduction == Reduction::Mean) {
auto target_lengths_t = at::tensor(target_lengths, res.options());
auto target_lengths_t =
at::tensor(target_lengths, res.options()).clamp_min(1);
return (res / target_lengths_t).mean();
} else if (reduction == Reduction::Sum) {
return res.sum();
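The change above clamps the per-sample target lengths to at least 1 before the division, so the mean reduction stays well-defined when a target has length 0. A rough sketch of the intended reduction semantics (standalone Python, not the actual dispatcher code):

```python
import torch

def ctc_mean_reduction(per_sample_loss: torch.Tensor,
                       target_lengths: torch.Tensor) -> torch.Tensor:
    # Divide each sample's loss by its target length, clamped to at least 1
    # so that zero-length targets contribute their blank-only loss unscaled
    # instead of triggering a division by zero.
    lengths = target_lengths.to(per_sample_loss.dtype).clamp(min=1)
    return (per_sample_loss / lengths).mean()
```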
122 changes: 89 additions & 33 deletions aten/src/ATen/native/cuda/LossCTC.cu
@@ -24,10 +24,20 @@ namespace native {

namespace {

// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1]) note that no bound-checking is done
// __restrict__ impact to be measured, https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
template<typename target_t>
__device__ static inline int64_t get_target_prime(const target_t* __restrict__ target, int64_t offset, int64_t stride, int64_t idx, int64_t BLANK) {
// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1])
// so if l is l_0 l_1 ... l_(tl-1) then this looks up idx in
// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
// - note that no bound-checking is done
// - it is important to call it only with idx == 0 if the target length is 0
// - __restrict__ impact to be measured, see
// https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
template <typename target_t>
__device__ static inline int64_t get_target_prime(
const target_t* __restrict__ target,
int64_t offset,
int64_t stride,
int64_t idx,
int64_t BLANK) {
if (idx % 2 == 0) {
return BLANK;
} else {
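For reference, a plain-Python sketch of the augmented-target lookup described in the comment above (the offset/stride arithmetic of the kernel is omitted; this is an illustration, not the kernel itself):

```python
def get_target_prime(target, idx, blank=0):
    # target is l = [l_0, ..., l_{tl-1}]; the augmented sequence is
    # l' = [BLANK, l_0, BLANK, l_1, ..., BLANK, l_{tl-1}, BLANK].
    # Even positions of l' are blanks; odd position 2k + 1 maps to l_k.
    # As in the kernel, no bounds checking is done: for an empty target,
    # only idx == 0 is meaningful.
    if idx % 2 == 0:
        return blank
    return target[idx // 2]
```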
@@ -80,12 +90,16 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
break;
case 1:
if (target_length > 0) {
la = log_probs_data[lp_batch_offset + lp_char_stride * get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];
}
else {
la = neginf;
}
la = target_length == 0 ? neginf
: log_probs_data
[lp_batch_offset +
lp_char_stride *
get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
1,
BLANK)];
break;
default:
la = neginf;
@@ -100,16 +114,28 @@
// These two only depend on s, so we can cache them.
int64_t current_char; // l_s in eq (6)
bool have_three; // flag which of the two cases in eq (6) we have
if (s < 2*target_length+1) {
current_char = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s-2, BLANK) != current_char));
if (s < 2 * target_length + 1 && target_length > 0) {
current_char = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
have_three =
((s > 1) &&
(get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s - 2,
BLANK) != current_char));
} else {
current_char = BLANK;
have_three = false;
}
for (int64_t t=1; t < max_input_length; t++) {
__syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch
if ((t < input_length) && (target_length > 0) && (s < 2*target_length+1)) {
if ((t < input_length) && (s < 2 * target_length + 1)) {
// only for valid t, s. This is equation (6) and (7), la1, la2, la3 are the three summands,
// lamax is the maximum for the logsumexp trick.
scalar_t la1 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * s];
@@ -146,7 +172,11 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
// compute the loss (eq (8))
if (threadIdx.x == 0) {
scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)];
scalar_t l2 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2-1)];
scalar_t l2 = target_length > 0
? log_alpha_data
[la_batch_offset + la_input_stride * (input_length - 1) +
la_target_stride * (target_length * 2 - 1)]
: neginf;
scalar_t m = ((l1 > l2) ? l1 : l2);
m = ((m == neginf) ? 0 : m);
scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
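The loss in eq. (8) combines the two terminal log-alpha entries with the log-sum-exp trick; with a zero-length target, `l2` is forced to `-inf`, so the likelihood reduces to `l1`, the all-blank path. A standalone numeric sketch of that combination (not the kernel code; the per-sample CTC loss is the negative of this log-likelihood):

```python
import math

def ctc_log_likelihood(l1: float, l2: float) -> float:
    # Log-sum-exp of the two terminal log-alpha values (eq. (8)).
    # For a zero-length target l2 is -inf, so the result is simply l1.
    m = max(l1, l2)
    if m == float("-inf"):
        return float("-inf")  # both paths impossible
    return math.log(math.exp(l1 - m) + math.exp(l2 - m)) + m
```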
@@ -236,7 +266,6 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int) batch_size);

dim3 block(threads_target, threads_batch);
dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -285,8 +314,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
scalar_t lb;
if (s == 2*target_length) {
lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * BLANK];
} else if ((target_length > 0) && (s == 2*target_length-1)) {
int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
} else if (s == 2 * target_length - 1) { // false for target_length == 0
int64_t current_target_prime = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * current_target_prime];
} else {
lb = neginf;
@@ -301,19 +335,29 @@
int64_t s = threadIdx.x + block_s;
int64_t current_target_prime;
bool have_three;
if (s < 2*target_length+1) {
current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
have_three = ((s < 2*target_length-1) &&
(get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s+2, BLANK) !=
current_target_prime));
if (s < 2 * target_length + 1 && target_length > 0) {
current_target_prime = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
have_three =
((s < 2 * target_length - 1) &&
(get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s + 2,
BLANK) != current_target_prime));
} else {
current_target_prime = BLANK;
have_three = false;
}
// now go backward in t. Note that we need to skip the last timestep that we did above.
for (int64_t t=max_input_length-2; t>=0; t--) {
__syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch item
if ((t < input_length-1) && (target_length > 0) && (s < 2*target_length+1)) {
if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
scalar_t lb1 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * s];
scalar_t lbmax = lb1;
scalar_t lb2, lb3;
@@ -339,8 +383,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
+ log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];

log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb;
} else if ((s < 2*max_target_length+1) && ((target_length == 0) || (s >= 2*target_length+1) || (t >= input_length))) {
log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = neginf;
} else if (
(s < 2 * max_target_length + 1) &&
(((target_length == 0) && (s > 0)) || (s >= 2 * target_length + 1) ||
(t >= input_length))) {
log_beta_data
[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] =
neginf;
}
}
}
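The backward (beta) recursion mirrors the forward one: each entry combines up to three entries from the next time step with the log-sum-exp trick and then adds the emission log-probability for the current time step. A standalone sketch of that per-entry update (Python, with invalid summands passed as `-inf`; not the kernel code):

```python
import math

NEG_INF = float("-inf")

def log_beta_step(lb1: float, lb2: float, lb3: float,
                  emission_logp: float) -> float:
    # Log-sum-exp of up to three log-beta values from step t + 1,
    # followed by adding log p(l'_s at time t), mirroring the kernel above.
    m = max(lb1, lb2, lb3)
    if m == NEG_INF:
        return NEG_INF
    total = math.exp(lb1 - m) + math.exp(lb2 - m) + math.exp(lb3 - m)
    return math.log(total) + m + emission_logp
```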
@@ -448,8 +497,13 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,

// collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s]
for (int s = 0; s < 2*max_target_length+1; s++) {
if ((target_length > 0) && (s < 2*target_length+1)) {
int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
if (s < 2 * target_length + 1) { // if target_length == 0, s == 0
int64_t current_target_prime = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s]
+ log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s]);
scalar_t& lcab = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * current_target_prime];
@@ -569,7 +623,6 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
{
dim3 block(threads_target, threads_batch);
dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);

ctc_loss_backward_log_beta_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(log_beta.data<scalar_t>(),
log_probs.data<scalar_t>(), input_lengths_t.data<int64_t>(), log_probs.size(0),
@@ -612,12 +665,16 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
// For the non-blank characters, we use a kernel to compute the subtrahend.
// Again we might configure block and grid in a better way.
int threads_target = max_threads;
while (threads_target / 2 >= max_target_length) {
while (threads_target / 2 >= max_target_length && threads_target > 1) {
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
dim3 block(threads_target, threads_batch);
dim3 grid((max_target_length + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
dim3 grid(
std::max<int>(
(max_target_length + threads_target - 1) / threads_target, 1),
(batch_size + threads_batch - 1) / threads_batch,
1);
ctc_loss_backward_collect_nonblank_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(grad.data<scalar_t>(),
grad_out.data<scalar_t>(), grad_out.stride(0),
@@ -635,13 +692,12 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
} else { // small problem, use naive algorithm
// Still no block/grid configuration guru...
int threads_input = max_threads;
while (threads_input / 2 >= log_probs.size(0)) {
while (threads_input / 2 >= log_probs.size(0) && threads_input > 1) {
threads_input /= 2;
}
threads_batch = std::min(max_threads / threads_input, (int) batch_size);
dim3 block(threads_input, threads_batch);
dim3 grid((log_probs.size(0) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch);

ctc_loss_backward_collect_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(grad.data<scalar_t>(),
grad_out.data<scalar_t>(), grad_out.stride(0),
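The commit message mentions added tests; for illustration only (this is not the test added to the suite), the kind of check the fix enables is that GPU loss values for a batch containing zero-length targets now agree with the CPU reference:

```python
import torch
import torch.nn.functional as F

def check_zero_length_targets_match_cpu(T=30, N=4, C=10, blank=0):
    # Illustrative check, not the actual test from this commit: CPU and GPU
    # CTC losses should agree when some targets in the batch have length 0.
    log_probs = torch.randn(T, N, C).log_softmax(2)
    target_lengths = torch.tensor([5, 0, 3, 0])
    targets = torch.randint(1, C, (int(target_lengths.sum()),))
    input_lengths = torch.full((N,), T, dtype=torch.long)

    cpu = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                     blank=blank, reduction="none")
    gpu = F.ctc_loss(log_probs.cuda(), targets.cuda(), input_lengths.cuda(),
                     target_lengths.cuda(), blank=blank, reduction="none")
    assert torch.allclose(cpu, gpu.cpu(), atol=1e-4)
```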
