From 772d7c717b5464c6027667245413b66f3ee56432 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Wed, 31 Jul 2019 11:56:48 -0700
Subject: [PATCH] Fix CTC loss for zero-length targets on GPU (#23298)

Summary:
Fixes: https://github.com/pytorch/pytorch/issues/18215 at last!

Also sprinkle tests...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/23298

Differential Revision: D16582145

Pulled By: soumith

fbshipit-source-id: bc8b1a629de0c2606e70a2218ccd135f4a9cdc5d
---
 aten/src/ATen/native/LossCTC.cpp     |   3 +-
 aten/src/ATen/native/cuda/LossCTC.cu | 122 +++++++++++++++++++--------
 test/test_autograd.py                |  39 ++++++---
 test/test_nn.py                      |  19 +++--
 4 files changed, 133 insertions(+), 50 deletions(-)

diff --git a/aten/src/ATen/native/LossCTC.cpp b/aten/src/ATen/native/LossCTC.cpp
index a8f91c9653668..8337761f0a70c 100644
--- a/aten/src/ATen/native/LossCTC.cpp
+++ b/aten/src/ATen/native/LossCTC.cpp
@@ -374,7 +374,8 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef inpu
     }
   }
   if (reduction == Reduction::Mean) {
-    auto target_lengths_t = at::tensor(target_lengths, res.options());
+    auto target_lengths_t =
+        at::tensor(target_lengths, res.options()).clamp_min(1);
     return (res / target_lengths_t).mean();
   } else if (reduction == Reduction::Sum) {
     return res.sum();
diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu
index 89d5d4f4af3e8..f2559b00397be 100644
--- a/aten/src/ATen/native/cuda/LossCTC.cu
+++ b/aten/src/ATen/native/cuda/LossCTC.cu
@@ -24,10 +24,20 @@ namespace native {
 
 namespace {
 
-// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1]) note that no bound-checking is done
-// __restrict__ impact to be measured, https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
-template<typename target_t>
-__device__ static inline int64_t get_target_prime(const target_t* __restrict__ target, int64_t offset, int64_t stride, int64_t idx, int64_t BLANK) {
+// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1])
+// so if l is l_0 l_1 ... l_(tl-1) then this looks up idx in
+// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
+// - note that no bound-checking is done
+// - it is important to only call it with idx == 0 if the target length is 0
+// - __restrict__ impact to be measured, see
+//   https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
+template <typename target_t>
+__device__ static inline int64_t get_target_prime(
+    const target_t* __restrict__ target,
+    int64_t offset,
+    int64_t stride,
+    int64_t idx,
+    int64_t BLANK) {
   if (idx % 2 == 0) {
     return BLANK;
   } else {
@@ -80,12 +90,16 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
       la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
       break;
     case 1:
-      if (target_length > 0) {
-        la = log_probs_data[lp_batch_offset + lp_char_stride * get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];
-      }
-      else {
-        la = neginf;
-      }
+      la = target_length == 0 ? neginf
+                              : log_probs_data
+                                    [lp_batch_offset +
+                                     lp_char_stride *
+                                         get_target_prime(
+                                             targets_data,
+                                             tg_batch_offset,
+                                             tg_target_stride,
+                                             1,
+                                             BLANK)];
       break;
     default:
       la = neginf;
@@ -100,16 +114,28 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
     // These two only depend on s, so we can cache them.
     int64_t current_char; // l_s in eq (6)
     bool have_three; // flag which of the two cases in eq (6) we have
-    if (s < 2*target_length+1) {
-      current_char = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
-      have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s-2, BLANK) != current_char));
+    if (s < 2 * target_length + 1 && target_length > 0) {
+      current_char = get_target_prime(
+          targets_data,
+          tg_batch_offset,
+          tg_target_stride,
+          s,
+          BLANK);
+      have_three =
+          ((s > 1) &&
+           (get_target_prime(
+                targets_data,
+                tg_batch_offset,
+                tg_target_stride,
+                s - 2,
+                BLANK) != current_char));
     } else {
       current_char = BLANK;
       have_three = false;
     }
     for (int64_t t=1; t < max_input_length; t++) {
       __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch
-      if ((t < input_length) && (target_length > 0) && (s < 2*target_length+1)) {
+      if ((t < input_length) && (s < 2 * target_length + 1)) {
         // only for valid t, s. This is equation (6) and (7), la1, la2, la3 are the three summands,
         // lamax is the maximum for the logsumexp trick.
         scalar_t la1 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * s];
@@ -146,7 +172,11 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
   // compute the loss (eq (8))
   if (threadIdx.x == 0) {
     scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)];
-    scalar_t l2 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2-1)];
+    scalar_t l2 = target_length > 0
+        ? log_alpha_data
+              [la_batch_offset + la_input_stride * (input_length - 1) +
+               la_target_stride * (target_length * 2 - 1)]
+        : neginf;
     scalar_t m = ((l1 > l2) ? l1 : l2);
     m = ((m == neginf) ? 0 : m);
     scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
@@ -236,7 +266,6 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
     threads_target /= 2;
   }
   int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
-
   dim3 block(threads_target, threads_batch);
   dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -285,8 +314,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
       scalar_t lb;
       if (s == 2*target_length) {
         lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * BLANK];
-      } else if ((target_length > 0) && (s == 2*target_length-1)) {
-        int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
+      } else if (s == 2 * target_length - 1) { // false for target_length == 0
+        int64_t current_target_prime = get_target_prime(
+            targets_data,
+            tg_batch_offset,
+            tg_target_stride,
+            s,
+            BLANK);
         lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * current_target_prime];
       } else {
         lb = neginf;
@@ -301,11 +335,21 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
     int64_t s = threadIdx.x + block_s;
     int64_t current_target_prime;
     bool have_three;
-    if (s < 2*target_length+1) {
-      current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
-      have_three = ((s < 2*target_length-1) &&
-                    (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s+2, BLANK) !=
-                     current_target_prime));
+    if (s < 2 * target_length + 1 && target_length > 0) {
+      current_target_prime = get_target_prime(
+          targets_data,
+          tg_batch_offset,
+          tg_target_stride,
+          s,
+          BLANK);
+      have_three =
+          ((s < 2 * target_length - 1) &&
+           (get_target_prime(
+                targets_data,
+                tg_batch_offset,
+                tg_target_stride,
+                s + 2,
+                BLANK) != current_target_prime));
     } else {
       current_target_prime = BLANK;
       have_three = false;
@@ -313,7 +357,7 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
     // now go backward in t. Note that we need to skip the last timestep that we did above.
     for (int64_t t=max_input_length-2; t>=0; t--) {
       __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch item
-      if ((t < input_length-1) && (target_length > 0) && (s < 2*target_length+1)) {
+      if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
         scalar_t lb1 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * s];
         scalar_t lbmax = lb1;
         scalar_t lb2, lb3;
@@ -339,8 +383,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
                     + log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];
         log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb;
-      } else if ((s < 2*max_target_length+1) && ((target_length == 0) || (s >= 2*target_length+1) || (t >= input_length))) {
-        log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = neginf;
+      } else if (
+          (s < 2 * max_target_length + 1) &&
+          (((target_length == 0) && (s > 0)) || (s >= 2 * target_length + 1) ||
+           (t >= input_length))) {
+        log_beta_data
+            [lb_batch_offset + lb_input_stride * t + lb_target_stride * s] =
+                neginf;
       }
     }
   }
@@ -448,8 +497,13 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
     // collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s]
     for (int s = 0; s < 2*max_target_length+1; s++) {
-      if ((target_length > 0) && (s < 2*target_length+1)) {
-        int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
+      if (s < 2 * target_length + 1) { // if target_length == 0, s == 0
+        int64_t current_target_prime = get_target_prime(
+            targets_data,
+            tg_batch_offset,
+            tg_target_stride,
+            s,
+            BLANK);
         scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s]
                                    + log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s]);
         scalar_t& lcab = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * current_target_prime];
@@ -569,7 +623,6 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
   {
     dim3 block(threads_target, threads_batch);
     dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
-
     ctc_loss_backward_log_beta_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
       (log_beta.data<scalar_t>(),
        log_probs.data<scalar_t>(), input_lengths_t.data<int64_t>(), log_probs.size(0),
@@ -612,12 +665,16 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
     // For the non-blank characters, we use a kernel to compute the subtrahend.
     // Again we might configure block and grid in a better way.
     int threads_target = max_threads;
-    while (threads_target / 2 >= max_target_length) {
+    while (threads_target / 2 >= max_target_length && threads_target > 1) {
       threads_target /= 2;
     }
     int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
     dim3 block(threads_target, threads_batch);
-    dim3 grid((max_target_length + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
+    dim3 grid(
+        std::max(
+            (max_target_length + threads_target - 1) / threads_target, 1),
+        (batch_size + threads_batch - 1) / threads_batch,
+        1);
     ctc_loss_backward_collect_nonblank_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
       (grad.data<scalar_t>(),
        grad_out.data<scalar_t>(), grad_out.stride(0),
@@ -635,13 +692,12 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
   } else { // small problem, use naive algorithm
     // Still no block/grid configuration guru...
     int threads_input = max_threads;
-    while (threads_input / 2 >= log_probs.size(0)) {
+    while (threads_input / 2 >= log_probs.size(0) && threads_input > 1) {
       threads_input /= 2;
     }
     threads_batch = std::min(max_threads / threads_input, (int) batch_size);
     dim3 block(threads_input, threads_batch);
     dim3 grid((log_probs.size(0) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch);
-
     ctc_loss_backward_collect_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
       (grad.data<scalar_t>(),
        grad_out.data<scalar_t>(), grad_out.stride(0),
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 6b97d576ee0d5..f3e491adbd7a3 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -1609,16 +1609,26 @@ def test_ctc_loss(self):
         target_length = 15
         gradcheck_input_size = 10
-        # device, input_length
-        tests = [('cpu', 150, False),
-                 ('cpu', 150, True)]
+        ZERO_NONE = 0
+        ZERO_SOME = 1
+        ZERO_ALL = 2
+
+        # device, input_length, vary_lengths, zero_lengths
+        tests = [('cpu', 150, False, ZERO_NONE),
+                 ('cpu', 150, True, ZERO_NONE),
+                 ('cpu', 50, True, ZERO_SOME),
+                 ('cpu', 50, True, ZERO_ALL)]
         if torch.cuda.is_available():
-            tests += [('cuda', 50, False),
-                      ('cuda', 150, False),
-                      ('cuda', 50, True),
-                      ('cuda', 150, True)]
-
-        for device, input_length, vary_lengths in tests:
+            tests += [('cuda', 50, False, ZERO_NONE),
+                      ('cuda', 150, False, ZERO_NONE),
+                      ('cuda', 50, True, ZERO_NONE),
+                      ('cuda', 150, True, ZERO_NONE),
+                      ('cuda', 50, True, ZERO_SOME),
+                      ('cuda', 150, True, ZERO_SOME),
+                      ('cuda', 50, True, ZERO_ALL),
+                      ('cuda', 150, True, ZERO_ALL)]
+
+        for device, input_length, vary_lengths, zero_mode in tests:
             targets = torch.randint(1, num_labels, (batch_size, target_length),
                                     device=device, dtype=torch.long)
             x = torch.randn(gradcheck_input_size, device=device, requires_grad=True)
@@ -1626,8 +1636,15 @@ def test_ctc_loss(self):
                                       device=device)
             input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
                               if vary_lengths or i == 0 else input_length) for i in range(batch_size)]
-            target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
-                               if vary_lengths else target_length) for i in range(batch_size)]
+            if zero_mode == ZERO_ALL:
+                target_lengths = [0 for _ in range(batch_size)]
+            else:
+                target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
+                                   if vary_lengths else target_length) for _ in range(batch_size)]
+                if zero_mode == ZERO_SOME:
+                    idxes = torch.randint(0, batch_size, (10,))
+                    for i in idxes:
+                        target_lengths[i] = 0
 
             def ctc_after_softmax(x):
                 x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels]
diff --git a/test/test_nn.py b/test/test_nn.py
index e0db0d6bd5237..a972a50c19ad5 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -5599,20 +5599,29 @@ def test_CTCLoss_lengthchecks_cpu(self):
         with self.assertRaises(RuntimeError):
             torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
 
-    def test_CTCLoss_empty_target_cpu(self):
+    def _test_CTCLoss_empty_target(self, device):
         target_lengths = [0, 0, 0]
         input_lengths = [50, 50, 50]
-        targets = torch.randint(1, 15, (0,), dtype=torch.int)
-        log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
+        targets = torch.randint(1, 15, (0,), dtype=torch.long, device=device)
+        log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
         loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
         self.assertTrue((loss >= 0).all().item())
+        self.assertAlmostEqual(-log_probs.sum(0)[:, 0], loss)
 
         target_lengths = [0, 9, 0]
         input_lengths = [50, 50, 50]
-        targets = torch.randint(1, 15, (9,), dtype=torch.int)
-        log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
+        targets = torch.randint(1, 15, (9,), dtype=torch.long, device=device)
+        log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
         loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
         self.assertTrue((loss >= 0).all().item())
+        self.assertAlmostEqual(-log_probs.sum(0)[[0, 2], 0], loss[[0, 2]])
+
+    def test_CTCLoss_empty_target_cpu(self):
+        self._test_CTCLoss_empty_target('cpu')
+
+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    def test_CTCLoss_empty_target_cuda(self):
+        self._test_CTCLoss_empty_target('cuda')
 
     @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
     def test_CTCLoss_zero_infinity(self):
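
For reference, a minimal standalone sketch of the zero-length-target call that the new test_nn.py cases exercise; the shapes, dtype, and the device fallback below are illustrative choices mirroring the test, not code from the patch itself:

    import torch
    import torch.nn.functional as F

    # (T, N, C) = (50, 3, 15) log-probabilities, as in _test_CTCLoss_empty_target
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
    targets = torch.randint(1, 15, (0,), dtype=torch.long, device=device)
    input_lengths = [50, 50, 50]
    target_lengths = [0, 0, 0]  # every target in the batch is empty

    loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
    # With an empty target, the only valid alignment is all blanks (index 0 by default),
    # so each per-sample loss is minus the sum of the blank log-probabilities over time.
    expected = -log_probs.sum(0)[:, 0]
    print(torch.allclose(loss, expected))  # True on builds that include this fix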