diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 51a995664..4104e3080 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -231,7 +231,7 @@ def __call__( # We reshape the output to (batch_size, sample_size) output: List[List[FormattedOutput]] = list() - for i in range(batch_size): + for i in range(0, batch_size * num_samples, num_samples): output.append(formatted[i : i + num_samples]) # We remove leading dimensions for the output @@ -372,7 +372,7 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: previously_generated_sequences = generated_sequences # We reshape the output to (batch_size, sample_size) output: List[List[str]] = list() - for i in range(batch_size): + for i in range(0, batch_size * num_samples, num_samples): output.append(next_tokens[i : i + num_samples]) # We remove leading dimensions for the output