Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add Megatron-11B #10301

Closed
wants to merge 7 commits into from
Closed

[WIP] Add Megatron-11B #10301

wants to merge 7 commits into from

Conversation

anton-l
Copy link
Member

@anton-l anton-l commented Feb 20, 2021

What does this PR do?

Fixes #9560

This PR introduces the Megatron model as described in https://github.com/pytorch/fairseq/blob/master/examples/megatron_11b/README.md
This one will probably be fun to test with DeepSpeed, as @stas00 mentioned it's referenced a lot in its docs 😄

It's important to mention that there are actually two independent implementations of Megatron-LM:

After some tinkering I realized that fairseq's checkpoint is already pretty compatible with the existing BART port. So, based on that and the fact that NVIDIA doesn't plan on releasing the 3B and 8B checkpoints, I chose to port only the fairseq version.

NOTE: The original fairseq implementation requires an 8-GPU server to even load the model weights, so I just load the checkpoints manually one by one and merge the model-parallelized tensors into single-model ones.

How to reproduce the conversion

  1. First, find a server with at least 85GB of RAM, this model is huge!
  2. Next, download and untar the checkpoint:
# WARNING: this file is 19GB
wget https://dl.fbaipublicfiles.com/fairseq/models/model_parallel/megatron_11b.tar.gz
tar -xzvf megatron_11b.tar.gz
wget -P ./megatron_11b/ 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -P ./megatron_11b/ 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
  1. Run the conversion script
python convert_megatron_original_pytorch_checkpoint_to_pytorch.py --fairseq_path /path/to/megatron_11b --pytorch_dump_path /path/to/megatron_hf_dump
  1. The conversion script will load the model-parallel shards of the checkpoint, group the sharded parameters and concatenate the weights, so that the fairseq.ModelParallelTransformerLanguageModel state_dict can be easily loaded into a CPU-compatible faiseq.TransformerLanguageModel. The de-parallelisation is based on ParlAI's conversion script.
  2. Then the script will initialize the huggingface Megatron model and load the converted state_dict into it.

Here's how Megatron differs from the existing BART/MBART implemenations:

  1. The most controversial difference, IMO, is the missing encoder, since it's a decoder-only model. For now, I decided to remove the encoder parts inherited from MBART, bit left the encoder-dependent parts in the decoder (e.g. encoder_hidden_states, encoder_attention_mask) and the cross-attention to simplify the review process on your end.
  2. Megatron uses SinusoidalPositionalEmbedding instead of learned ones, so I just yanked those from FSMT 😄
  3. Megatron does not have a layernorm_embedding
  4. Minor detail: the self_attn_layer_norm is applied before self-attention (like in MBART) instead of after (like in BART).

Important questions regarding the API:

  1. What should be done about the missing encoder? I think the decoder variable can be left as is, since it's compatible with the fairseq checkpoint keys, but the encoder_* references in the code bother me a lot. We need to somehow strike a balance between Copied from and removing the unused parts.
  2. I think the position of self_attn_layer_norm should be a parameter in the config, similar to decoder_normalize_before=True in faiseq. This will close the not-so-obvious difference between BART and MBART.
  3. The existence of layernorm_embedding can also be parametrized, similar to layernorm_embedding=False in fairseq.

Quick LM test

You can test out the model's capabilities like so (again, you'll probably need at least 85GB RAM, there's some weird memory duplication happening somewhere, this should not need more than 50):

from transformers import MegatronForCausalLM, MegatronTokenizer, TextGenerationPipeline

tokenizer = MegatronTokenizer.from_pretrained("megatron-11b")
model = MegatronForCausalLM.from_pretrained("anton-l/megatron-11b")

def generate(prompt, max_length=40, num_beams=5, num_return=3):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    outputs = model.generate(
        input_ids=input_ids, num_beams=num_beams, num_return_sequences=num_return, max_length=max_length
    )
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return decoded

print(generate("Before boarding your rocket to Mars, remember to pack these items: "))
['Before boarding your rocket to Mars, remember to pack these items: 1. A parachute.',
 'Before boarding your rocket to Mars, remember to pack these items: 1. A parachute $100 bill2. A copy of your passport3. A copy of your passport444',
 'Before boarding your rocket to Mars, remember to pack these items: 1. A parachute $1 million dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars']

To be honest, I'm not too impressed with its text-generation power. 😄 I guess it's either that the model was too large to train it for enough steps, or I missed something during the conversion. The original implementation does not have a text-generation script (or any non-wikitext results, for that matter), so I'm kinda in the dark here.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@patrickvonplaten, @patil-suraj

@anton-l anton-l marked this pull request as draft February 20, 2021 15:30
@anton-l anton-l changed the title Add Megatron-11B [WIP] Add Megatron-11B Feb 20, 2021
@stas00
Copy link
Contributor

stas00 commented Feb 21, 2021

That's very neat, @anton-l! thank you for the port

You demonstrated a very good creativity by finding a way to recompose the model shards!

This one will probably be fun to test with DeepSpeed, as @stas00 mentioned it's referenced a lot in its docs

As you correctly noticed studying Megatron-LM's horizontal model parallel sharding is on my TODO list.

I suppose since transformers currently doesn't provide this feature you didn't port that part of the model, correct? i.e. you unsharded it. I had a brief read through the PR and didn't see anything of a sort - unless I somehow missed it? And without this feature, this is like any other transformers model - It's its horizontal model parallel feature that is needed to complete 3D parallelism with Deepspeed. Your PR is an excellent start.

I think the part that deals with sharding is here in the original:
https://github.com/jeffra/DSE/blob/79888e162425e8d64043a9597ee14751bd4b53d1/megatron/data/realm_index.py
Though this is the NVIDIA version.

So if the horizontal MP is eventually re-ported (I hope it will be so) the model will need to know when to load the flattened version and when the sharded one. But transformers doesn't even have a framework for loading multiple-part models at the moment, so I guess we will cross that bridge when we get to it.

I'm just just thinking aloud here, considering different options, not making any requests ;)


The fp32 weights are ~41GB https://huggingface.co/anton-l/megatron-11b/tree/main - i.e. it's quite similar to t5-11b, so it should be possible to load it on a 40GB gpu w/ DeepSpeed ZeRO-Offload if there are some 256GB of RAM available.


Also, FYI, Deepspeed are making a new port of Megatron-LM to work with DeepSpeed. https://github.com/jeffra/DSE/tree/master/megatron-lm

@anton-l
Copy link
Member Author

anton-l commented Feb 21, 2021

@stas00 you're correct, I didn't port the model-parallel implementation. Fairseq uses an older Megatron-LM version as a submodule here for its MP map-reduce fuctions. This makes it quite cumbersome to reproduce, since it requires compiling an older apex library among other dependencies with broken versioning. It would also require a patched version of faiseq's state loader, since right now it requires exactly 8 GPUs available to load the sharded checkpoint correctly.

However, on the surface it seems like adding support for model parallelism comes down to porting VocabParallelEmbedding, ColumnParallelLinear and RowParallelLinear layers as implemented here. This seems doable, but I don't have multiple GPUs to test it out :(

I guess a proper MP implementation should also take care of splitting the checkpointed layers regardless of how many GPUs are available (i.e. 2, 4 or 8). That would remove the requirement to have a full DGX setup if the user is willing to use gradient checkpointing/accumulation instead.

@stas00
Copy link
Contributor

stas00 commented Feb 21, 2021

@anhon-l, in order not to make your and reviewers' lives unnecessarily difficult, let's take the discussion of the Horizontal MP to a dedicated issue, since it could take some time to figure and none of is required for you to complete this PR and I trust @patil-suraj and @patrickvonplaten will support you at completing this awesome effort.

So if you could re-post your last comment here: #10321 and I will follow up there. Thank you!

@vspruyt
Copy link

vspruyt commented Feb 25, 2021

['Before boarding your rocket to Mars, remember to pack these items: 1. A parachute.',
 'Before boarding your rocket to Mars, remember to pack these items: 1. A parachute $100 bill2. A copy of your passport3. A copy of your passport444',
 'Before boarding your rocket to Mars, remember to pack these items: 1. A parachute $1 million dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars dollars']

To be honest, I'm not too impressed with its text-generation power. 😄 I guess it's either that the model was too large to train it for enough steps, or I missed something during the conversion. The original implementation does not have a text-generation script (or any non-wikitext results, for that matter), so I'm kinda in the dark here.

This is amazing work, big kudos! The seemingly low text-generation quality surprises me though, because of the crazy good output you get from https://inferkit.com/ which is also just Megatron11b, according to their docs (https://inferkit.com/docs/generation). Their output seems to be much better than GPT2.

@huggingface huggingface deleted a comment from github-actions bot Apr 14, 2021
@stas00 stas00 added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Apr 14, 2021
@stas00
Copy link
Contributor

stas00 commented Apr 14, 2021

@anton-l, would you like to complete this PR? For it to be reviewed it needs to be a normal PR and not a draft.

I marked it as WIP so that the stale bot won't try to close it.

Thank you.

@stas00
Copy link
Contributor

stas00 commented May 11, 2021

pinging @anton-l - let's revisit this? Please let us know what you need.

I know meanwhile someone else did the porting of the original GPT2-345M checkpoint https://huggingface.co/nvidia/megatron-gpt2-345m and I see from the docs they use straight GPT2 transformers model to operate it.
https://huggingface.co/nvidia/megatron-gpt2-345m#text-generation

All they have is a conversion script:
https://github.com/huggingface/transformers/tree/master/src/transformers/models/megatron_gpt2
Can the same be done with the fairseq version - i.e. reuse some of the existing models for that? or is it unique enough to warrant its own?

Please bear with me, I'm just starting to figure out Megatron-LM and its variants (there is also a Deepspeed variant), so I'm just slightly above clueless at the moment - should have a better understanding in a few days once I had a chance working with it.

@anton-l
Copy link
Member Author

anton-l commented May 14, 2021

@stas00 sorry for the late reply!

It's great that someone figured out a way to post the original megatron models. When I was looking into that, it wasn't exactly straightforward due to the differences between the attention block implementations in HF GPT2 and Megatron, which was probably patched/parameterized in the meantime.

I chose to implement a separate model for the fairseq megatron because the model uses the same code as the existing MBART & FSMT, but there's only an encoder model, without the decoder. However, we could take a different route and convert the fairseq weights to fit GPT2, since it's clearly possible now. I'll try that tomorrow, and if it works out, we can discard this PR and just add a simple conversion script 👍

@ViktorThink
Copy link

This PR seems very promising and I know the model would be really useful to many.

As it was earlier pointed out, the converted model doesn't seem to have the same quality of generation as the model elsewhere. Perhaps the conversion script could have caused it somehow? Just curious if there was any success with converting the fairseq weights to fit GPT2.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Adding Megatron models.
5 participants