Gemma 2: 9b and 27b versions #1545

Merged · 59 commits merged into main on Jul 22, 2024

Conversation

@Andrei-Aksionov (Collaborator) commented Jul 2, 2024

Hi there 👋

Fixes #1535

Google released the latest and greatest Gemma model - v2.
This time it comes in three sizes:

  • 2b (not yet released)
  • 9b
  • 27b

Based on the technical report and the official implementation, here are the main changes I've spotted (a rough sketch of a few of them follows the list):

  1. The embeddings scaler needs to be cast down before it is applied.
  2. One needs to be careful with the attention-scores scaler: it's not head_size but n_embd/n_head, and in Gemma's case head_size might not be equal to n_embd/n_head.
  3. Logit soft-capping for the attention scores and for the final logits. Soft-capping is needed mainly for training (it looks more important for the larger model) and not so much for inference. Since flash attention doesn't support soft-capping, it has to be disabled when not in training mode.
  4. Sliding window attention is used instead of a global window on every odd-indexed layer, with the window being half the full context size.
  5. RMSNorm does the downcasting right at the end, which was the behavior before I added support for Gemma v1.
  6. The transformer block now has two more normalization layers: right after the attention layer (before the residual connection) and right after the MLP (also before the residual connection). Previously we had norm -> attn -> residual -> norm -> MLP -> residual. Now it's: norm -> attn -> norm -> residual -> norm -> MLP -> norm -> residual.
  7. Both 9b and 27b use grouped-query attention. In Gemma v1, the 7b variant had regular multi-head attention, while the 2b variant had multi-query attention (a single key-value pair shared across all query heads).
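
A minimal PyTorch sketch of points 2, 3, and 6, based only on the description above. The names (softcap, attention_scores, BlockSketch, attn_softcap, norm_cls) and the default cap value are illustrative assumptions, not the identifiers or values used in this PR:

```python
import math
import torch
import torch.nn as nn


def softcap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Soft-capping squashes values into (-cap, cap) instead of hard clipping them.
    return cap * torch.tanh(x / cap)


def attention_scores(q: torch.Tensor, k: torch.Tensor, n_embd: int, n_head: int,
                     attn_softcap: float | None = 50.0) -> torch.Tensor:
    # Point 2: the scale is derived from n_embd / n_head, which for Gemma 2
    # is not necessarily equal to head_size.
    scale = 1.0 / math.sqrt(n_embd / n_head)
    scores = (q @ k.transpose(-2, -1)) * scale
    # Point 3: soft-cap the attention logits; flash attention can't do this,
    # so the fused kernel has to be skipped whenever capping is enabled.
    if attn_softcap is not None:
        scores = softcap(scores, attn_softcap)
    return scores  # the final lm_head logits are capped the same way


class BlockSketch(nn.Module):
    # Point 6: norm -> attn -> norm -> residual, then norm -> MLP -> norm -> residual.
    def __init__(self, attn: nn.Module, mlp: nn.Module, norm_cls, n_embd: int) -> None:
        super().__init__()
        self.attn, self.mlp = attn, mlp
        self.norm_1, self.post_attn_norm = norm_cls(n_embd), norm_cls(n_embd)
        self.norm_2, self.post_mlp_norm = norm_cls(n_embd), norm_cls(n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.post_attn_norm(self.attn(self.norm_1(x)))
        x = x + self.post_mlp_norm(self.mlp(self.norm_2(x)))
        return x
```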

@rasbt (Collaborator) commented Jul 2, 2024

Nice summary. I think this touches all the main points. The others (knowledge distillation for the smaller models; tied embeddings) don't affect the architecture; they're more about the pretraining method. So yeah, looks great! Many thanks for taking this on!

@rasbt (Collaborator) commented Jul 5, 2024

@Andrei-Aksionov
Sliding window attention (an ugly one, but hey, it works)

Cool! We can also add that to the existing Mistral/Mixtral models then 😊

@Andrei-Aksionov (Collaborator, Author)

Cool! We can also add that to the existing Mistral/Mixtral models then

I believe only Mistral v0.1 used sliding window attention; the subsequent models from Mistral.ai don't use it.
But after this PR is merged, adding SWA would be just a matter of an additional line in a config.

@rasbt (Collaborator) commented Jul 5, 2024

I believe only Mistral v0.1 used sliding window attention; the subsequent models from Mistral.ai don't use it.

I think you are right.

But after this PR is merged, adding SWA would be just a matter of an additional line in a config.

Nice!

@Andrei-Aksionov (Collaborator, Author) commented Jul 6, 2024

Gemma 2 9b/9b-it now has initial support (with a lot of “scaffolding”).

Generation returns plausible results, but chat does a couple of strange things:

  1. OOM. I don't understand why, since the regular generate script consumes ~20 GB.
    Update: it's not a Gemma-specific problem (Chat consumes more VRAM than Generate #1558), so it's not a blocker.
  2. I had to use quantization (bnb.nf4), and the model was very restrictive: it often didn't want to respond and instead asked to rephrase the question. I know that Llama 3, because of its very long training, has saturated the bf16 dtype up to the very last digit, so quantization affects it more than other models. Maybe we have the same thing here (thanks to a “teacher”)? 🤷
    Update #1: if I use the generate script with quantization, I get proper output. Something else is broken in the chat script, besides the higher memory consumption.
    Update #2: the KV-cache needs to be changed to support sliding window attention. The chat script pre-allocates too much memory (up to model.max_seq_length), so a layer with a sliding window ends up with a wrongly sized KV-cache (see the sketch after this list).
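
As a rough illustration of update #2 (the function name, arguments, and shape numbers here are hypothetical, not the actual litgpt KV-cache code), a sliding-window layer only ever needs to cache as many positions as the window holds:

```python
import torch


def kv_cache_shape(batch_size: int, n_kv_heads: int, head_size: int,
                   max_seq_length: int, sliding_window_size: int | None) -> tuple[int, int, int, int]:
    # A globally-attending layer may need up to max_seq_length cached positions,
    # but a sliding-window layer only attends within the window, so pre-allocating
    # max_seq_length for it both wastes memory and mismatches the layer's
    # expectation of how its cache is laid out.
    cache_len = max_seq_length if sliding_window_size is None else min(max_seq_length, sliding_window_size)
    return (batch_size, n_kv_heads, cache_len, head_size)


# Example: a global layer vs. a sliding-window layer (window of 4096 tokens)
k_global = torch.zeros(kv_cache_shape(1, 8, 256, 8192, None))   # (1, 8, 8192, 256)
k_local = torch.zeros(kv_cache_shape(1, 8, 256, 8192, 4096))    # (1, 8, 4096, 256)
```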

Anyway, there is a lot of work that needs to be done (besides what I've mentioned above) before I can open this PR for a review:

  • 1. Only final_softcapping affects tests. Need to make tests fail if attention_logit_softcapping is messed up.
  • 2. Deal with all TODOs. The code works, but is very ugly and non-performant.
  • 3. Use the torch profiler to make sure that there are no shady device syncs happening in the background.
  • 4. Add code and a test for LitGPT --> HF format conversion.
  • 5. Does torch.compile work with soft-capping? If not, a clear error message needs to be printed.
  • 6. Do a short training as a sanity-check.
  • 7. Figure out what to do with CausalSelfAttention from adapter.py. Tests for the adapter don't fail because of item 1.
  • 8. Add support for 27b variant.

@rasbt (Collaborator) left a review comment

This is a great PR. It's crazy that you pulled this off. Really awesome.

Andrei-Aksionov and others added 9 commits on July 19, 2024 (several co-authored by Sebastian Raschka <mail@sebastianraschka.com>).

@Andrei-Aksionov (Collaborator, Author) commented Jul 19, 2024

One more thing. Due to time constraints, I didn't test the Gemma v2 27b version.
Tests are running fine, but it would be nice to check the generated output.

@rasbt could you do this?

@rasbt (Collaborator) commented Jul 19, 2024

One more thing. Due to time constraints, I didn't test the Gemma v2 27b version.
Tests are running fine, but it would be nice to check the generated output.

@rasbt could you do this?

Yes, I am happy to do this. I will also generate config files for the smaller models.

@rasbt (Collaborator) commented Jul 19, 2024

Works great!

[screenshot: generation output]

@rasbt (Collaborator) commented Jul 22, 2024

Based on the config file run, the train and val loss look great. The MMLU score is surprisingly low, though. There's nothing wrong with the finetuned model, and it works fine during chat:

[screenshot: chat session with the finetuned model]

(Not 100% sure, but maybe the MMLU scores in the README were created with --num_fewshot greater than 1.)

Anyway, everything else seems to be fine, so it's good to merge now, right?

@Andrei-Aksionov (Collaborator, Author)

Yep, let's merge.

@rasbt (Collaborator) commented Jul 22, 2024

Awesome, this is great! Thanks for this amazing PR!

@rasbt merged commit 916a84c into main on Jul 22, 2024 (9 checks passed) and deleted the gemma_2 branch.