Fix CTC loss for zero-length targets on GPU #23298
Conversation
```cuda
template <typename target_t>
__device__ static inline int64_t get_target_prime(
    const target_t* __restrict__ target,
    int64_t offset,
    int64_t stride,
    int64_t idx,
    int64_t BLANK) {
  if (idx % 2 == 0) {
```
This is me being ignorant about how CTCLoss works :) For my education, what exactly does `get_target_prime` do? E.g., what is it called in the paper?
You've adjusted this function to handle `target_length == 0`, but at many of the call sites there is already a condition implying that if you get to this function, the target length is nonzero. Does this hold at all of the call sites? I guess I'll go check now.
I guess in backwards there are a few cases where you will get here with `target_length == 0`.
So the comment above the function, `// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1]); note that no bound-checking is done`, isn't all that great. Maybe amending it with "when l is the targets, l' is BLANK l_0 BLANK l_1 ... l_{targetlen-1} BLANK" helps?
I'll do a bit more analysis of whether we need the target-length condition here. It might well be that, for zero-length targets, it is not called except with `idx == 0`, which would work equally well...
Edit: Turns out that works well.
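For readers following along, here is a minimal sketch of the lookup being discussed, reconstructed from the comments above rather than quoted from the diff. In the paper's notation it implements the augmented label sequence l': even positions return BLANK, and odd position 2k+1 returns the k-th target symbol, so no target-length check is needed as long as callers keep `idx` in range (or call it only with `idx == 0` for empty targets):

```cuda
// Illustrative sketch of get_target_prime: maps an index into the augmented
// targets l' = BLANK l_0 BLANK l_1 ... l_{targetlen-1} BLANK. No bound
// checking is done; callers must guarantee 0 <= idx < 2*target_length+1
// (idx == 0 returns BLANK regardless of target_length).
template <typename target_t>
__device__ static inline int64_t get_target_prime(
    const target_t* __restrict__ target, // targets of this batch element
    int64_t offset,                      // start offset of this element's targets
    int64_t stride,                      // stride between consecutive targets
    int64_t idx,                         // index into the augmented targets l'
    int64_t BLANK) {
  if (idx % 2 == 0) {
    return BLANK;                        // even positions of l' are blanks
  } else {
    return target[offset + stride * (idx / 2)]; // position 2k+1 holds l_k
  }
}
```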
Something that would make me feel more confident about this is specifically running all of the tests under `cuda-memcheck`.
I know this is super goofy and will never happen in practice, but what happens if all of the targets have length zero?
```cuda
        target_length,
        BLANK);
    have_three =
        ((s < 2 * target_length - 1) &&
```
Always false when `target_length == 0`: the bound `2 * target_length - 1` becomes `-1`, and `s` is never negative.
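To spell that out, a hedged sketch of the transition test (illustrative names, not the literal kernel source; it assumes the beta recursion looks two positions ahead at `s + 2`):

```cuda
// In the beta recursion, have_three marks whether position s can also
// receive probability from s + 2, i.e. the transition that skips a blank.
bool have_three =
    ((s < 2 * target_length - 1) &&                   // s + 2 still inside l'
     (get_target_prime(targets_data, offset, stride, s + 2, BLANK) !=
      current_target_prime));                         // symbols differ, skip legal
// When target_length == 0, the bound 2 * target_length - 1 is -1, so the
// first clause fails for every s >= 0 and have_three is always false.
```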
(Edited) I actually changed the condition here to include `target_length > 0` in the outer `if`; this removes the need to check the target length in `get_target_prime`.
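That change might look roughly like this (a sketch of the shape of the guard, with illustrative variable names; not the literal diff):

```cuda
// Outer guard: with target_length == 0 the condition short-circuits, so
// get_target_prime is never asked to index into an empty target sequence
// and needs no target-length check of its own.
if ((target_length > 0) && (s < 2 * target_length + 1)) {
  int64_t current_target_prime = get_target_prime(
      targets_data, offset, stride, s, BLANK);
  // ... alpha/beta update for lattice position s ...
}
```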
I'm not a CTCLoss algorithmic expert, but I did do a reasonable amount of auditing of `target_length` use sites, and all of the adjustments look reasonable.
Test failure is real.
Tests need to pass.
Thank you for the thorough review!
Then you'll be glad to hear we do test that in test_autograd.py :) The grid setup changes were needed for these cases. I'll amend the PR for the other comments, thank you!
So at least the most basic invocation of cuda-memcheck seems to not detect any failures in the tests:

```
$ PYTHONPATH=./build/lib.linux-x86_64-3.7/ cuda-memcheck python3 test/test_nn.py TestNN.test_CTCLoss_empty_target_{cuda,cpu}
========= CUDA-MEMCHECK
/usr/lib/python3/dist-packages/numba/errors.py:104: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9
warnings.warn(msg)
..
----------------------------------------------------------------------
Ran 2 tests in 0.180s
OK
========= ERROR SUMMARY: 0 errors
$ PYTHONPATH=build/lib.linux-x86_64-3.7/ cuda-memcheck python3 test/test_autograd.py TestAutograd.test_ctc_loss
========= CUDA-MEMCHECK
/usr/lib/python3/dist-packages/numba/errors.py:104: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9
warnings.warn(msg)
..
----------------------------------------------------------------------
Ran 1 test in 6.869s
OK
========= ERROR SUMMARY: 0 errors
```

Regarding the tolerance in the failing tests: I previously had this at a tolerance of 3e-5. Apparently that is not good enough, so I increased it to 1e-4. The loss is ~1.5e2, so this corresponds to a relative tolerance of roughly 1e-4 / 1.5e2 ≈ 6.7e-7. (I added a comment to the test.) The obvious alternative would be to run the check with double precision.

Either works. We do often run things in double precision for this reason; it might be a good choice here.
So I think the remaining failures are spurious.
@ezyang: anything I can do to move this forward?
@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Fixes pytorch/pytorch#18215 at last! Also sprinkle tests...
Pull Request resolved: pytorch/pytorch#23298
Differential Revision: D16582145
Pulled By: soumith
fbshipit-source-id: bc8b1a629de0c2606e70a2218ccd135f4a9cdc5d
Fixes: #18215 at last!
Also sprinkle tests...