
Commit

Merge pull request #83 from 3outeille/main
Add Mamba PR
NouamaneTazi authored Mar 4, 2024
2 parents d059fca + dbefcc7 commit 5e128fc
Showing 23 changed files with 2,309 additions and 578 deletions.
16 changes: 8 additions & 8 deletions .pre-commit-config.yaml
@@ -19,14 +19,14 @@ repos:
args:
- --fix
- --exit-non-zero-on-fix
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
args:
- --profile=black
- --skip-glob=wandb/**/*
- --thirdparty=wandb
# - repo: https://github.com/PyCQA/isort
# rev: 5.12.0
# hooks:
# - id: isort
# args:
# - --profile=black
# - --skip-glob=wandb/**/*
# - --thirdparty=wandb
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
hooks:
162 changes: 47 additions & 115 deletions examples/doremi/doremi/llama.py
@@ -1,6 +1,8 @@
import math
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from transformers import LlamaConfig

from nanotron import logging
@@ -26,141 +28,71 @@

class BaseLLaMa(NanotronModel):
@torch.no_grad()
def init_model_randomly(self, init_method, scaled_init_method):
def init_model_randomly(self, config):
"""Initialize model parameters randomly.
Args:
init_method (callable): Used for embedding/position/qkv weight in attention/first layer weight of mlp/ /lm_head/
scaled_init_method (callable): Used for o weight in attention/second layer weight of mlp/
Note:
Layernorm weight all 0 or 1 depending on `apply_layernorm_1p`
"""

model = self
initialized_parameters = set()
# Handle tensor parallelism
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""

for module_name, module in model.named_modules():
if isinstance(module, TensorParallelColumnLinear):
# Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
# What it does:
# - instantiate a buffer of the `full size` in fp32
# - run init method on it
# - shard result to get only a specific shard
# Instead I'm lazy and just going to run init_method, since they are scalar independent
assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == {
name for name, _ in module.named_parameters()
}
for param_name, param in module.named_parameters():
assert isinstance(param, NanotronParameter)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"
std = config.model.init_method.std
sigma = config.model.init_method.std
num_layers = config.model.model_config.num_hidden_layers

if full_param_name in initialized_parameters:
# Already initialized
continue
for param_name, param in model.named_parameters():
assert isinstance(param, NanotronParameter)

if "weight" == param_name:
init_method(param)
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
module_name, param_name = param_name.rsplit(".", 1)

assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
elif isinstance(module, TensorParallelRowLinear):
# Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
# What it does:
# - instantiate a buffer of the `full size` in fp32
# - run init method on it
# - shard result to get only a specific shard
# Instead I'm lazy and just going to run init_method, since they are scalar independent
assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == {
name for name, _ in module.named_parameters()
}
for param_name, param in module.named_parameters():
assert isinstance(param, NanotronParameter)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"

if full_param_name in initialized_parameters:
# Already initialized
continue
if full_param_name in initialized_parameters:
# Already initialized
continue

if "weight" == param_name:
scaled_init_method(param)
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
module = model.get_submodule(module_name)

assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
if isinstance(module, TensorParallelColumnLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=std)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelRowLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers))
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TritonRMSNorm):
assert {"weight"} == {name for name, _ in module.named_parameters()}
for param_name, param in module.named_parameters():
assert isinstance(param, NanotronParameter)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"

if full_param_name in initialized_parameters:
# Already initialized
continue

if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
param.fill_(1)
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")

assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
elif isinstance(module, TensorParallelEmbedding):
# TODO @thomasw21: Handle tied embeddings
# Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
# What it does:
# - instantiate a buffer of the `full size` in fp32
# - run init method on it
# - shard result to get only a specific shard
# Instead I'm lazy and just going to run init_method, since they are scalar independent
assert {"weight"} == {name for name, _ in module.named_parameters()}

assert isinstance(module.weight, NanotronParameter)
if module.weight.is_tied:
tied_info = module.weight.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
else:
full_param_name = f"{module_name}.weight"

if full_param_name in initialized_parameters:
# Already initialized
continue
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelEmbedding):
nn.init.normal_(module.weight, mean=0.0, std=std)
else:
raise Exception(f"Parameter {full_param_name} was not initialized")

init_method(module.weight)
assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)

assert initialized_parameters == {
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
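The new `init_model_randomly` above draws column-parallel weights from a normal distribution with the configured `std`, and scales row-parallel weights (attention output projection, second MLP projection) by `1 / sqrt(2 * num_layers)`. Below is a minimal standalone sketch of that rule, with plain `nn.Linear` standing in for the tensor-parallel layers; the helper name `init_linear` and the sample sizes are illustrative and not part of the diff.

```python
import math

import torch.nn as nn


def init_linear(module: nn.Linear, std: float, num_layers: int, is_row_parallel: bool) -> None:
    # Row-parallel weights (attention output, second MLP projection) are scaled
    # down by sqrt(2 * num_layers); everything else uses N(0, std).
    scale = std / math.sqrt(2 * num_layers) if is_row_parallel else std
    nn.init.normal_(module.weight, mean=0.0, std=scale)
    if module.bias is not None:
        nn.init.zeros_(module.bias)


# Illustrative usage with made-up sizes.
layer = nn.Linear(1024, 1024)
init_linear(layer, std=0.02, num_layers=32, is_row_parallel=True)
```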
23 changes: 23 additions & 0 deletions examples/mamba/README.md
@@ -0,0 +1,23 @@
---
library_name: nanotron
---

# Mamba

Modeling code for Mamba, for use with [Nanotron](https://github.com/huggingface/nanotron/).

## 🚀 Quickstart

```bash
pip install -r requirements.txt
# Run training
./examples/mamba/train_mamba.sh
```

![mamba](./assets/loss_mamba.png)

> https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5

## Credits
Credits to the following repositories from which the code was adapted:
- https://github.com/state-spaces/mamba
Binary file added examples/mamba/assets/loss_mamba.png
64 changes: 64 additions & 0 deletions examples/mamba/config.py
@@ -0,0 +1,64 @@
from dataclasses import dataclass
from typing import Optional, Union

import torch

from nanotron.config import Config, ExistingCheckpointInit, NanotronConfigs
from nanotron.config.utils_config import cast_str_to_torch_dtype


@dataclass
class MambaInit:
# mamba_ssm.models.mixer_seq_simple._init_weights
initializer_range: float = 0.02
rescale_prenorm_residual: bool = True
n_residuals_per_layer: int = 1  # Change to 2 if we have MLP


@dataclass
class ModelArgs:
"""Arguments related to model architecture"""

model_config: NanotronConfigs
init_method: Union[MambaInit, ExistingCheckpointInit]
dtype: Optional[torch.dtype] = None
make_vocab_size_divisible_by: int = 1
ddp_bucket_cap_mb: int = 25

def __post_init__(self):
if self.dtype is None:
self.dtype = torch.bfloat16
if isinstance(self.dtype, str):
self.dtype = cast_str_to_torch_dtype(self.dtype)

# if self.model_config.max_position_embeddings is None:
# self.model_config.max_position_embeddings = 0


@dataclass(kw_only=True) # pylint: disable=unexpected-keyword-arg
class MambaConfig(Config):
"""Main configuration class"""

model: ModelArgs


@dataclass
class MambaModelConfig:
"""Configuration for a Mamba model
Be careful to keep the typing coherent, as it is used to reconstruct the model from YAML
"""

is_mamba_config: bool = True # We use this to help differentiate models in yaml/python conversion
d_model: int = 2560
num_hidden_layers: int = 64
vocab_size: int = 50277
ssm_cfg: Optional[dict] = None
rms_norm: bool = True
fused_add_norm: bool = True
residual_in_fp32: bool = True
pad_vocab_size_multiple: int = 8
# ==== Custom ======
dtype: str = "float32"
rms_norm_eps: float = 1e-5
pad_token_id: Optional[int] = None
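The dataclasses above can also be assembled directly in Python rather than from YAML. A rough sketch is shown below, assuming the classes defined in examples/mamba/config.py; the import path and the field values are illustrative, not taken from the PR.

```python
import torch

from config import MambaInit, MambaModelConfig, ModelArgs  # import path is an assumption

# Architecture config; dtype is stored as a string, per the dataclass above.
model_config = MambaModelConfig(
    d_model=2560,
    num_hidden_layers=64,
    vocab_size=50277,
    dtype="bfloat16",
)

# Model arguments; __post_init__ would also accept a string dtype and cast it.
model_args = ModelArgs(
    model_config=model_config,
    init_method=MambaInit(initializer_range=0.02),
    dtype=torch.bfloat16,
)
```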