Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some Sequence interface issues #319

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions outlines/text/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def __init__(
model,
regex_string: str,
max_tokens: Optional[int] = None,
*,
sampler: Optional["Sampler"] = None,
stop: Union[str, List[str]] = [],
allow_empty_tokens: bool = True,
initial_state: Optional[int] = None,
final_states: Optional[Set[int]] = None,
Expand All @@ -62,6 +64,8 @@ def __init__(
`outlines.text.generate.sample.multinomial`. See
`outlines.text.generate.sample.Sampler` for the expected form of
such functions.
stop
Optional stopping string(s).
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.
states_to_token_maps
Expand All @@ -71,7 +75,7 @@ def __init__(
Pre-computed set of token ids for tokens that are empty strings.

"""
super().__init__(model, max_tokens, sampler)
super().__init__(model, max_tokens, sampler, stop)

if (
states_to_token_maps is None
Expand Down Expand Up @@ -248,7 +252,13 @@ def regex(
Allow sampling of tokens corresponding to empty strings.

"""
return Regex(model, regex_string, max_tokens, sampler, allow_empty_tokens)
return Regex(
model,
regex_string,
max_tokens,
sampler=sampler,
allow_empty_tokens=allow_empty_tokens,
)


def integer(
Expand Down Expand Up @@ -284,7 +294,13 @@ def integer(
Allow sampling of tokens corresponding to empty strings.

"""
return Regex(model, r"[-+]?\d+", max_tokens, sampler, allow_empty_tokens)
return Regex(
model,
r"[-+]?\d+",
max_tokens,
sampler=sampler,
allow_empty_tokens=allow_empty_tokens,
)


def float(
Expand Down Expand Up @@ -324,8 +340,8 @@ def float(
model,
r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))",
max_tokens,
sampler,
allow_empty_tokens,
sampler=sampler,
allow_empty_tokens=allow_empty_tokens,
)


Expand Down Expand Up @@ -359,7 +375,13 @@ def choice(
Allow sampling of tokens corresponding to empty strings.
"""
regex_str = r"(" + r"|".join(choices) + r")"
return Regex(model, regex_str, max_tokens, sampler, allow_empty_tokens)
return Regex(
model,
regex_str,
max_tokens,
sampler=sampler,
allow_empty_tokens=allow_empty_tokens,
)


def json(
Expand Down Expand Up @@ -399,4 +421,10 @@ def json(

regex_str = build_regex_from_schema(schema)

return Regex(model, regex_str, max_tokens, sampler, allow_empty_tokens)
return Regex(
model,
regex_str,
max_tokens,
sampler=sampler,
allow_empty_tokens=allow_empty_tokens,
)
2 changes: 2 additions & 0 deletions outlines/text/generate/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(
from outlines.text.generate.sample import multinomial

self.sampler = multinomial
else:
self.sampler = sampler

def create_proposal(
self, generated_token_ids: torch.LongTensor, logits: torch.DoubleTensor
Expand Down
30 changes: 30 additions & 0 deletions tests/text/generate/test_integration_transfomers.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,33 @@ def test_transformers_reduced_vocabulary_caching():
vocab2 = reduced_vocabulary(tokenizer2)

assert vocab2 is vocab


def test_custom_sampler():
    """Integration check that a user-supplied sampler is actually invoked.

    A biased sampler forces the token for "c" on its first call and the EOS
    token on every call after that, so `generate.choice` must produce exactly
    the string "c".
    """
    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"

    model = models.transformers(model_name)

    first_call_done = False
    target_token_ids = model.tokenizer.encode(["c"])[0]

    def biased_sampler(
        logits: torch.DoubleTensor, samples: int, *_
    ) -> torch.DoubleTensor:
        nonlocal first_call_done

        if first_call_done:
            # Any call after the first one terminates generation.
            return torch.tensor([[model.tokenizer.eos_token_id]])
        first_call_done = True
        return target_token_ids

    generator = generate.choice(model, ["a", "b", "c"], sampler=biased_sampler)
    sequence = generator(
        """What is 1+1?
a. 3
b. 4
c. 2"""
    )

    assert sequence == "c"
Loading