Add Mamba PR #83

Merged: 60 commits, Mar 4, 2024

Changes from 52 commits

Commits (60)
abafcb9
fix: llama import
3outeille Jan 24, 2024
e8b8314
feat: at leading loading mamba properly
3outeille Jan 24, 2024
d9513b0
loss going down
3outeille Jan 24, 2024
147fd25
refacto: init mamba weights
3outeille Jan 24, 2024
3e4dda1
fix: dtype for brrr compatibility
3outeille Jan 28, 2024
bebbf6f
fix: bring back strict union to fix NanotronConfigs + use torch.dtype…
3outeille Jan 29, 2024
ad384b9
fix: revert back to casting str to torch dtype
3outeille Jan 29, 2024
02d21de
fix
3outeille Jan 29, 2024
1168bfe
fix: mismatch params counting due to no tie embedding weight
3outeille Jan 29, 2024
7bfc56c
fix: run_generate is now compatible with Brrr
3outeille Jan 30, 2024
1fe1cd4
chore: single file mamba slow (for residual fix)
3outeille Jan 30, 2024
720df8e
fix: residual for slow mamba
3outeille Feb 2, 2024
7902959
fix(mamba-slow): transpose embedding
3outeille Feb 3, 2024
b2ca553
update(mamba-slow)
3outeille Feb 8, 2024
3c4ef42
fix: unified random unit
3outeille Feb 8, 2024
afcbb10
feat(optimizer): can now apply weight decay to specifcs parameters
3outeille Feb 8, 2024
7216b11
Merge branch 'mamba' into mamba-update
3outeille Feb 12, 2024
a6a163f
refacto: mamba1
3outeille Feb 12, 2024
ad17887
feat: add selective scan interface
3outeille Feb 12, 2024
6bfd6aa
feat: make fast path compatible with TP
3outeille Feb 12, 2024
f0e219b
refacto: no split weight
3outeille Feb 13, 2024
ae312fe
refacto: cleaner init weights
3outeille Feb 13, 2024
ff26ff5
refacto: init weights
3outeille Feb 14, 2024
0ab7c9b
feat: save/load _no_weight_decay attribute for Mamba model
3outeille Feb 16, 2024
3caf7fe
clean import
3outeille Feb 28, 2024
ec6f954
clean utils in mamba
3outeille Feb 28, 2024
1773420
cleaning initialization for mamba
3outeille Feb 28, 2024
6fc0d9a
no sync conv and rmsnorm for mamba
3outeille Feb 28, 2024
6e8fd1f
Merge branch 'mamba-update' into mamba-pr
3outeille Feb 28, 2024
fa6953d
clean run_train
3outeille Feb 28, 2024
20dc9ec
clean run_train
3outeille Feb 28, 2024
beb73ee
rename folder
3outeille Feb 29, 2024
d6ae604
add assert only ALL_REDUCE mode
3outeille Feb 29, 2024
5d161b1
add typing to mamba
3outeille Feb 29, 2024
befb66d
remove init method in favor of new initialization method
3outeille Feb 29, 2024
35370b7
fix test serialize
3outeille Feb 29, 2024
790eddc
small fix
3outeille Feb 29, 2024
37867d1
add mamba example
3outeille Feb 29, 2024
54a9c63
fix logger
3outeille Mar 1, 2024
3953a5c
update README
3outeille Feb 29, 2024
3fa2194
update yaml
3outeille Feb 29, 2024
d4bba7c
move to examples
3outeille Mar 1, 2024
a52c37f
update readme
3outeille Mar 1, 2024
1f54ae2
discard run_generate for now
3outeille Mar 1, 2024
7616ce7
revert dynamic weight decay for now
3outeille Mar 4, 2024
e944f63
unifying init_model_randomly for other models
3outeille Mar 1, 2024
8224221
decouple mamba logic from core
3outeille Mar 1, 2024
d28c968
fix logging + init_method
3outeille Mar 4, 2024
576eb56
delete yaml file
3outeille Mar 4, 2024
ea2fa96
small fix
3outeille Mar 4, 2024
a8448bb
fix tp assert
3outeille Mar 4, 2024
106d59c
separate config mamba
3outeille Mar 4, 2024
db889b9
update requirements readme
3outeille Mar 4, 2024
f4816ee
various fix
3outeille Mar 4, 2024
d243b70
Merge branch 'main' into main
3outeille Mar 4, 2024
1098e19
fix import
3outeille Mar 4, 2024
ae6af18
change directory level
3outeille Mar 4, 2024
baba00e
decouple Mamba logic from sync weight
3outeille Mar 4, 2024
7b513e8
fix decorator
3outeille Mar 4, 2024
dbefcc7
small fixes
3outeille Mar 4, 2024
162 changes: 47 additions & 115 deletions examples/doremi/doremi/llama.py
@@ -1,6 +1,8 @@
import math
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from transformers import LlamaConfig

from nanotron import logging
@@ -26,141 +28,71 @@

class BaseLLaMa(NanotronModel):
@torch.no_grad()
def init_model_randomly(self, init_method, scaled_init_method):
def init_model_randomly(self, config):
"""Initialize model parameters randomly.
Args:
init_method (callable): Used for embedding/position/qkv weight in attention/first layer weight of mlp/ /lm_head/
scaled_init_method (callable): Used for o weight in attention/second layer weight of mlp/

Note:
Layernorm weight all 0 or 1 depending on `apply_layernorm_1p`
"""

model = self
initialized_parameters = set()
# Handle tensor parallelism
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""

for module_name, module in model.named_modules():
if isinstance(module, TensorParallelColumnLinear):
# Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
# What it does:
# - instantiate a buffer of the `full size` in fp32
# - run init method on it
# - shard result to get only a specific shard
# Instead I'm lazy and just going to run init_method, since they are scalar independent
assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == {
name for name, _ in module.named_parameters()
}
for param_name, param in module.named_parameters():
assert isinstance(param, NanotronParameter)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"
std = config.model.init_method.std
sigma = config.model.init_method.std
num_layers = config.model.model_config.num_hidden_layers

if full_param_name in initialized_parameters:
# Already initialized
continue
for param_name, param in model.named_parameters():
assert isinstance(param, NanotronParameter)

if "weight" == param_name:
init_method(param)
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
module_name, param_name = param_name.rsplit(".", 1)

assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
elif isinstance(module, TensorParallelRowLinear):
# Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
# What it does:
# - instantiate a buffer of the `full size` in fp32
# - run init method on it
# - shard result to get only a specific shard
# Instead I'm lazy and just going to run init_method, since they are scalar independent
assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == {
name for name, _ in module.named_parameters()
}
for param_name, param in module.named_parameters():
assert isinstance(param, NanotronParameter)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"

if full_param_name in initialized_parameters:
# Already initialized
continue
if full_param_name in initialized_parameters:
# Already initialized
continue

if "weight" == param_name:
scaled_init_method(param)
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
module = model.get_submodule(module_name)

assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
if isinstance(module, TensorParallelColumnLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=std)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelRowLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers))
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TritonRMSNorm):
assert {"weight"} == {name for name, _ in module.named_parameters()}
for param_name, param in module.named_parameters():
assert isinstance(param, NanotronParameter)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"

if full_param_name in initialized_parameters:
# Already initialized
continue

if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
param.fill_(1)
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")

assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
elif isinstance(module, TensorParallelEmbedding):
# TODO @thomasw21: Handle tied embeddings
# Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
# What it does:
# - instantiate a buffer of the `full size` in fp32
# - run init method on it
# - shard result to get only a specific shard
# Instead I'm lazy and just going to run init_method, since they are scalar independent
assert {"weight"} == {name for name, _ in module.named_parameters()}

assert isinstance(module.weight, NanotronParameter)
if module.weight.is_tied:
tied_info = module.weight.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
else:
full_param_name = f"{module_name}.weight"

if full_param_name in initialized_parameters:
# Already initialized
continue
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelEmbedding):
nn.init.normal_(module.weight, mean=0.0, std=std)
else:
raise Exception(f"Parameter {full_param_name} was not intialized")

init_method(module.weight)
assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)

assert initialized_parameters == {
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
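A note on the refactor above: `init_model_randomly` no longer takes `init_method`/`scaled_init_method` callables; it reads `std` and `num_hidden_layers` from the config, walks `model.named_parameters()` once, resolves the owning module with `get_submodule`, and dispatches on the module type. The sketch below illustrates the same pattern with plain `torch.nn` layers standing in for `TensorParallelColumnLinear`, `TensorParallelRowLinear`, `TritonRMSNorm`, and `TensorParallelEmbedding`; the function name, arguments, and defaults are illustrative, not taken from the diff.

```python
# Sketch only: plain nn.Linear / nn.LayerNorm / nn.Embedding stand in for the
# tensor-parallel modules used in the real diff; names and defaults are illustrative.
import torch
import torch.nn as nn


@torch.no_grad()
def init_randomly(model: nn.Module, std: float = 0.02) -> None:
    initialized = set()
    for full_name, param in model.named_parameters():
        if full_name in initialized:  # tied weights would already be handled here
            continue
        module_name, _, param_name = full_name.rpartition(".")
        module = model.get_submodule(module_name)  # "" resolves to the root model

        if isinstance(module, nn.Embedding):    # stands in for TensorParallelEmbedding
            nn.init.normal_(param, mean=0.0, std=std)
        elif isinstance(module, nn.LayerNorm):  # stands in for TritonRMSNorm
            if param_name == "weight":
                param.fill_(1.0)
            else:
                param.zero_()
        elif isinstance(module, nn.Linear):     # column-/row-parallel differ only in std
            if param_name == "weight":
                nn.init.normal_(param, mean=0.0, std=std)
            elif param_name == "bias":
                param.zero_()
            else:
                raise ValueError(f"Unexpected parameter {full_name}")
        else:
            raise ValueError(f"No init rule for {full_name}")

        initialized.add(full_name)


# Every parameter is initialized exactly once:
init_randomly(nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16), nn.LayerNorm(16)))
```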
27 changes: 27 additions & 0 deletions examples/mamba/README.md
@@ -0,0 +1,27 @@
---
library_name: nanotron
---

# Mamba

Modeling code for Mamba to use with [Nanotron](https://github.com/huggingface/nanotron/)

## 🚀 Quickstart

```bash

pip install einops
pip install "causal-conv1d>=1.1.0,<1.2.0"
pip install mamba-ssm

# Run training
./examples/mamba/train_mamba.sh
```
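
Since `causal-conv1d` and `mamba-ssm` ship compiled CUDA kernels, a quick import check before launching training can save a failed run. This is an optional sanity check, not part of the PR; it only assumes the public entry points those two packages expose:

```python
# Optional pre-flight check (not part of this PR): verify the fused kernels
# from causal-conv1d and mamba-ssm import and that a CUDA device is visible.
import torch
from causal_conv1d import causal_conv1d_fn
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

assert torch.cuda.is_available(), "Mamba's fused kernels need a CUDA device"
print("causal_conv1d_fn:", causal_conv1d_fn.__module__)
print("selective_scan_fn:", selective_scan_fn.__module__)
```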

![mamba](./assets/loss_mamba.png)

> https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5

## Credits
Credits to the following repositories from which the code was adapted:
- https://github.com/state-spaces/mamba
Binary file added examples/mamba/assets/loss_mamba.png
126 changes: 126 additions & 0 deletions examples/mamba/mamba/config.py
@@ -0,0 +1,126 @@
from dataclasses import dataclass, fields
from typing import Optional, Union

import torch
import yaml

from nanotron.config import (
CheckpointsArgs,
DataArgs,
ExistingCheckpointInit,
GeneralArgs,
LoggingArgs,
LRSchedulerArgs,
PretrainDatasetsArgs,
NanotronConfigs,
OptimizerArgs,
ParallelismArgs,
ProfilerArgs,
TokenizerArgs,
TokensArgs,
get_config_from_file,
)
from nanotron.config.lighteval_config import LightEvalConfig
from nanotron.config.utils_config import cast_str_to_torch_dtype, serialize


@dataclass
class MambaInit:
# mamba_ssm.models.mixer_seq_simple._init_weights
initializer_range: float = 0.02
rescale_prenorm_residual: bool = True
n_residuals_per_layer: int = 1  # Change to 2 if we have MLP


@dataclass
class ModelArgs:
"""Arguments related to model architecture"""

model_config: NanotronConfigs
init_method: Union[MambaInit, ExistingCheckpointInit]
dtype: Optional[torch.dtype] = None
make_vocab_size_divisible_by: int = 1
ddp_bucket_cap_mb: int = 25

def __post_init__(self):
if self.dtype is None:
self.dtype = torch.bfloat16
if isinstance(self.dtype, str):
self.dtype = cast_str_to_torch_dtype(self.dtype)

# if self.model_config.max_position_embeddings is None:
# self.model_config.max_position_embeddings = 0


@dataclass
class Config:
"""Main configuration class"""

general: GeneralArgs
parallelism: ParallelismArgs
model: ModelArgs
tokenizer: TokenizerArgs
checkpoints: Optional[CheckpointsArgs] = None
logging: Optional[LoggingArgs] = None
tokens: Optional[TokensArgs] = None
optimizer: Optional[OptimizerArgs] = None
data: Optional[DataArgs] = None
profiler: Optional[ProfilerArgs] = None
lighteval: Optional[LightEvalConfig] = None

@classmethod
def create_empty(cls):
cls_fields = fields(cls)
return cls(**{f.name: None for f in cls_fields})

def __post_init__(self):
# Some final sanity checks across separate arguments sections:
if self.profiler is not None and self.profiler.profiler_export_path is not None:
assert self.tokens.train_steps < 10

if self.optimizer is not None and self.optimizer.learning_rate_scheduler.lr_decay_steps is None:
self.optimizer.learning_rate_scheduler.lr_decay_steps = (
self.tokens.train_steps - self.optimizer.learning_rate_scheduler.lr_warmup_steps
)

# # if lighteval, we need tokenizer to be defined
# if self.checkpoints.lighteval is not None:
# assert self.tokenizer.tokenizer_name_or_path is not None

@property
def global_batch_size(self):
return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp

def save_as_yaml(self, file_path: str):
config_dict = serialize(self)
file_path = str(file_path)
with open(file_path, "w") as f:
yaml.dump(config_dict, f)

# Sanity test config can be reloaded
_ = get_config_from_file(file_path, config_class=self.__class__)

def as_dict(self) -> dict:
return serialize(self)


@dataclass
class MambaConfig:
"""Configuration for a Mamba model

Be careful to keep the typing coherent, as it is used to reconstruct the model from YAML
"""

is_mamba_config: bool = True  # We use this to help differentiate models in yaml/python conversion
d_model: int = 2560
num_hidden_layers: int = 64
vocab_size: int = 50277
ssm_cfg: Optional[dict] = None
rms_norm: bool = True
fused_add_norm: bool = True
residual_in_fp32: bool = True
pad_vocab_size_multiple: int = 8
# ==== Custom ======
dtype: str = "float32"
rms_norm_eps: float = 1e-5
pad_token_id: Optional[int] = None
Loading