From 94d6c2a9931cf735366be9b356a01270e657a9a1 Mon Sep 17 00:00:00 2001
From: tj-solergibert
Date: Wed, 24 Jul 2024 16:18:36 +0000
Subject: [PATCH] This looks good

---
 examples/config_multilingual_nanoset.yaml |  60 +++++-----
 run_train.py                              |   9 +-
 src/nanotron/models/llama.py              |   7 +-
 .../parallel/pipeline_parallel/engine.py   |  16 +--
 src/nanotron/trainer.py                    | 103 +++++++++---------
 5 files changed, 99 insertions(+), 96 deletions(-)

diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml
index 33f9db41..5573a224 100644
--- a/examples/config_multilingual_nanoset.yaml
+++ b/examples/config_multilingual_nanoset.yaml
@@ -1,5 +1,5 @@
 checkpoints:
-  checkpoint_interval: 1000
+  checkpoint_interval: 1000000
   checkpoints_path: checkpoints/
   checkpoints_path_is_shared_file_system: false
   resume_checkpoint_path: null
@@ -7,38 +7,40 @@ checkpoints:
 data_stages:
 - data:
     dataset:
-      training_folder: datasets/c4-es/train
-      validation_folder: datasets/c4-es/validation
+      training_folder:
+        datasets/c4-es/train: 0.85
+        datasets/c4-en/train: 0.05
+        datasets/c4-fr/train: 0.1
+      validation_folder:
+      - datasets/c4-es/validation
+      - datasets/c4-en/validation
+      - datasets/c4-fr/validation
       lang_to_ids:
         es: 128002
+        en: 128003
+        fr: 128004
     num_loading_workers: 1
     seed: 42
-  name: General purpose training (Single dataset)
+  name: General purpose training (Blended dataset)
   start_training_step: 1
 - data:
     dataset:
       training_folder:
       - datasets/c4-es/train
-      - datasets/c4-en/train
-      - datasets/c4-fr/train
       validation_folder:
       - datasets/c4-es/validation
-      - datasets/c4-en/validation
-      - datasets/c4-fr/validation
       lang_to_ids:
         es: 128002
-        en: 128003
-        fr: 128004
     num_loading_workers: 1
     seed: 42
-  name: Second purpose training (> 1 dataset)
-  start_training_step: 15
+  name: Second purpose training (Single dataset)
+  start_training_step: 100
 - data:
     dataset:
       training_folder:
-        datasets/c4-es/train: 0.6
-        datasets/c4-en/train: 0.3
-        datasets/c4-fr/train: 0.1
+      - datasets/c4-es/train
+      - datasets/c4-en/train
+      - datasets/c4-fr/train
       validation_folder:
       - datasets/c4-es/validation
       - datasets/c4-en/validation
       - datasets/c4-fr/validation
@@ -50,13 +52,13 @@ data_stages:
     num_loading_workers: 1
     seed: 42
-  name: Third purpose training (Blended dataset)
-  start_training_step: 25
+  name: Third purpose training (>1 dataset)
+  start_training_step: 200
 general:
   benchmark_csv_path: null
   consumed_train_samples: null
   ignore_sanity_checks: true
-  project: Nanoset
+  project: Multilingual
   run: llama
   seed: 42
   step: null
@@ -75,12 +77,12 @@ model:
     bos_token_id: 1
     eos_token_id: 2
     hidden_act: silu
-    hidden_size: 512
+    hidden_size: 4096
    initializer_range: 0.02
-    intermediate_size: 512
+    intermediate_size: 14336
    is_llama_config: true
-    max_position_embeddings: 1024
-    num_hidden_layers: 2
+    max_position_embeddings: 4096
+    num_hidden_layers: 32
    num_attention_heads: 32
    num_key_value_heads: 8
    pad_token_id: null
@@ -89,7 +91,7 @@ model:
    rope_theta: 500000.0
    rms_norm_eps: 1.0e-06
    rope_scaling: null
-    tie_word_embeddings: true
+    tie_word_embeddings: false
    use_cache: true
    vocab_size: 128256
 optimizer:
@@ -116,19 +118,19 @@ parallelism:
   expert_parallel_size: 1
   pp: 1
   pp_engine: 1f1b
-  tp: 1
+  tp: 4
   tp_linear_async_communication: false
   tp_mode: REDUCE_SCATTER
 profiler: null
 tokenizer:
   tokenizer_max_length: null
-  tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B
+  tokenizer_name_or_path: /mloscratch/homes/solergib/models/Meta-Llama-3-8B
   tokenizer_revision: null
 tokens:
   batch_accumulation_per_replica: 1
   limit_test_batches: 0
   limit_val_batches: 10
-  micro_batch_size: 4
-  sequence_length: 1024
-  train_steps: 200
-  val_check_interval: 3
+  micro_batch_size: 3
+  sequence_length: 4096
+  train_steps: 800
+  val_check_interval: 50
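The first stage above uses the weighted-mapping form of `training_folder` (weights 0.85 / 0.05 / 0.1), while the third stage uses the plain list form and leaves the weighting to the dataset implementation. As a minimal sketch of how such a mapping can be interpreted (this is not the MultilingualNanoset implementation, just an illustration), the weights can be normalized into per-dataset sampling probabilities:

    # Sketch only: normalize the weighted `training_folder` mapping into sampling probabilities.
    training_folder = {
        "datasets/c4-es/train": 0.85,
        "datasets/c4-en/train": 0.05,
        "datasets/c4-fr/train": 0.1,
    }
    total = sum(training_folder.values())
    sampling_probabilities = {path: weight / total for path, weight in training_folder.items()}
    print(sampling_probabilities)  # these weights already sum to 1.0, so they come back unchanged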
diff --git a/run_train.py b/run_train.py
index 80c7a426..2ddff5ad 100644
--- a/run_train.py
+++ b/run_train.py
@@ -325,8 +325,15 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
         dataloader = (
             get_valid_dataloader_from_data_stage(trainer, stage.data)
             if stage_idx == 0
-            else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data)
+            else lambda stage=stage: get_valid_dataloader_from_data_stage(trainer, stage.data)
         )
+        # TODO(tj.solergibert) Because we create the valid dataloader again in every validation stage, we print the
+        # validation MultilingualNanoset info (number of samples, etc.) multiple times. To avoid that, we could get rid of these lambda
+        # funcs and create all the dataloaders upfront.
+        #
+        # These lambda funcs (used in training too) create the DataLoaders lazily in order to 1. start training faster instead of building every DataLoader upfront
+        # and 2. consume less memory, as a lambda is lighter than a DataLoader object with its Dataset, collator, etc.
+        # BUT 1. the Nanoset creation process is very fast and 2. a Nanoset doesn't consume any memory at all until we start sampling from it
 
         dataloaders[stage.name] = dataloader
     return dataloaders
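The comment above motivates the lazy, lambda-based DataLoader creation. A minimal sketch of the pattern with hypothetical names (not the nanotron code): the `stage=stage` default argument, which the patched line also keeps, binds the current loop value into each lambda; without it, every lambda would see the last value of the loop variable.

    def build_dataloader(stage_name):  # hypothetical stand-in for get_valid_dataloader_from_data_stage
        print(f"Building dataloader for {stage_name}")
        return [f"{stage_name}-batch-{i}" for i in range(3)]

    stages = ["stage_0", "stage_1", "stage_2"]

    # Lazy creation: store callables instead of the heavier DataLoader objects.
    lazy_dataloaders = {name: (lambda name=name: build_dataloader(name)) for name in stages}

    # Nothing is built until a stage actually starts; only then do we pay the construction cost.
    dataloader = lazy_dataloaders["stage_1"]()  # prints "Building dataloader for stage_1"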
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 7ae34dd5..133442af 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -803,12 +803,7 @@ def masked_mean(loss, label_mask, dtype):
     # type: (Tensor, Tensor, torch.dtype) -> Tensor
     return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(
         dim=1
-    )  # TODO esto de entrada da float/float = float
-
-
-# TODO la loss de cada uno !!!! ((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1))
-# Y pasa el assert close!!
-# assert_close(((loss * label_mask).sum(dtype=dtype) / label_mask.sum()), torch.mean((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1)))
+    )  # NOTE(tj.solergibert) Added dim=1 to return a tensor of shape [batch_size] (one loss per sample) instead of a single scalar
 
 
 class Loss(nn.Module):
diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py
index bc6dc5b5..549ef5eb 100644
--- a/src/nanotron/parallel/pipeline_parallel/engine.py
+++ b/src/nanotron/parallel/pipeline_parallel/engine.py
@@ -136,7 +136,7 @@ def validate_batch_iter(
         nb_microbatches: int,
     ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
         # Assign a new state for the current batch
-        state = PipelineEvalBatchState()  # PipelineTrainBatchState()  # TODO: do i need state?
+        state = PipelineEvalBatchState()
         self.nb_microbatches = nb_microbatches
 
         outputs = []
@@ -156,21 +156,17 @@ def validate_batch_iter(
                 send_activation()
 
             # We make `output` a dict
-            # TODO convert to dict other items returned by the model (MoE aux loss for example)
-            # But in next if statement be careful if we return other items in all of the pp processes
-            # This conversion to dicts is kind of useless as the model already returns a dict with loss key. Maybe the PP ranks return TensorPointer Objects?
             if not isinstance(output, dict):
                 output = {"loss": output}
 
             # Store the loss for each microbatch
             if not isinstance(output["loss"], TensorPointer):
                 output = {k: v.detach() for k, v in output.items()}
-                # TODO ver este output que es y tambien ver en outputs como se guarda. Donde se have la media? En el training step lol
-                # Aqui deberiamos segregar por languagues porque es el unico punto en el que tenemos la languague!! O al menos "etiquetarla" o acumularla por language
-                # 1. Hacemos dict con key para cada idioma 2. cada key tiene una lista donde append los tensors 3. en valid step hacemos lo del stack y allreduces
-                # Finalmente: Aqui metemos solo el lang ids, en trainer.py acumularemos los resultados y tal.
-                outputs.extend(list(output["sample_loss"]))  # TODO flatten?????? o extend??????
-                lang_ids.extend(micro_batch["input_ids"][:, 0].tolist())  # TODO esto deberia se un extend????
+
+                outputs.extend(
+                    list(output["sample_loss"])
+                )  # NOTE(tj.solergibert) Yes, it might look pointless to do list + extend, but it's needed to split the output["sample_loss"] tensor into one tensor per sample
+                lang_ids.extend(micro_batch["input_ids"][:, 0].tolist())
 
         return outputs, lang_ids
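For reference, a small sketch (illustrative values, not taken from the patch) of what the two lines added to validate_batch_iter collect: list() splits the per-sample loss tensor into one 0-dim tensor per sample, and the first token of every sequence is read as its language id (128002 = es, 128003 = en in the example config above); the trainer later groups the losses by that id.

    import torch

    sample_loss = torch.tensor([0.9, 1.2, 0.7])    # shape [batch_size], one loss per sample
    input_ids = torch.tensor([[128002, 11, 12],
                              [128003, 13, 14],
                              [128002, 15, 16]])   # first token encodes the language

    outputs, lang_ids = [], []
    outputs.extend(list(sample_loss))              # [tensor(0.9000), tensor(1.2000), tensor(0.7000)]
    lang_ids.extend(input_ids[:, 0].tolist())      # [128002, 128003, 128002]

    per_lang = {}
    for loss, lang in zip(outputs, lang_ids):
        per_lang.setdefault(lang, []).append(loss)
    print({lang: torch.stack(losses).mean().item() for lang, losses in per_lang.items()})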
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index b1cc36ad..f720446a 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -311,7 +311,25 @@ def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataL
         # 1. We call this function EVERY TIME we run the validation loop
         # 2. Every time it returns a NEW validation iterator DataLoader. If you don't do this you'll consume the whole validation dataset
         # in the first iteration and subsequent validations will fail
-        # TODO(tj.solergibert) Delete previous DataLoaders from memory like we do with training DataLoaders
+        # `dataloaders` are either torch DataLoaders (the very first stage) OR functions that, when called, return torch DataLoaders (subsequent stages).
+        # On these torch DataLoader objects we then call `sanity_check_dataloader`, which returns an iterator.
+        # In short, `sanity_check_dataloader` just places the input tensors on the GPU when necessary (TensorPointers stay on the CPU)
+        #
+        # TBH, the for loop below is just for deleting the DataLoaders of previous stages, which is not so problematic. The important part is returning the
+        # DataLoader iterator every time we call this function for the current training stage, which is tracked during training
+        #
+        # Also, keep in mind that if val_check_interval = 5 & data.start_training_step = 10 we will already perform the evaluation with the SECOND data stage
+        # after training on it for just one iteration, so it might not be a good idea to schedule evals at the step where we switch data stages
+        #
+        # NOTE(tj.solergibert) Further investigation is needed, but there is strange behaviour when deleting the DataLoaders / lambda funcs. As they
+        # are converted into iterators with `sanity_check_dataloader`, we can no longer access the DataLoader object to del its dataset (after the first stage,
+        # this function locally creates the DataLoader from the lambda func --> returns an iterator)
+        #
+        # Also, when the gc deletes the first stage dataloader, all the `DatatroveFileDataset._f` are already None AND the `del` statements delete a copy of the
+        # object, not the object itself
+        #
+        # FINAL NOTE(tj.solergibert) I will open an issue in nanotron to check with them whether they are aware of these useless deletions
+        #
         # TODO(tj.solergibert) Check the tuple case below
 
         from collections.abc import Generator
@@ -339,11 +357,11 @@ def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataL
             self.config.data_stages
         ), "Number of dataloaders should match the number of dataset stages"
 
-        def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
+        def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str):
             import gc
 
             log_rank(
-                f"[Validation Stage: {stage_name}] Clearing the previous validation stage's dataloader and dataset from memory",
+                f"[Validation Stage: {stage_name}] Clearing the previous validation stage's ({prev_stage_name}) dataloader and dataset from memory",
                 logger=logger,
                 level=logging.INFO,
             )
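A minimal sketch (plain Python, not nanotron code) of the point the comment block above keeps returning to: an iterator over a finite dataloader is exhausted after one pass, so the validation iterator has to be rebuilt on every validation run, which is exactly why this function is called each time instead of once.

    valid_dataloader = [{"input_ids": [1, 2]}, {"input_ids": [3, 4]}]  # stand-in for a torch DataLoader

    first_run = iter(valid_dataloader)
    assert len(list(first_run)) == 2   # the first validation pass consumes every batch
    assert len(list(first_run)) == 0   # reusing the exhausted iterator would yield zero batches

    second_run = iter(valid_dataloader)  # re-creating the iterator restores the full dataset
    assert len(list(second_run)) == 2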
@@ -355,57 +373,38 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
 
             gc.collect()
 
-        dataloader = None
-
         for stage_idx, stage in enumerate(self.config.data_stages):
             if stage_idx < self.metadata.last_stage_idx:
                 continue
 
+            # NOTE(tj.solergibert) From this point on, stage_idx = self.metadata.last_stage_idx. We update self.metadata.last_stage_idx (which keeps track of the training stage)
+            # on every training step.
             if (
                 stage_idx is not self.metadata.last_validation_stage_idx
-                and self.metadata.last_validation_stage_idx is not None
-            ):
+            ):  # When stage_idx (= self.metadata.last_stage_idx, the training stage index) is different from the last validation stage index
                 self.metadata.last_validation_stage_idx = stage_idx  # Update validation stage index
-                # Si cambiamos de stage borramo el antiguo
-                # En ambos casos recrear el que toca !!!
-                # TODO Aqui nos quedamos!!! Tenemos que borrar el anterior dataloader cuando sea necesario y hacer el sanity del current dataloader SIEMPRE
-                stage = cast(DatasetStageArgs, stage)
-                print(
-                    stage.name
-                )  # TODO como actualizamos el last stage index en el training aqui estamos mirando el dataloader de la siguiente iteracion que mal por dios!!!!!
-
-                log_rank(
-                    f"Ese print bueno {stage.name}",
-                    logger=logger,
-                    level=logging.INFO,
-                    rank=0,
-                )
-                # self.metadata.last_stage_idx = stage_idx
-                """
-                if self.current_validation_dataloader is not None:  # TODO Si hay algun dataloader ya lo eliminamos. Igualmente creamos de nuevo. Bueno el dataloader como tal ya esta creado, solo hay que devolver el sanity check raro
+                # Delete the previous stage's DataLoader
                 prev_stage_name = self.config.data_stages[stage_idx - 1].name
                 prev_dataloader = dataloaders[prev_stage_name]
-                if isinstance(prev_dataloader, DataLoader):
-                    # NOTE: we don't need to clear dummy data generator from memory
-                    clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name)
-                """
-            log_rank(
-                f"Preparing validation DataLoader from stage {stage.name}",
-                logger=logger,
-                level=logging.INFO,
-                rank=0,
-            )
+                if isinstance(prev_dataloader, DataLoader):
+                    # NOTE: we don't need to clear dummy data generator from memory
+                    clear_dataloader_from_memory(
+                        prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name
+                    )
+
+            self.metadata.last_validation_stage_idx = stage_idx  # Update validation stage index
+            # NOTE(tj.solergibert) Re-create the DataLoader
             dataloader = dataloaders[stage.name]
             # NOTE: if a dataloader is lazy initialized, we need to call it to initialize it
             dataloader = dataloader() if callable(dataloader) else dataloader
             break
 
-        self.current_validation_dataloader_lenght = 200  # TODO len(dataloader)
+        self.current_validation_dataloader_lenght = len(dataloader)
         self.current_validation_dataloader = sanity_check_dataloader(
             dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
-        )
+        )  # NOTE(tj.solergibert) Create an iterator from the DataLoader
 
     def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]):
         from collections.abc import Generator
@@ -431,11 +430,11 @@
-        def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
+        def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str):
             import gc
 
             log_rank(
-                f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory",
+                f"[Training Stage: {stage_name}] Clearing the previous training stage's ({prev_stage_name}) dataloader and datasets from memory",
                 logger=logger,
                 level=logging.INFO,
             )
@@ -472,7 +471,9 @@ def find_stage_idx_to_resume():
 
                 if isinstance(prev_dataloader, DataLoader):
                     # NOTE: we don't need to clear dummy data generator from memory
-                    clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name)
+                    clear_dataloader_from_memory(
+                        prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name
+                    )
 
                 self.metadata.last_stage_idx = stage_idx
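To make the val_check_interval caveat from the comment block concrete, a small sketch using the values of the example config in this patch (val_check_interval = 50, stages starting at steps 1, 100 and 200, train_steps = 800) that prints which data stage each validation run will see:

    val_check_interval = 50
    train_steps = 800
    stage_starts = {
        "General purpose training (Blended dataset)": 1,
        "Second purpose training (Single dataset)": 100,
        "Third purpose training (>1 dataset)": 200,
    }

    def active_stage(step):
        # The stage with the largest start_training_step that is <= the current step.
        return max((name for name, start in stage_starts.items() if start <= step), key=stage_starts.get)

    for step in range(1, train_steps + 1):
        if step % val_check_interval == 0:
            print(step, "->", active_stage(step))
    # 50 -> General purpose ..., 100 -> Second purpose ..., 150 -> Second purpose ..., 200 -> Third purpose ...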
@@ -548,18 +549,14 @@ def train(
             # Cada validation es mucho mas largo que un training step
             # Puede que el len valid dataloader de el numero de valid batches por lo que con eso y la batch size podemos tirar millas
             if self.iteration_step % self.config.tokens.val_check_interval == 0:
-                log_rank(
-                    f"KOMO???? {self.iteration_step}",
-                    logger=logger,
-                    level=logging.INFO,
-                    rank=0,
-                )
                 self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls)
                 val_global_loss, val_lang_losses = self.validation_step(
                     dataloader=self.current_validation_dataloader
                 )
                 self.validation_step_time = time.time()
-                self.validation_step_logs(val_global_loss, val_lang_losses)
+                self.validation_step_logs(
+                    val_global_loss, val_lang_losses
+                )  # TODO(tj.solergibert) Check what happens when val_check_interval % iteration_step_info_interval != 0
 
             # Training Logs
             # TODO(xrsrke): refactor using callbacks would be better
@@ -684,6 +681,14 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten
         lang_losses = {
             lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.lang_to_ids.keys()
         }
+        # WARNING(tj.solergibert) This mechanism will fail in the following [corner] case:
+        # If the lang_losses list for a given lang IS EMPTY, i.e. in the validation step a Data Parallel Group
+        # sees 0 SAMPLES of a given lang, lang_losses[lang] will be an empty Python list and the torch.stack call
+        # will fail with "stack expects a non-empty TensorList". I've tried setting lang_losses[lang] to torch.empty,
+        # but of course that doesn't work, as we then average across the DP group.
+        # We will fix this issue in the future if we encounter this problem again.
+        # A bit of inspo https://blog.speechmatics.com/Sparse-All-Reduce-Part-1
+
         # Compute losses
         if isinstance(outputs[0], torch.Tensor):
             # Multilingual losses
@@ -696,9 +701,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten
             # Sync losses across DP
             for lang in lang_losses.keys():
                 lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang]))
-                dist.all_reduce(
-                    lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG
-                )  # TODO Estas averages dan enormes porque debe de hacer el average con un solo valor!!!!!!!! Debe de set loss per batch o asi no? Sino meter en el outputs de arriba coger el "loss" y comparar a mano vamos...
+                dist.all_reduce(lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG)
             dist.all_reduce(global_loss_avg, group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG)
         else:
             global_loss_avg = None
@@ -833,7 +836,7 @@ def validation_step_logs(
                 ),  # , "1.6E"),
                 LogItem("validation_loss", global_loss.item(), "human_format"),  # , "1.6E"),
                 LogItem("validation_model_tflops_per_gpu", model_tflops / 3, "human_format"),  # , ".2f"),
-                LogItem("validation_hardware_tflops_per_gpu", hardware_tflops, "human_format"),  # , ".2f"),
+                LogItem("validation_hardware_tflops_per_gpu", hardware_tflops / 3, "human_format"),  # , ".2f"),
             ]
 
         # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI.
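Finally, a small sketch (illustrative, not part of the patch) of the corner case described in the WARNING comment above: if a data-parallel rank sees no samples for some language, torch.stack over the empty list raises, so the per-language mean cannot be built on that rank before the all_reduce.

    import torch

    lang_losses = {"es": [torch.tensor(1.1), torch.tensor(0.9)], "fr": []}  # "fr" got 0 samples on this rank

    for lang, losses in lang_losses.items():
        try:
            print(lang, torch.stack(losses).mean().item())
        except RuntimeError as err:
            print(lang, "-> no samples on this rank:", err)  # stack expects a non-empty TensorList

One common way around this, mentioned here only as a general technique and not as what the patch does, is to all-reduce a per-language (sum, count) pair across the DP group and divide afterwards, so ranks with zero samples contribute (0, 0) instead of failing.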