Commit

Put ancestors on same device as next_token_logits (#651)
Fixes #649

---------

Co-authored-by: Andrew Lapp <andrew@rew.la>
lapp0 and Andrew Lapp committed Feb 13, 2024
1 parent 29bd1fe commit a33692e
Showing 1 changed file with 7 additions and 3 deletions.
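
Why the change matters: torch.arange builds its result on the CPU unless a
device is passed, so the ancestors index tensor could land on a different
device than CUDA logits and later fail PyTorch's same-device check. A minimal
sketch of the failure mode and the fix pattern (illustrative shapes, not code
from the repository):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    next_token_logits = torch.randn(4, 50_257, device=device)  # (batch, vocab)

    # A default arange lives on the CPU, so mixing it with CUDA logits can
    # raise "Expected all tensors to be on the same device".
    ancestors_cpu = torch.arange(next_token_logits.shape[0])

    # Creating it on the logits' device keeps the sampler's outputs together.
    ancestors = torch.arange(
        next_token_logits.shape[0], device=next_token_logits.device
    )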
10 changes: 7 additions & 3 deletions outlines/samplers.py
@@ -66,7 +66,9 @@ def __call__(
         logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
         next_token_ids = torch.argmax(logprobs, dim=-1, keepdim=True)

-        ancestors = torch.arange(next_token_logits.shape[0])
+        ancestors = torch.arange(
+            next_token_logits.shape[0], device=next_token_logits.device
+        )
         weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze()

         return next_token_ids, ancestors, weights
@@ -144,7 +146,9 @@ def __call__(
         next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng)

         logprobs = torch.nn.functional.log_softmax(altered_next_token_logits, dim=-1)
-        ancestors = torch.arange(altered_next_token_logits.shape[0])
+        ancestors = torch.arange(
+            altered_next_token_logits.shape[0], device=next_token_logits.device
+        )
         weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze()

         return next_token_ids, ancestors, weights
@@ -292,7 +296,7 @@ def __call__(

         # Re-shape the weights, next_token_ids and ancestors to (n_batch * n_samples, 1)
         first_batch_idx = torch.arange(
-            0, batch_size * self.samples, self.samples
+            0, batch_size * self.samples, self.samples, device=next_token_logits.device
         ).unsqueeze(1)
         ancestors = ancestors + first_batch_idx
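
For the last hunk, a small worked example (made-up values, not repository
code) showing why first_batch_idx is spaced self.samples apart: it turns the
per-batch ancestor indices into row indices of the flattened
(batch_size * samples, ...) tensors, and building it on the logits' device
keeps the addition on a single device:

    import torch

    batch_size, samples = 2, 3
    device = torch.device("cpu")  # any device works as long as both tensors share it

    ancestors = torch.tensor([[2, 0, 1], [1, 1, 0]], device=device)  # per-batch picks
    first_batch_idx = torch.arange(
        0, batch_size * samples, samples, device=device
    ).unsqueeze(1)

    print(first_batch_idx)              # tensor([[0], [3]])
    print(ancestors + first_batch_idx)  # tensor([[2, 0, 1], [4, 4, 3]])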
