Add nextafter op #2510
base: devel
Conversation
Test is currently segfaulting for scalar-only inputs.
```cpp
if (config.require_full_precision_promoted) {
  TORCH_CHECK(
      common_dtype == c10::ScalarType::Float ||
      common_dtype == c10::ScalarType::Double,
```
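The `TORCH_CHECK` above rejects any promoted dtype other than `Float` or `Double`. A minimal pure-Python stand-in for that rule (the function and dtype-string names here are illustrative, not part of the actual codebase):

```python
# Sketch of the promotion rule the TORCH_CHECK enforces: when full-precision
# promotion is required, only float32 and float64 common dtypes are accepted.
FULL_PRECISION_DTYPES = {"float32", "float64"}

def check_full_precision_promoted(common_dtype: str) -> None:
    # Hypothetical helper; mirrors the C++ check in pure Python.
    if common_dtype not in FULL_PRECISION_DTYPES:
        raise TypeError(
            f"nextafter requires promotion to float32/float64, "
            f"got {common_dtype}"
        )
```

So a `float16`/`float16` pair, for example, would be rejected, while `int32`/`float32` promotes to `float32` and passes.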
I am not sure that this is the right place to perform this check.
Hmm. I got this wrong too. PyTorch actually supports bfloat16, just not float16 for this op. It was added with a manual implementation taken from musl: https://github.com/pytorch/pytorch/pull/61829/files#diff-ece04c31934b3504382e10ed3e9a69f03ffabd81ad1a2a890aab19b1642f53c0R120
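For reference, the musl-style implementation linked above works purely on the bit representation. A sketch of the same idea for `float32` in Python, using `struct` to reinterpret the bits (this is an illustrative port, not the PyTorch code itself):

```python
import math
import struct

def nextafter_f32(x: float, y: float) -> float:
    """Bit-twiddling nextafter for float32, sketching the musl-style
    approach: step the integer bit pattern by one toward y."""
    if math.isnan(x) or math.isnan(y):
        return math.nan
    ux = struct.unpack('<I', struct.pack('<f', x))[0]
    uy = struct.unpack('<I', struct.pack('<f', y))[0]
    if ux == uy:
        return y
    ax, ay = ux & 0x7FFFFFFF, uy & 0x7FFFFFFF
    if ax == 0:
        if ay == 0:
            return y
        # x is +/-0: smallest subnormal with y's sign
        ux = (uy & 0x80000000) | 1
    elif ax > ay or ((ux ^ uy) & 0x80000000):
        ux -= 1  # step toward zero
    else:
        ux += 1  # step away from zero
    return struct.unpack('<f', struct.pack('<I', ux))[0]
```

The same trick carries over to bfloat16 (or float16) by operating on a 16-bit pattern instead, which is how an op can be supported without a hardware builtin.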
The `nextafter(x, y)` operation provides the nearest distinct representable floating-point value to `x` between `x` and `y`. In CUDA these are obtained with the builtins `nextafter` and `nextafterf`. Other types such as bfloat16 are not directly supported, though PyTorch has implemented that case based on some code in musl: https://github.com/pytorch/pytorch/pull/61829/files#diff-ece04c31934b3504382e10ed3e9a69f03ffabd81ad1a2a890aab19b1642f53c0R120.

In PyTorch, the `torch.nextafter` function is defined for any pair of arguments which would normally be promoted to either `float32` or `float64`. So arguments which are both ints, bools, complex, or half-precision floats are not supported. This PR implements a binary op macro with a `TypePromotionConfig` that enforces that rule.

This is a translation/update of csarofeen/pytorch#2510.

---------

Co-authored-by: Jacob Hinkle <jhinkle@nvidia.com>
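The semantics described above can be tried out with the standard-library `math.nextafter` (Python 3.9+), which mirrors the C builtin that the CUDA `nextafter`/`nextafterf` intrinsics expose:

```python
import math

# Step from x toward y by exactly one representable double.
print(math.nextafter(1.0, 2.0))  # nearest double above 1.0
print(math.nextafter(1.0, 0.0))  # nearest double below 1.0
print(math.nextafter(1.0, 1.0))  # y == x: returns y unchanged
```

Because doubles have a 53-bit significand, the step above 1.0 is exactly `2**-52`, which is the behavior the op must reproduce for `float64` inputs.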