Skip to content

Commit

Permalink
Change argument name in logits_bias
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Dec 21, 2023
1 parent 8eb7ac0 commit 6327421
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions outlines/generate/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def expand_attention_masks(attention_masks: torch.Tensor) -> torch.Tensor:
@torch.inference_mode()
def bias_logits(
logits: torch.Tensor,
ids_to_mask: List,
allowed_token_ids: List,
) -> torch.Tensor:
"""Mask the logits.
Expand All @@ -278,15 +278,15 @@ def bias_logits(
logits
Two dimensional tensor that contains the next-token probability
distribution.
ids_to_mask
The ids to mask in each dimension.
allowed_token_ids
A list that contains the tokens that can be generated by the model.
Returns
-------
A view of the original logits tensor where some values are masked.
"""
biased_logits = torch.full(logits.shape, -math.inf, device=logits.device)
for i, ids in enumerate(ids_to_mask):
for i, ids in enumerate(allowed_token_ids):
biased_logits[i, ids] = logits[i, ids]
return biased_logits

0 comments on commit 6327421

Please sign in to comment.