diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml
index 599bff6c..33f9db41 100644
--- a/examples/config_multilingual_nanoset.yaml
+++ b/examples/config_multilingual_nanoset.yaml
@@ -131,4 +131,4 @@ tokens:
   micro_batch_size: 4
   sequence_length: 1024
   train_steps: 200
-  val_check_interval: -1
+  val_check_interval: 3
diff --git a/run_train.py b/run_train.py
index 39cda23b..ed9b5607 100644
--- a/run_train.py
+++ b/run_train.py
@@ -238,10 +238,10 @@ def get_valid_dataloader_from_data_stage(
 
     with main_rank_first(trainer.parallel_context.world_pg):
         valid_dataset = MultilingualNanoset(
-            dataset_folders=data.dataset.validation_folder,
+            dataset_folders=data.dataset.validation_folder,  # TODO Just 1 folder
             sequence_length=trainer.sequence_length,
             token_size=token_size,
-            dataset_tokens=data.dataset.dataset_tokens,
+            dataset_tokens=data.dataset.dataset_tokens,  # TODO Just 1 lang
             is_valid=True,
             random_seed=data.seed,
         )
 
diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index dd2c157d..e5ea3ec1 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -125,6 +125,7 @@ def __post_init__(self):
             self.training_folder = list(tmp_training_folder.keys())
             self.dataset_weights = list(tmp_training_folder.values())
 
+        self.ids_to_lang = {v: k for k, v in self.lang_to_ids.items()}
         self.dataset_tokens = list(self.lang_to_ids.values())
         assert len(self.training_folder) == len(
             self.validation_folder
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 2411e5fa..7ae34dd5 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -801,7 +801,14 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch
 @torch.jit.script
 def masked_mean(loss, label_mask, dtype):
     # type: (Tensor, Tensor, torch.dtype) -> Tensor
-    return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
+    return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(
+        dim=1
+    )  # TODO For now this already gives float / float = float
+
+
+# TODO The per-sample loss: ((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1))
+# And the assert_close passes!
+# assert_close(((loss * label_mask).sum(dtype=dtype) / label_mask.sum()), torch.mean((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1)))
 
 
 class Loss(nn.Module):
@@ -818,14 +825,16 @@ def forward(
 
         # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision.
        # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38
-        loss = sharded_cross_entropy(
+        sample_loss = sharded_cross_entropy(
             sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float
         ).transpose(0, 1)
         # TODO @thomasw21: It's unclear what kind of normalization we want to do.
-        loss = masked_mean(loss, label_mask, dtype=torch.float)
-        # I think indexing causes a sync we don't actually want
-        # loss = loss[label_mask].sum()
-        return {"loss": loss}
+        sample_loss = masked_mean(sample_loss, label_mask, dtype=torch.float)
+        # NOTE @tj.solergibert: masked_mean used to return a single scalar with the batch loss; we changed it to compute the per-SAMPLE loss.
+        # We keep using "loss" for the batch loss, but we add "sample_loss" for the multilingual effort.
+        # TODO @thomasw21: I think indexing causes a sync we don't actually want
+        # TODO @thomasw21: loss = loss[label_mask].sum()
+        return {"sample_loss": sample_loss}
 
 
 class LlamaForTraining(NanotronModel):
@@ -847,7 +856,7 @@ def __init__(
                 "label_ids",
                 "label_mask",
             },
-            module_output_keys={"loss"},
+            module_output_keys={"sample_loss"},
         )
         self.parallel_context = parallel_context
         self.config = config
@@ -864,12 +873,13 @@ def forward(
             input_ids=input_ids,
             input_mask=input_mask,
         )
-        loss = self.loss(
+        outputs = self.loss(
             sharded_logits=sharded_logits,
             label_ids=label_ids,
             label_mask=label_mask,
-        )["loss"]
-        return {"loss": loss}
+        )
+        outputs["loss"] = torch.mean(outputs["sample_loss"])
+        return outputs
 
     @torch.no_grad()
     def init_model_randomly(self, config: Config):
diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py
index ca9df312..bf690bd0 100644
--- a/src/nanotron/parallel/pipeline_parallel/engine.py
+++ b/src/nanotron/parallel/pipeline_parallel/engine.py
@@ -2,6 +2,9 @@
 from typing import Dict, Iterable, Optional, Union
 
 import torch
+from torch import nn as torch_nn
+from torch.nn.parallel import DistributedDataParallel
+
 from nanotron import distributed as dist
 from nanotron import logging
 from nanotron.distributed import ProcessGroup
@@ -12,8 +15,6 @@
 from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState
 from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
 from nanotron.utils import ContextManagers
-from torch import nn as torch_nn
-from torch.nn.parallel import DistributedDataParallel
 
 logger = logging.get_logger(__name__)
 
@@ -29,6 +30,7 @@ def forward(
         state: PipelineTrainBatchState,
         micro_batch: Dict[str, Union[torch.Tensor, TensorPointer]],
         model: torch_nn.Module,
+        is_validation: bool = False,
     ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
         # Increment the number of backwards
         state.nb_forwards += 1
@@ -52,7 +54,7 @@ def forward(
             output["loss"] = output["loss"] / self.nb_microbatches
 
         # Add output as activations that require backward pass
-        if not isinstance(output["loss"], TensorPointer):
+        if not isinstance(output["loss"], TensorPointer) and not is_validation:
             assert output["loss"].requires_grad
             state.register_activation_requiring_backward(output["loss"])
         return output
@@ -138,12 +140,15 @@ def validate_batch_iter(
         self.nb_microbatches = nb_microbatches
 
         outputs = []
+        lang_ids = []
        with attach_pipeline_state_to_model(model=model, pipeline_state=state):
             # All forward
             for micro_batch in batch:
                 context = self._get_fwd_context(model=model)
-                output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
+                output = self.forward(
+                    context=context, state=state, micro_batch=micro_batch, model=model, is_validation=True
+                )
                 # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
                 for _ in range(len(state.microbatches_activations_to_send)):
                     send_activation = state.microbatches_activations_to_send.popleft()
@@ -151,15 +156,23 @@
                     send_activation()
 
             # We make `output` a dict
+            # TODO Also convert to dict any other items returned by the model (the MoE aux loss, for example).
+            # But in the next if statement, be careful in case we return other items on all of the PP processes.
+            # This conversion to dict is somewhat redundant, since the model already returns a dict with a loss key. Maybe the PP ranks return TensorPointer objects?
             if not isinstance(output, dict):
                 output = {"loss": output}
 
             # Store the loss for each microbatch
             if not isinstance(output["loss"], TensorPointer):
                 output = {k: v.detach() for k, v in output.items()}
-            outputs.append(output)
-
-        return outputs
+                # TODO Check what this output is and how it gets stored in outputs. Where is the mean taken? In the training step.
+                # Here we should split by language, because this is the only point where we still have the language available, or at least tag it / accumulate it per language:
+                # 1. Build a dict with a key per language. 2. Each key holds a list to which we append the tensors. 3. In the validation step, do the stack and the all-reduces.
+                # For now: here we only collect the lang ids; the results are accumulated in trainer.py.
+                outputs.extend(list(output["sample_loss"]))  # TODO flatten or extend?
+                lang_ids.extend(micro_batch["input_ids"][:, 0].tolist())  # TODO should this be an extend?
+
+        return outputs, lang_ids
 
 
 class AllForwardAllBackwardPipelineEngine(PipelineEngine):
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 3f4c5189..583068cd 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -300,7 +300,9 @@ def _print_training_plan(self):
         )
         log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0)
 
-    def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]):
+    def _update_dataloader_based_on_training_stages(
+        self, dataloaders: Union[List[DataLoader], DataLoader], is_validation: bool = False
+    ):
         from collections.abc import Generator
 
         if not hasattr(self.config, "data_stages") or self.config.data_stages is None:
@@ -309,9 +311,16 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da
                 dataloader = dataloaders[0]
             else:
                 dataloader = dataloaders
-            self.current_dataloader = sanity_check_dataloader(
-                dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
-            )
+
+            if is_validation:
+                self.current_validation_dataloader_length = len(dataloader)
+                self.current_validation_dataloader = sanity_check_dataloader(
+                    dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
+                )
+            else:
+                self.current_dataloader = sanity_check_dataloader(
+                    dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
+                )
             return
         elif isinstance(dataloaders, Generator):
             # TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader
@@ -328,7 +337,7 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
             import gc
 
             log_rank(
-                f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory",
+                f"[{'Validation' if is_validation else 'Training'} Stage: {stage_name}] Clearing the previous {'validation' if is_validation else 'training'} stage's dataloader and datasets from memory",
                 logger=logger,
                 level=logging.INFO,
             )
@@ -369,7 +378,7 @@ def find_stage_idx_to_resume():
 
                 self.metadata.last_stage_idx = stage_idx
 
-                if is_resume_from_training:
+                if is_resume_from_training and not is_validation:
                     remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
                         stage, self.config, self.metadata
                     )
@@ -387,9 +396,15 @@ def find_stage_idx_to_resume():
                     break
 
         if dataloader is not None:
-            self.current_dataloader = sanity_check_dataloader(
-                dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
-            )
+            if is_validation:
+                self.current_validation_dataloader_length = len(dataloader)
+                self.current_validation_dataloader = sanity_check_dataloader(
+                    dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
+                )
+            else:
+                self.current_dataloader = sanity_check_dataloader(
+                    dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
+                )
 
     def train(
         self,
@@ -428,9 +443,23 @@ def train(
             self.iteration_start_time = time.time()
             self._update_dataloader_based_on_training_stages(train_dataloader_or_dls)
+            self._update_dataloader_based_on_training_stages(valid_dataloader_or_dls, is_validation=True)
 
             # Training step
             outputs, loss_avg = self.training_step(dataloader=self.current_dataloader)
+            self.training_step_time = time.time()
+
+            # Validation step
+            # TODO Each pass through this loop runs a single training iteration but many validation iterations.
+            # Maybe we should move this somewhere else? That is, here we do one training step but several validation steps.
+            # We can leave it here, but then the throughput and consumed-tokens metrics should be revisited,
+            # because they currently use the global batch size, which is correct for a training step,
+            # whereas each validation run is much longer than a training step.
+            # len() of the valid dataloader should give the number of valid batches, so with that and the batch size we can manage.
+            if self.iteration_step % self.config.tokens.val_check_interval == 0:
+                global_loss, lang_losses = self.validation_step(dataloader=self.current_validation_dataloader)
+                self.validation_step_time = time.time()
+                self.validation_step_logs(global_loss, lang_losses)
 
             # Training Logs
             # TODO(xrsrke): refactor using callbacks would be better
@@ -546,12 +575,36 @@ def training_step(
         return outputs, loss_avg
 
     def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]:
-        outputs = self.pipeline_engine.validate_batch_iter(
+        outputs, lang_ids = self.pipeline_engine.validate_batch_iter(
             model=self.model,
-            batch=(next(dataloader) for _ in range(self.limit_val_batches)),
-            nb_microbatches=self.limit_val_batches,
+            batch=(next(dataloader) for _ in range(self.current_validation_dataloader_length)),
+            nb_microbatches=self.current_validation_dataloader_length,
         )
-        return outputs
+
+        lang_losses = {
+            lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.lang_to_ids.keys()
+        }
+        # Compute losses
+        if isinstance(outputs[0], torch.Tensor):
+            # Multilingual losses
+            for loss, lang_id in zip(outputs, lang_ids):
+                lang_losses[
+                    self.config.data_stages[self.metadata.last_stage_idx].data.dataset.ids_to_lang[lang_id]
+                ].append(loss)
+            # Global loss
+            global_loss_avg = torch.stack(outputs).sum()
+            # Sync losses across DP
+            for lang in lang_losses.keys():
+                lang_losses[lang] = torch.stack(lang_losses[lang]).sum()
+                dist.all_reduce(
+                    lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG
+                )  # TODO These averages come out huge, probably because the average is taken over a single value! It should be the per-batch loss, right? Otherwise, take the "loss" from the outputs above and compare by hand.
+            dist.all_reduce(global_loss_avg, group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG)
+        else:
+            global_loss_avg = None
+            lang_losses = None
+
+        return global_loss_avg, lang_losses
 
     def train_step_logs(
         self,
@@ -561,7 +614,7 @@ def train_step_logs(
         # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607
         dist.barrier()
         torch.cuda.synchronize()
-        elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000
+        elapsed_time_per_iteration_ms = (self.training_step_time - self.iteration_start_time) * 1000
         tokens_per_sec = (
             self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000)
         )  # tokens_per_sec is calculated using sequence_length
@@ -641,6 +694,68 @@ def train_step_logs(
         else:
             exit(0)
 
+    def validation_step_logs(
+        self,
+        global_loss: torch.Tensor,
+        lang_losses: torch.Tensor,
+    ) -> None:
+        # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607
+        dist.barrier()
+        torch.cuda.synchronize()
+        total_validation_samples = self.current_validation_dataloader_length * self.micro_batch_size
+        elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000
+        tokens_per_sec = (
+            total_validation_samples * self.sequence_length / (elapsed_time_per_iteration_ms / 1000)
+        )  # tokens_per_sec is calculated using sequence_length
+        # TODO For validation, careful with changing global_batch_size = len(dataloader) * mbs
+        model_tflops, hardware_tflops = self.unwrapped_model.get_flops_per_sec(
+            iteration_time_in_sec=elapsed_time_per_iteration_ms / 1000,
+            sequence_length=self.sequence_length,
+            global_batch_size=total_validation_samples,  # TODO Regarding the global batch size, I would set it to 1 and multiply by the number of batches
+        )
+
+        if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks:
+            assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks"
+
+            log_entries = [
+                # LogItem("consumed_samples", self.consumed_train_samples, "human_format"),  # , "12d"),
+                LogItem(
+                    "validation_consumed_tokens",
+                    self.metadata.consumed_train_samples * self.config.tokens.sequence_length,
+                    "human_format",
+                ),  # , "12d"),
+                LogItem(
+                    "validation_elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"
+                ),  # , ".1f"),
+                LogItem("validation_tokens_per_sec", tokens_per_sec, "human_format"),  # , "1.6E"),
+                LogItem(
+                    "validation_tokens_per_sec_per_gpu",
+                    tokens_per_sec / self.parallel_context.world_pg.size(),
+                    "human_format",
+                ),  # , "1.6E"),
+                LogItem("validation_loss", global_loss.item(), "human_format"),  # , "1.6E"),
+                LogItem("validation_model_tflops_per_gpu", model_tflops, "human_format"),  # , ".2f"),
+                LogItem("validation_hardware_tflops_per_gpu", hardware_tflops, "human_format"),  # , ".2f"),
+            ]
+
+            # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI.
+            # Example: https://community.wandb.ai/t/log-multiple-variables-at-the-same-plot/2474
+            # Related GitHub issue: https://github.com/wandb/wandb/issues/3035
+            log_entries.extend(
+                [LogItem(f"{lang}_validation_loss", loss.item(), "human_format") for lang, loss in lang_losses.items()]
+            )
+
+            # NOTE: only one rank writes to wandb
+            if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
+                wandb.log(
+                    {
+                        **{log_item.tag: log_item.scalar_value for log_item in log_entries},
+                        "iteration_step": self.iteration_step,
+                    }
+                )
+
+            self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step)
+
     def init_model(self) -> Union[NanotronModel, DistributedDataParallel]:
         """Initialize the model and load weights from checkpoint if needed."""
         # TODO: add max_position_embeddings
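
For reference, the loss bookkeeping that this diff threads through masked_mean, validate_batch_iter and validation_step can be exercised in isolation. The sketch below is not the PR's code: it is a single-process approximation that replaces sharded_cross_entropy with a random tensor, drops the tensor/pipeline/data-parallel plumbing, and invents a small lang_to_ids mapping plus helper names (masked_mean_per_sample, aggregate_by_language) purely for illustration. Only the reduction pattern mirrors the diff: per-sample losses from a dim=1 masked mean, then grouping of sample losses by the language id carried in the first token of each sample.

# Minimal, single-process sketch of the per-language validation loss aggregation.
# Assumptions (not taken verbatim from the PR): the lang_to_ids values and the helper
# names below are made up; only the reduction pattern follows the diff above.
import torch


def masked_mean_per_sample(loss: torch.Tensor, label_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Reduce over the sequence dimension only, so every sample keeps its own loss
    # (this is the dim=1 change made to masked_mean in llama.py).
    return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1)


def aggregate_by_language(sample_loss: torch.Tensor, lang_ids: list, ids_to_lang: dict) -> dict:
    # Group the per-sample losses by language id, then average within each group,
    # analogous to what validation_step does before the DP all-reduce.
    per_lang = {lang: [] for lang in ids_to_lang.values()}
    for loss, lang_id in zip(sample_loss, lang_ids):
        per_lang[ids_to_lang[lang_id]].append(loss)
    return {lang: torch.stack(losses).mean() for lang, losses in per_lang.items() if losses}


if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, sequence_length = 4, 8
    lang_to_ids = {"en": 128002, "es": 128003}  # hypothetical mapping, stands in for config lang_to_ids
    ids_to_lang = {v: k for k, v in lang_to_ids.items()}  # mirrors the ids_to_lang added in config.py

    token_loss = torch.rand(batch_size, sequence_length)  # stand-in for the sharded cross-entropy output
    label_mask = torch.ones(batch_size, sequence_length)
    sample_loss = masked_mean_per_sample(token_loss, label_mask, torch.float)

    # The diff reads the language id from the first token of each sample (input_ids[:, 0]).
    lang_ids = [128002, 128003, 128002, 128003]
    print(aggregate_by_language(sample_loss, lang_ids, ids_to_lang))
    print("global loss:", sample_loss.mean().item())

One design note: averaging per-sample losses weights every sample equally, while the old masked_mean weighted every unmasked token equally. The two reductions coincide when all samples carry the same number of unmasked tokens, which is presumably why the assert_close mentioned in the llama.py comment passes.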