Major bug & fix: Fix bug in batched multi sample generation #1025

Merged: 4 commits, Jul 11, 2024

JulesGM
Contributor

@JulesGM JulesGM commented Jul 8, 2024

The following lines break batched generation when more than one sample is requested per prompt.
The generated sequences are stored in a single flat list:

[b_0_s_0, b_0_s_1, b_0_s_2, b_0_s_3, b_1_s_0, b_1_s_1, ...]

with b_0_s_0 being the example generated for batch 0 and sample 0 of multi-sample batch generation.

At the end of the generation code, the samples are supposed to be split into batch_size sublists of num_samples items each. The problematic code is as follows:

for i in range(batch_size):
    output.append(formatted[i : i + num_samples])

We do get a list of batch_size sublists, but not the expected ones: the start index advances by 1 instead of num_samples, so consecutive sublists overlap. Instead it should be:

for i in range(0, batch_size * num_samples, num_samples):
    output.append(next_tokens[i : i + num_samples])
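
To make the difference concrete, here is a minimal, model-free sketch of the two loops applied to a toy flat list. The function names split_buggy and split_fixed are made up for illustration and are not part of outlines; only the two slicing loops come from the snippets above.

```python
def split_buggy(flat, batch_size, num_samples):
    # Original loop: the start index advances by 1, so the slices overlap.
    return [flat[i : i + num_samples] for i in range(batch_size)]


def split_fixed(flat, batch_size, num_samples):
    # Fixed loop: the start index advances by num_samples.
    return [
        flat[i : i + num_samples]
        for i in range(0, batch_size * num_samples, num_samples)
    ]


batch_size, num_samples = 2, 4
# Flat list laid out as described above: all samples of batch 0, then batch 1.
flat = [f"b_{b}_s_{s}" for b in range(batch_size) for s in range(num_samples)]

buggy = split_buggy(flat, batch_size, num_samples)
fixed = split_fixed(flat, batch_size, num_samples)

print(buggy)  # second sublist starts at b_0_s_1 and ends at b_1_s_0
print(fixed)  # second sublist holds only b_1_* items
```

With batch_size = 2 the buggy version only ever looks at the first num_samples + 1 items of the flat list, so almost all samples of the second prompt are silently dropped.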

As an example, take the two prompts

'1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 ='

and

'2 2 2 2 + 4 4 4 4 =? Solution: 2 2 2 2 + 4 4 4 4 ='

which give a batch size of 2. With the \d( \d)+ regex and multinomial sampling with 10 samples, the current wrong output is shown below. (I disabled the removal of the prompt from the output, because the outputs looked fishy.)

[[
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 7 7 7',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 9 9 9',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 1 6 6',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 7 7 7',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 1 1 2',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 5 5 5',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 7 7 7',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 0 0 0',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 1 4 5',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 3 3 3'
], [
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 9 9 9',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 1 6 6',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 7 7 7',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 1 1 2',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 5 5 5',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 7 7 7',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 0 0 0',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 1 4 5',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 3 3 3',
    '2 2 2 2 + 4 4 4 4 =? Solution: 2 2 2 2 + 4 4 4 4 = 8 0 8 8'
]]

We can see the problem: it returned next_tokens[0 : num_samples] and next_tokens[1 : num_samples + 1], which is not what we want.

The new code returns

[[
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 6 8 1',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 7 7 7',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 7 2 4',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 7 7 7',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 9 3 3',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 1 4 9',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 6 6 6',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 4 4 4',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 7 3 2',
    '1 1 1 + 3 3 3 =? Solution: 1 1 1 + 3 3 3 = 1 1 0'
], [
    '2 2 2 2 + 4 4 4 4 =? Solution: 2 2 2 2 + 4 4 4 4 = 4 4 4 4',
    '2 2 2 2 + 4 4 4 4 =? Solution: 2 2 2 2 + 4 4 4 4 = 2 4 4 4',
    '2 2 2 2 + 4 4 4 4 =? Solution: 2 2 2 2 + 4 4 4 4 = 8 2 2 2',
    '2 2 2 2 + 4 4 4 4 =? Solution: 2 2 2 2 + 4 4 4 4 = 8 4 4 4',
    '2 2 2 2 + 4 4 4 4 =? Solution: 2 2 2 2 + 4 4 4 4 = 4 4 4 4',
    '2 2 2 2 + 4 4 4 4 =? Solution: 2 2 2 2 + 4 4 4 4 = 2 2 2 2',
    '2 2 2 2 + 4 4 4 4 =? Solution: 2 2 2 2 + 4 4 4 4 = 5 5 5 5',
    '2 2 2 2 + 4 4 4 4 =? Solution: 2 2 2 2 + 4 4 4 4 = 2 2 2 2',
    '2 2 2 2 + 4 4 4 4 =? Solution: 2 2 2 2 + 4 4 4 4 = 2 2 2 2',
    '2 2 2 2 + 4 4 4 4 =? Solution: 2 2 2 2 + 4 4 4 4 = 8 8 8 8'
]]

This corresponds to the index ranges [0 : num_samples] and [num_samples : 2 * num_samples], which is what we want.

A simpler, complete example is

#!/usr/bin/env python
# coding: utf-8

import outlines
import outlines.generate
import outlines.models.transformers
import outlines.samplers
import rich
import transformers

MODEL = "susnato/phi-2"
NUM_SAMPLES = 10

# Load the Hugging Face model and tokenizer, then wrap them for outlines.
hf_model = transformers.AutoModelForCausalLM.from_pretrained(MODEL).cuda()
hf_tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL, padding_side="right")
outlines_model = outlines.models.Transformers(hf_model, hf_tokenizer)

# Two prompts (batch size 2), NUM_SAMPLES multinomial samples per prompt.
# The regex is a raw string so that \d is not treated as a string escape.
generator = outlines.generate.regex(
    outlines_model,
    r"\d+",
    sampler=outlines.samplers.MultinomialSampler(NUM_SAMPLES),
)
output = generator.stream(
    prompts=[
        "What is 11 + 33 ? Solution: 11 + 33 = ",
        "What is 2222 + 4444 ? Solution: 2222 + 4444 = ",
    ],
    max_tokens=100,
)

for o in output:
    rich.print(o)

gives

[
    ['21', '11', '41', '14', '55', '10', '11', '09', '31', '10'],
    ['11', '41', '14', '55', '10', '11', '09', '31', '10', '6666']
]

when not fixed (note how the second list is just the first list with its first item dropped and one reasonable-looking generation appended at the end),

and, when fixed:

[
    ['21', '11', '41', '14', '55', '10', '11', '09', '31', '10'],
    ['6666', '66666666', '6666', '0000000000000000', '6666', '6666', '6666', '6666', '6666', '6666']
]
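
The symptom in the unfixed output can also be checked mechanically. Below is a small sketch that uses the two outputs above as data. is_shifted_by_one is a hypothetical helper written for this comment, not an outlines function, and it is only a heuristic: it assumes more than one prompt and that independent samples would not line up like this by chance.

```python
def is_shifted_by_one(groups):
    """Return True if every group is the previous group shifted left by one item."""
    return len(groups) > 1 and all(
        nxt[:-1] == prev[1:] for prev, nxt in zip(groups, groups[1:])
    )


# Output copied from the run without the fix.
not_fixed = [
    ['21', '11', '41', '14', '55', '10', '11', '09', '31', '10'],
    ['11', '41', '14', '55', '10', '11', '09', '31', '10', '6666'],
]

# Output copied from the run with the fix.
fixed = [
    ['21', '11', '41', '14', '55', '10', '11', '09', '31', '10'],
    ['6666', '66666666', '6666', '0000000000000000', '6666', '6666', '6666', '6666', '6666', '6666'],
]

print(is_shifted_by_one(not_fixed))  # True: the overlap bug is visible
print(is_shifted_by_one(fixed))      # False
```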

@JulesGM JulesGM changed the title Major bug fix: Fix bug in batched multi sample generation Major bug & fix: Fix bug in batched multi sample generation Jul 8, 2024
@JulesGM
Contributor Author

JulesGM commented Jul 8, 2024

@lapp0 @rlouf

Co-authored-by: Patrice Bechard <bechardpatrice@gmail.com>
@JulesGM
Contributor Author

JulesGM commented Jul 10, 2024

folks this is serious

@JulesGM
Contributor Author

JulesGM commented Jul 10, 2024

@brandonwillard

@JulesGM
Contributor Author

JulesGM commented Jul 10, 2024

Thanks @patricebechard, are you able to give approval / merge?

@lapp0
Collaborator

lapp0 commented Jul 10, 2024

Thanks so much for finding and fixing this bug!

Could you please add a test case which fails in main and passes with your fix?

Also as an alternative we might consider #966

@patricebechard
Contributor

> Thanks @patricebechard, are you able to give approval / merge?

Nope, I am not a maintainer, just trying to help :)

@JulesGM
Contributor Author

JulesGM commented Jul 10, 2024 via email

@JulesGM
Contributor Author

JulesGM commented Jul 10, 2024

Making a test that fails is really hard without an option to return the inputs as part of the outputs. Having that option would make it easier though.
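
One possible way around this, sketched under assumptions: test the regrouping step on its own, with stand-in items that carry their origin, so no model is needed. Everything named here is hypothetical. fake_flat_outputs stands in for the generation step, and group_samples assumes the reshaping loop were factored out into a helper, which is not how the code is currently organized.

```python
from typing import List, Sequence, Tuple

Tagged = Tuple[int, int]  # (prompt index, sample index)


def fake_flat_outputs(num_prompts: int, num_samples: int) -> List[Tagged]:
    """Hypothetical stand-in for the model: items tagged with their origin."""
    return [(p, s) for p in range(num_prompts) for s in range(num_samples)]


def group_samples(
    flat: Sequence[Tagged], batch_size: int, num_samples: int
) -> List[List[Tagged]]:
    """Hypothetical helper holding the fixed reshaping loop from this PR."""
    return [
        list(flat[i : i + num_samples])
        for i in range(0, batch_size * num_samples, num_samples)
    ]


def test_each_group_contains_only_its_own_prompt() -> None:
    batch_size, num_samples = 3, 5
    flat = fake_flat_outputs(batch_size, num_samples)
    groups = group_samples(flat, batch_size, num_samples)

    assert len(groups) == batch_size
    for prompt_index, group in enumerate(groups):
        assert len(group) == num_samples
        # Every item in this group must come from this prompt.
        assert all(origin == prompt_index for origin, _ in group)


test_each_group_contains_only_its_own_prompt()
```

If group_samples used the old range(batch_size) loop, the last assertion would fail for the second group, since its final item would come from the next prompt.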

Collaborator

@lapp0 lapp0 left a comment

I was able to reproduce your results on main and verify your fix resolves the issue!

@JulesGM
Contributor Author

JulesGM commented Jul 11, 2024

[Screenshot, 2024-07-11 9:42:59 AM] It's not letting me merge

@rlouf
Member

rlouf commented Jul 11, 2024

Thank you for contributing a fix! In the future, please do not tag maintainers and other users in the PR.

@rlouf rlouf merged commit b54a964 into dottxt-ai:main Jul 11, 2024
6 checks passed
@JulesGM
Contributor Author

JulesGM commented Jul 12, 2024 via email
