Merge pull request #62 from huggingface/nouamane/fixes
Refactoring tying mechanism + small fixes
NouamaneTazi authored Feb 12, 2024
2 parents 43b5e6b + bae29be commit 939522e
Showing 10 changed files with 139 additions and 98 deletions.
22 changes: 7 additions & 15 deletions src/nanotron/helpers.py
@@ -25,6 +25,7 @@
)
from nanotron.distributed import ProcessGroup
from nanotron.logging import LogItem, log_rank
+ from nanotron.models.base import NanotronModel
from nanotron.optim.base import BaseOptimizer, Optimizer
from nanotron.optim.gradient_accumulator import (
FP32GradBucketManager,
@@ -157,24 +158,15 @@ def lr_lambda(current_step: int):
def init_optimizer_and_grad_accumulator(
model: nn.Module, optimizer_args: OptimizerArgs, parallel_context: ParallelContext
) -> Tuple[BaseOptimizer, GradientAccumulator]:
- # Normalize DDP
- normalized_model = model.module if isinstance(model, DistributedDataParallel) else model
+ # Unwrap DDP
+ unwrapped_model: NanotronModel = model.module if isinstance(model, DistributedDataParallel) else model

- module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in normalized_model.named_modules()}
+ module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in unwrapped_model.named_modules()}
# Fix the root_model
- root_model_id = id(normalized_model)
- module_id_to_prefix[root_model_id] = ""
+ module_id_to_prefix[id(unwrapped_model)] = ""

- # named parameters
- named_parameters = [
- (
- param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
- if param.is_tied
- else name,
- param,
- )
- for name, param in normalized_model.named_parameters()
- ]
+ named_parameters = list(unwrapped_model.get_named_params_with_correct_tied())

# Basic optimizer builder
def basic_optimizer_builder(named_param_groups):
@@ -262,7 +254,7 @@ def grad_optimizer_builder(named_param_groups):
)
if param.is_tied
else name
- for name, param in normalized_model.named_parameters()
+ for name, param in unwrapped_model.named_parameters()
},
),
hook=get_fp32_accum_hook(
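For context: the net effect of this hunk is that tied-parameter name resolution moves out of init_optimizer_and_grad_accumulator and into the model itself. A minimal sketch of the new call path, assuming `model` is a Nanotron model that may be wrapped in DDP (the function and variable names here are illustrative, not taken verbatim from the repo):

# Hedged sketch mirroring the refactored logic above.
from torch.nn.parallel import DistributedDataParallel

def named_params_for_optimizer(model):
    # Unwrap DDP to reach the underlying NanotronModel
    unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
    # The model now owns tied-name resolution (see src/nanotron/models/base.py below)
    return list(unwrapped.get_named_params_with_correct_tied())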
2 changes: 1 addition & 1 deletion src/nanotron/logging.py
@@ -214,7 +214,7 @@ def log_rank(

@lru_cache(maxsize=None)
def warn_once(
- logger: Logger, msg: str, group: Optional[dist.ProcessGroup] = None, rank: Optional[int] = None, **kwargs
+ msg: str, logger: Logger, group: Optional[dist.ProcessGroup] = None, rank: Optional[int] = None, **kwargs
):
log_rank(msg=msg, logger=logger, level=logging.WARNING, group=group, rank=rank, **kwargs)

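The only change here is the argument order: `msg` now comes first, so call sites can pass the message positionally. A hedged usage sketch (the message text is made up):

# Because warn_once is wrapped in @lru_cache, repeated calls with the same
# arguments are served from the cache, so the warning is emitted only once.
warn_once("tied kv weights are duplicated across tp ranks", logger=logger, rank=0)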
29 changes: 28 additions & 1 deletion src/nanotron/models/base.py
@@ -1,6 +1,6 @@
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
- from typing import TYPE_CHECKING, Callable, List, Optional
+ from typing import TYPE_CHECKING, Callable, Iterator, List, Optional, Tuple

import numpy as np
import torch
@@ -15,6 +15,7 @@

if TYPE_CHECKING:
from nanotron.config import NanotronConfigs
+ from nanotron.parallel.parameters import NanotronParameter

logger = logging.get_logger(__name__)

@@ -34,10 +35,36 @@ def __init__(self, *args, **kwargs) -> None:
self.input_pp_rank: int
self.output_pp_rank: int

# Useful mapping to get param names
self.module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in self.named_modules()}
self.module_id_to_prefix[id(self)] = ""

def get_named_params_with_correct_tied(self) -> Iterator[Tuple[str, "NanotronParameter"]]:
"""Return named parameters with correct tied params names.
For example in the case of tied kv heads in MQA, we need to make sure tied params names are correct."""

def params_gen():
for name, param in self.named_parameters():
if param.is_tied:
yield (
param.get_tied_info().get_full_name_from_module_id_to_prefix(
module_id_to_prefix=self.module_id_to_prefix
),
param,
)
else:
yield name, param

yield from params_gen()

@abstractmethod
def init_model_randomly(self, init_method, scaled_init_method):
...

def tie_custom_params(self) -> None:
"""Tie custom parameters. For example for MQA marks kv heads as tied."""
pass

@staticmethod
def get_embeddings_lm_head_tied_names() -> list[str]:
"""Returns the names of the embeddings and lm_head weights that are tied together. Returns empty list if not tied.
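To make the new helper concrete, a hedged sketch of how it might be consumed (assumes `model` is a NanotronModel subclass whose custom tied parameters, e.g. MQA kv heads, were marked via tie_custom_params()):

# Tied parameters are reported under the full name stored in their TiedInfo,
# resolved through model.module_id_to_prefix, so the optimizer and the
# gradient-accumulation hooks agree on one canonical name per tensor.
named = dict(model.get_named_params_with_correct_tied())
tied_names = [name for name, param in named.items() if param.is_tied]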
7 changes: 7 additions & 0 deletions src/nanotron/models/llama.py
@@ -1027,6 +1027,13 @@ def init_model_randomly(self, init_method, scaled_init_method):
for name, param in model.named_parameters()
}, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}"

def get_embeddings_lm_head_tied_names(self):
"""Get the names of the tied embeddings and lm_head weights"""
if self.config.tie_word_embeddings is True:
return ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"]
else:
return []

def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
return self.model.get_block_compute_costs()
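A short hedged example of the new hook (assumes a config with `tie_word_embeddings: true`; the `model` object stands in for a constructed Llama training model):

names = model.get_embeddings_lm_head_tied_names()
if names:
    # Two entries: the token-embedding weight and the lm_head weight, which
    # downstream tying logic can bind to the same tensor.
    embedding_name, lm_head_name = names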
46 changes: 31 additions & 15 deletions src/nanotron/models/starcoder2.py
@@ -35,7 +35,6 @@

from nanotron import distributed as dist
from nanotron.config import ParallelismArgs, RecomputeGranularity, Starcoder2Config
- from nanotron.distributed import get_global_rank
from nanotron.generation.generate_store import AttachableStore
from nanotron.models import NanotronModel
from nanotron.nn.activations import ACT2FN
@@ -53,7 +52,7 @@
TensorParallelEmbedding,
TensorParallelRowLinear,
)
- from nanotron.parallel.tied_parameters import create_tied_parameter
+ from nanotron.parallel.tied_parameters import tie_parameters
from nanotron.random import RandomStates, branch_random_state
from nanotron.utils import checkpoint_method

@@ -570,7 +569,6 @@ def __init__(

# Marking as tied/sharded
mark_all_parameters_in_module_as_sharded(self.q, pg=self.pg, split_config=SplitConfig(split_dim=0))
- self._mark_kv_parameters_in_module_as_tied()

# Init
self.reset_parameters()
@@ -586,18 +584,6 @@ def reset_parameters(self) -> None:
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self._qkv_bias, -bound, bound)

- def _mark_kv_parameters_in_module_as_tied(self):
- for name, param in list(self.kv.named_parameters()):
- new_param = create_tied_parameter(
- parameter=param,
- name=name,
- global_ranks=tuple(sorted((get_global_rank(self.pg, i) for i in range(self.pg.size())))),
- # Always has to be ReduceOp SUM as now this is always duplicated regardless of tp mode
- reduce_op=dist.ReduceOp.SUM,
- root_module=self.kv,
- )
- setattr(self.kv, name, new_param)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.use_MQAColumnLinearReduceScatterAsyncCommunication:
assert self._qkv_weight.requires_grad is False
@@ -1449,6 +1435,36 @@ def forward(
)["loss"]
}

def tie_custom_params(self) -> None:
# find all params with names qkv.kv.weight and qkv.kv.bias in them
for module_name, module in self.named_modules():
for param_name, param in module.named_parameters(recurse=False):
name = f"{module_name}.{param_name}"
if ".qkv.kv." in name:
assert not param.is_tied, f"Parameter {name} is already tied"
shared_weights = [
(
name,
# This adds all the tp_ranks in one go
tuple(
sorted(
self.parallel_context.world_rank_matrix[
dist.get_rank(self.parallel_context.pp_pg),
dist.get_rank(self.parallel_context.dp_pg),
:,
]
)
),
)
]
tie_parameters(
root_module=self,
ties=shared_weights,
parallel_context=self.parallel_context,
# We always SUM grads, because kv weights are always duplicated in MQA
reduce_op=dist.ReduceOp.SUM,
)

@torch.no_grad()
def init_model_randomly(self, init_method, scaled_init_method):
model = self
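Why this block exists: in MQA the single key/value projection is duplicated on every tensor-parallel rank rather than sharded, so its gradients must be SUM-reduced across those ranks. tie_custom_params now registers that tie through the generic tie_parameters path instead of the removed _mark_kv_parameters_in_module_as_tied helper. A hedged sketch of the intended call site (the constructor name and the timing are assumptions, not shown in this diff):

# `build_starcoder2_model` is a hypothetical stand-in for however the trainer
# constructs the model in this codebase.
model = build_starcoder2_model(config, parallel_context)
model.tie_custom_params()  # marks every *.qkv.kv.{weight,bias} as tied across tp ranks
# sync_tied_weights_gradients() will later SUM-reduce the kv grads over the
# global ranks recorded in each parameter's TiedInfo.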
7 changes: 4 additions & 3 deletions src/nanotron/parallel/parameters.py
@@ -6,6 +6,7 @@

from nanotron import distributed as dist
from nanotron import logging
+ from nanotron.models import NanotronModel

logger = logging.get_logger(__name__)

@@ -55,7 +56,7 @@ def tuple_from_str(cls, string: str):
@dataclasses.dataclass
class TiedInfo:
name: str
- # This allows us to define the scope in which `name` is valid.
+ # name must be defined starting from `root_module` (e.g. root_module.dense0.dense1.weight)
root_module: nn.Module
global_ranks: Tuple[int, ...]
# None signifies that we do not reduce
@@ -68,7 +69,7 @@ def get_full_name_from_model(self, model: nn.Module) -> str:
return self.get_full_name_from_module_id_to_prefix(module_id_to_prefix)

def get_full_name_from_module_id_to_prefix(self, module_id_to_prefix: Dict[int, str]) -> str:
return f"{module_id_to_prefix[id(self.root_module)]}{self.name}"
return f"{module_id_to_prefix[id(self.root_module)]}{self.name}" # this assumes root_module is part of module_id_to_prefix


@dataclasses.dataclass
@@ -127,7 +128,7 @@ def _set_metadata(self, key: str, value: Any):
metadata[key] = value

def mark_as_tied(
- self, name: str, global_ranks: Tuple[int, ...], reduce_op: Optional[dist.ReduceOp], root_module: nn.Module
+ self, name: str, global_ranks: Tuple[int, ...], reduce_op: Optional[dist.ReduceOp], root_module: NanotronModel
):
self._set_metadata(
self.NANOTRON_PARAMETER_METADATA_TIED_KEY,
12 changes: 11 additions & 1 deletion src/nanotron/parallel/tied_parameters.py
@@ -117,14 +117,24 @@ def get_tied_id_to_param(


def sync_tied_weights_gradients(
- module: nn.Module,
+ module: nn.Module,  # TODO: NanotronModel
parallel_context: ParallelContext,
grad_accumulator: Optional[GradientAccumulator],
):
tied_id_to_param = get_tied_id_to_param(
parameters=[param for param in module.parameters() if param.requires_grad], root_module=module
)

# Only first and last rank should print the warning
for rank in [0, parallel_context.world_pg.size() - 1]:
log_rank(
f"[Debug Tied Weights] Syncing the following tied weights: {tied_id_to_param.keys()}",
logger=logger,
level=logging.DEBUG,
group=parallel_context.world_pg,
rank=rank,
)

# Group tensors to reduce by process groups
# Important to use ordered dict in order to be synchronized across all ranks
group_ranks_and_reduce_op_to_tensors_to_reduce = OrderedDict()
(3 more changed files not shown)
