diff --git a/run_train.py b/run_train.py
index ed9b5607..80c7a426 100644
--- a/run_train.py
+++ b/run_train.py
@@ -256,6 +256,7 @@ def get_valid_dataloader_from_data_stage(
         micro_batch_size=trainer.micro_batch_size,
         dataloader_num_workers=data.num_loading_workers,
         dataloader_drop_last=True,
+        shuffle=True,
     )
 
     return valid_dataloader
@@ -315,7 +316,7 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
             stage = cast(DatasetStageArgs, stage)
 
             log_rank(
-                f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set",
+                f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples for the validation set",
                 logger=logger,
                 level=logging.INFO,
                 rank=0,
diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py
index 9d3285f6..b8bfb303 100644
--- a/src/nanotron/data/dataloader_builder.py
+++ b/src/nanotron/data/dataloader_builder.py
@@ -23,6 +23,7 @@ def build_nanoset_dataloader(
     consumed_train_samples: int = 0,
     dataloader_drop_last: bool = True,
     dataloader_pin_memory: bool = True,
+    shuffle: bool = False,
 ) -> DataLoader:
 
     # Case of ranks not requiring data. We give them a dummy dataset, then the collator will do his job
@@ -49,7 +50,7 @@ def build_nanoset_dataloader(
         dl_rank=dp_rank,
         drop_last=dataloader_drop_last,
         consumed_train_samples=consumed_train_samples,
-        shuffle=False,
+        shuffle=shuffle,
     )
 
     return DataLoader(
diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py
index bf690bd0..bc6dc5b5 100644
--- a/src/nanotron/parallel/pipeline_parallel/engine.py
+++ b/src/nanotron/parallel/pipeline_parallel/engine.py
@@ -12,7 +12,7 @@ from nanotron.optim.gradient_accumulator import GradientAccumulator
 from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd
 from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model
-from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState
+from nanotron.parallel.pipeline_parallel.state import PipelineEvalBatchState, PipelineTrainBatchState
 from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
 from nanotron.utils import ContextManagers
 
@@ -136,7 +136,7 @@ def validate_batch_iter(
         nb_microbatches: int,
     ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
         # Assign a new state for the current batch
-        state = PipelineTrainBatchState()  # TODO: do i need state?
+        state = PipelineEvalBatchState()
         self.nb_microbatches = nb_microbatches
 
         outputs = []
diff --git a/src/nanotron/parallel/pipeline_parallel/state.py b/src/nanotron/parallel/pipeline_parallel/state.py
index e07cc89a..f22d6571 100644
--- a/src/nanotron/parallel/pipeline_parallel/state.py
+++ b/src/nanotron/parallel/pipeline_parallel/state.py
@@ -4,6 +4,7 @@ from typing import List
 
 import torch
+
 from nanotron import distributed as dist
 from nanotron import logging
 from nanotron.logging import log_rank
@@ -203,6 +204,9 @@ class PipelineEvalBatchState(PipelineBatchState):
     microbatches_activations_to_recv = collections.deque()
     activations_buffer = collections.deque()
 
+    # Reinitialise counter
+    nb_forwards = 0
+
     def register_activation_requiring_backward(self, activation: torch.Tensor):
         pass
diff --git a/src/nanotron/serialize/metadata.py b/src/nanotron/serialize/metadata.py
index 0d8708f9..4bd36c19 100644
--- a/src/nanotron/serialize/metadata.py
+++ b/src/nanotron/serialize/metadata.py
@@ -46,6 +46,8 @@ class TrainingMetadata:
     last_stage_idx: Optional[int] = None
     data_stages: Optional[List[DataStageMetadata]] = None
 
+    last_validation_stage_idx: Optional[int] = None
+
     def __post_init__(self):
         # NOTE: this is a sanity check after loading a trained checkpoint
         total_consumed_samples_across_stages = sum(stage.consumed_train_samples for stage in self.data_stages)
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 583068cd..b1cc36ad 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -231,7 +231,11 @@ def __init__(
             for stage in self.config.data_stages
         ]
         self.metadata: TrainingMetadata = TrainingMetadata(
-            consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages
+            consumed_train_samples=0,
+            last_train_step=0,
+            last_stage_idx=0,
+            data_stages=data_stages,
+            last_validation_stage_idx=0,
         )
 
         # Setup tensorboard write and log writers on output rank
@@ -253,6 +257,8 @@ def __init__(
         self.limit_val_batches = self.config.tokens.limit_val_batches
         # NOTE: the dataloader currently in use for the current training stage
         self.current_dataloader: Optional[DataLoader] = None
+        # NOTE: the dataloader currently in use for the current validation stage
+        self.current_validation_dataloader: Optional[DataLoader] = None
 
         self.post_init()
 
@@ -300,9 +306,108 @@ def _print_training_plan(self):
         )
         log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0)
 
-    def _update_dataloader_based_on_training_stages(
-        self, dataloaders: Union[List[DataLoader], DataLoader], is_validation: bool = False
-    ):
+    def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataLoader], DataLoader]):
+        # NOTE(tj.solergibert) Similar to _update_dataloader_based_on_training_stages BUT:
+        # 1. We call this function EVERY TIME we run the validation loop
+        # 2. Every time it returns a NEW validation iterator DataLoader. If you don't do this you'll consume the whole validation dataset
+        #    in the first iteration and subsequent validations will fail
+        # TODO(tj.solergibert) Delete previous DataLoaders from memory like we do with training DataLoaders
+        # TODO(tj.solergibert) Check the tuple case below
+        from collections.abc import Generator
+
+        if not hasattr(self.config, "data_stages") or self.config.data_stages is None:
+
+            if isinstance(dataloaders, tuple):  # TODO(tj.solergibert) Check this tuple case
+                dataloader = dataloaders[0]
+            else:
+                dataloader = dataloaders
+
+            self.current_validation_dataloader_lenght = len(dataloader)
+            self.current_validation_dataloader = sanity_check_dataloader(
+                dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
+            )
+
+            return
+        elif isinstance(dataloaders, Generator):
+            # TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader
+            # remove this in the next PR
+            self.current_validation_dataloader = dataloaders
+            return
+
+        assert len(dataloaders) > 0, "No dataloaders provided"
+        assert len(dataloaders) == len(
+            self.config.data_stages
+        ), "Number of dataloaders should match the number of dataset stages"
+
+        def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
+            import gc
+
+            log_rank(
+                f"[Validation Stage: {stage_name}] Clearing the previous validation stage's dataloader and dataset from memory",
+                logger=logger,
+                level=logging.INFO,
+            )
+
+            # NOTE: Clear dataloader from memory
+            del dataloader.dataset
+            del dataloader.sampler
+            del dataloader.batch_sampler
+
+            gc.collect()
+
+        dataloader = None
+
+        for stage_idx, stage in enumerate(self.config.data_stages):
+            if stage_idx < self.metadata.last_stage_idx:
+                continue
+
+            if (
+                stage_idx != self.metadata.last_validation_stage_idx
+                and self.metadata.last_validation_stage_idx is not None
+            ):
+                self.metadata.last_validation_stage_idx = stage_idx  # Update validation stage index
+                # If we switch stages, delete the old dataloader
+                # In both cases, recreate the one for the current stage!!!
+                # TODO: This is where we left off!!! We need to delete the previous dataloader when necessary and ALWAYS run the sanity check on the current dataloader
+            stage = cast(DatasetStageArgs, stage)
+            print(
+                stage.name
+            )  # TODO: since the last stage index is updated during training, here we are already looking at the next iteration's dataloader, which is plain wrong!!!!!
+
+            log_rank(
+                f"Validation stage: {stage.name}",
+                logger=logger,
+                level=logging.INFO,
+                rank=0,
+            )
+            # self.metadata.last_stage_idx = stage_idx
+            """
+            if self.current_validation_dataloader is not None:  # TODO: If a dataloader already exists, delete it. Either way we create it again. Well, the dataloader itself is already built, we only need to return the odd sanity-check wrapper
+                prev_stage_name = self.config.data_stages[stage_idx - 1].name
+                prev_dataloader = dataloaders[prev_stage_name]
+
+                if isinstance(prev_dataloader, DataLoader):
+                    # NOTE: we don't need to clear dummy data generator from memory
+                    clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name)
+            """
+            log_rank(
+                f"Preparing validation DataLoader from stage {stage.name}",
+                logger=logger,
+                level=logging.INFO,
+                rank=0,
+            )
+
+            dataloader = dataloaders[stage.name]
+            # NOTE: if a dataloader is lazy initialized, we need to call it to initialize it
+            dataloader = dataloader() if callable(dataloader) else dataloader
+            break
+
+        self.current_validation_dataloader_lenght = 200  # TODO: use len(dataloader)
+        self.current_validation_dataloader = sanity_check_dataloader(
+            dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
+        )
+
+    def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]):
         from collections.abc import Generator
 
         if not hasattr(self.config, "data_stages") or self.config.data_stages is None:
@@ -311,16 +416,9 @@ def _update_dataloader_based_on_training_stages(
                 dataloader = dataloaders[0]
             else:
                 dataloader = dataloaders
-
-            if is_validation:
-                self.current_validation_dataloader_lenght = len(dataloader)
-                self.current_validation_dataloader = sanity_check_dataloader(
-                    dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
-                )
-            else:
-                self.current_dataloader = sanity_check_dataloader(
-                    dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
-                )
+            self.current_dataloader = sanity_check_dataloader(
+                dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
+            )
             return
         elif isinstance(dataloaders, Generator):
             # TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader
@@ -337,7 +435,7 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
             import gc
 
             log_rank(
-                f"[{'Validation' if is_validation else 'Training'} Stage: {stage_name}] Clearing the previous {'validation' if is_validation else 'training'} stage's dataloader and datasets from memory",
+                f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory",
                 logger=logger,
                 level=logging.INFO,
             )
@@ -360,7 +458,7 @@ def find_stage_idx_to_resume():
 
         stage_idx_to_resume = find_stage_idx_to_resume()
 
-        for stage_idx, stage in enumerate(self.config.data_stages):
+        for stage_idx, stage in enumerate(self.config.data_stages):  # TODO: check metadata.last_stage_idx init
             if stage_idx < self.metadata.last_stage_idx:
                 continue
 
@@ -378,7 +476,7 @@ def find_stage_idx_to_resume():
 
                 self.metadata.last_stage_idx = stage_idx
 
-                if is_resume_from_training and not is_validation:
+                if is_resume_from_training:
                     remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
                         stage, self.config, self.metadata
                     )
@@ -396,15 +494,9 @@ def find_stage_idx_to_resume():
                 break
 
         if dataloader is not None:
-            if is_validation:
-                self.current_validation_dataloader_lenght = len(dataloader)
-                self.current_validation_dataloader = sanity_check_dataloader(
-                    dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
-                )
-            else:
-                self.current_dataloader = sanity_check_dataloader(
-                    dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
-                )
+            self.current_dataloader = sanity_check_dataloader(
+                dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
+            )
 
     def train(
         self,
@@ -443,7 +535,6 @@ def train(
             self.iteration_start_time = time.time()
             self._update_dataloader_based_on_training_stages(train_dataloader_or_dls)
-            self._update_dataloader_based_on_training_stages(valid_dataloader_or_dls, is_validation=True)
 
             # Training step
             outputs, loss_avg = self.training_step(dataloader=self.current_dataloader)
 
@@ -457,9 +548,18 @@ def train(
             # Each validation run is much longer than a training step
             # len() of the valid dataloader probably gives the number of validation batches, so with that and the batch size we can work the rest out
             if self.iteration_step % self.config.tokens.val_check_interval == 0:
-                global_loss, lang_losses = self.validation_step(dataloader=self.current_validation_dataloader)
+                log_rank(
+                    f"Running validation at iteration {self.iteration_step}",
+                    logger=logger,
+                    level=logging.INFO,
+                    rank=0,
+                )
+                self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls)
+                val_global_loss, val_lang_losses = self.validation_step(
+                    dataloader=self.current_validation_dataloader
+                )
                 self.validation_step_time = time.time()
-                self.validation_step_logs(global_loss, lang_losses)
+                self.validation_step_logs(val_global_loss, val_lang_losses)
 
             # Training Logs
             # TODO(xrsrke): refactor using callbacks would be better
@@ -592,10 +692,10 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten
                     self.config.data_stages[self.metadata.last_stage_idx].data.dataset.ids_to_lang[lang_id]
                 ].append(loss)
         # Global loss
-        global_loss_avg = torch.stack(outputs).sum()
+        global_loss_avg = torch.mean(torch.stack(outputs))
         # Sync losses across DP
         for lang in lang_losses.keys():
-            lang_losses[lang] = torch.stack(lang_losses[lang]).sum()
+            lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang]))
             dist.all_reduce(
                 lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG
             )  # TODO: These averages come out huge, probably because the average is taken over a single value!!! It should be the loss per batch, right? Otherwise, take the "loss" from the outputs above and compare it by hand...
@@ -630,7 +730,6 @@ def train_step_logs(
             lr = self.lr_scheduler.get_last_lr()[0]
 
             log_entries = [
-                # LogItem("consumed_samples", self.consumed_train_samples, "human_format"),  # , "12d"),
                 LogItem(
                     "consumed_tokens",
                     self.metadata.consumed_train_samples * self.config.tokens.sequence_length,
@@ -718,7 +817,6 @@ def validation_step_logs(
         assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks"
 
         log_entries = [
-            # LogItem("consumed_samples", self.consumed_train_samples, "human_format"),  # , "12d"),
             LogItem(
                 "validation_consumed_tokens",
                 self.metadata.consumed_train_samples * self.config.tokens.sequence_length,
@@ -734,7 +832,7 @@ def validation_step_logs(
                 "human_format",
             ),  # , "1.6E"),
             LogItem("validation_loss", global_loss.item(), "human_format"),  # , "1.6E"),
-            LogItem("validation_model_tflops_per_gpu", model_tflops, "human_format"),  # , ".2f"),
+            LogItem("validation_model_tflops_per_gpu", model_tflops / 3, "human_format"),  # , ".2f"),
             LogItem("validation_hardware_tflops_per_gpu", hardware_tflops, "human_format"),  # , ".2f"),
         ]
 
@@ -746,12 +844,15 @@ def validation_step_logs(
             )
 
         # NOTE: only one rank writes to wandb
+        # NOTE(tj.solergibert) By default every wandb.log call advances the step on the x-axis.
+        #      Set commit=False so these values are committed together with the next wandb.log call, alongside the training logs
         if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
             wandb.log(
                 {
                     **{log_item.tag: log_item.scalar_value for log_item in log_entries},
                     "iteration_step": self.iteration_step,
-                }
+                },
+                commit=False,
             )
 
         self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step)
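A minimal standalone sketch of the `commit=False` behaviour the last hunk relies on (the project name and metric values below are illustrative only, not part of the patch): metrics logged with `commit=False` are buffered and written under the same wandb step as the metrics committed by the next plain `wandb.log` call.

```python
# Sketch: buffer validation metrics with commit=False so they share a step
# with the training metrics committed right afterwards.
import wandb

run = wandb.init(project="commit-false-demo")  # illustrative project name

for step in range(1, 5):
    train_loss = 1.0 / step  # dummy training metric
    if step % 2 == 0:
        # Buffered: does not advance the wandb step counter on its own
        wandb.log({"validation_loss": 2.0 / step}, commit=False)
    # Default commit=True: writes everything buffered for this step at once
    wandb.log({"train_loss": train_loss, "iteration_step": step})

run.finish()
```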