diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5141302e..523d5ef1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,14 +19,14 @@ repos: args: - --fix - --exit-non-zero-on-fix - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - args: - - --profile=black - - --skip-glob=wandb/**/* - - --thirdparty=wandb + # - repo: https://github.com/PyCQA/isort + # rev: 5.12.0 + # hooks: + # - id: isort + # args: + # - --profile=black + # - --skip-glob=wandb/**/* + # - --thirdparty=wandb - repo: https://github.com/codespell-project/codespell rev: v2.1.0 hooks: diff --git a/examples/doremi/doremi/llama.py b/examples/doremi/doremi/llama.py index aa9f47eb..8ae9202a 100644 --- a/examples/doremi/doremi/llama.py +++ b/examples/doremi/doremi/llama.py @@ -1,6 +1,8 @@ +import math from typing import Dict, Optional, Union import torch +import torch.nn as nn from transformers import LlamaConfig from nanotron import logging @@ -26,15 +28,12 @@ class BaseLLaMa(NanotronModel): @torch.no_grad() - def init_model_randomly(self, init_method, scaled_init_method): + def init_model_randomly(self, config): """Initialize model parameters randomly. - Args: - init_method (callable): Used for embedding/position/qkv weight in attention/first layer weight of mlp/ /lm_head/ - scaled_init_method (callable): Used for o weight in attention/second layer weight of mlp/ - Note: Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` """ + model = self initialized_parameters = set() # Handle tensor parallelism @@ -42,125 +41,58 @@ def init_model_randomly(self, init_method, scaled_init_method): # Fix the root_model module_id_to_prefix[id(model)] = "" - for module_name, module in model.named_modules(): - if isinstance(module, TensorParallelColumnLinear): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" + std = config.model.init_method.std + sigma = config.model.init_method.std + num_layers = config.model.model_config.num_hidden_layers - if full_param_name in initialized_parameters: - # Already initialized - continue + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) - if "weight" == param_name: - init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") + module_name, param_name = param_name.rsplit(".", 1) - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelRowLinear): - # Somehow Megatron-LM does something super complicated, 
https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" - if full_param_name in initialized_parameters: - # Already initialized - continue + if full_param_name in initialized_parameters: + # Already initialized + continue - if "weight" == param_name: - scaled_init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") + module = model.get_submodule(module_name) - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + if isinstance(module, TensorParallelColumnLinear): + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=std) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, TensorParallelRowLinear): + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers)) + elif "bias" == param_name: + param.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, TritonRMSNorm): - assert {"weight"} == {name for name, _ in module.named_parameters()} - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - # TODO @thomasw21: Sometimes we actually want 0 - param.fill_(1) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelEmbedding): - # TODO @thomasw21: Handle tied embeddings - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} - - assert isinstance(module.weight, NanotronParameter) - if 
module.weight.is_tied:
-                    tied_info = module.weight.get_tied_info()
-                    full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
-                        module_id_to_prefix=module_id_to_prefix
-                    )
+                if "weight" == param_name:
+                    # TODO @thomasw21: Sometimes we actually want 0
+                    module.weight.fill_(1)
+                elif "bias" == param_name:
+                    module.bias.zero_()
+                else:
-                    full_param_name = f"{module_name}.weight"
-
-                if full_param_name in initialized_parameters:
-                    # Already initialized
-                    continue
+                    raise ValueError(f"Who the fuck is {param_name}?")
+            elif isinstance(module, TensorParallelEmbedding):
+                nn.init.normal_(module.weight, mean=0.0, std=std)
+            else:
+                raise Exception(f"Parameter {full_param_name} was not initialized")
-                init_method(module.weight)
-                assert full_param_name not in initialized_parameters
-                initialized_parameters.add(full_param_name)
+            assert full_param_name not in initialized_parameters
+            initialized_parameters.add(full_param_name)
         assert initialized_parameters == {
             param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
diff --git a/examples/mamba/README.md b/examples/mamba/README.md
new file mode 100644
index 00000000..5c31d07f
--- /dev/null
+++ b/examples/mamba/README.md
@@ -0,0 +1,23 @@
+---
+library_name: nanotron
+---
+
+# Mamba
+
+Modeling code for Mamba to use with [Nanotron](https://github.com/huggingface/nanotron/)
+
+## 🚀 Quickstart
+
+```bash
+pip install -r requirements.txt
+# Run training
+./examples/mamba/train_mamba.sh
+```
+
+![mamba](./assets/loss_mamba.png)
+
+> https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5
+
+## Credits
+Credits to the following repositories from which the code was adapted:
+- https://github.com/state-spaces/mamba
diff --git a/examples/mamba/assets/loss_mamba.png b/examples/mamba/assets/loss_mamba.png
new file mode 100644
index 00000000..f2bfc040
Binary files /dev/null and b/examples/mamba/assets/loss_mamba.png differ
diff --git a/examples/mamba/config.py b/examples/mamba/config.py
new file mode 100644
index 00000000..c7bdc7b9
--- /dev/null
+++ b/examples/mamba/config.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+
+from nanotron.config import Config, ExistingCheckpointInit, NanotronConfigs
+from nanotron.config.utils_config import cast_str_to_torch_dtype
+
+
+@dataclass
+class MambaInit:
+    # mamba_ssm.models.mixer_seq_simple._init_weights
+    initializer_range: float = 0.02
+    rescale_prenorm_residual: bool = True
+    n_residuals_per_layer: int = 1  # Change to 2 if we have MLP
+
+
+@dataclass
+class ModelArgs:
+    """Arguments related to model architecture"""
+
+    model_config: NanotronConfigs
+    init_method: Union[MambaInit, ExistingCheckpointInit]
+    dtype: Optional[torch.dtype] = None
+    make_vocab_size_divisible_by: int = 1
+    ddp_bucket_cap_mb: int = 25
+
+    def __post_init__(self):
+        if self.dtype is None:
+            self.dtype = torch.bfloat16
+        if isinstance(self.dtype, str):
+            self.dtype = cast_str_to_torch_dtype(self.dtype)
+
+        # if self.model_config.max_position_embeddings is None:
+        #     self.model_config.max_position_embeddings = 0
+
+
+@dataclass(kw_only=True)  # pylint: disable=unexpected-keyword-arg
+class MambaConfig(Config):
+    """Main configuration class"""
+
+    model: ModelArgs
+
+
+@dataclass
+class MambaModelConfig:
+    """Configuration for a Mamba model
+
+    Be careful to keep the typing coherent, as we use it to reconstruct the model from yaml
+    """
+
+    is_mamba_config: bool = True  # We use this to help differentiate models in
yaml/python conversion + d_model: int = 2560 + num_hidden_layers: int = 64 + vocab_size: int = 50277 + ssm_cfg: Optional[dict] = None + rms_norm: bool = True + fused_add_norm: bool = True + residual_in_fp32: bool = True + pad_vocab_size_multiple: int = 8 + # ==== Custom ====== + dtype: str = "float32" + rms_norm_eps: float = 1e-5 + pad_token_id: Optional[int] = None diff --git a/examples/mamba/config_mamba.yaml b/examples/mamba/config_mamba.yaml new file mode 100644 index 00000000..2d79880d --- /dev/null +++ b/examples/mamba/config_mamba.yaml @@ -0,0 +1,101 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/brrr/nanotron/examples/checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 24 + hf_dataset_config_name: null + hf_dataset_or_datasets: + roneneldan/TinyStories: 1.0 + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: test + run: mamba + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + initializer_range: 0.02 + n_residuals_per_layer: 1 + rescale_prenorm_residual: true + make_vocab_size_divisible_by: 1 + model_config: + d_model: 1536 + dtype: bfloat16 + fused_add_norm: true + is_mamba_config: true + num_hidden_layers: 48 + pad_token_id: null + pad_vocab_size_multiple: 8 + residual_in_fp32: true + rms_norm: true + rms_norm_eps: 1.0e-05 + ssm_cfg: + bias: false + conv_bias: true + d_conv: 4 + d_state: 16 + dt_init: random + dt_init_floor: 0.0001 + dt_max: 0.1 + dt_min: 0.001 + dt_rank: auto + dt_scale: 1.0 + expand: 2 + use_fast_path: true + vocab_size: 50277 +optimizer: + accumulate_grad_in_fp32: true + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 90 + lr_decay_style: cosine + lr_warmup_steps: 10 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 2 + expert_parallel_size: 1 + pp: 2 + pp_engine: 1f1b + tp: 2 + tp_linear_async_communication: false + tp_mode: ALL_REDUCE +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: gpt2 + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 2048 + train_steps: 100 + val_check_interval: -1 diff --git a/examples/mamba/create_config_mamba.py b/examples/mamba/create_config_mamba.py new file mode 100644 index 00000000..40c211f9 --- /dev/null +++ b/examples/mamba/create_config_mamba.py @@ -0,0 +1,148 @@ +""" Example python script to generate a YAML config file which can be used to run a training with nanotron. 
Refer to "examples" section in the `/README.md` for more information.""" +import math +import os + +from config import MambaConfig, MambaInit, MambaModelConfig + +from nanotron.config import ( + CheckpointsArgs, + DataArgs, + GeneralArgs, + LoggingArgs, + LRSchedulerArgs, + ModelArgs, + OptimizerArgs, + ParallelismArgs, + PretrainDatasetsArgs, + TokenizerArgs, + TokensArgs, +) +from nanotron.logging import human_format + +ssm_cfg_dtype = "bfloat16" +ssm_cfg = { + "d_state": 16, + "d_conv": 4, + "expand": 2, + "dt_rank": "auto", + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init": "random", + "dt_scale": 1.0, + "dt_init_floor": 1e-4, + "conv_bias": True, + "bias": False, + "use_fast_path": True, +} +# https://huggingface.co/state-spaces/mamba-790m/blob/main/config.json +model_config = MambaModelConfig( + d_model=1536, + num_hidden_layers=48, + vocab_size=50277, + ssm_cfg=ssm_cfg, + rms_norm=True, + fused_add_norm=True, + residual_in_fp32=True, + pad_vocab_size_multiple=8, + # Custom + dtype=ssm_cfg_dtype, + rms_norm_eps=1e-5, +) + +# NOTE: vocab_size is normally round up to the nearest multiple of 10. But here, we don't really care +tie_embedding = model_config.vocab_size * model_config.d_model # model_config.vocab_size * model_config.d_model +expand = 2 if ("expand" not in ssm_cfg) else ssm_cfg["expand"] +ngroups = 1 if ("ngroups" not in ssm_cfg) else ssm_cfg["ngroups"] +d_state = 16 if ("d_state" not in ssm_cfg) else ssm_cfg["d_state"] +d_conv = 4 if ("d_conv" not in ssm_cfg) else ssm_cfg["d_conv"] +dt_rank = ( + math.ceil(model_config.d_model / 16) + if ("dt_rank" not in ssm_cfg or ssm_cfg["dt_rank"] == "auto") + else ssm_cfg["dt_rank"] +) + +d_inner = int(expand * model_config.d_model) +in_proj = model_config.d_model * d_inner * 2 + +# conv1d.weight = out_channels * (in_channels // groups) * kernel_size +# conv1d.bias = out_channels +conv1d = d_inner * int(d_inner / d_inner) * d_conv + d_inner +# linear.weight = out_features * in_features +in_proj = model_config.d_model * d_inner * 2 + 0 +x_proj = d_inner * (dt_rank + d_state * 2) + 0 +out_proj = d_inner * model_config.d_model + 0 +dt_proj = dt_rank * d_inner + d_inner +A_log = d_inner * d_state +D = d_inner +norm = model_config.d_model +norm_f = model_config.d_model + +num_params = human_format( + ( + tie_embedding + + model_config.num_hidden_layers * (A_log + D + in_proj + conv1d + x_proj + dt_proj + out_proj + norm + norm_f) + ) +).replace(".", "p") + +print(f"Model has {num_params} parameters") + +seed = 42 + +optimizer = OptimizerArgs( + zero_stage=0, + weight_decay=0.01, + clip_grad=1.0, + accumulate_grad_in_fp32=True, # NOTE(fmom): because we are using PP=TP=DP=1 + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + learning_rate_scheduler=LRSchedulerArgs( + learning_rate=3e-4, lr_warmup_steps=10, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5 + ), +) + +parallelism = ParallelismArgs( + dp=2, + pp=2, + tp=2, + pp_engine="1f1b", + tp_mode="ALL_REDUCE", + tp_linear_async_communication=False, +) + +tokens = TokensArgs(sequence_length=2048, train_steps=100, micro_batch_size=2, batch_accumulation_per_replica=1) + +dataset = PretrainDatasetsArgs( + hf_dataset_or_datasets={"roneneldan/TinyStories": 1.0}, + hf_dataset_config_name=None, + hf_dataset_splits="train", + dataset_processing_num_proc_per_process=24, + dataset_overwrite_cache=False, + text_column_name="text", +) + +checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints" +os.makedirs(checkpoints_path, 
exist_ok=True) + +config = MambaConfig( + general=GeneralArgs(project="test", run="mamba", seed=seed, ignore_sanity_checks=True), + checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), + parallelism=parallelism, + model=ModelArgs( + init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1), + model_config=model_config, + ), + tokenizer=TokenizerArgs("gpt2"), + optimizer=optimizer, + logging=LoggingArgs(), + tokens=tokens, + data=DataArgs(dataset=dataset, seed=seed), + profiler=None, +) + +if __name__ == "__main__": + dir = os.path.dirname(__file__) + + # Save config as YAML file + config.save_as_yaml(f"{dir}/config_mamba.yaml") diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py new file mode 100644 index 00000000..5065ed53 --- /dev/null +++ b/examples/mamba/mamba.py @@ -0,0 +1,929 @@ +# coding=utf-8 +# Copyright 2018 HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Mamba model. +""" +import math +import os +from functools import partial +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from config import MambaModelConfig +from einops import rearrange, repeat +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import ParallelismArgs +from nanotron.config.utils_config import cast_str_to_torch_dtype +from nanotron.generation.generate_store import AttachableStore +from nanotron.logging import log_rank +from nanotron.models import NanotronModel +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer +from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelLinearMode, + TensorParallelRowLinear, +) +from nanotron.random import RandomStates +from selective_scan_interface import mamba_inner_fn, selective_scan_fn +from torch.nn import init + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn, causal_conv1d_update = None, None + +try: + from mamba_ssm.ops.triton.selective_state_update import selective_state_update +except ImportError: + selective_state_update = None + +try: + from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn +except ImportError: + RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + +# import lovely_tensors as lt; lt.monkey_patch() + +logger = logging.get_logger(__name__) + + +class Mamba(nn.Module): + def __init__( + self, + d_model: int, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + d_state: int = 16, + d_conv: int = 4, + expand: int = 2, + dt_rank: str = "auto", + dt_min: float = 0.001, + dt_max: 
float = 0.1,
+        dt_init: str = "random",
+        dt_scale: float = 1.0,
+        dt_init_floor: float = 1e-4,
+        conv_bias: bool = True,
+        bias: bool = False,
+        use_fast_path: bool = True,  # Fused kernel options
+        layer_idx: Optional[int] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.d_model = d_model
+        self.d_state = d_state
+        self.d_conv = d_conv
+        self.expand = expand
+        self.d_inner = int(self.expand * self.d_model)
+        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
+        self.use_fast_path = use_fast_path
+        self.layer_idx = layer_idx
+
+        tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
+        assert (
+            tp_mode == TensorParallelLinearMode.ALL_REDUCE or parallel_config.tp_linear_async_communication is False
+        ), "Only ALL_REDUCE and tp_linear_async_communication=False are supported"
+
+        tp_linear_async_communication = (
+            parallel_config.tp_linear_async_communication if parallel_config is not None else False
+        )
+
+        # Get current tensor parallel rank
+        self.tp_pg = tp_pg
+        self.tp_rank = dist.get_rank(self.tp_pg)
+
+        self.in_proj = TensorParallelColumnLinear(
+            in_features=self.d_model,
+            out_features=self.d_inner * 2,
+            pg=tp_pg,
+            mode=tp_mode,
+            bias=False,
+            async_communication=False,
+            contiguous_chunks=None,
+        )
+
+        assert self.d_inner % self.tp_pg.size() == 0
+
+        self.conv1d = nn.Conv1d(
+            in_channels=self.d_inner // self.tp_pg.size(),
+            out_channels=self.d_inner // self.tp_pg.size(),
+            bias=conv_bias,
+            kernel_size=d_conv,
+            groups=self.d_inner // self.tp_pg.size(),
+            padding=d_conv - 1,
+            **factory_kwargs,
+        )
+
+        self.activation = "silu"
+        self.act = nn.SiLU()
+
+        self.x_proj = TensorParallelRowLinear(
+            in_features=self.d_inner,
+            out_features=self.dt_rank + self.d_state * 2,
+            pg=tp_pg,
+            mode=tp_mode,
+            bias=False,
+            async_communication=tp_linear_async_communication,
+            contiguous_chunks=None,
+        )
+
+        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner // self.tp_pg.size(), bias=True, **factory_kwargs)
+
+        # Initialize special dt projection to preserve variance at initialization
+        # Performed in `init_model_randomly` instead
+        # dt_init_std = self.dt_rank**-0.5 * dt_scale
+        # if dt_init == "constant":
+        #     nn.init.constant_(self.dt_proj.weight, dt_init_std)
+        # elif dt_init == "random":
+        #     nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
+        # else:
+        #     raise NotImplementedError
+
+        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
+        dt = torch.exp(
+            torch.rand(self.d_inner // self.tp_pg.size(), **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+            + math.log(dt_min)
+        ).clamp(min=dt_init_floor)
+        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+        inv_dt = dt + torch.log(-torch.expm1(-dt))
+        with torch.no_grad():
+            self.dt_proj.bias.copy_(inv_dt)
+        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
+        self.dt_proj.bias._no_reinit = True
+
+        # S4D real initialization
+        A = repeat(
+            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
+            "n -> d n",
+            d=self.d_inner // self.tp_pg.size(),
+        ).contiguous()
+        A_log = torch.log(A)  # Keep A_log in fp32
+        self.A_log = nn.Parameter(A_log)
+        self.A_log._no_weight_decay = True
+
+        # D "skip" parameter
+        self.D = nn.Parameter(torch.ones(self.d_inner // self.tp_pg.size(), device=device))  # Keep in fp32
+        self.D._no_weight_decay = True
+
+ # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + self.out_proj = TensorParallelRowLinear( + in_features=self.d_inner, + out_features=self.d_model, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + contiguous_chunks=None, + ) + + def forward(self, hidden_states: Union[torch.Tensor, TensorPointer], inference_params=None): + """ + hidden_states: (B, L, D) + Returns: same shape as hidden_states + """ + + if inference_params is not None: + raise NotImplementedError("Inference params not tested yet.") + + batch, seqlen, dim = hidden_states.shape + + conv_state, ssm_state = None, None + if inference_params is not None: + conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # The states are updated inplace + out, _, _ = self.step(hidden_states, conv_state, ssm_state) + return out + + # We do matmul and transpose BLH -> HBL at the same time + xz = self.in_proj(hidden_states).transpose(1, 2) + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + + # In the backward pass we write dx and dz next to each other to avoid torch.cat + if ( + self.use_fast_path and inference_params is None and os.environ.get("FAST_PATH", "0") == "1" + ): # Doesn't support outputting the states + y = mamba_inner_fn( + d_inner=self.d_inner, + tp_pg=self.tp_pg, + xz=xz, + conv1d_weight=self.conv1d.weight, + conv1d_bias=self.conv1d.bias, + x_proj_weight=self.x_proj.weight, + delta_proj_weight=self.dt_proj.weight, + A=A, + B=None, # input-dependent B + C=None, # input-dependent C + D=self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + ) + else: + assert self.d_inner % self.tp_pg.size() == 0 + x, z = xz.view(batch, self.d_inner // self.tp_pg.size(), 2, seqlen).chunk(2, dim=2) + x = x.squeeze(2) + z = z.squeeze(2) + # Compute short convolution + if conv_state is not None: + # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) + if causal_conv1d_fn is None: + x = self.act(self.conv1d(x)[..., :seqlen]) + else: + assert self.activation in ["silu", "swish"] + x = causal_conv1d_fn( + x=x, + weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation=self.activation, + ) + + # We're careful here about the layout, to avoid extra transposes. + # We want dt to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) + dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = self.dt_proj.weight @ dt.t() + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + assert self.activation in ["silu", "swish"] + y = selective_scan_fn( + x, + dt, + A, + B, + C, + self.D.float(), + z=z, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=ssm_state is not None, + ) + if ssm_state is not None: + y, last_state = y + ssm_state.copy_(last_state) + y = rearrange(y, "b d l -> b l d") + + out = self.out_proj(y) + return out + + def step( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ): + dtype = hidden_states.dtype + assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" + xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) + x, z = xz.chunk(2, dim=-1) # (B D) + + # Conv step + if causal_conv1d_update is None: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = x + x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv1d.bias is not None: + x = x + self.conv1d.bias + x = self.act(x).to(dtype=dtype) + else: + x = causal_conv1d_update( + x, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + x_db = self.x_proj(x) # (B dt_rank+2*d_state) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + # Don't add dt_bias here + dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + + # SSM step + if selective_state_update is None: + # Discretize A and B + dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) + dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) + dB = torch.einsum("bd,bn->bdn", dt, B) + ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) + y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) + y = y + self.D.to(dtype) * x + y = y * self.act(z) # (B D) + else: + y = selective_state_update( + ssm_state, + x, + dt, + A, + B, + C, + self.D, + z=z, + dt_bias=self.dt_proj.bias, + dt_softplus=True, + ) + + out = self.out_proj(y) + return out.unsqueeze(1), conv_state, ssm_state + + def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = None, **kwargs): + device = self.out_proj.weight.device + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_conv, + device=device, + dtype=conv_dtype, + ) + ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype + # ssm_dtype = torch.float32 + ssm_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_state, + device=device, + dtype=ssm_dtype, + ) + return conv_state, ssm_state + + def _get_states_from_cache(self, inference_params, batch_size: int, initialize_states: bool = False): + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict: + conv_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_conv, + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + 
self.d_model * self.expand, + self.d_state, + device=self.dt_proj.weight.device, + dtype=self.dt_proj.weight.dtype, + # dtype=torch.float32, + ) + inference_params.key_value_memory_dict[self.layer_idx] = ( + conv_state, + ssm_state, + ) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + # TODO: What if batch size changes between generation, and we reuse the same states? + if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state + + +class Embedding(nn.Module, AttachableStore): + def __init__( + self, + tp_pg: dist.ProcessGroup, + config: MambaModelConfig, + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + self.token_embedding = TensorParallelEmbedding( + num_embeddings=config.vocab_size, + embedding_dim=config.d_model, + padding_idx=config.pad_token_id, + pg=tp_pg, + mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, + ) + self.pg = tp_pg + + def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length] + store = self.get_local_store() + if store is not None: + if "past_length" in store: + past_length = store["past_length"] + else: + past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0]) + + cumsum_mask = input_mask.cumsum(-1, dtype=torch.long) + # Store new past_length in store + store["past_length"] = past_length + cumsum_mask[:, -1] + + # Format input in `[seq_length, batch_size]` to support high TP with low batch_size + # input_ids = input_ids.transpose(0, 1) + input_embeds = self.token_embedding(input_ids) + return {"input_embeds": input_embeds} + + +class MambaDecoderLayer(nn.Module): + def __init__( + self, + config: MambaModelConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + layer_idx: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + + factory_kwargs = {"device": device, "dtype": dtype} + + if config.ssm_cfg is None: + ssm_cfg = {} + else: + ssm_cfg = config.ssm_cfg + + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.fused_add_norm = config.fused_add_norm + + self.mixer = Mamba( + d_model=config.d_model, + parallel_config=parallel_config, + tp_pg=tp_pg, + layer_idx=layer_idx, + **ssm_cfg, + **factory_kwargs, + ) + + self.norm = partial( + nn.LayerNorm if not config.rms_norm else RMSNorm, + eps=config.rms_norm_eps, + **factory_kwargs, + )(config.d_model) + + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + + def forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + residual: Optional[Union[torch.Tensor, TensorPointer]], + inference_params=None, + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + if not self.fused_add_norm: + # self.layer_idx was assigned when calling create_block + # residual=None happens only at the first block + residual = hidden_states if (self.layer_idx == 0) else hidden_states + residual + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + hidden_states, residual = fused_add_norm_fn( + 
hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + hidden_states = self.mixer(hidden_states, inference_params=inference_params) + + return { + "hidden_states": hidden_states, + "sequence_mask": sequence_mask, # NOTE(fmom): dunno how to use it for now. Just keep it + "residual": residual, + } + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +class MambaModel(nn.Module): + def __init__( + self, + config: MambaModelConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: Optional[RandomStates] = None, + ): + super().__init__() + + # Declare all the nodes + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + self.config = config + self.parallel_config = parallel_config + self.parallel_context = parallel_context + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + self.token_position_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=Embedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"input_ids", "input_mask"}, + module_output_keys={"input_embeds"}, + ) + + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=MambaDecoderLayer, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "layer_idx": layer_idx, + "device": self.p2p.device, + "dtype": cast_str_to_torch_dtype(config.dtype), + }, + module_input_keys={"hidden_states", "sequence_mask", "residual"}, + module_output_keys={"hidden_states", "sequence_mask", "residual"}, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.final_layer_norm = PipelineBlock( + p2p=self.p2p, + module_builder=RMSNorm, + module_kwargs={"hidden_size": config.d_model, "eps": config.rms_norm_eps}, + module_input_keys={"x", "residual"}, + module_output_keys={"hidden_states"}, + ) + + self.lm_head = PipelineBlock( + p2p=self.p2p, + # Understand that this means that we return sharded logits that are going to need to be gathered + module_builder=TensorParallelColumnLinear, + module_kwargs={ + "in_features": config.d_model, + "out_features": config.vocab_size, + "pg": parallel_context.tp_pg, + "bias": False, + # TODO @thomasw21: refactor so that we store that default in a single place. 
+ "mode": self.tp_mode, + "async_communication": tp_linear_async_communication, + }, + module_input_keys={"x"}, + module_output_keys={"logits"}, + ) + + self.cast_to_fp32 = PipelineBlock( + p2p=self.p2p, + module_builder=lambda: lambda x: x.float(), + module_kwargs={}, + module_input_keys={"x"}, + module_output_keys={"output"}, + ) + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + ): + return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0] + + def forward_with_hidden_states( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + ): + # all tensors are optional as most ranks don't need anything from the dataloader. + + output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) + + hidden_encoder_states = { + "hidden_states": output["input_embeds"], + "sequence_mask": input_mask, + "residual": output["input_embeds"], + } + + for block in self.decoder: + hidden_encoder_states = block(**hidden_encoder_states) + + hidden_states = self.final_layer_norm( + x=hidden_encoder_states["hidden_states"], + residual=hidden_encoder_states["residual"], + )["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + + return fp32_sharded_logits, hidden_states + + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + # model_config = self.config + # d_ff = model_config.intermediate_size + # d_qkv = model_config.d_model // model_config.num_attention_heads + # block_compute_costs = { + # # CausalSelfAttention (qkv proj + attn out) + MLP + # LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.d_model + # + 3 * d_ff * model_config.d_model, + # # This is the last lm_head + # TensorParallelColumnLinear: model_config.vocab_size * model_config.d_model, + # } + + block_compute_costs = { + # CausalSelfAttention (qkv proj + attn out) + MLP + MambaDecoderLayer: 1, + # This is the last lm_head + TensorParallelColumnLinear: 0, + } + log_rank( + "get_block_compute_costs() Not implemented yet", + logger=logger, + level=logging.INFO, + rank=0, + ) + return block_compute_costs + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + # world_size = self.parallel_context.world_pg.size() + # try: + # num_key_values_heads = self.config.num_key_value_heads + # except AttributeError: + # num_key_values_heads = self.config.num_attention_heads + + # model_flops, hardware_flops = get_flops( + # num_layers=self.config.num_hidden_layers, + # hidden_size=self.config.d_model, + # num_heads=self.config.num_attention_heads, + # num_key_value_heads=num_key_values_heads, + # vocab_size=self.config.vocab_size, + # ffn_hidden_size=self.config.intermediate_size, + # seq_len=sequence_length, + # batch_size=global_batch_size, + # recompute_granularity=self.parallel_config.recompute_granularity, + # ) + + # model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12) + # hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) + + # TODO(fmom): undo hardcoding of model_flops_per_s and hardware_flops_per_s + model_flops_per_s = 
0 + hardware_flops_per_s = 0 + log_rank( + "get_flops_per_sec() Not implemented yet", + logger=logger, + level=logging.INFO, + rank=0, + ) + return model_flops_per_s, hardware_flops_per_s + + +def masked_mean(loss, label_mask, dtype): + # type: (Tensor, Tensor, torch.dtype) -> Tensor + return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + + +class Loss(nn.Module): + def __init__(self, tp_pg: dist.ProcessGroup): + super().__init__() + self.tp_pg = tp_pg + + def forward( + self, + sharded_logits: torch.Tensor, # [seq_length, batch_size, logits] + label_ids: torch.Tensor, # [batch_size, seq_length] + label_mask: torch.Tensor, # [batch_size, seq_length] + ) -> Dict[str, torch.Tensor]: + # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. + # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 + + # NOTE(fmom): undo transpose for now since Mamba is not using TP + # loss = sharded_cross_entropy( + # sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float + # ).transpose(0, 1) + + loss = sharded_cross_entropy(sharded_logits, label_ids, group=self.tp_pg, dtype=torch.float) + + # TODO @thomasw21: It's unclear what kind of normalization we want to do. + loss = masked_mean(loss, label_mask, dtype=torch.float) + # I think indexing causes a sync we don't actually want + # loss = loss[label_mask].sum() + return {"loss": loss} + + +class MambaForTraining(NanotronModel): + def __init__( + self, + config: MambaModelConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: Optional[RandomStates] = None, + ): + super().__init__() + + self.model = MambaModel( + config=config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=random_states, + ) + + self.loss = PipelineBlock( + p2p=self.model.p2p, + module_builder=Loss, + module_kwargs={"tp_pg": parallel_context.tp_pg}, + module_input_keys={ + "sharded_logits", + "label_ids", + "label_mask", + }, + module_output_keys={"loss"}, + ) + self.parallel_context = parallel_context + self.config = config + self.parallel_config = parallel_config + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], + input_mask: Union[torch.Tensor, TensorPointer], + label_ids: Union[torch.Tensor, TensorPointer], + label_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + sharded_logits = self.model( + input_ids=input_ids, + input_mask=input_mask, + ) + loss = self.loss( + sharded_logits=sharded_logits, + label_ids=label_ids, + label_mask=label_mask, + )["loss"] + return {"loss": loss} + + @torch.no_grad() + def init_model_randomly(self, config): + model = self + initialized_parameters = set() + + # Handle tensor parallelism + module_id_to_prefix = {id(module): f"{module_name}." 
for module_name, module in model.named_modules()} + # Fix the root_model + module_id_to_prefix[id(model)] = "" + + initializer_range = config.model.init_method.initializer_range + n_residuals_per_layer = config.model.init_method.n_residuals_per_layer + num_hidden_layers = config.model.model_config.num_hidden_layers + rescale_prenorm_residual = config.model.init_method.rescale_prenorm_residual + d_model = config.model.model_config.d_model + + if config.model.model_config.ssm_cfg is not None: + dt_init = config.model.model_config.ssm_cfg["dt_init"] + dt_rank = config.model.model_config.ssm_cfg["dt_rank"] + dt_scale = config.model.model_config.ssm_cfg["dt_scale"] + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + module_name, param_name = param_name.rsplit(".", 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + module = model.get_submodule(module_name) + + if isinstance(module, TensorParallelColumnLinear) or isinstance(module, TensorParallelRowLinear): + if "weight" == param_name: + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + elif "bias" == param_name: + raise ValueError("We don't use bias for TensorParallelColumnLinear and TensorParallelRow") + else: + raise ValueError(f"Who the fuck is {param_name}?") + + if rescale_prenorm_residual and full_param_name.endswith("out_proj.weight"): + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + module.weight /= math.sqrt(n_residuals_per_layer * num_hidden_layers) + + elif isinstance(module, nn.Conv1d): + fan_in = None + if "weight" == param_name: + fan_in, _ = init._calculate_fan_in_and_fan_out(param) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + elif "bias" == param_name: + bound = 1 / math.sqrt(fan_in) if (fan_in is not None and fan_in > 0) else 0 + init.uniform_(module.bias, -bound, bound) + else: + raise ValueError(f"Who the fuck is {param_name}?") + + elif isinstance(module, nn.Linear): + fan_in = None + + if "weight" == param_name: + fan_in, _ = init._calculate_fan_in_and_fan_out(module.weight) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + elif "bias" == param_name: + bound = 1 / math.sqrt(fan_in) if (fan_in is not None and fan_in > 0) else 0 + init.uniform_(module.bias, -bound, bound) + else: + raise ValueError(f"Who the fuck is {param_name}?") + + if config.model.model_config.ssm_cfg is not None: + if dt_rank == "auto": + dt_init_std = math.ceil(d_model / 16) ** -0.5 * dt_scale + else: + dt_init_std = dt_rank**-0.5 * dt_scale + + if dt_init == "constant": + nn.init.constant_(module.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(module.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + elif isinstance(module, TensorParallelEmbedding): + nn.init.normal_(module.weight, std=initializer_range) + + elif isinstance(module, RMSNorm) or isinstance(module, nn.LayerNorm): + if "weight" == param_name: + # TODO @thomasw21: Sometimes we actually want 0 + module.weight.fill_(1) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + + elif isinstance(module, Mamba): + # NOTE(fmom): nn.Parameter are initialized in Mamba __init__ + # In Mamba, only those 3 parameters don't have weight decay. 
+ if param_name in ["dt_bias", "A_log", "D"]: + param._no_weight_decay = True + + else: + raise Exception(f"Parameter {full_param_name} was not initialized") + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + + assert initialized_parameters == { + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + for name, param in model.named_parameters() + }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" + + @staticmethod + def get_embeddings_lm_head_tied_names(): + return [ + "model.token_position_embeddings.pp_block.token_embedding.weight", + "model.lm_head.pp_block.weight", + ] + + # TODO(fmom): implement get_block_compute_costs + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + return self.model.get_block_compute_costs() + + # TODO(fmom): implement get_flops_per_sec + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size) diff --git a/examples/mamba/requirements.txt b/examples/mamba/requirements.txt new file mode 100644 index 00000000..abc3bd62 --- /dev/null +++ b/examples/mamba/requirements.txt @@ -0,0 +1,5 @@ +torch==2.1.0 +einops +causal-conv1d==1.1.0 +mamba-ssm==1.1.4 +flash-attn==2.5.0 diff --git a/examples/mamba/selective_scan_interface.py b/examples/mamba/selective_scan_interface.py new file mode 100644 index 00000000..123641c8 --- /dev/null +++ b/examples/mamba/selective_scan_interface.py @@ -0,0 +1,513 @@ +# Copyright (c) 2023, Tri Dao, Albert Gu. 
+ +import causal_conv1d_cuda +import selective_scan_cuda +import torch +import torch.nn.functional as F +from causal_conv1d import causal_conv1d_fn +from einops import rearrange, repeat +from torch.cuda.amp import custom_bwd, custom_fwd + + +class SelectiveScanFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + ): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + dout, + x, + out, + None, + ctx.delta_softplus, + False, # option to recompute out_z, not used here + ) + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + return ( + du, + ddelta, + dA, + dB, + dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, + None, + ) + + +def selective_scan_fn( + u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, +): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. 
+ """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + + +def selective_scan_ref( + u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, +): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum("bdn,dn->bd", x, C) + else: + if C.dim() == 3: + y = torch.einsum("bdn,bn->bd", x, C[:, :, i]) + else: + y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +class MambaInnerFn(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + d_inner, + tp_pg, + xz, + conv1d_weight, + conv1d_bias, + x_proj_weight, + delta_proj_weight, + A, + B=None, + C=None, + D=None, + delta_bias=None, + B_proj_bias=None, + C_proj_bias=None, + delta_softplus=True, + checkpoint_lvl=1, + ): + """ + xz: (batch, dim, seqlen) + """ + assert checkpoint_lvl in [0, 1] + batch, L = xz.shape[0], xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + if torch.is_autocast_enabled(): + x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + + if xz.stride(-1) != 1: + xz = xz.contiguous() + conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") + + # x, z = xz.chunk(2, dim=1) + assert d_inner % tp_pg.size() == 0 + x, z = xz.view(batch, d_inner // tp_pg.size(), 2, L).chunk(2, dim=2) + x = x.squeeze(2) + z = z.squeeze(2) + + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) + # We're being very careful here about the 
layout, to avoid extra transposes. + # We want delta to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = F.linear(rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight) # (bl d) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L) + ctx.is_variable_B = B is None + ctx.is_variable_C = C is None + ctx.B_proj_bias_is_None = B_proj_bias is None + ctx.C_proj_bias_is_None = C_proj_bias is None + if B is None: # variable B + B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate) + if B_proj_bias is not None: + B = B + B_proj_bias.to(dtype=B.dtype) + if not A.is_complex(): + # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if B.stride(-1) != 1: + B = B.contiguous() + if C is None: # variable C + C = x_dbl[:, -d_state:] # (bl dstate) + if C_proj_bias is not None: + C = C + C_proj_bias.to(dtype=C.dtype) + if not A.is_complex(): + # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if C.stride(-1) != 1: + C = C.contiguous() + if D is not None: + D = D.contiguous() + out, scan_intermediates, out_z = selective_scan_cuda.fwd( + conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus + ) + ctx.delta_softplus = delta_softplus + # ctx.out_proj_bias_is_None = out_proj_bias is None + ctx.checkpoint_lvl = checkpoint_lvl + if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass + conv1d_out, delta = None, None + + ctx.d_inner = d_inner + ctx.tp_pg = tp_pg + + ctx.save_for_backward( + xz, + conv1d_weight, + conv1d_bias, + x_dbl, + x_proj_weight, + delta_proj_weight, + conv1d_out, + delta, + A, + B, + C, + D, + delta_bias, + scan_intermediates, + out, + ) + + return rearrange(out_z, "b d l -> b l d") + + @staticmethod + @custom_bwd + def backward(ctx, dout): + # dout: (batch, seqlen, dim) + ( + xz, + conv1d_weight, + conv1d_bias, + x_dbl, + x_proj_weight, + delta_proj_weight, + conv1d_out, + delta, + A, + B, + C, + D, + delta_bias, + scan_intermediates, + out, + ) = ctx.saved_tensors + batch, L = xz.shape[0], xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + + # x, z = xz.chunk(2, dim=1) + assert ctx.d_inner % ctx.tp_pg.size() == 0 + x, z = xz.view(batch, ctx.d_inner // ctx.tp_pg.size(), 2, L).chunk(2, dim=2) + x = x.squeeze(2) + z = z.squeeze(2) + + if ctx.checkpoint_lvl == 1: + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L) + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). 
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen) + + # dx, dz = dxz.chunk(2, dim=1) + assert ctx.d_inner % ctx.tp_pg.size() == 0 + dx, dz = dxz.view(batch, ctx.d_inner // ctx.tp_pg.size(), 2, L).chunk(2, dim=2) + dx = dx.squeeze(2) + dz = dz.squeeze(2) + + dout = rearrange(dout, "b l e -> b e l") + + if dout.stride(-1) != 1: + dout = dout.contiguous() + + (dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z,) = selective_scan_cuda.bwd( + conv1d_out, + delta, + A, + B, + C, + D, + z, + delta_bias, + dout, + scan_intermediates, + out, + dz, + ctx.delta_softplus, + True, # option to recompute out_z + ) + + dD = dD if D is not None else None + dx_dbl = torch.empty_like(x_dbl) + dB_proj_bias = None + if ctx.is_variable_B: + if not A.is_complex(): + dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None + dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d) + dB = None + dC_proj_bias = None + if ctx.is_variable_C: + if not A.is_complex(): + dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None + dx_dbl[:, -d_state:] = dC # (bl d) + dC = None + ddelta = rearrange(ddelta, "b d l -> d (b l)") + ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) + dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) + dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") + dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) + dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) + dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). 
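+ # dx here is the view into dxz created above, so passing it as the pre-allocated buffer lets
+ # causal_conv1d_bwd write the input gradient directly into dxz.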
+ dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd( + x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True + ) + dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None + dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") + return ( + None, # d_inner + None, # tp_pg + dxz, + dconv1d_weight, + dconv1d_bias, + dx_proj_weight, + ddelta_proj_weight, + dA, + dB, + dC, + dD, + ddelta_bias if delta_bias is not None else None, + dB_proj_bias, + dC_proj_bias, + None, + ) + + +def mamba_inner_fn( + d_inner, + tp_pg, + xz, + conv1d_weight, + conv1d_bias, + x_proj_weight, + delta_proj_weight, + A, + B=None, + C=None, + D=None, + delta_bias=None, + B_proj_bias=None, + C_proj_bias=None, + delta_softplus=True, +): + return MambaInnerFn.apply( + d_inner, + tp_pg, + xz, + conv1d_weight, + conv1d_bias, + x_proj_weight, + delta_proj_weight, + A, + B, + C, + D, + delta_bias, + B_proj_bias, + C_proj_bias, + delta_softplus, + ) + + +def mamba_inner_ref( + xz, + conv1d_weight, + conv1d_bias, + x_proj_weight, + delta_proj_weight, + out_proj_weight, + out_proj_bias, + A, + B=None, + C=None, + D=None, + delta_bias=None, + B_proj_bias=None, + C_proj_bias=None, + delta_softplus=True, +): + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + x, z = xz.chunk(2, dim=1) + x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu") + # We're being very careful here about the layout, to avoid extra transposes. + # We want delta to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight) # (bl d) + delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() + delta = rearrange(delta, "d (b l) -> b d l", l=L) + if B is None: # variable B + B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl d) + if B_proj_bias is not None: + B = B + B_proj_bias.to(dtype=B.dtype) + if not A.is_complex(): + B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + else: + B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + if C is None: # variable B + C = x_dbl[:, -d_state:] # (bl d) + if C_proj_bias is not None: + C = C + C_proj_bias.to(dtype=C.dtype) + if not A.is_complex(): + C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + else: + C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) + return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) diff --git a/examples/mamba/train_mamba.py b/examples/mamba/train_mamba.py new file mode 100644 index 00000000..4d587fcf --- /dev/null +++ b/examples/mamba/train_mamba.py @@ -0,0 +1,33 @@ +import argparse +import os +import sys + +from config import MambaModelConfig +from mamba import MambaForTraining +from trainer import MambaTrainer + +from nanotron import logging + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) + +from run_train import get_dataloader # noqa + +logger = logging.get_logger(__name__) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + config_file = args.config_file + + # Load trainer and 
data + trainer = MambaTrainer(config_file, model_config_class=MambaModelConfig, model_class=MambaForTraining) + dataloader = get_dataloader(trainer) + + # Train + trainer.train(dataloader) diff --git a/examples/mamba/train_mamba.sh b/examples/mamba/train_mamba.sh new file mode 100755 index 00000000..36384c8c --- /dev/null +++ b/examples/mamba/train_mamba.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Simple script to create a tiny mamba model and train it + +set -e -x + +# Create the YAML config file + +EXAMPLE_PATH=$(cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P) +REPO_PATH=$(dirname $EXAMPLE_PATH) +python $EXAMPLE_PATH/create_config_mamba.py + +# Setup from environment variables + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export FI_PROVIDER="efa" + +python -u -m torch.distributed.run \ + --nproc_per_node 8 \ + --nnodes 1 \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + $REPO_PATH/mamba/train_mamba.py --config-file $EXAMPLE_PATH/config_mamba.yaml diff --git a/examples/mamba/trainer.py b/examples/mamba/trainer.py new file mode 100644 index 00000000..e3dec27a --- /dev/null +++ b/examples/mamba/trainer.py @@ -0,0 +1,155 @@ +from typing import Optional, Type, Union + +from config import ExistingCheckpointInit, MambaConfig, MambaInit +from torch.nn.parallel import DistributedDataParallel + +from nanotron import logging +from nanotron.trainer import DistributedTrainer + +logger = logging.get_logger(__name__) + +from nanotron import distributed as dist +from nanotron.config import ParallelismArgs +from nanotron.logging import log_rank +from nanotron.models import NanotronModel +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.pipeline_parallel.utils import get_pp_rank_of +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelLinearMode, + TensorParallelRowLinear, +) +from nanotron.parallel.tied_parameters import ( + create_pg_for_tied_weights, + get_tied_id_to_param, + tie_parameters, +) +from nanotron.serialize import load_weights, parse_ckpt_path + + +class MambaTrainer(DistributedTrainer): + def __init__( + self, + config_or_config_file: Union[MambaConfig, str], + config_class: Type[MambaConfig] = MambaConfig, + model_config_class: Optional[Type] = None, + model_class: Type[NanotronModel] = None, + ): + assert config_class == MambaConfig + super().__init__(config_or_config_file, config_class, model_config_class, model_class) + + def _mark_tied_parameters( + self, + model: NanotronModel, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs] = None, + ): + # Tie embeddings + embeddings_lm_head_tied_names = model.get_embeddings_lm_head_tied_names() + if len(embeddings_lm_head_tied_names) > 0: + shared_embeddings = [ + ( + target, + ( + parallel_context.world_rank_matrix[ + dist.get_rank(parallel_context.expert_pg), + get_pp_rank_of(target, module=model), + dist.get_rank(parallel_context.dp_pg), + dist.get_rank(parallel_context.tp_pg), + ], + ), + ) + for target in embeddings_lm_head_tied_names + ] + tie_parameters( + root_module=model, + ties=shared_embeddings, + parallel_context=parallel_context, + reduce_op=dist.ReduceOp.SUM, + ) + + # Tie custom params + model.tie_custom_params() + + # Sync all parameters that have the same name and that are not sharded + assert not isinstance(model, DistributedDataParallel), "model shouldn't be DDP at this point" + for module_name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + 
name = f"{module_name}.{param_name}" + + if isinstance(param, NanotronParameter) and (param.is_sharded or param.is_tied): + continue + + if isinstance(module, TensorParallelRowLinear) and "bias" == param_name: + # bias for TensorParallelRowLinear only exists on TP=0 so we don't need to tie it + continue + + shared_weights = [ + ( + name, + # sync across TP group + tuple(sorted(dist.get_process_group_ranks(parallel_context.tp_pg))), + ) + ] + + if ( + parallel_config is None + or parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE + or hasattr(model.config.model.model_config, "is_mamba_config") + ): + # We add `reduce_op=None` in order to signal that the weight are synced by design without needing to reduce + # when TP=2 we have LN that is duplicated across TP, so by design it's tied + reduce_op = None + else: + reduce_op = dist.ReduceOp.SUM + + tie_parameters( + root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op + ) + + create_pg_for_tied_weights(root_module=model, parallel_context=parallel_context) + + def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: + unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model + + # Load or initialize model weights + self.init_checkpoint_path = parse_ckpt_path(config=self.config) + reloaded_from_checkpoint = False + if self.init_checkpoint_path is not None: + # Reload from a training checkpoint + log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0) + self.param_shard_metadata = load_weights( + model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path + ) + reloaded_from_checkpoint = True + if not reloaded_from_checkpoint: + log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO) + if isinstance(self.config.model.init_method, ExistingCheckpointInit): + # Initialize model from an pretrained model checkpoint + self.param_shard_metadata = load_weights( + model=unwrapped_model, + parallel_context=self.parallel_context, + root_folder=self.config.model.init_method.path, + ) + elif isinstance(self.config.model.init_method, MambaInit): + + unwrapped_model.init_model_randomly(config=self.config) + # Synchronize parameters so that the model is consistent + # sync all params across dp + for name, param in sorted(model.named_parameters(), key=lambda x: x[0]): + dist.all_reduce(param, op=dist.ReduceOp.AVG, group=self.parallel_context.dp_pg) + + # sync tied params across tied groups + for (_, group_ranks), param in sorted( + get_tied_id_to_param( + parameters=model.parameters(), + root_module=unwrapped_model, + ).items(), + key=lambda x: x[0], + ): + group = self.parallel_context.world_ranks_to_pg[group_ranks] + dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group) + else: + raise ValueError(f"Unsupported {self.config.model.init_method}") + + return model diff --git a/examples/moe/llamoe.py b/examples/moe/llamoe.py index fb500cc2..83f23ab6 100644 --- a/examples/moe/llamoe.py +++ b/examples/moe/llamoe.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
""" PyTorch LLaMa MoE model.""" +import math from typing import Dict, Optional, Union import torch @@ -24,6 +25,9 @@ ) from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding from moe import dMoE +from torch import nn +from torch.nn import init + from nanotron import distributed as dist from nanotron import logging from nanotron.config import ParallelismArgs @@ -33,10 +37,7 @@ from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter -from nanotron.parallel.pipeline_parallel.block import ( - PipelineBlock, - TensorPointer, -) +from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( @@ -47,7 +48,6 @@ ) from nanotron.random import RandomStates from nanotron.utils import checkpoint_method -from torch import nn logger = logging.get_logger(__name__) @@ -834,12 +834,8 @@ def forward( return {"loss": loss} @torch.no_grad() - def init_model_randomly(self, init_method, scaled_init_method): + def init_model_randomly(self, config): """Initialize model parameters randomly. - Args: - init_method (callable): Used for embedding/position/qkv weight in attention/first layer weight of mlp/ /lm_head/ - scaled_init_method (callable): Used for o weight in attention/second layer weight of mlp/ - Note: Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` """ @@ -850,122 +846,77 @@ def init_model_randomly(self, init_method, scaled_init_method): # Fix the root_model module_id_to_prefix[id(model)] = "" - # TODO @nouamane: initialization for dmoe - for module_name, module in model.named_modules(): + std = config.model.init_method.std + sigma = config.model.init_method.std + num_layers = config.model.model_config.num_hidden_layers + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + module_name, param_name = param_name.rsplit(".", 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + module = model.get_submodule(module_name) + if isinstance(module, TensorParallelColumnLinear): - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=std) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") 
elif isinstance(module, TensorParallelRowLinear): - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - scaled_init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers)) + elif "bias" == param_name: + param.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, TritonRMSNorm): - assert {"weight"} == {name for name, _ in module.named_parameters()} - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - # TODO @thomasw21: Sometimes we actually want 0 - param.fill_(1) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelEmbedding): - # TODO @thomasw21: Handle tied embeddings - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} - - assert isinstance(module.weight, NanotronParameter) - if module.weight.is_tied: - tied_info = module.weight.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) + if "weight" == param_name: + # TODO @thomasw21: Sometimes we actually want 0 + module.weight.fill_(1) + elif "bias" == param_name: + module.bias.zero_() else: - full_param_name = f"{module_name}.weight" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - init_method(module.weight) - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - - # assert initialized_parameters == { - # param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) - # if param.is_tied - # else name - # for name, param in model.named_parameters() - # }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: 
{initialized_parameters}" - # TODO @nouamane: init dMoE + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, nn.Linear): + fan_in = None + + if "weight" == param_name: + fan_in, _ = init._calculate_fan_in_and_fan_out(module.weight) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + elif "bias" == param_name: + bound = 1 / math.sqrt(fan_in) if (fan_in is not None and fan_in > 0) else 0 + init.uniform_(module.bias, -bound, bound) + else: + raise ValueError(f"Who the fuck is {param_name}?") + + elif isinstance(module, TensorParallelEmbedding): + nn.init.normal_(module.weight, mean=0.0, std=std) + else: + raise Exception(f"Parameter {full_param_name} was not intialized") + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + + assert initialized_parameters == { + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + for name, param in model.named_parameters() + }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 10b09105..ec48f8bc 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -11,7 +11,11 @@ from yaml.loader import SafeLoader from nanotron.config.lighteval_config import LightEvalConfig -from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit +from nanotron.config.models_config import ( + ExistingCheckpointInit, + NanotronConfigs, + RandomInit, +) from nanotron.config.parallelism_config import ParallelismArgs from nanotron.config.utils_config import ( RecomputeGranularity, @@ -21,9 +25,7 @@ ) from nanotron.generation.sampler import SamplerType from nanotron.logging import get_logger -from nanotron.parallel.pipeline_parallel.engine import ( - PipelineEngine, -) +from nanotron.parallel.pipeline_parallel.engine import PipelineEngine from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode logger = get_logger(__name__) @@ -384,7 +386,7 @@ def get_config_from_file( skip_unused_config_keys: bool = False, skip_null_keys: bool = False, ) -> Config: - """Get a config objet from a file (python or YAML) + """Get a config object from a file (python or YAML) Args: config_path: path to the config file diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index da8760d7..17a859fe 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -18,12 +18,7 @@ from nanotron import distributed as dist from nanotron import logging -from nanotron.config import ( - Config, - LRSchedulerArgs, - OptimizerArgs, - ParallelismArgs, -) +from nanotron.config import Config, LRSchedulerArgs, OptimizerArgs, ParallelismArgs from nanotron.distributed import ProcessGroup from nanotron.logging import LogItem, log_rank from nanotron.models.base import NanotronModel @@ -40,9 +35,7 @@ ) from nanotron.optim.zero import ZeroDistributedOptimizer from nanotron.parallel import ParallelContext -from nanotron.parallel.tensor_parallel.nn import ( - TensorParallelLinearMode, -) +from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode from nanotron.random import ( RandomStates, get_current_random_state, @@ -166,7 +159,6 @@ def 
init_optimizer_and_grad_accumulator( # Fix the root_model module_id_to_prefix[id(unwrapped_model)] = "" - # named parameters named_parameters = list(unwrapped_model.get_named_params_with_correct_tied()) # Basic optimizer builder @@ -175,8 +167,8 @@ def basic_optimizer_builder(named_param_groups): named_params_or_groups=named_param_groups, optimizer_builder=lambda param_groups: AdamW( # pylint: disable=E0601 param_groups, - lr=optimizer_args.learning_rate_scheduler.learning_rate, weight_decay=optimizer_args.weight_decay, + lr=optimizer_args.learning_rate_scheduler.learning_rate, eps=optimizer_args.adam_eps, betas=(optimizer_args.adam_beta1, optimizer_args.adam_beta2), fused=optimizer_args.torch_adam_is_fused, diff --git a/src/nanotron/models/base.py b/src/nanotron/models/base.py index 5fc175f8..ac54e9fb 100644 --- a/src/nanotron/models/base.py +++ b/src/nanotron/models/base.py @@ -58,7 +58,7 @@ def params_gen(): yield from params_gen() @abstractmethod - def init_model_randomly(self, init_method, scaled_init_method): + def init_model_randomly(self, config): ... def tie_custom_params(self) -> None: diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 9a69800f..b930e0eb 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -15,7 +15,7 @@ """ PyTorch LLaMa model. """ from typing import Dict, Optional, Union - +import math import torch from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -881,14 +881,10 @@ def forward( label_mask=label_mask, )["loss"] return {"loss": loss} - + @torch.no_grad() - def init_model_randomly(self, init_method, scaled_init_method): + def init_model_randomly(self, config): """Initialize model parameters randomly. - Args: - init_method (callable): Used for embedding/position/qkv weight in attention/first layer weight of mlp/ /lm_head/ - scaled_init_method (callable): Used for o weight in attention/second layer weight of mlp/ - Note: Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` """ @@ -899,126 +895,59 @@ def init_model_randomly(self, init_method, scaled_init_method): # Fix the root_model module_id_to_prefix[id(model)] = "" - for module_name, module in model.named_modules(): + std = config.model.init_method.std + sigma = config.model.init_method.std + num_layers = config.model.model_config.num_hidden_layers + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + module_name, param_name = param_name.rsplit('.', 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + module = model.get_submodule(module_name) + if isinstance(module, TensorParallelColumnLinear): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in 
module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + if "weight" == param_name: + torch.nn.init.normal_(module.weight, mean=0.0, std=std) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, TensorParallelRowLinear): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - scaled_init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + if "weight" == param_name: + torch.nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers)) + elif "bias" == param_name: + param.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, TritonRMSNorm): - assert {"weight"} == {name for name, _ in module.named_parameters()} - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - # TODO @thomasw21: Sometimes we actually want 0 - param.fill_(1) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelEmbedding): - # TODO @thomasw21: Handle tied embeddings - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in 
fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} - - assert isinstance(module.weight, NanotronParameter) - if module.weight.is_tied: - tied_info = module.weight.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) + if "weight" == param_name: + # TODO @thomasw21: Sometimes we actually want 0 + module.weight.fill_(1) + elif "bias" == param_name: + module.bias.zero_() else: - full_param_name = f"{module_name}.weight" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - init_method(module.weight) - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, TensorParallelEmbedding): + nn.init.normal_(module.weight, mean=0.0, std=std) + else: + raise Exception(f"Parameter {full_param_name} was not intialized") + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + assert initialized_parameters == { param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) if param.is_tied diff --git a/src/nanotron/models/starcoder2.py b/src/nanotron/models/starcoder2.py index d67361d5..81b5bca6 100644 --- a/src/nanotron/models/starcoder2.py +++ b/src/nanotron/models/starcoder2.py @@ -30,8 +30,9 @@ flash_attn_with_kvcache, ) from torch import nn -from torch.nn import LayerNorm, init +from torch.nn import LayerNorm from torch.nn import functional as F +from torch.nn import init from nanotron import distributed as dist from nanotron.config import ParallelismArgs, Starcoder2Config @@ -44,9 +45,15 @@ from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.parallel.pipeline_parallel.p2p import P2P from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer -from nanotron.parallel.sharded_parameters import SplitConfig, mark_all_parameters_in_module_as_sharded +from nanotron.parallel.sharded_parameters import ( + SplitConfig, + mark_all_parameters_in_module_as_sharded, +) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.tensor_parallel.functional import column_linear, sharded_cross_entropy +from nanotron.parallel.tensor_parallel.functional import ( + column_linear, + sharded_cross_entropy, +) from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, TensorParallelEmbedding, @@ -1458,169 +1465,85 @@ def tie_custom_params(self) -> None: ) @torch.no_grad() - def init_model_randomly(self, init_method, scaled_init_method): + def init_model_randomly(self, config): + """Initialize model parameters randomly. + Note: + Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` + """ model = self - # Set to 0: LayerNorm bias / all bias initialized_parameters = set() # Handle tensor parallelism - with torch.no_grad(): - module_id_to_prefix = {id(module): f"{module_name}." 
for module_name, module in model.named_modules()} - # Fix the root_model - module_id_to_prefix[id(model)] = "" - - for module_name, module in model.named_modules(): - if isinstance(module, TensorParallelColumnLinear): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight", "bias"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelRowLinear): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight", "bias"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - scaled_init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, LayerNorm): - assert {"weight", "bias"} == {name for name, _ in module.named_parameters()} - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - # TODO @thomasw21: Sometimes we actually want 0 - param.fill_(1) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in 
initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, MQAColumnLinears): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - # TODO @thomasw21: handle the case there's no bias - assert {"q.weight", "q.bias", "kv.weight", "kv.bias"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if ".weight" in param_name: - init_method(param) - elif ".bias" in param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelEmbedding): - # TODO @thomasw21: Handle tied embeddings - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} - - assert isinstance(module.weight, NanotronParameter) - if module.weight.is_tied: - tied_info = module.weight.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.weight" + module_id_to_prefix = {id(module): f"{module_name}." 
for module_name, module in model.named_modules()} + # Fix the root_model + module_id_to_prefix[id(model)] = "" + + std = config.model.init_method.std + sigma = config.model.init_method.std + num_layers = config.model.model_config.num_hidden_layers + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + module_name, param_name = param_name.rsplit(".", 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + module = model.get_submodule(module_name) + + if isinstance(module, TensorParallelColumnLinear): + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=std) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, TensorParallelRowLinear): + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers)) + elif "bias" == param_name: + param.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, LayerNorm): + if "weight" == param_name: + # TODO @thomasw21: Sometimes we actually want 0 + module.weight.fill_(1) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, MQAColumnLinears): + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=std) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + + elif isinstance(module, TensorParallelEmbedding): + nn.init.normal_(module.weight, mean=0.0, std=std) + else: + raise Exception(f"Parameter {full_param_name} was not intialized") - if full_param_name in initialized_parameters: - # Already initialized - continue + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) - init_method(module.weight) - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + assert initialized_parameters == { + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + for name, param in model.named_parameters() + }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" @staticmethod def get_embeddings_lm_head_tied_names() -> List[str]: diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index aeb0c027..a2a3d4aa 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -13,9 +13,17 @@ from nanotron.logging import log_rank from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter -from nanotron.sanity_checks import assert_tensor_synced_across_pg, check_optim_state_in_sync +from nanotron.sanity_checks import ( + assert_tensor_synced_across_pg, + check_optim_state_in_sync, +) from nanotron.serialize.metadata import CheckpointMetadata, load_meta, save_meta -from nanotron.serialize.optimizer import load_lr_scheduler, load_optimizer, save_lr_scheduler, save_optimizer +from nanotron.serialize.optimizer import ( + load_lr_scheduler, + load_optimizer, + 
save_lr_scheduler, + save_optimizer, +) from nanotron.serialize.weights import load_weights, save_weights """ @@ -131,10 +139,10 @@ def save( tied_info = tied_param.get_tied_info() group_ranks = tied_info.global_ranks group = parallel_context.world_ranks_to_pg[group_ranks] + assert_tensor_synced_across_pg( tensor=tied_param, pg=group, msg=lambda err: f"Tied {tied_info.name} are not synced {err}" ) - if not optimizer.inherit_from(optim.ZeroDistributedOptimizer): check_optim_state_in_sync(optimizer, parallel_context.dp_pg) @@ -178,6 +186,7 @@ def save( src=get_global_rank(group=group, group_rank=reference_rank), group=group, ) + torch.testing.assert_close( tensor, reference_tensor, diff --git a/src/nanotron/serialize/weights.py b/src/nanotron/serialize/weights.py index aab0d531..27a36146 100644 --- a/src/nanotron/serialize/weights.py +++ b/src/nanotron/serialize/weights.py @@ -5,6 +5,8 @@ import torch from packaging.version import Version from safetensors.torch import safe_open, save_file +from safetensors import SafetensorError + from torch import nn from tqdm import tqdm @@ -88,7 +90,13 @@ def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folde ) path.parent.mkdir(exist_ok=True, parents=True) try: - save_file(tensors={"data": param_or_buffer}, filename=path, metadata=metadata) + tensors = {"data": param_or_buffer} + + # Mamba has some parameters that should not be weight decayed + if hasattr(model.get_parameter(name), "_no_weight_decay"): + tensors.update({"_no_weight_decay": torch.tensor(model.get_parameter(name)._no_weight_decay)}) + + save_file(tensors=tensors, filename=path, metadata=metadata) except Exception as e: log_rank( f"Error saving {path} with {metadata}", @@ -251,6 +259,13 @@ def load_weights( with safe_open(path, framework="pt", device=str(param.device)) as fi: # TODO @thomasw21: Choose only a slice if we switch the TP topology param_or_buffer[:] = fi.get_tensor("data") + + # Only Mamba params has this attribute + try: + param._no_weight_decay = fi.get_tensor("_no_weight_decay") + except SafetensorError: + pass + elif not path.parent.exists(): raise ValueError( f"Checkpoint is empty or checkpoint structure is not matching the model architecture." 
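For reference, a minimal standalone sketch of the `_no_weight_decay` round-trip added to save_weights / load_weights above; the file name and tensor shape are made up for illustration, and only the safetensors calls already used in the patch appear here:

import torch
from safetensors import SafetensorError
from safetensors.torch import safe_open, save_file

# Save: store the flag as an extra tensor next to the data (e.g. for Mamba's A_log / D parameters).
param = torch.randn(16, 16)
save_file(tensors={"data": param, "_no_weight_decay": torch.tensor(True)}, filename="param.safetensors")

# Load: the flag is optional, so plain parameters without it are handled by catching SafetensorError.
with safe_open("param.safetensors", framework="pt", device="cpu") as fi:
    data = fi.get_tensor("data")
    try:
        no_weight_decay = bool(fi.get_tensor("_no_weight_decay"))
    except SafetensorError:
        no_weight_decay = False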
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index d2c9dd66..dc4a44dd 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -84,7 +84,6 @@ save, save_random_states, ) -from nanotron.utils import init_method_normal, scaled_init_method_normal logger = logging.get_logger(__name__) @@ -583,13 +582,9 @@ def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: root_folder=self.config.model.init_method.path, ) elif isinstance(self.config.model.init_method, RandomInit): - # Initialize model randomly - unwrapped_model.init_model_randomly( - init_method=init_method_normal(self.config.model.init_method.std), - scaled_init_method=scaled_init_method_normal( - self.config.model.init_method.std, self.model_config.num_hidden_layers - ), - ) + + unwrapped_model.init_model_randomly(config=self.config) + # Synchronize parameters so that the model is consistent # sync all params across dp for name, param in sorted(model.named_parameters(), key=lambda x: x[0]): @@ -638,7 +633,7 @@ def _init_model( module.init_rotary_embeddings() # Mark some parameters as tied - mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) + self._mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) # count number of parameters num_params = sum(p.numel() for p in model.parameters()) @@ -774,6 +769,14 @@ def save_checkpoint(self) -> Path: return checkpoint_path + def _mark_tied_parameters( + self, + model: NanotronModel, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs] = None, + ): + mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) + def mark_tied_parameters( model: NanotronModel, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs] = None diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index 80b3680b..14fe1ca8 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -123,26 +123,6 @@ def get_untyped_storage(tensor: torch.Tensor) -> torch.UntypedStorage: else: return tensor.storage().untyped() - -def init_method_normal(sigma: float) -> Callable[[torch.Tensor], None]: - """Init method based on N(0, sigma).""" - - def init_(tensor: torch.Tensor): - torch.nn.init.normal_(tensor, mean=0.0, std=sigma) - - return init_ - - -def scaled_init_method_normal(sigma: float, num_layers: int) -> Callable[[torch.Tensor], None]: - """Init method based on N(0, sigma/sqrt(2*num_layers).""" - std = sigma / math.sqrt(2.0 * num_layers) - - def init_(tensor: torch.Tensor): - torch.nn.init.normal_(tensor, mean=0.0, std=std) - - return init_ - - def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: torch.dtype): # TODO @thomasw21: Figure out what's the best Pytorch way of building a tensor from a storage. device = untyped_storage.device
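The trainer and utils changes above fold the removed init_method_normal / scaled_init_method_normal helpers into each model's init_model_randomly(config). A minimal sketch of the equivalent math, with made-up shapes and config values for illustration:

import math
import torch
import torch.nn as nn

std = 0.02          # would come from config.model.init_method.std
num_layers = 24     # would come from config.model.model_config.num_hidden_layers

column_weight = torch.empty(1024, 4096)   # e.g. a TensorParallelColumnLinear or embedding weight
row_weight = torch.empty(4096, 1024)      # e.g. a TensorParallelRowLinear weight

# Former init_method_normal(std): N(0, std) for embeddings and column-parallel weights.
nn.init.normal_(column_weight, mean=0.0, std=std)

# Former scaled_init_method_normal(std, num_layers): N(0, std / sqrt(2 * num_layers)) for row-parallel weights.
nn.init.normal_(row_weight, mean=0.0, std=std / math.sqrt(2 * num_layers))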