From 08005524dbdeccb2dd7c35d46c5e46f0213158bf Mon Sep 17 00:00:00 2001 From: Jules Gagnon-Marchand Date: Sun, 7 Jul 2024 20:37:27 -0700 Subject: [PATCH 1/3] Fix bug in batched multi sample generation --- outlines/generate/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 51a995664..8b2561075 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -372,7 +372,7 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: previously_generated_sequences = generated_sequences # We reshape the output to (batch_size, sample_size) output: List[List[str]] = list() - for i in range(batch_size): + for i in range(0, batch_size * num_samples, num_samples): output.append(next_tokens[i : i + num_samples]) # We remove leading dimensions for the output From 937f0f40dfe0dfc1e7d08937feee28d74306c6e6 Mon Sep 17 00:00:00 2001 From: Jules Gagnon-Marchand Date: Sun, 7 Jul 2024 21:38:53 -0700 Subject: [PATCH 2/3] Fixed __call__ as well --- outlines/generate/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 8b2561075..f739dc98c 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -231,8 +231,8 @@ def __call__( # We reshape the output to (batch_size, sample_size) output: List[List[FormattedOutput]] = list() - for i in range(batch_size): - output.append(formatted[i : i + num_samples]) + for i in range(0, batch_size * num_samples, num_samples): + output.append(formatted[i : i + num_samples] # We remove leading dimensions for the output if batch_size == 1 and num_samples == 1: From c4f37c91ba44b2514886ddb7538a14fe6583c1d3 Mon Sep 17 00:00:00 2001 From: Jules Gagnon-Marchand Date: Mon, 8 Jul 2024 09:35:32 -0700 Subject: [PATCH 3/3] Update outlines/generate/api.py Co-authored-by: Patrice Bechard --- outlines/generate/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/outlines/generate/api.py b/outlines/generate/api.py index f739dc98c..4104e3080 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -232,7 +232,7 @@ def __call__( # We reshape the output to (batch_size, sample_size) output: List[List[FormattedOutput]] = list() for i in range(0, batch_size * num_samples, num_samples): - output.append(formatted[i : i + num_samples] + output.append(formatted[i : i + num_samples]) # We remove leading dimensions for the output if batch_size == 1 and num_samples == 1: