
fix/Extend generation for all candidate completions #920

Conversation

@mikeedjones commented Apr 28, 2024

Following on from #918: this PR attempts a fuller implementation of the fix suggested there, so that every candidate completion is extended until it contains valid content for every field in the template.

Fixes #914 and, I suspect, #749 and some other open issues.

The current, unpatched, behaviour also seems to be replicated in the backend-refactor branch.

This PR makes the following changes to the logic of do_generate in dsp/primitives/predict.py:

  • Differentiates between completions and finished_completions where finished_completions have entries for every field in the template.
  • Introduces the function extend_generation in dsp/primitives/predict.py, in the namespace of _generate, which ensures (up to two levels of recursion) that every template field has valid content for all n completions requested. A rough sketch is included at the end of this comment.
  • Introduces get_last_field to acquire the last valid field of an Example.
  • Maintains overwriting of temperature in extended completions.

Further suggestions:

  • Should this method raise warnings suggesting that max_tokens should be increased? This implementation could slow down forward passes significantly.
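
For reference, a rough, self-contained sketch of the idea behind extend_generation is below. The helper names (first_missing_field, generate_fn) and the signature are illustrative only and do not mirror the patched dsp/primitives/predict.py line for line; the point is just the shape of the recursion - extend every incomplete candidate, not only the most complete one, until every field is filled or the depth budget runs out.

from typing import Callable, Dict, List, Optional

def first_missing_field(completion: Dict[str, str], field_names: List[str]) -> Optional[str]:
    # Return the first template field that is absent or empty in this completion.
    for name in field_names:
        if not completion.get(name):
            return name
    return None

def extend_generation(
    completions: List[Dict[str, str]],
    field_names: List[str],
    generate_fn: Callable[[Dict[str, str]], Dict[str, str]],
    max_depth: int = 2,
) -> List[Dict[str, str]]:
    # Extend every incomplete candidate until it has valid content for
    # every field, or until max_depth is exhausted.
    finished = [c for c in completions if first_missing_field(c, field_names) is None]
    unfinished = [c for c in completions if first_missing_field(c, field_names) is not None]

    if not unfinished or max_depth == 0:
        return finished + unfinished  # return whatever we have

    # One extension round: ask the LM to continue each partial completion
    # (the real implementation re-renders the prompt from the partial Example).
    extended = [{**c, **generate_fn(c)} for c in unfinished]
    return finished + extend_generation(extended, field_names, generate_fn, max_depth - 1)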

@arnavsinghvi11 (Collaborator)

Thanks @mikeedjones for the PR! This makes sense, but just curious, following #734: does setting a large number of tokens likely solve this issue? I feel it's actually better to give the user control over request parameters and adjust accordingly than to excessively increase them beyond what's needed (unless the recursion has some control over this).

Should this method raise warnings suggesting that max_tokens might be increased?

This should definitely be logged, as any impact on request parameters is important to flag!

@mikeedjones (Author) commented Apr 28, 2024

What do you mean by "solves the issue"? Increasing max_tokens would make it more likely that the fallback logic is not entered, but very long or complicated signatures might still exceed even very high token-generation limits. For example, whilst Claude 3's context window is 200k tokens, its generation limit is 4096 tokens.

I've gone into more detail of the problem in the other, more atomised fix I proposed: #918 (comment)

@arnavsinghvi11 (Collaborator)

I'm not sure I understand. If the generation limit is restricted, does setting max_tokens = 4096 not capture what's done here? If long signatures exceed even very high token-generation limits, it would not work anyway, right? Maybe I'm misinterpreting, so feel free to correct me with an example!

@mikeedjones (Author) commented Apr 28, 2024

The current flow, as implemented on main, checks which template fields are present in the n completions made in the first pass. If none of the completions contains all the fields, some fallback logic is entered, in which the LM (the generate function) is called recursively until the missing fields are created:

        # If none of the completions is completed (i.e., none has the final field set).
        if last_field_idx < len(field_names):
            # Pick the first completion that has gone farthest.
            completion = completions[0]
            ...

            new_kwargs = {
                **kwargs,
                max_tokens_key: max_tokens,
                "n": 1,
                "temperature": 0.0,
            }

            assert max_depth > 0
            return generate(template, **new_kwargs)(
                completion,
                stage=stage,
                max_depth=max_depth - 1,
                original_example=original_example,
            )

The fallback logic takes the "most complete" completion and uses it to make a further call to the LM to generate an extra k tokens (k is chosen by some more logic in primitives/predict.py).
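
In outline, that step looks something like the sketch below - the helper names and the token arithmetic are illustrative, not a verbatim quote of predict.py:

def pick_most_complete(completions, field_names):
    # Count how many leading template fields each candidate has filled and
    # return the candidate that got the farthest (illustrative only).
    def filled_prefix_len(completion):
        count = 0
        for name in field_names:
            if not completion.get(name):
                break
            count += 1
        return count
    return max(completions, key=filled_prefix_len)

def extension_kwargs(kwargs, max_tokens_key="max_tokens", cap=4096):
    # Build kwargs for the follow-up call: a single greedy completion with a
    # recomputed token budget k. The real k-selection logic differs in detail.
    k = min(2 * kwargs.get(max_tokens_key, 256), cap)
    return {**kwargs, max_tokens_key: k, "n": 1, "temperature": 0.0}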

If max_tokens is increased, the likelihood that the LM generates all the required fields goes up, but it is not certain.

For example, with the below:

import dspy

llm = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=100)
dspy.settings.configure(lm=llm)

class SciGen(dspy.Signature):
    """context and answer on science questions"""
    question = dspy.InputField(desc="Input question")
    foo = dspy.OutputField(desc="foo_context") 
    bar = dspy.OutputField(desc="bar_context") 
    bazz = dspy.OutputField(desc="bazz_context") 
    buzz = dspy.OutputField(desc="buzz_context") 
    fuzz = dspy.OutputField(desc="fuzz_context") 
    fooz = dspy.OutputField(desc="fooz_context") 
    answer = dspy.OutputField(desc="Answer using only the contexts")

context_generator = dspy.Predict(SciGen, n=10)
response = context_generator(question="Why do we have solar eclipse?")

print(len(response.completions))
# 1

The first call to gpt-3.5-turbo (round 1) produces 10 completions, none of which contains the required fields fooz or answer.

The fallback logic is therefore entered, where the "most complete" completion is used (maybe the one which contains fuzz) as the Example for another call of generate, with updated kwargs {"n":1, "max_tokens": k, "temperature":0} (round 2).

This second call produces only 1 completion, as n has been overwritten, built on top of the most complete completion from round 1.

It is this fallback logic which is causing only one completion to be returned in #914.

For an arbitrarily long and complex signature, there is no guarantee that the model will generate the required fields - I suspect that's why the fallback logic was included in the first place! The fallback logic (and my update to it) extends the generation (using the completion from round 1 as input to the calls in round 2) to allow for arbitrarily long signatures - up to the context limit of the LM. But the current implementation replaces the user's n with 1.

The ultimate limit on "rounds" is set by max_depth - so the ultimate limit on the output is 4096 * max_depth tokens, as opposed to 4096.
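
As a quick sanity check on that bound (treating both numbers as assumptions - a 4096-token per-call generation limit and an illustrative max_depth):

generation_limit = 4096   # per-call output cap, e.g. Claude 3 as mentioned above
max_depth = 5             # illustrative value; the default in predict.py may differ
print(generation_limit * max_depth)  # 20480 tokens of output available across rounds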

@mikeedjones changed the title from "Extend generation for all candidate completions" to "fix/Extend generation for all candidate completions" on Apr 28, 2024
@arnavsinghvi11 (Collaborator)

Thanks @mikeedjones, this is really helpful! I do see now that the issue lies more in the response parsing, which triggers the fallback completion logic.

With your code above and the proposed changes in the PR, there are indeed 10 completions outputted, but these are actually 10 "incomplete" completions due to output parsing errors, e.g.:

Prediction(
    foo='....',
    bar='....',
    bazz='...',
    buzz='....',
    fuzz='....\n\nFooz',
    fooz='',  # empty because of the previous parsing error: fuzz likely did not produce the Fooz prefix as "Fooz:"
    answer='.... \n\nAnswer: ....'
)

whereas with the existing logic, there are only 2 completions outputted, but they are "complete" with all fields parsed correctly (from the fallback logic).

This occurs even when I remove `"n": 1` as proposed in #918. Based on this, I believe we need to tackle a deeper parsing issue rather than extending generation for all candidates, especially since it's better to have 2 "complete" completions instead of 10 incomplete ones - but ideally we want 10 "complete" completions!
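
To make that failure mode concrete, here is a deliberately naive prefix-based extractor (an illustration only, not dsp's actual parser). If the model emits "Fooz" without the trailing colon, the fooz field comes back empty - the same shape as the incomplete predictions above:

def naive_extract(text, field_names):
    # Split raw LM output on "FieldName:" prefixes (illustrative only).
    values = {name: "" for name in field_names}
    current = None
    for line in text.splitlines():
        matched = False
        for name in field_names:
            prefix = f"{name.capitalize()}:"
            if line.startswith(prefix):
                current, matched = name, True
                values[name] = line[len(prefix):].strip()
                break
        if not matched and current:
            values[current] = (values[current] + " " + line.strip()).strip()
    return values

output = "Fuzz: some context\n\nFooz\nAnswer: because the Moon blocks the Sun"
print(naive_extract(output, ["fuzz", "fooz", "answer"]))
# {'fuzz': 'some context Fooz', 'fooz': '', 'answer': 'because the Moon blocks the Sun'}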

Let me know if that makes sense as this PR doesn't fully correct the existing issue (but potentially is on the right track!).

@mikeedjones (Author) commented Apr 28, 2024

Good catch @arnavsinghvi11! Thank you :)

Yes, it looks like I was using the last filled field to restart the completion, as opposed to the first missing field.

Updated the test, as the LM didn't reliably fill the nonsense fields, which led to inconsistent results:

import dspy


llm = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=75)
dspy.settings.configure(lm=llm)
class SciGen(dspy.Signature):
    """context and answer on science questions"""
    question = dspy.InputField(desc="Input question")
    sun_context = dspy.OutputField(desc="sun_context") 
    moon_context = dspy.OutputField(desc="moon_context") 
    earth_context = dspy.OutputField(desc="earth_context") 
    relative_distances_context = dspy.OutputField(desc="relative_distances_context") 
    answer = dspy.OutputField(desc="Answer only when you have all the context fields.")

context_generator = dspy.Predict(SciGen, n=10)
response = context_generator(question="Why do we have solar eclipse?")

assert len(response.completions) == 10

for answer in response.completions:
    for key in [
        "sun_context",
        "moon_context",
        "earth_context",
        "answer",
    ]:
        assert key in answer
        assert answer[key] is not None
        assert answer[key] != ""

I think the parsing errors you're seeing are also due to the LM producing junk when given the odd prompt generated by the odd signature. I think this points to a larger problem with dspy and its attempts to make LM output reliably parsable.

@arnavsinghvi11 (Collaborator)

Thanks @mikeedjones. Could you run `ruff check . --fix-only` and push again? Ready to merge after that.

To confirm, this change is more comprehensive than #918 and that PR can be closed after this is merged?

@mikeedjones (Author)

@arnavsinghvi11 yes, that's correct. It should pick up a few other issues as well relating to n != 1.

Linting applied - cheers!

@mikeedjones (Author)

@arnavsinghvi11 - is there anything outstanding for this PR? cheers! :)

@arnavsinghvi11 merged commit 6899b5f into stanfordnlp:main May 11, 2024
4 checks passed
@arnavsinghvi11 (Collaborator)

Thanks @mikeedjones ! Very useful PR that caught an elaborate issue!

@okhat (Collaborator) commented Jun 18, 2024

This should not have been merged.

@mikeedjones (Author) commented Jun 18, 2024

There's an open issue on empty input fields (#1108) which is being caused by this PR - but I think the older logic would have the same issue - or is the problem more serious?

EDIT: I'm not sure if the logic as originally implemented works as expected.

Successfully merging this pull request may close these issues.

Number of generation doesn't give 'n' results on other LM clients (vllm, ollama) except OpenAI