Fix CTC loss for zero-length targets on GPU (pytorch#23298)
Summary:
Fixes: pytorch#18215 at last!

Also adds tests covering zero-length targets on CPU and CUDA.
Pull Request resolved: pytorch#23298

Differential Revision: D16582145

Pulled By: soumith

fbshipit-source-id: bc8b1a629de0c2606e70a2218ccd135f4a9cdc5d
t-vi authored and ssnl committed Aug 2, 2019
1 parent 4f52116 commit 772d7c7
Showing 4 changed files with 133 additions and 50 deletions.
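For context, the situation this commit fixes can be exercised with a few lines of PyTorch; the snippet below is a sketch modeled on the new tests (shapes, the number of classes, and the device handling are illustrative, not taken from the diff):

    import torch
    import torch.nn.functional as F

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # 50 time steps, batch of 3, 15 classes; class 0 is the blank.
    log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
    targets = torch.randint(1, 15, (0,), dtype=torch.long, device=device)  # no target symbols at all
    input_lengths = [50, 50, 50]
    target_lengths = [0, 0, 0]  # zero-length targets: the case this commit fixes on GPU

    loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
    # For an empty target the only valid alignment is all blanks, so each per-sample
    # loss should be finite and equal -log_probs[:, b, 0].sum() for sample b.
    print(loss)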
3 changes: 2 additions & 1 deletion aten/src/ATen/native/LossCTC.cpp
@@ -374,7 +374,8 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef inpu
}
}
if (reduction == Reduction::Mean) {
auto target_lengths_t = at::tensor(target_lengths, res.options());
auto target_lengths_t =
at::tensor(target_lengths, res.options()).clamp_min(1);
return (res / target_lengths_t).mean();
} else if (reduction == Reduction::Sum) {
return res.sum();
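The change above affects only the 'mean' reduction path: each per-sample loss is divided by its target length before averaging, and clamping the lengths to a minimum of 1 avoids a division by zero (and the resulting NaNs) for empty targets. A Python sketch of the reduction semantics, mirroring the C++ above (variable names are illustrative):

    import torch

    def ctc_reduce(res, target_lengths, reduction):
        # res: per-sample CTC losses, shape (batch,)
        if reduction == 'mean':
            # Clamp so that samples with an empty target divide by 1 instead of 0,
            # mirroring clamp_min(1) in LossCTC.cpp above.
            lengths = torch.as_tensor(target_lengths, dtype=res.dtype, device=res.device).clamp_min(1)
            return (res / lengths).mean()
        elif reduction == 'sum':
            return res.sum()
        return res  # 'none'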
122 changes: 89 additions & 33 deletions aten/src/ATen/native/cuda/LossCTC.cu
@@ -24,10 +24,20 @@ namespace native {

namespace {

// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1]) note that no bound-checking is done
// __restrict__ impact to be measured, https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
template<typename target_t>
__device__ static inline int64_t get_target_prime(const target_t* __restrict__ target, int64_t offset, int64_t stride, int64_t idx, int64_t BLANK) {
// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1])
// so if l is l_0 l_1 ... l_(tl-1) then this looks up idx in
// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
// - note that no bound-checking is done
// - it is important to only call it with idx == 0 if the target length is 0
// - __restrict__ impact to be measured, see
// https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
template <typename target_t>
__device__ static inline int64_t get_target_prime(
const target_t* __restrict__ target,
int64_t offset,
int64_t stride,
int64_t idx,
int64_t BLANK) {
if (idx % 2 == 0) {
return BLANK;
} else {
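The comment block above describes how an index into the augmented sequence l' maps back to the original targets. A plain-Python sketch of that lookup (the odd-index branch is cut off in this hunk, but follows directly from the comment; names are illustrative):

    def get_target_prime(targets, idx, blank):
        # l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
        # Even positions of l' are blanks, odd positions are the original labels.
        if idx % 2 == 0:
            return blank
        return targets[(idx - 1) // 2]

    # e.g. targets = [3, 5, 2]  ->  l' = [B, 3, B, 5, B, 2, B]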
@@ -80,12 +90,16 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
break;
case 1:
if (target_length > 0) {
la = log_probs_data[lp_batch_offset + lp_char_stride * get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];
}
else {
la = neginf;
}
la = target_length == 0 ? neginf
: log_probs_data
[lp_batch_offset +
lp_char_stride *
get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
1,
BLANK)];
break;
default:
la = neginf;
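The switch above initializes log_alpha at the first time step: s == 0 always starts from the blank, s == 1 starts from the first target label unless the target is empty, and every other s is unreachable. Roughly, in Python (a sketch with illustrative names):

    neginf = float('-inf')

    def init_log_alpha_t0(log_probs_t0, targets, target_length, max_target_length, blank):
        # log_probs_t0: log-probabilities at the first time step, indexed by label
        row = [neginf] * (2 * max_target_length + 1)   # default: unreachable states
        row[0] = log_probs_t0[blank]                   # s == 0: begin with a blank
        if target_length > 0:
            row[1] = log_probs_t0[targets[0]]          # s == 1: begin with the first label
        return row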
@@ -100,16 +114,28 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
// These two only depend on s, so we can cache them.
int64_t current_char; // l_s in eq (6)
bool have_three; // flag which of the two cases in eq (6) we have
if (s < 2*target_length+1) {
current_char = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s-2, BLANK) != current_char));
if (s < 2 * target_length + 1 && target_length > 0) {
current_char = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
have_three =
((s > 1) &&
(get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s - 2,
BLANK) != current_char));
} else {
current_char = BLANK;
have_three = false;
}
for (int64_t t=1; t < max_input_length; t++) {
__syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch
if ((t < input_length) && (target_length > 0) && (s < 2*target_length+1)) {
if ((t < input_length) && (s < 2 * target_length + 1)) {
// only for valid t, s. This is equation (6) and (7), la1, la2, la3 are the three summands,
// lamax is the maximum for the logsumexp trick.
scalar_t la1 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * s];
@@ -146,7 +172,11 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
// compute the loss (eq (8))
if (threadIdx.x == 0) {
scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)];
scalar_t l2 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2-1)];
scalar_t l2 = target_length > 0
? log_alpha_data
[la_batch_offset + la_input_stride * (input_length - 1) +
la_target_stride * (target_length * 2 - 1)]
: neginf;
scalar_t m = ((l1 > l2) ? l1 : l2);
m = ((m == neginf) ? 0 : m);
scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
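The final loss in eq (8) is the log-sum-exp of the last two log_alpha entries, the paths ending in the trailing blank or in the last label; for an empty target the second path does not exist, hence the new neginf guard. A Python sketch of this reduction (names are illustrative):

    import torch

    neginf = float('-inf')

    def ctc_log_likelihood(log_alpha_last, target_length):
        # log_alpha_last: 1-D tensor holding log_alpha[input_length - 1, :] for one sample
        l1 = log_alpha_last[2 * target_length]                       # path ending in the final blank
        l2 = (log_alpha_last[2 * target_length - 1] if target_length > 0
              else log_alpha_last.new_tensor(neginf))                # no such path for an empty target
        m = torch.max(l1, l2)
        m = torch.where(m == neginf, torch.zeros_like(m), m)         # keep exp() well behaved
        return torch.log(torch.exp(l1 - m) + torch.exp(l2 - m)) + m  # logsumexp trick, eq (8)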
@@ -236,7 +266,6 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int) batch_size);

dim3 block(threads_target, threads_batch);
dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -285,8 +314,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
scalar_t lb;
if (s == 2*target_length) {
lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * BLANK];
} else if ((target_length > 0) && (s == 2*target_length-1)) {
int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
} else if (s == 2 * target_length - 1) { // false for target_length == 0
int64_t current_target_prime = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * current_target_prime];
} else {
lb = neginf;
@@ -301,19 +335,29 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
int64_t s = threadIdx.x + block_s;
int64_t current_target_prime;
bool have_three;
if (s < 2*target_length+1) {
current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
have_three = ((s < 2*target_length-1) &&
(get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s+2, BLANK) !=
current_target_prime));
if (s < 2 * target_length + 1 && target_length > 0) {
current_target_prime = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
have_three =
((s < 2 * target_length - 1) &&
(get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s + 2,
BLANK) != current_target_prime));
} else {
current_target_prime = BLANK;
have_three = false;
}
// now go backward in t. Note that we need to skip the last timestep that we did above.
for (int64_t t=max_input_length-2; t>=0; t--) {
__syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch item
if ((t < input_length-1) && (target_length > 0) && (s < 2*target_length+1)) {
if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
scalar_t lb1 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * s];
scalar_t lbmax = lb1;
scalar_t lb2, lb3;
@@ -339,8 +383,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
+ log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];

log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb;
} else if ((s < 2*max_target_length+1) && ((target_length == 0) || (s >= 2*target_length+1) || (t >= input_length))) {
log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = neginf;
} else if (
(s < 2 * max_target_length + 1) &&
(((target_length == 0) && (s > 0)) || (s >= 2 * target_length + 1) ||
(t >= input_length))) {
log_beta_data
[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] =
neginf;
}
}
}
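The extra (s > 0) in the padding branch above matters because, for an empty target, column s == 0 is now filled by the regular update and must not be overwritten with -inf; all other invalid (t, s) entries are still padded. A compact restatement of that condition in Python (a sketch, not the kernel itself):

    def pads_to_neginf(t, s, input_length, target_length, max_target_length):
        # Mirrors the else-if above, evaluated only when the main update did not run:
        # pad invalid (t, s) entries with -inf, but leave s == 0 alone for empty
        # targets, since that column now carries real log_beta values.
        if s >= 2 * max_target_length + 1:
            return False  # outside the allocated log_beta columns, nothing to write
        return ((target_length == 0 and s > 0)
                or s >= 2 * target_length + 1
                or t >= input_length)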
@@ -448,8 +497,13 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,

// collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s]
for (int s = 0; s < 2*max_target_length+1; s++) {
if ((target_length > 0) && (s < 2*target_length+1)) {
int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
if (s < 2 * target_length + 1) { // if target_length == 0, s == 0
int64_t current_target_prime = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s]
+ log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s]);
scalar_t& lcab = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * current_target_prime];
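The loop above accumulates alpha-times-beta mass, in log space, into the gradient entry of whichever character position s refers to; with the relaxed bound, an empty target only ever touches s == 0, i.e. the blank. A rough Python sketch of the accumulation (names and the accumulator layout are illustrative):

    import torch

    def collect_log_alpha_beta(grad_collected, log_alpha, log_beta, target_prime, t, target_length):
        # grad_collected: (num_labels,) log-space accumulator for time step t of one
        # sample, initialized to -inf; target_prime(s) gives the augmented-target
        # label at position s (blank for even s).
        for s in range(2 * target_length + 1):
            c = target_prime(s)
            contrib = log_alpha[t, s] + log_beta[t, s]
            # "log+=": accumulate into the entry for character c via logaddexp
            grad_collected[c] = torch.logaddexp(grad_collected[c], contrib)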
@@ -569,7 +623,6 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
{
dim3 block(threads_target, threads_batch);
dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);

ctc_loss_backward_log_beta_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(log_beta.data<scalar_t>(),
log_probs.data<scalar_t>(), input_lengths_t.data<int64_t>(), log_probs.size(0),
@@ -612,12 +665,16 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
// For the non-blank characters, we use a kernel to compute the subtrahend.
// Again we might configure block and grid in a better way.
int threads_target = max_threads;
while (threads_target / 2 >= max_target_length) {
while (threads_target / 2 >= max_target_length && threads_target > 1) {
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
dim3 block(threads_target, threads_batch);
dim3 grid((max_target_length + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
dim3 grid(
std::max<int>(
(max_target_length + threads_target - 1) / threads_target, 1),
(batch_size + threads_batch - 1) / threads_batch,
1);
ctc_loss_backward_collect_nonblank_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(grad.data<scalar_t>(),
grad_out.data<scalar_t>(), grad_out.stride(0),
@@ -635,13 +692,12 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
} else { // small problem, use naive algorithm
// Still no block/grid configuration guru...
int threads_input = max_threads;
while (threads_input / 2 >= log_probs.size(0)) {
while (threads_input / 2 >= log_probs.size(0) && threads_input > 1) {
threads_input /= 2;
}
threads_batch = std::min(max_threads / threads_input, (int) batch_size);
dim3 block(threads_input, threads_batch);
dim3 grid((log_probs.size(0) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch);

ctc_loss_backward_collect_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(grad.data<scalar_t>(),
grad_out.data<scalar_t>(), grad_out.stride(0),
39 changes: 28 additions & 11 deletions test/test_autograd.py
@@ -1609,25 +1609,42 @@ def test_ctc_loss(self):
target_length = 15
gradcheck_input_size = 10

# device, input_length
tests = [('cpu', 150, False),
('cpu', 150, True)]
ZERO_NONE = 0
ZERO_SOME = 1
ZERO_ALL = 2

# device, input_length, vary_lengths, zero_lengths
tests = [('cpu', 150, False, ZERO_NONE),
('cpu', 150, True, ZERO_NONE),
('cpu', 50, True, ZERO_SOME),
('cpu', 50, True, ZERO_ALL)]
if torch.cuda.is_available():
tests += [('cuda', 50, False),
('cuda', 150, False),
('cuda', 50, True),
('cuda', 150, True)]

for device, input_length, vary_lengths in tests:
tests += [('cuda', 50, False, ZERO_NONE),
('cuda', 150, False, ZERO_NONE),
('cuda', 50, True, ZERO_NONE),
('cuda', 150, True, ZERO_NONE),
('cuda', 50, True, ZERO_SOME),
('cuda', 150, True, ZERO_SOME),
('cuda', 50, True, ZERO_ALL),
('cuda', 150, True, ZERO_ALL)]

for device, input_length, vary_lengths, zero_mode in tests:
targets = torch.randint(1, num_labels, (batch_size, target_length),
device=device, dtype=torch.long)
x = torch.randn(gradcheck_input_size, device=device, requires_grad=True)
tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1,
device=device)
input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
if vary_lengths or i == 0 else input_length) for i in range(batch_size)]
target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
if vary_lengths else target_length) for i in range(batch_size)]
if zero_mode == ZERO_ALL:
target_lengths = [0 for _ in range(batch_size)]
else:
target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
if vary_lengths else target_length) for _ in range(batch_size)]
if zero_mode == ZERO_SOME:
idxes = torch.randint(0, batch_size, (10,))
for i in idxes:
target_lengths[i] = 0

def ctc_after_softmax(x):
x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels]
19 changes: 14 additions & 5 deletions test/test_nn.py
@@ -5599,20 +5599,29 @@ def test_CTCLoss_lengthchecks_cpu(self):
with self.assertRaises(RuntimeError):
torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)

def test_CTCLoss_empty_target_cpu(self):
def _test_CTCLoss_empty_target(self, device):
target_lengths = [0, 0, 0]
input_lengths = [50, 50, 50]
targets = torch.randint(1, 15, (0,), dtype=torch.int)
log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
targets = torch.randint(1, 15, (0,), dtype=torch.long, device=device)
log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
self.assertTrue((loss >= 0).all().item())
self.assertAlmostEqual(-log_probs.sum(0)[:, 0], loss)

target_lengths = [0, 9, 0]
input_lengths = [50, 50, 50]
targets = torch.randint(1, 15, (9,), dtype=torch.int)
log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
targets = torch.randint(1, 15, (9,), dtype=torch.long, device=device)
log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
self.assertTrue((loss >= 0).all().item())
self.assertAlmostEqual(-log_probs.sum(0)[[0, 2], 0], loss[[0, 2]])

def test_CTCLoss_empty_target_cpu(self):
self._test_CTCLoss_empty_target('cpu')

@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
def test_CTCLoss_empty_target_cuda(self):
self._test_CTCLoss_empty_target('cuda')

@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
def test_CTCLoss_zero_infinity(self):
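The new assertions rely on the fact that, for a zero-length target, the only valid alignment emits the blank at every time step, so the loss reduces to minus the summed blank log-probabilities. A standalone check in the same spirit as the test above (a sketch; shapes are illustrative):

    import torch
    import torch.nn.functional as F

    log_probs = torch.randn(50, 3, 15, dtype=torch.double).log_softmax(2)
    targets = torch.randint(1, 15, (0,), dtype=torch.long)
    loss = F.ctc_loss(log_probs, targets, [50, 50, 50], [0, 0, 0], reduction='none')

    # Only the all-blank path (class 0 here) contributes when the target is empty.
    expected = -log_probs.sum(0)[:, 0]
    assert torch.allclose(loss, expected)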
