Fix vectorization of transformers models
rlouf committed Jun 6, 2023
1 parent cc914d6 commit 8c4fb4b
Showing 2 changed files with 50 additions and 6 deletions.
27 changes: 21 additions & 6 deletions outlines/models/hf_transformers.py
@@ -91,10 +91,11 @@ def call_model_generate_method(
         # `generate` does not accept NumPy arrays
         prompt = list(prompt)
 
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
     model = AutoModelForCausalLM.from_pretrained(model_name)
 
-    prompt_tokens = tokenizer(prompt, return_tensors="pt")
+    tokenizer.pad_token = tokenizer.eos_token
+    prompt_tokens = tokenizer(prompt, return_tensors="pt", padding=True)
 
     logit_processors: Optional[List[Callable]] = None
     stopping_criteria: Optional[List[Callable]] = None
@@ -153,15 +154,29 @@ def call_model_generate_method(
         stopping_criteria=stopping_criteria,
     )
     new_tokens = returned_tokens[:, prompt_tokens["input_ids"].shape[1] :]
-    new_tokens = new_tokens.squeeze()
+    if len(prompt) == 1:
+        new_tokens = new_tokens.squeeze()
 
-    if samples == 1:
+    if new_tokens.ndim < 2:
         results = tokenizer.decode(new_tokens, skip_special_tokens=True)
-        results = [postprocessing(results)]
+        results = np.array([postprocessing(results)])
     else:
         results = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
         results = [postprocessing(result) for result in results]
+        results = np.array(results)
 
-    return np.atleast_2d(results)
+    if len(prompt) == 1:
+        results = np.expand_dims(results, 0)
+    else:
+        results = np.expand_dims(results, 1)
+
+    # If we pass a batch of prompts to the model and ask for
+    # several samples we get a list of results that we need
+    # to reshape to the right dimensions.
+    if len(prompt) > 1 and samples > 1:
+        results = np.reshape(results, (-1, samples))
+
+    return results


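Why the first hunk pairs `padding_side="left"` with a pad token: decoder-only models continue from the end of the input, so batched prompts of different lengths must be padded on the left for every continuation to start at the same index, and GPT-2-style tokenizers ship without a pad token. The following is a minimal standalone sketch of that setup, not code from this commit; the "gpt2" checkpoint and the 5-token limit are illustrative choices.

import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # illustrative checkpoint, not from the commit
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# GPT-2-style tokenizers define no pad token; reuse EOS so `padding=True` works.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

prompts = ["Hello", "Bonjour"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
output_tokens = model.generate(**inputs, max_new_tokens=5, do_sample=False)
# With left padding, every prompt ends at the same column, so each row's
# continuation starts right after the prompt block:
new_tokens = output_tokens[:, inputs["input_ids"].shape[1] :]
print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))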
29 changes: 29 additions & 0 deletions tests/models/test_hf_transformers.py
@@ -20,17 +20,37 @@ def test_samples():
     assert len(answers) == 3
 
 
+def test_prompt_array():
+    model = HuggingFaceCompletion(TEST_MODEL, max_tokens=10)
+    prompts = [["Hello", "Bonjour"], ["Ciao", "Hallo"]]
+    answers = model(prompts)
+    assert isinstance(answers, np.ndarray)
+    assert answers.shape == (2, 2)
+
+    answers = model(prompts, samples=5)
+    assert isinstance(answers, np.ndarray)
+    assert answers.shape == (2, 2, 5)
+
+
 def test_type_int():
     model = HuggingFaceCompletion(TEST_MODEL, max_tokens=10)
     answer = model("test", type="int")
     int(answer)
 
+    answers = model(["test", "other_test"], type="int")
+    for answer in answers:
+        int(answer)
+
 
 def test_type_float():
     model = HuggingFaceCompletion(TEST_MODEL, max_tokens=10)
     answer = model("test", type="float")
     float(answer)
 
+    answers = model(["test", "other_test"], type="float")
+    for answer in answers:
+        float(answer)
+
 
 def test_incompatible_constraints():
     model = HuggingFaceCompletion(TEST_MODEL, max_tokens=10)
@@ -46,6 +66,10 @@ def test_choices():
answer = model("test", is_in=choices)
assert answer in choices

answers = model(["test", "other_test"], is_in=choices)
for answer in answers:
assert answer in choices


def test_stop():
model = HuggingFaceCompletion(TEST_MODEL, max_tokens=1000)
@@ -55,6 +79,11 @@
     for seq in stop:
         assert seq not in answer
 
+    answers = model(["test", "other_test"], stop_at=stop)
+    for seq in stop:
+        for answer in answers:
+            assert seq not in answer
+
 
 @pytest.mark.xfail
 def test_type_multiple_samples():
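For intuition on the reshape that the comment in call_model_generate_method describes: with a batch of prompts and samples > 1, `generate` returns the return sequences for each prompt as consecutive rows, so a single `np.reshape(results, (-1, samples))` folds the flat list into one row per prompt. A NumPy-only sketch, with placeholder strings standing in for decoded generations (the higher-dimensional shapes asserted in test_prompt_array come from the vectorized wrapper around this function, not shown here):

import numpy as np

samples = 3
# Rows come back grouped per prompt: prompt 0's samples first, then prompt 1's.
flat = np.array(["p0_s0", "p0_s1", "p0_s2", "p1_s0", "p1_s1", "p1_s2"])

grouped = np.reshape(flat, (-1, samples))
print(grouped.shape)  # (2, 3): one row per prompt, one column per sample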
