Python where op floatxfloat promotes to float64 #2380

Closed · mruberry opened this issue Jan 30, 2023 · 7 comments · Fixed by #2423

@mruberry (Collaborator):

import torch
from torch.testing import make_tensor
# Import path assumed for the nvfuser Python frontend bundled with PyTorch at the time:
from torch._C._nvfuser import Fusion, FusionDefinition, DataType

pred = make_tensor((5,), device='cuda', dtype=torch.bool)

fs = Fusion()
with FusionDefinition(fs) as fd:
    nv_pred = fd.define_tensor(sizes=pred.shape, strides=pred.stride(), dtype=DataType.Bool)
    five = fd.define_constant(5.)
    three = fd.define_constant(3.)

    # where(Bool, scalar, scalar) — the Python float constants are inferred as Double
    result = fd.ops.where(nv_pred, five, three)

    fd.add_output(result)


nv_result = fs.execute((pred,))[0]
print(f"nv_result={nv_result}")
# nv_result=tensor([5., 3., 3., 3., 5.], device='cuda:0', dtype=torch.float64)

torch_result = torch.where(pred, 5., 3.)
print(f"torch_result={torch_result}")
# torch_result=tensor([5., 3., 3., 3., 5.], device='cuda:0')
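
A possible interim workaround (not part of the original report; it assumes fd.ops.cast is exposed by this frontend) is to cast the where result back to single precision inside the definition:

    # Workaround sketch, inside the FusionDefinition block above (assumes fd.ops.cast is available):
    result = fd.ops.where(nv_pred, five, three)       # inferred as Double today
    result_f32 = fd.ops.cast(result, DataType.Float)  # cast back down to float32
    fd.add_output(result_f32)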
@kevinstephano (Collaborator):

The solution is to change the define_constant API to allow a type specification, since Python floating-point numbers are inferred to be double.
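
For context (an illustration, not part of the original comment): PyTorch treats Python float literals as weakly typed, so its eager where keeps the float32 default, while define_constant(5.) records the literal with its natural double precision:

import torch

pred = torch.tensor([True, False])
# PyTorch: Python float scalars are weakly typed, so the result stays float32
print(torch.where(pred, 5., 3.).dtype)  # torch.float32
# nvfuser: define_constant(5.) infers DataType.Double, hence the float64 output in the report above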

@mruberry (Collaborator, Author):

And define_scalar, too?

@kevinstephano self-assigned this Jan 30, 2023
@kevinstephano (Collaborator):

define_scalar already allows you to specify a type. I think we are okay there unless you saw an issue?
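
For reference, a minimal sketch of the typed define_scalar usage being referred to (the exact signature is an assumption; it is not shown anywhere in this thread):

# Sketch: define_scalar declares a runtime scalar input with an explicit DataType
s0 = fd.define_scalar(DataType.Float)
t1 = fd.ops.where(nv_pred, s0, s0)  # would stay float32 since both scalar inputs are Float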

@mruberry (Collaborator, Author):

> define_scalar already allows you to specify a type. I think we are okay there unless you saw an issue?

My mistake, I didn't realize

@jacobhinkle (Collaborator) commented Feb 2, 2023:

We currently don't have single-precision scalars in either Python or C++. I hit some errors trying to add those (see #2403), but it could probably be done. However, it may be simpler to add a dtype argument to the where op that defaults to DataType.Float; the effect of the argument would be to insert a cast op after where.
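
A sketch of what that alternative would look like from the user's side (hypothetical; this dtype argument was not what ultimately landed, since the fix below went through typed constants instead):

# Hypothetical API from this proposal: where accepts an output dtype...
result = fd.ops.where(nv_pred, five, three, dtype=DataType.Float)
# ...which would just be shorthand for inserting a cast after the existing op:
result = fd.ops.cast(fd.ops.where(nv_pred, five, three), DataType.Float)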

@jacobhinkle (Collaborator):

@mruberry the following lines in the PR branch above allow you to force the constant DataTypes to Float, which for where is enough to ensure a float32-valued output:

c0f = fd.define_constant(3.0, DataType.Float)
c1f = fd.define_constant(5.0, DataType.Float)
t1f = fd.ops.where(t0, c0f, c1f)  # t0 is the Bool predicate tensor; result dtype is DataType.Float
fd.add_output(t1f)

Would that sufficiently address this issue? Note that we haven't changed the promotion rules for nvfuser: if an op receives only scalar floating-point arguments, we do not fall back to the default floating-point type as PyTorch does, but instead use the highest-precision type among the given arguments.
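
To illustrate that rule (a sketch, assuming the PR-branch define_constant dtype argument): mixing scalar precisions promotes to the widest one rather than falling back to float32.

# Sketch: mixed-precision scalar arguments promote to the widest type, not to a PyTorch-style default
c_f32 = fd.define_constant(3.0, DataType.Float)
c_f64 = fd.define_constant(5.0, DataType.Double)
t2 = fd.ops.where(t0, c_f32, c_f64)  # result is Double: the highest precision among the arguments
fd.add_output(t2)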

@mruberry (Collaborator, Author) commented Feb 8, 2023:

> @mruberry the following lines in the PR branch above allow you to force the constant DataTypes to Float, which for where is enough to ensure a float32-valued output:
>
>     c0f = fd.define_constant(3.0, DataType.Float)
>     c1f = fd.define_constant(5.0, DataType.Float)
>     t1f = fd.ops.where(t0, c0f, c1f)  # DataType.Float
>     fd.add_output(t1f)
>
> Would that sufficiently address this issue? Note that we haven't changed the promotion rules for nvfuser: if an op receives only scalar floating-point arguments, we do not fall back to the default floating-point type as PyTorch does, but instead use the highest-precision type among the given arguments.

Yes, I think that would address the issue
