
Ensure Ancestors on Correct Device During Sampling #651

Merged: 3 commits merged into outlines-dev:main on Feb 13, 2024

Conversation

lapp0 (Collaborator) commented on Feb 12, 2024

Fixes #649

rlouf (Member) commented on Feb 12, 2024

You would need to do this for every sampling method.

lapp0 (Collaborator, Author) commented on Feb 12, 2024

Good catch. Beam and greedy need changes as well.

Smoke tested; I can confirm the device is set correctly for all three samplers:

>>> next_token_ids.device, ancestors.device, weights.device
(device(type='cuda', index=0), device(type='cuda', index=0), device(type='cuda', index=0))
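A smoke test along these lines can be scripted; the sketch below is illustrative only, and the helper name and the assumed call signature (sampler(next_token_logits) returning next_token_ids, ancestors, weights) are assumptions rather than the library's exact API:

import torch

def check_sampler_devices(sampler, batch_size=2, vocab_size=8):
    # Illustrative helper (not part of outlines): build logits on the GPU when
    # available and verify every tensor the sampler returns is on that device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    next_token_logits = torch.randn(batch_size, vocab_size, device=device)

    # Assumed call shape: sampler(next_token_logits) -> (ids, ancestors, weights).
    next_token_ids, ancestors, weights = sampler(next_token_logits)

    assert next_token_ids.device == device
    assert ancestors.device == device
    assert weights.device == device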

@@ -66,7 +66,9 @@ def __call__(
         logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
         next_token_ids = torch.argmax(logprobs, dim=-1, keepdim=True)

-        ancestors = torch.arange(next_token_logits.shape[0])
+        ancestors = torch.arange(
+            next_token_logits.shape[0], device=next_token_logits.device
+        )
rlouf (Member) commented on this diff:

nit: let's be consistent across samplers and use next_token_logits for each sampler :)
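For context, the convention being asked for is that every tensor a sampler allocates derives its device from next_token_logits. A minimal sketch of that pattern (simplified stand-ins, not the actual outlines sampler classes):

import torch

def greedy_ancestors(next_token_logits: torch.Tensor) -> torch.Tensor:
    # Allocate ancestors on the same device as the incoming logits instead of
    # the default (CPU) device.
    return torch.arange(
        next_token_logits.shape[0], device=next_token_logits.device
    )

def multinomial_ancestors(next_token_logits: torch.Tensor) -> torch.Tensor:
    # Same rule for the multinomial sampler: key the device off next_token_logits.
    return torch.arange(
        next_token_logits.shape[0], device=next_token_logits.device
    )

def beam_search_init_weights(next_token_logits: torch.Tensor) -> torch.Tensor:
    # Beam search likewise builds any fresh tensors (here, illustrative initial
    # weights) on next_token_logits.device rather than the default device.
    return torch.zeros(
        next_token_logits.shape[0], device=next_token_logits.device
    )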

rlouf merged commit a33692e into outlines-dev:main on Feb 13, 2024. 5 checks passed.
rlouf (Member) commented on Feb 13, 2024

Thanks for the fix!

Merging this pull request may close: Device error for JSON generation with outlines 0.0.29