Why do you modify the function prepare_inputs_for_generation of the LLM? #106

Closed
linhaojia13 opened this issue Jul 12, 2024 · 2 comments

linhaojia13 commented Jul 12, 2024

In the transformers code, prepare_inputs_for_generation is:

    def prepare_inputs_for_generation(

            ...

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

            ...
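
For context, in plain text-only decoding it is the elif branch that fires: after prefill, past_length is exactly input_ids.shape[1] - 1, so only the newest token is kept. Here is a toy sketch of that arithmetic, using made-up numbers rather than real model state:

    import torch

    # Toy text-only decode step (illustrative numbers, not taken from transformers internals):
    # prefill processed 10 prompt tokens, then generate() appended 1 newly sampled token.
    past_length = 10                              # tokens already stored in the KV cache
    input_ids = torch.arange(11).unsqueeze(0)     # shape (1, 11): 10 prompt tokens + 1 new token

    if past_length < input_ids.shape[1]:
        # the elif branch above: keep only the tokens not yet in the cache
        input_ids = input_ids[:, past_length:]

    print(input_ids.shape)                        # torch.Size([1, 1]): only the new token is fed in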

However, in bunny/model/language_model/llama/modeling_llama.py, prepare_inputs_for_generation is:

    def prepare_inputs_for_generation(

            ...

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            else:
                remove_prefix_length = input_ids.shape[1] - 1
                input_ids = input_ids[:, remove_prefix_length:]
            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

            ...

You add

            else:
                remove_prefix_length = input_ids.shape[1] - 1
                input_ids = input_ids[:, remove_prefix_length:]

When I debug the Bunny code, I find that input_ids.shape[1] is not the number of generated tokens, which is different from LLaVA's code.
From what I understand, Bunny's code framework is pretty much the same as LLaVA's. So why is there such a big difference in input_ids during inference between your code and LLaVA's that you need to change the transformers source code?

@Isaachhh (Collaborator)

import torch

# text prompt (tokenizer and device are assumed to be set up beforehand)
prompt = 'Why is the image funny?'
text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{prompt} ASSISTANT:"
# tokenize the two text chunks around the <image> tag separately
text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
# splice in -200 as the image placeholder; drop the extra leading token the tokenizer prepends to the second chunk
input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1][1:], dtype=torch.long).unsqueeze(0).to(device)

When encoding the text, Bunny uses -200 to represent <image> as the placeholder for the image. Assuming the length of input_ids is L, it consists of L-1 normal tokens and one -200.

After prepare_inputs_labels_for_multimodal, the input_ids are converted into input_embeds and then fed into the model. At this point the image takes up its actual length in input_embeds: for SigLIP-SO, an image is encoded into 729 tokens, so the length of input_embeds is L+728.

This change of length causes a mismatch with the attention_mask and related bookkeeping during generation.
We add

            else:
                remove_prefix_length = input_ids.shape[1] - 1
                input_ids = input_ids[:, remove_prefix_length:]

to make the inference run normally. You can try deleting this part of the code to see what happens.
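
To make the length bookkeeping concrete, here is a toy sketch with made-up numbers (say L = 50 text tokens and one SigLIP-SO image of 729 tokens); it only illustrates why the added else branch is reached, and is not Bunny's actual code:

    import torch

    # Illustrative numbers, not Bunny's actual state:
    L = 50                                   # length of input_ids, including one -200 placeholder
    image_tokens = 729                       # SigLIP-SO encodes one image into 729 tokens

    # After prepare_inputs_labels_for_multimodal, the prefill sequence seen by the LLM is
    # (L - 1) text embeddings + 729 image embeddings = L + 728 positions in the KV cache.
    past_length = (L - 1) + image_tokens     # 778

    # generate() still tracks input_ids at text-token granularity:
    # the L original tokens (including -200) plus the 1 token decoded so far.
    input_ids = torch.arange(L + 1).unsqueeze(0)              # shape (1, 51)

    # assume the mask tracks input_ids (one extra slot per decoded token)
    attention_mask = torch.ones(1, L + 1, dtype=torch.long)   # shape (1, 51)

    if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
        input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
    elif past_length < input_ids.shape[1]:
        input_ids = input_ids[:, past_length:]
    else:
        # past_length (778) >= input_ids.shape[1] (51): without this added branch, all 51 tokens,
        # including the -200 placeholder, would be fed through the model again.
        remove_prefix_length = input_ids.shape[1] - 1
        input_ids = input_ids[:, remove_prefix_length:]

    print(input_ids.shape)                   # torch.Size([1, 1]): only the newest token goes in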

Actually, LLaVA once required transformers == 4.31, and at that time this part of the code was still present in transformers (a later commit of transformers deleted it). Bunny is consistent with LLaVA as of that time.

The release of LLaVA-1.6 overrides the generate function (see this commit), so LLaVA can be compatible with the latest version of transformers. Bunny doesn't override generate and therefore needs to modify the prepare_inputs_for_generation function instead.

@linhaojia13 (Author)

Thank you very much!
