Add Mamba PR #83

Merged
60 commits merged on Mar 4, 2024

Commits
abafcb9
fix: llama import
3outeille Jan 24, 2024
e8b8314
feat: at least loading mamba properly
3outeille Jan 24, 2024
d9513b0
loss going down
3outeille Jan 24, 2024
147fd25
refacto: init mamba weights
3outeille Jan 24, 2024
3e4dda1
fix: dtype for brrr compatibility
3outeille Jan 28, 2024
bebbf6f
fix: bring back strict union to fix NanotronConfigs + use torch.dtype…
3outeille Jan 29, 2024
ad384b9
fix: revert back to casting str to torch dtype
3outeille Jan 29, 2024
02d21de
fix
3outeille Jan 29, 2024
1168bfe
fix: mismatch params counting due to no tie embedding weight
3outeille Jan 29, 2024
7bfc56c
fix: run_generate is now compatible with Brrr
3outeille Jan 30, 2024
1fe1cd4
chore: single file mamba slow (for residual fix)
3outeille Jan 30, 2024
720df8e
fix: residual for slow mamba
3outeille Feb 2, 2024
7902959
fix(mamba-slow): transpose embedding
3outeille Feb 3, 2024
b2ca553
update(mamba-slow)
3outeille Feb 8, 2024
3c4ef42
fix: unified random unit
3outeille Feb 8, 2024
afcbb10
feat(optimizer): can now apply weight decay to specific parameters
3outeille Feb 8, 2024
7216b11
Merge branch 'mamba' into mamba-update
3outeille Feb 12, 2024
a6a163f
refacto: mamba1
3outeille Feb 12, 2024
ad17887
feat: add selective scan interface
3outeille Feb 12, 2024
6bfd6aa
feat: make fast path compatible with TP
3outeille Feb 12, 2024
f0e219b
refacto: no split weight
3outeille Feb 13, 2024
ae312fe
refacto: cleaner init weights
3outeille Feb 13, 2024
ff26ff5
refacto: init weights
3outeille Feb 14, 2024
0ab7c9b
feat: save/load _no_weight_decay attribute for Mamba model
3outeille Feb 16, 2024
3caf7fe
clean import
3outeille Feb 28, 2024
ec6f954
clean utils in mamba
3outeille Feb 28, 2024
1773420
cleaning initialization for mamba
3outeille Feb 28, 2024
6fc0d9a
no sync conv and rmsnorm for mamba
3outeille Feb 28, 2024
6e8fd1f
Merge branch 'mamba-update' into mamba-pr
3outeille Feb 28, 2024
fa6953d
clean run_train
3outeille Feb 28, 2024
20dc9ec
clean run_train
3outeille Feb 28, 2024
beb73ee
rename folder
3outeille Feb 29, 2024
d6ae604
add assert only ALL_REDUCE mode
3outeille Feb 29, 2024
5d161b1
add typing to mamba
3outeille Feb 29, 2024
befb66d
remove init method in favor of new initialization method
3outeille Feb 29, 2024
35370b7
fix test serialize
3outeille Feb 29, 2024
790eddc
small fix
3outeille Feb 29, 2024
37867d1
add mamba example
3outeille Feb 29, 2024
54a9c63
fix logger
3outeille Mar 1, 2024
3953a5c
update README
3outeille Feb 29, 2024
3fa2194
update yaml
3outeille Feb 29, 2024
d4bba7c
move to examples
3outeille Mar 1, 2024
a52c37f
update readme
3outeille Mar 1, 2024
1f54ae2
discard run_generate for now
3outeille Mar 1, 2024
7616ce7
revert dynamic weight decay for now
3outeille Mar 4, 2024
e944f63
unifying init_model_randomly for other models
3outeille Mar 1, 2024
8224221
decouple mamba logic from core
3outeille Mar 1, 2024
d28c968
fix logging + init_method
3outeille Mar 4, 2024
576eb56
delete yaml file
3outeille Mar 4, 2024
ea2fa96
small fix
3outeille Mar 4, 2024
a8448bb
fix tp assert
3outeille Mar 4, 2024
106d59c
separate config mamba
3outeille Mar 4, 2024
db889b9
update requirements readme
3outeille Mar 4, 2024
f4816ee
various fix
3outeille Mar 4, 2024
d243b70
Merge branch 'main' into main
3outeille Mar 4, 2024
1098e19
fix import
3outeille Mar 4, 2024
ae6af18
change directory level
3outeille Mar 4, 2024
baba00e
decouple Mamba logic from sync weight
3outeille Mar 4, 2024
7b513e8
fix decorator
3outeille Mar 4, 2024
dbefcc7
small fixes
3outeille Mar 4, 2024
162 changes: 47 additions & 115 deletions examples/doremi/doremi/llama.py
@@ -1,6 +1,8 @@
import math
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from transformers import LlamaConfig

from nanotron import logging
@@ -26,141 +28,71 @@

class BaseLLaMa(NanotronModel):
@torch.no_grad()
def init_model_randomly(self, init_method, scaled_init_method):
def init_model_randomly(self, config):
"""Initialize model parameters randomly.
Args:
init_method (callable): Used for embedding/position/qkv weight in attention/first layer weight of mlp/ /lm_head/
scaled_init_method (callable): Used for o weight in attention/second layer weight of mlp/
Note:
Layernorm weight all 0 or 1 depending on `apply_layernorm_1p`
"""

model = self
initialized_parameters = set()
# Handle tensor parallelism
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""

for module_name, module in model.named_modules():
if isinstance(module, TensorParallelColumnLinear):
# Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
# What it does:
# - instantiate a buffer of the `full size` in fp32
# - run init method on it
# - shard result to get only a specific shard
# Instead I'm lazy and just going to run init_method, since they are scalar independent
assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == {
name for name, _ in module.named_parameters()
}
for param_name, param in module.named_parameters():
assert isinstance(param, NanotronParameter)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"
std = config.model.init_method.std
sigma = config.model.init_method.std
num_layers = config.model.model_config.num_hidden_layers

if full_param_name in initialized_parameters:
# Already initialized
continue
for param_name, param in model.named_parameters():
assert isinstance(param, NanotronParameter)

if "weight" == param_name:
init_method(param)
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
module_name, param_name = param_name.rsplit(".", 1)

assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
elif isinstance(module, TensorParallelRowLinear):
# Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
# What it does:
# - instantiate a buffer of the `full size` in fp32
# - run init method on it
# - shard result to get only a specific shard
# Instead I'm lazy and just going to run init_method, since they are scalar independent
assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == {
name for name, _ in module.named_parameters()
}
for param_name, param in module.named_parameters():
assert isinstance(param, NanotronParameter)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"

if full_param_name in initialized_parameters:
# Already initialized
continue
if full_param_name in initialized_parameters:
# Already initialized
continue

if "weight" == param_name:
scaled_init_method(param)
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
module = model.get_submodule(module_name)

assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
if isinstance(module, TensorParallelColumnLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=std)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelRowLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers))
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TritonRMSNorm):
assert {"weight"} == {name for name, _ in module.named_parameters()}
for param_name, param in module.named_parameters():
assert isinstance(param, NanotronParameter)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"

if full_param_name in initialized_parameters:
# Already initialized
continue

if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
param.fill_(1)
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")

assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
elif isinstance(module, TensorParallelEmbedding):
# TODO @thomasw21: Handle tied embeddings
# Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
# What it does:
# - instantiate a buffer of the `full size` in fp32
# - run init method on it
# - shard result to get only a specific shard
# Instead I'm lazy and just going to run init_method, since they are scalar independent
assert {"weight"} == {name for name, _ in module.named_parameters()}

assert isinstance(module.weight, NanotronParameter)
if module.weight.is_tied:
tied_info = module.weight.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
else:
full_param_name = f"{module_name}.weight"

if full_param_name in initialized_parameters:
# Already initialized
continue
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelEmbedding):
nn.init.normal_(module.weight, mean=0.0, std=std)
else:
raise Exception(f"Parameter {full_param_name} was not intialized")

init_method(module.weight)
assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)

assert initialized_parameters == {
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
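The hunk above changes the initialization contract: instead of receiving `init_method` / `scaled_init_method` callables, `init_model_randomly` now takes the config and derives everything from `config.model.init_method.std` and `config.model.model_config.num_hidden_layers`. The sketch below reproduces the two normal initializers the new code applies; the concrete `std` and `num_layers` values are illustrative, not taken from the diff.

```python
import math

import torch
import torch.nn as nn

# Values the new code reads from the config:
#   std        = config.model.init_method.std
#   num_layers = config.model.model_config.num_hidden_layers
std, num_layers = 0.02, 48  # hypothetical values

column_weight = torch.empty(1024, 1024)
row_weight = torch.empty(1024, 1024)

# TensorParallelColumnLinear (and embedding) weights: plain normal init.
nn.init.normal_(column_weight, mean=0.0, std=std)

# TensorParallelRowLinear weights: std scaled down by sqrt(2 * num_layers),
# which is what the old `scaled_init_method` callable used to do.
nn.init.normal_(row_weight, mean=0.0, std=std / math.sqrt(2 * num_layers))
```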
23 changes: 23 additions & 0 deletions examples/mamba/README.md
@@ -0,0 +1,23 @@
---
library_name: nanotron
---

# Mamba

Modeling code for Mamba to use with [Nanotron](https://github.com/huggingface/nanotron/)

## 🚀 Quickstart

```bash
pip install -r requirements.txt
# Run training
./examples/mamba/train_mamba.sh
```

![mamba](./assets/loss_mamba.png)

> https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5
## Credits
Credits to the following repositories from which the code was adapted:
- https://github.com/state-spaces/mamba
Binary file added examples/mamba/assets/loss_mamba.png
64 changes: 64 additions & 0 deletions examples/mamba/config.py
@@ -0,0 +1,64 @@
from dataclasses import dataclass
from typing import Optional, Union

import torch

from nanotron.config import Config, ExistingCheckpointInit, NanotronConfigs
from nanotron.config.utils_config import cast_str_to_torch_dtype


@dataclass
class MambaInit:
# mamba_ssm.models.mixer_seq_simple._init_weights
initializer_range: float = 0.02
rescale_prenorm_residual: bool = True
n_residuals_per_layer: int = 1  # Change to 2 if we have MLP


@dataclass
class ModelArgs:
"""Arguments related to model architecture"""

model_config: NanotronConfigs
init_method: Union[MambaInit, ExistingCheckpointInit]
dtype: Optional[torch.dtype] = None
make_vocab_size_divisible_by: int = 1
ddp_bucket_cap_mb: int = 25

def __post_init__(self):
if self.dtype is None:
self.dtype = torch.bfloat16
if isinstance(self.dtype, str):
self.dtype = cast_str_to_torch_dtype(self.dtype)

# if self.model_config.max_position_embeddings is None:
# self.model_config.max_position_embeddings = 0


@dataclass(kw_only=True) # pylint: disable=unexpected-keyword-arg
class MambaConfig(Config):
"""Main configuration class"""

model: ModelArgs


@dataclass
class MambaModelConfig:
"""Configuration for a Mamba model
Be careful on having a coherent typing as we use it to reconstruct the model from yaml
"""

is_mamba_config: bool = True # We use this help differentiate models in yaml/python conversion
d_model: int = 2560
num_hidden_layers: int = 64
vocab_size: int = 50277
ssm_cfg: Optional[dict] = None
rms_norm: bool = True
fused_add_norm: bool = True
residual_in_fp32: bool = True
pad_vocab_size_multiple: int = 8
# ==== Custom ======
dtype: str = "float32"
rms_norm_eps: float = 1e-5
pad_token_id: Optional[int] = None
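For orientation, a minimal sketch of how these dataclasses fit together, assuming nanotron is installed and the snippet is run from `examples/mamba/` so that the `config.py` above is importable; the field values are illustrative:

```python
from config import MambaInit, MambaModelConfig, ModelArgs  # dataclasses defined above

model_config = MambaModelConfig(
    d_model=1536,
    num_hidden_layers=48,
    vocab_size=50277,
    ssm_cfg={"d_state": 16, "d_conv": 4, "expand": 2, "use_fast_path": True},
)

model_args = ModelArgs(
    model_config=model_config,
    init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1),
    dtype="bfloat16",  # __post_init__ casts the string to a torch.dtype
)

print(model_args.dtype)  # torch.bfloat16 after __post_init__
```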
101 changes: 101 additions & 0 deletions examples/mamba/config_mamba.yaml
@@ -0,0 +1,101 @@
checkpoints:
checkpoint_interval: 10
checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/brrr/nanotron/examples/checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 24
hf_dataset_config_name: null
hf_dataset_or_datasets:
roneneldan/TinyStories: 1.0
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: test
run: mamba
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
initializer_range: 0.02
n_residuals_per_layer: 1
rescale_prenorm_residual: true
make_vocab_size_divisible_by: 1
model_config:
d_model: 1536
dtype: bfloat16
fused_add_norm: true
is_mamba_config: true
num_hidden_layers: 48
pad_token_id: null
pad_vocab_size_multiple: 8
residual_in_fp32: true
rms_norm: true
rms_norm_eps: 1.0e-05
ssm_cfg:
bias: false
conv_bias: true
d_conv: 4
d_state: 16
dt_init: random
dt_init_floor: 0.0001
dt_max: 0.1
dt_min: 0.001
dt_rank: auto
dt_scale: 1.0
expand: 2
use_fast_path: true
vocab_size: 50277
optimizer:
accumulate_grad_in_fp32: true
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_starting_step: null
lr_decay_steps: 90
lr_decay_style: cosine
lr_warmup_steps: 10
lr_warmup_style: linear
min_decay_lr: 1.0e-05
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 2
expert_parallel_size: 1
pp: 2
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: false
tp_mode: ALL_REDUCE
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: gpt2
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 1
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 2
sequence_length: 2048
train_steps: 100
val_check_interval: -1
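As a quick sanity check on the `parallelism` block: with `dp: 2`, `pp: 2`, `tp: 2`, a run needs 2 × 2 × 2 = 8 processes (GPUs). A sketch using plain PyYAML (no nanotron import; the path is the one added in this diff):

```python
import yaml

with open("examples/mamba/config_mamba.yaml") as f:
    cfg = yaml.safe_load(f)

par = cfg["parallelism"]
world_size = par["dp"] * par["pp"] * par["tp"]  # 2 * 2 * 2 = 8
print(f"launch with {world_size} processes (e.g. torchrun --nproc_per_node={world_size} ...)")
```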