
Commit

just in case
TJ-Solergibert committed Jul 23, 2024
1 parent 27133e1 commit 5c09e11
Showing 6 changed files with 148 additions and 39 deletions.
3 changes: 2 additions & 1 deletion run_train.py
@@ -256,6 +256,7 @@ def get_valid_dataloader_from_data_stage(
micro_batch_size=trainer.micro_batch_size,
dataloader_num_workers=data.num_loading_workers,
dataloader_drop_last=True,
shuffle=True,
)

return valid_dataloader
@@ -315,7 +316,7 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
stage = cast(DatasetStageArgs, stage)

log_rank(
f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set",
f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples for the validation set",
logger=logger,
level=logging.INFO,
rank=0,
3 changes: 2 additions & 1 deletion src/nanotron/data/dataloader_builder.py
@@ -23,6 +23,7 @@ def build_nanoset_dataloader(
consumed_train_samples: int = 0,
dataloader_drop_last: bool = True,
dataloader_pin_memory: bool = True,
shuffle: bool = False,
) -> DataLoader:

# Case of ranks not requiring data. We give them a dummy dataset, then the collator will do its job
@@ -49,7 +50,7 @@
dl_rank=dp_rank,
drop_last=dataloader_drop_last,
consumed_train_samples=consumed_train_samples,
shuffle=False,
shuffle=shuffle,
)

return DataLoader(
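
The new shuffle parameter is not consumed by the DataLoader itself; it is threaded down to what appears to be the sampler construction in the hunk above. A minimal sketch of that pattern using PyTorch's stock DistributedSampler; ToyDataset and build_loader are illustrative stand-ins, not repository code:

import torch
from torch.utils.data import DataLoader, Dataset, DistributedSampler


class ToyDataset(Dataset):
    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.tensor([idx])


def build_loader(dataset: Dataset, dp_rank: int, dp_size: int, shuffle: bool) -> DataLoader:
    # The shuffle flag lands on the sampler, not on the DataLoader itself.
    sampler = DistributedSampler(dataset, num_replicas=dp_size, rank=dp_rank, shuffle=shuffle, drop_last=True)
    return DataLoader(dataset, batch_size=2, sampler=sampler)


if __name__ == "__main__":
    for batch in build_loader(ToyDataset(), dp_rank=0, dp_size=1, shuffle=True):
        print(batch)
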
4 changes: 2 additions & 2 deletions src/nanotron/parallel/pipeline_parallel/engine.py
@@ -12,7 +12,7 @@
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd
from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model
from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState
from nanotron.parallel.pipeline_parallel.state import PipelineEvalBatchState, PipelineTrainBatchState
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.utils import ContextManagers

@@ -136,7 +136,7 @@ def validate_batch_iter(
nb_microbatches: int,
) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
# Assign a new state for the current batch
state = PipelineTrainBatchState() # TODO: do i need state?
state = PipelineEvalBatchState()  # was PipelineTrainBatchState()  # TODO: do I need state?
self.nb_microbatches = nb_microbatches

outputs = []
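
validate_batch_iter now allocates a PipelineEvalBatchState instead of the training state; as the state.py change below shows, the eval state's register_activation_requiring_backward is a no-op, so validation does not keep activations alive for a backward pass that never runs. A rough sketch of that distinction with stand-in classes (not nanotron's actual implementations):

import collections

import torch


class TrainBatchStateSketch:
    def __init__(self) -> None:
        self.activations_requiring_backward = collections.deque()

    def register_activation_requiring_backward(self, activation: torch.Tensor) -> None:
        # Training keeps the activation around until the backward pass consumes it.
        self.activations_requiring_backward.append(activation)


class EvalBatchStateSketch(TrainBatchStateSketch):
    def register_activation_requiring_backward(self, activation: torch.Tensor) -> None:
        # Validation never runs backward, so nothing is retained.
        pass


if __name__ == "__main__":
    train_state, eval_state = TrainBatchStateSketch(), EvalBatchStateSketch()
    activation = torch.randn(2, 2)
    train_state.register_activation_requiring_backward(activation)
    eval_state.register_activation_requiring_backward(activation)
    print(len(train_state.activations_requiring_backward))  # 1
    print(len(eval_state.activations_requiring_backward))   # 0
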
4 changes: 4 additions & 0 deletions src/nanotron/parallel/pipeline_parallel/state.py
@@ -4,6 +4,7 @@
from typing import List

import torch

from nanotron import distributed as dist
from nanotron import logging
from nanotron.logging import log_rank
@@ -203,6 +204,9 @@ class PipelineEvalBatchState(PipelineBatchState):
microbatches_activations_to_recv = collections.deque()
activations_buffer = collections.deque()

# Reinitialise counter
nb_forwards = 0

def register_activation_requiring_backward(self, activation: torch.Tensor):
pass

2 changes: 2 additions & 0 deletions src/nanotron/serialize/metadata.py
@@ -46,6 +46,8 @@ class TrainingMetadata:
last_stage_idx: Optional[int] = None
data_stages: Optional[List[DataStageMetadata]] = None

last_validation_stage_idx: Optional[int] = None

def __post_init__(self):
# NOTE: this is a sanity check after loading a trained checkpoint
total_consumed_samples_across_stages = sum(stage.consumed_train_samples for stage in self.data_stages)
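
For context, TrainingMetadata gains last_validation_stage_idx next to the existing per-stage bookkeeping, and the __post_init__ below it is the sanity check mentioned in the diff. A simplified, self-contained sketch of that shape (field names follow the diff; everything else is pared down for illustration):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class DataStageMetadataSketch:
    name: str
    consumed_train_samples: int = 0


@dataclass
class TrainingMetadataSketch:
    consumed_train_samples: int = 0
    last_train_step: int = 0
    last_stage_idx: Optional[int] = None
    data_stages: Optional[List[DataStageMetadataSketch]] = None
    # New in this commit: which data stage the validation loop is currently on.
    last_validation_stage_idx: Optional[int] = None

    def __post_init__(self):
        # Sanity check after loading a checkpoint: per-stage counters must add up
        # to the global consumed-samples counter.
        total = sum(stage.consumed_train_samples for stage in self.data_stages or [])
        assert total == self.consumed_train_samples


meta = TrainingMetadataSketch(
    consumed_train_samples=100,
    last_train_step=10,
    last_stage_idx=0,
    data_stages=[DataStageMetadataSketch(name="stage_0", consumed_train_samples=100)],
    last_validation_stage_idx=0,
)
print(meta.last_validation_stage_idx)
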
171 changes: 136 additions & 35 deletions src/nanotron/trainer.py
@@ -231,7 +231,11 @@ def __init__(
for stage in self.config.data_stages
]
self.metadata: TrainingMetadata = TrainingMetadata(
consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages
consumed_train_samples=0,
last_train_step=0,
last_stage_idx=0,
data_stages=data_stages,
last_validation_stage_idx=0,
)

# Setup tensorboard writer and log writers on output rank
@@ -253,6 +257,8 @@ def __init__(
self.limit_val_batches = self.config.tokens.limit_val_batches
# NOTE: the dataloader currently in use for the current training stage
self.current_dataloader: Optional[DataLoader] = None
# NOTE: the dataloader currently in use for the current validation stage
self.current_validation_dataloader: Optional[DataLoader] = None

self.post_init()

@@ -300,9 +306,108 @@ def _print_training_plan(self):
)
log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0)

def _update_dataloader_based_on_training_stages(
self, dataloaders: Union[List[DataLoader], DataLoader], is_validation: bool = False
):
def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataLoader], DataLoader]):
# NOTE(tj.solergibert) Similar to _update_dataloader_based_on_training_stages BUT:
# 1. We call this function EVERY TIME we run the validation loop
# 2. Every time, it returns a NEW validation DataLoader iterator. Without this, the whole validation dataset
# would be consumed in the first validation pass and subsequent validations would fail
# (see the iterator sketch after the trainer.py diff)
# TODO(tj.solergibert) Delete previous DataLoaders from memory like we do with training DataLoaders
# TODO(tj.solergibert) Check the tuple case below
from collections.abc import Generator

if not hasattr(self.config, "data_stages") or self.config.data_stages is None:

if isinstance(dataloaders, tuple): # TODO(tj.solergibert) Check this tuple case
dataloader = dataloaders[0]
else:
dataloader = dataloaders

self.current_validation_dataloader_lenght = len(dataloader)
self.current_validation_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)

return
elif isinstance(dataloaders, Generator):
# TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader
# remove this in the next PR
self.current_validation_dataloader = dataloaders
return

assert len(dataloaders) > 0, "No dataloaders provided"
assert len(dataloaders) == len(
self.config.data_stages
), "Number of dataloaders should match the number of dataset stages"

def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
import gc

log_rank(
f"[Validation Stage: {stage_name}] Clearing the previous validation stage's dataloader and dataset from memory",
logger=logger,
level=logging.INFO,
)

# NOTE: Clear dataloader from memory
del dataloader.dataset
del dataloader.sampler
del dataloader.batch_sampler

gc.collect()

dataloader = None

for stage_idx, stage in enumerate(self.config.data_stages):
if stage_idx < self.metadata.last_stage_idx:
continue

if (
stage_idx != self.metadata.last_validation_stage_idx
and self.metadata.last_validation_stage_idx is not None
):
self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index
# If we move to a new stage, delete the old one
# In both cases, recreate the appropriate one!
# TODO We left off here! We need to delete the previous dataloader when necessary and ALWAYS run the sanity check on the current dataloader
stage = cast(DatasetStageArgs, stage)
print(
stage.name
)  # TODO since we update the last stage index during training, here we end up looking at the next iteration's dataloader, which is wrong!

log_rank(
f"Ese print bueno {stage.name}",
logger=logger,
level=logging.INFO,
rank=0,
)
# self.metadata.last_stage_idx = stage_idx
"""
if self.current_validation_dataloader is not None: # TODO If a dataloader already exists, delete it. Either way, we create it again. Well, the dataloader itself is already created; we just need to return the odd sanity check
prev_stage_name = self.config.data_stages[stage_idx - 1].name
prev_dataloader = dataloaders[prev_stage_name]
if isinstance(prev_dataloader, DataLoader):
# NOTE: we don't need to clear dummy data generator from memory
clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name)
"""
log_rank(
f"Preparing validation DataLoader from stage {stage.name}",
logger=logger,
level=logging.INFO,
rank=0,
)

dataloader = dataloaders[stage.name]
# NOTE: if a dataloader is lazy initialized, we need to call it to initialize it
dataloader = dataloader() if callable(dataloader) else dataloader
break

self.current_validation_dataloader_lenght = 200 # TODO len(dataloader)
self.current_validation_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)

def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]):
from collections.abc import Generator

if not hasattr(self.config, "data_stages") or self.config.data_stages is None:
@@ -311,16 +416,9 @@ def _update_dataloader_based_on_training_stages(
dataloader = dataloaders[0]
else:
dataloader = dataloaders

if is_validation:
self.current_validation_dataloader_lenght = len(dataloader)
self.current_validation_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
else:
self.current_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
self.current_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
return
elif isinstance(dataloaders, Generator):
# TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader
@@ -337,7 +435,7 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
import gc

log_rank(
f"[{'Validation' if is_validation else 'Training'} Stage: {stage_name}] Clearing the previous {'validation' if is_validation else 'training'} stage's dataloader and datasets from memory",
f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory",
logger=logger,
level=logging.INFO,
)
@@ -360,7 +458,7 @@ def find_stage_idx_to_resume():

stage_idx_to_resume = find_stage_idx_to_resume()

for stage_idx, stage in enumerate(self.config.data_stages):
for stage_idx, stage in enumerate(self.config.data_stages):  # TODO check metadata.last_stage_idx init
if stage_idx < self.metadata.last_stage_idx:
continue

@@ -378,7 +476,7 @@

self.metadata.last_stage_idx = stage_idx

if is_resume_from_training and not is_validation:
if is_resume_from_training:
remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
stage, self.config, self.metadata
)
@@ -396,15 +494,9 @@ def find_stage_idx_to_resume():
break

if dataloader is not None:
if is_validation:
self.current_validation_dataloader_lenght = len(dataloader)
self.current_validation_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
else:
self.current_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
self.current_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)

def train(
self,
@@ -443,7 +535,6 @@ def train(

self.iteration_start_time = time.time()
self._update_dataloader_based_on_training_stages(train_dataloader_or_dls)
self._update_dataloader_based_on_training_stages(valid_dataloader_or_dls, is_validation=True)

# Training step
outputs, loss_avg = self.training_step(dataloader=self.current_dataloader)
@@ -457,9 +548,18 @@
# Each validation pass is much longer than a training step
# len(valid_dataloader) probably gives the number of validation batches, so with that and the batch size we can take it from there
if self.iteration_step % self.config.tokens.val_check_interval == 0:
global_loss, lang_losses = self.validation_step(dataloader=self.current_validation_dataloader)
log_rank(
f"KOMO???? {self.iteration_step}",
logger=logger,
level=logging.INFO,
rank=0,
)
self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls)
val_global_loss, val_lang_losses = self.validation_step(
dataloader=self.current_validation_dataloader
)
self.validation_step_time = time.time()
self.validation_step_logs(global_loss, lang_losses)
self.validation_step_logs(val_global_loss, val_lang_losses)

# Training Logs
# TODO(xrsrke): refactor using callbacks would be better
@@ -592,10 +692,10 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten
self.config.data_stages[self.metadata.last_stage_idx].data.dataset.ids_to_lang[lang_id]
].append(loss)
# Global loss
global_loss_avg = torch.stack(outputs).sum()
global_loss_avg = torch.mean(torch.stack(outputs))
# Sync losses across DP
for lang in lang_losses.keys():
lang_losses[lang] = torch.stack(lang_losses[lang]).sum()
lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang]))
dist.all_reduce(
lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG
) # TODO These averages come out huge, probably because the average is taken over a single value! It should be the per-batch loss, right? Otherwise, take the "loss" from the outputs above and compare by hand...
@@ -630,7 +730,6 @@ def train_step_logs(
lr = self.lr_scheduler.get_last_lr()[0]

log_entries = [
# LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"),
LogItem(
"consumed_tokens",
self.metadata.consumed_train_samples * self.config.tokens.sequence_length,
@@ -718,7 +817,6 @@ def validation_step_logs(
assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks"

log_entries = [
# LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"),
LogItem(
"validation_consumed_tokens",
self.metadata.consumed_train_samples * self.config.tokens.sequence_length,
@@ -734,7 +832,7 @@
"human_format",
), # , "1.6E"),
LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"),
LogItem("validation_model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"),
LogItem("validation_model_tflops_per_gpu", model_tflops / 3, "human_format"), # , ".2f"),
LogItem("validation_hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"),
]

@@ -746,12 +844,15 @@
)

# NOTE: only one rank writes to wandb
# NOTE(tj.solergibert) By default, every wandb.log call advances the x-axis step.
# Set commit=False so these values are committed together with the next wandb.log call from the training logs
# (see the wandb sketch after the trainer.py diff)
if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
wandb.log(
{
**{log_item.tag: log_item.scalar_value for log_item in log_entries},
"iteration_step": self.iteration_step,
}
},
commit=False,
)

self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step)
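
As the comment in _prepare_dataloader_for_validation_stage notes, a fresh iterator has to be built before every validation pass because an iterator over a DataLoader is exhausted after one sweep. A standalone sketch of that failure mode and the fix (a plain list stands in for the Nanoset validation data):

from torch.utils.data import DataLoader

loader = DataLoader(list(range(6)), batch_size=2)

it = iter(loader)
print([batch.tolist() for batch in it])  # [[0, 1], [2, 3], [4, 5]]
print([batch.tolist() for batch in it])  # [] -- exhausted; a second validation pass would see no data

it = iter(loader)  # rebuild the iterator before every validation loop
print([batch.tolist() for batch in it])  # [[0, 1], [2, 3], [4, 5]]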

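
validation_step now takes torch.mean over the per-microbatch losses instead of sum() before the data-parallel all-reduce with ReduceOp.AVG, which keeps the reported loss at the per-sample scale regardless of how many microbatches each rank ran. A tiny single-rank sketch of the difference (plain tensors stand in for the trainer's outputs; the cross-rank all-reduce is only described in the comments):

import torch

# Losses of four validation microbatches on one data-parallel rank.
outputs = [torch.tensor(2.0), torch.tensor(2.2), torch.tensor(1.8), torch.tensor(2.4)]

# Previous behaviour: the sum grows with the number of microbatches.
print(torch.stack(outputs).sum())        # ~8.4

# New behaviour: the mean stays at the per-sample scale; dist.all_reduce(..., op=ReduceOp.AVG)
# would then average this value with the other data-parallel ranks.
print(torch.mean(torch.stack(outputs)))  # ~2.1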
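
Finally, the commit=False added to the validation wandb.log call follows wandb's step semantics: every committing wandb.log advances the step, so buffering the validation metrics lets them land on the same x-axis step as the next training log. A short sketch of that pattern (project and metric names are illustrative; mode="offline" just avoids needing a login):

import wandb

run = wandb.init(project="validation-logging-demo", mode="offline")

# Buffered: recorded, but the wandb step does not advance yet.
wandb.log({"validation_loss": 2.31}, commit=False)

# Committed: the buffered validation metrics and these training metrics share one step.
wandb.log({"lm_loss": 2.05, "iteration_step": 120})

run.finish()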