Commit

This looks good
TJ-Solergibert committed Jul 24, 2024
1 parent 5c09e11 commit 94d6c2a
Showing 5 changed files with 99 additions and 96 deletions.
60 changes: 31 additions & 29 deletions examples/config_multilingual_nanoset.yaml
@@ -1,44 +1,46 @@
checkpoints:
checkpoint_interval: 1000
checkpoint_interval: 1000000
checkpoints_path: checkpoints/
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false
data_stages:
- data:
dataset:
training_folder: datasets/c4-es/train
validation_folder: datasets/c4-es/validation
training_folder:
datasets/c4-es/train: 0.85
datasets/c4-en/train: 0.05
datasets/c4-fr/train: 0.1
validation_folder:
- datasets/c4-es/validation
- datasets/c4-en/validation
- datasets/c4-fr/validation
lang_to_ids:
es: 128002
en: 128003
fr: 128004
num_loading_workers: 1
seed: 42
name: General purpose training (Single dataset)
name: General purpose training (Blended dataset)
start_training_step: 1
- data:
dataset:
training_folder:
- datasets/c4-es/train
- datasets/c4-en/train
- datasets/c4-fr/train
validation_folder:
- datasets/c4-es/validation
- datasets/c4-en/validation
- datasets/c4-fr/validation
lang_to_ids:
es: 128002
en: 128003
fr: 128004
num_loading_workers: 1
seed: 42
name: Second purpose training (> 1 dataset)
start_training_step: 15
name: Second purpose training (Single dataset)
start_training_step: 100
- data:
dataset:
training_folder:
datasets/c4-es/train: 0.6
datasets/c4-en/train: 0.3
datasets/c4-fr/train: 0.1
- datasets/c4-es/train
- datasets/c4-en/train
- datasets/c4-fr/train
validation_folder:
- datasets/c4-es/validation
- datasets/c4-en/validation
@@ -50,13 +52,13 @@ data_stages:

num_loading_workers: 1
seed: 42
name: Third purpose training (Blended dataset)
start_training_step: 25
name: Third purpose training (>1 dataset)
start_training_step: 200
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: Nanoset
project: Multilingual
run: llama
seed: 42
step: null
@@ -75,12 +77,12 @@ model:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 512
hidden_size: 4096
initializer_range: 0.02
intermediate_size: 512
intermediate_size: 14336
is_llama_config: true
max_position_embeddings: 1024
num_hidden_layers: 2
max_position_embeddings: 4096
num_hidden_layers: 32
num_attention_heads: 32
num_key_value_heads: 8
pad_token_id: null
@@ -89,7 +91,7 @@ model:
rope_theta: 500000.0
rms_norm_eps: 1.0e-06
rope_scaling: null
tie_word_embeddings: true
tie_word_embeddings: false
use_cache: true
vocab_size: 128256
optimizer:
@@ -116,19 +118,19 @@ parallelism:
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
tp: 1
tp: 4
tp_linear_async_communication: false
tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B
tokenizer_name_or_path: /mloscratch/homes/solergib/models/Meta-Llama-3-8B
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 1
limit_test_batches: 0
limit_val_batches: 10
micro_batch_size: 4
sequence_length: 1024
train_steps: 200
val_check_interval: 3
micro_batch_size: 3
sequence_length: 4096
train_steps: 800
val_check_interval: 50
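
For context on the blended stage in the config above, here is a minimal sketch of how weighted training_folder entries could be turned into sampling probabilities. The normalization helper and the NumPy-based draw are illustrative assumptions, not the actual MultilingualNanoset implementation.

import numpy as np

# Illustrative only: weights are normalized to probabilities and used to pick
# which dataset the next training sample is drawn from.
def normalize_weights(training_folder):
    paths = list(training_folder.keys())
    weights = np.array(list(training_folder.values()), dtype=np.float64)
    return paths, weights / weights.sum()

paths, probs = normalize_weights(
    {"datasets/c4-es/train": 0.85, "datasets/c4-en/train": 0.05, "datasets/c4-fr/train": 0.1}
)
rng = np.random.default_rng(42)  # seed: 42, as in the config
next_dataset = rng.choice(paths, p=probs)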
9 changes: 8 additions & 1 deletion run_train.py
@@ -325,8 +325,15 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
dataloader = (
get_valid_dataloader_from_data_stage(trainer, stage.data)
if stage_idx == 0
else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data)
else lambda stage=stage: get_valid_dataloader_from_data_stage(trainer, stage.data)
)
# TODO(tj.solergibert) Because we create the valid dataloader again at every validation stage, we print the validation
# MultilingualNanoset info (number of samples, etc.) multiple times [UPDATE: ]. To solve that, we could get rid of these lambda
# funcs and directly create all dataloaders.
#
# These lambda funcs (used in training too) create the DataLoaders lazily SO THAT 1. training starts faster, instead of creating multiple DataLoaders
# up front, and 2. less memory is consumed, as the lambda func is lighter than the DataLoader object with its Dataset, collator, etc.
# BUT 1. the Nanoset creation process is very fast and 2. Nanosets don't consume any memory at all until we start sampling from them
dataloaders[stage.name] = dataloader
return dataloaders

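As a side note, the lazy-creation pattern the comment above describes can be sketched as follows; build_dataloader and the stage names are hypothetical placeholders, not the real run_train.py helpers.

import torch
from torch.utils.data import DataLoader, TensorDataset

def build_dataloader(stage_name):
    # Stand-in for get_valid_dataloader_from_data_stage(trainer, stage.data)
    data = TensorDataset(torch.arange(8).unsqueeze(1))
    return DataLoader(data, batch_size=2)

stages = ["General purpose training", "Second purpose training"]
dataloaders = {
    name: build_dataloader(name)
    if idx == 0
    # Bind `name` as a default argument so each lambda keeps its own stage,
    # mirroring the `lambda stage=stage:` trick in the diff above.
    else (lambda name=name: build_dataloader(name))
    for idx, name in enumerate(stages)
}

# Later, when a stage actually starts, the lambda is called to build the DataLoader.
dl = dataloaders["Second purpose training"]
dl = dl() if callable(dl) else dl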
7 changes: 1 addition & 6 deletions src/nanotron/models/llama.py
@@ -803,12 +803,7 @@ def masked_mean(loss, label_mask, dtype):
# type: (Tensor, Tensor, torch.dtype) -> Tensor
return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(
dim=1
) # TODO at first glance this gives float/float = float


# TODO the loss of each sample !!!! ((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1))
# And it passes the assert_close!!
# assert_close(((loss * label_mask).sum(dtype=dtype) / label_mask.sum()), torch.mean((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1)))
) # NOTE(tj.solergibert) Added dim=1 to return a tensor with shape [Batch size, 1] instead of [1]


class Loss(nn.Module):
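To illustrate what the dim=1 change does, here is a small self-contained sketch of the per-sample masked mean; the tensors and shapes are invented for the example.

import torch

loss = torch.tensor([[0.5, 0.2, 0.9, 0.1, 0.4],
                     [0.3, 0.8, 0.6, 0.7, 0.2],
                     [0.1, 0.4, 0.5, 0.9, 0.3]])
label_mask = torch.tensor([[1., 1., 1., 0., 0.],
                           [1., 1., 0., 0., 0.],
                           [1., 1., 1., 1., 1.]])

# Previous behaviour: a single scalar averaged over all unmasked tokens in the batch.
global_mean = (loss * label_mask).sum(dtype=torch.float32) / label_mask.sum()

# New behaviour: one loss per sample (shape [batch size]), which later lets the
# trainer group sample losses by language.
per_sample = (loss * label_mask).sum(dim=1, dtype=torch.float32) / label_mask.sum(dim=1)

assert per_sample.shape == (3,)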
16 changes: 6 additions & 10 deletions src/nanotron/parallel/pipeline_parallel/engine.py
@@ -136,7 +136,7 @@ def validate_batch_iter(
nb_microbatches: int,
) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
# Assign a new state for the current batch
state = PipelineEvalBatchState() # PipelineTrainBatchState() # TODO: do i need state?
state = PipelineEvalBatchState()
self.nb_microbatches = nb_microbatches

outputs = []
@@ -156,21 +156,17 @@ def validate_batch_iter(
send_activation()

# We make `output` a dict
# TODO convert to dict other items returned by the model (MoE aux loss for example)
# But in the next if statement, be careful if we return other items in all of the PP processes
# This conversion to dicts is kind of useless, as the model already returns a dict with a loss key. Maybe the PP ranks return TensorPointer objects?
if not isinstance(output, dict):
output = {"loss": output}

# Store the loss for each microbatch
if not isinstance(output["loss"], TensorPointer):
output = {k: v.detach() for k, v in output.items()}
# TODO check what this output is and also how it is stored in outputs. Where is the mean computed? In the training step lol
# Here we should segregate by language because this is the only point where we have the language!! Or at least "tag" it or accumulate it per language
# 1. Build a dict with a key per language 2. each key holds a list where we append the tensors 3. in the valid step we do the stack and the allreduces
# Finally: here we only store the lang ids; in trainer.py we will accumulate the results and so on.
outputs.extend(list(output["sample_loss"])) # TODO flatten?????? or extend??????
lang_ids.extend(micro_batch["input_ids"][:, 0].tolist()) # TODO should this be an extend????

outputs.extend(
list(output["sample_loss"])
) # NOTE(tj.solergibert) list + extend may look redundant, but it is needed to split the output["sample_loss"] tensor into one tensor per sample
lang_ids.extend(micro_batch["input_ids"][:, 0].tolist())

return outputs, lang_ids

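A toy reproduction of the bookkeeping above shows why list(...) + extend is used: it splits each microbatch's sample_loss tensor into one tensor per sample. The loss values and language token ids below are invented for illustration.

import torch

outputs, lang_ids = [], []
for micro_batch_losses, first_tokens in [
    (torch.tensor([0.7, 1.2]), [128002, 128003]),  # two samples: es, en
    (torch.tensor([0.9]), [128002]),               # one sample: es
]:
    # list(tensor) yields one 0-d tensor per sample, so extend() keeps a flat
    # list of per-sample losses instead of a list of per-microbatch tensors.
    outputs.extend(list(micro_batch_losses))
    lang_ids.extend(first_tokens)

assert len(outputs) == len(lang_ids) == 3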
103 changes: 53 additions & 50 deletions src/nanotron/trainer.py
@@ -311,7 +311,25 @@ def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataL
# 1. We call this function EVERY TIME we run the validation loop
# 2. Every time it returns a NEW validation iterator DataLoader. If you don't do this you'll consume the whole validation dataset
# in the first iteration and subsequent validations will fail
# TODO(tj.solergibert) Delete previous DataLoaders from memory like we do with training DataLoaders
# `dataloaders` are either torch DataLoaders (for the very first stage) OR functions that we call later to build torch DataLoaders (subsequent stages)
# From these torch DataLoader objects we then call `sanity_check_dataloader`, which returns an iterator.
# In short, `sanity_check_dataloader` just places the input tensors on the GPU when necessary (TensorPointers stay on the CPU)
#
# TBH, the for loop below is just for deleting the DataLoaders of previous stages, which is not so problematic. The important part is returning the
# DataLoader iterator every time we call this function from the current training stage, which is tracked during training
#
# Also, keep in mind that if val_check_interval = 5 & data.start_training_step = 10 we will already perform the evaluation with the SECOND data stage
# after training on it for just one iteration, so it might not be a good idea to schedule evals at the step where the data stage changes
#
# NOTE(tj.solergibert) Further investigation is needed, but there is strange behaviour when deleting the DataLoaders/lambda funcs. As they
# are converted into Iterators with `sanity_check_dataloader`, we can no longer access the DataLoader object to del the dataset (after the first stage,
# this function locally creates the DataLoader from the lambda func --> returns an Iterator)
#
# Also, when the gc deletes the first-stage dataloader, all the `DatatroveFileDataset._f` are already None AND the `del` calls delete a copy of the
# object, not the object itself
#
# FINAL NOTE(tj.solergibert) I will open an Issue in nanotron to check with them whether they are aware of these useless deletions
#
# TODO(tj.solergibert) Check the tuple case below
from collections.abc import Generator

@@ -339,11 +357,11 @@ def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataL
self.config.data_stages
), "Number of dataloaders should match the number of dataset stages"

def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str):
import gc

log_rank(
f"[Validation Stage: {stage_name}] Clearing the previous validation stage's dataloader and dataset from memory",
f"[Validation Stage: {stage_name}] Clearing the previous validation stage's ({prev_stage_name}) dataloader and dataset from memory",
logger=logger,
level=logging.INFO,
)
@@ -355,57 +373,38 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):

gc.collect()

dataloader = None

for stage_idx, stage in enumerate(self.config.data_stages):
if stage_idx < self.metadata.last_stage_idx:
continue
# NOTE(tj.solergibert) From this point stage_idx = self.metadata.last_stage_idx. We update self.metadata.last_stage_idx (which keeps track of the training stage)
# in each and every training step.

if (
stage_idx is not self.metadata.last_validation_stage_idx
and self.metadata.last_validation_stage_idx is not None
):
): # When stage_idx (= self.metadata.last_stage_idx, the training stage index) is different than the last validation stage index
self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index
# If we change stage, delete the old one
# In both cases, recreate the one that applies !!!
# TODO We left off here!!! We have to delete the previous dataloader when necessary and ALWAYS run the sanity check on the current dataloader
stage = cast(DatasetStageArgs, stage)
print(
stage.name
) # TODO since we update the last stage index during training, here we are looking at the dataloader of the next iteration, how awful, my god!!!!

log_rank(
f"Ese print bueno {stage.name}",
logger=logger,
level=logging.INFO,
rank=0,
)
# self.metadata.last_stage_idx = stage_idx
"""
if self.current_validation_dataloader is not None: # TODO If there is already a dataloader, delete it. We create it again anyway. Well, the dataloader as such is already created, we only have to return the weird sanity check
# Delete previous stage DataLoader
prev_stage_name = self.config.data_stages[stage_idx - 1].name
prev_dataloader = dataloaders[prev_stage_name]

if isinstance(prev_dataloader, DataLoader):
# NOTE: we don't need to clear dummy data generator from memory
clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name)
"""
log_rank(
f"Preparing validation DataLoader from stage {stage.name}",
logger=logger,
level=logging.INFO,
rank=0,
)
if isinstance(prev_dataloader, DataLoader):
# NOTE: we don't need to clear dummy data generator from memory
clear_dataloader_from_memory(
prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name
)

self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index

# NOTE(tj.solergibert) Create AGAIN the DataLoader
dataloader = dataloaders[stage.name]
# NOTE: if a dataloader is lazy initialized, we need to call it to initialize it
dataloader = dataloader() if callable(dataloader) else dataloader
break

self.current_validation_dataloader_lenght = 200 # TODO len(dataloader)
self.current_validation_dataloader_lenght = len(dataloader)
self.current_validation_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
) # NOTE(tj.solergibert) Create an Iterator from the DataLoader

def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]):
from collections.abc import Generator
@@ -431,11 +430,11 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da
self.config.data_stages
), "Number of dataloaders should match the number of dataset stages"

def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str):
import gc

log_rank(
f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory",
f"[Training Stage: {stage_name}] Clearing the previous training stage's ({prev_stage_name}) dataloader and datasets from memory",
logger=logger,
level=logging.INFO,
)
@@ -472,7 +471,9 @@ def find_stage_idx_to_resume():

if isinstance(prev_dataloader, DataLoader):
# NOTE: we don't need to clear dummy data generator from memory
clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name)
clear_dataloader_from_memory(
prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name
)

self.metadata.last_stage_idx = stage_idx

@@ -548,18 +549,14 @@ def train(
# Each validation is much longer than a training step
# len(valid dataloader) may give the number of valid batches, so with that and the batch size we can get going
if self.iteration_step % self.config.tokens.val_check_interval == 0:
log_rank(
f"KOMO???? {self.iteration_step}",
logger=logger,
level=logging.INFO,
rank=0,
)
self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls)
val_global_loss, val_lang_losses = self.validation_step(
dataloader=self.current_validation_dataloader
)
self.validation_step_time = time.time()
self.validation_step_logs(val_global_loss, val_lang_losses)
self.validation_step_logs(
val_global_loss, val_lang_losses
) # TODO(tj.solergibert) Check what happens when val_check_interval % iteration_step_info_interval != 0

# Training Logs
# TODO(xrsrke): refactor using callbacks would be better
@@ -684,6 +681,14 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten
lang_losses = {
lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.lang_to_ids.keys()
}
# WARNING(tj.solergibert) This mechanism will fail in the following [corner] case:
# If the lang_losses list for a given lang IS EMPTY, i.e. in the validation step a Data Parallel Group
# has 0 SAMPLES of a given lang, lang_losses[lang] will be an empty python list and the torch.stack call
# will fail with "stack expects a non-empty TensorList". I've tried setting lang_losses[lang] to torch.empty,
# but of course it doesn't work, as we then average across the DP group.
# We will fix this issue in the future if we encounter this problem again.
# A bit of inspo https://blog.speechmatics.com/Sparse-All-Reduce-Part-1

# Compute losses
if isinstance(outputs[0], torch.Tensor):
# Multilingual losses
@@ -696,9 +701,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten
# Sync losses across DP
for lang in lang_losses.keys():
lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang]))
dist.all_reduce(
lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG
) # TODO These averages come out huge because it must be averaging with a single value!!!!!!!! It should be the loss per batch or something, right? Otherwise, take the "loss" from the outputs above and compare by hand...
dist.all_reduce(lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG)
dist.all_reduce(global_loss_avg, group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG)
else:
global_loss_avg = None
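
A single-process sketch of the per-language averaging above, including the empty-list corner case flagged in the WARNING; the loss values are invented and the DP all-reduce is only indicated in a comment.

import torch

# Per-language sample losses gathered on one rank (invented values).
lang_losses = {"es": [torch.tensor(0.7), torch.tensor(0.9)],
               "en": [torch.tensor(1.2)],
               "fr": []}

for lang, losses in lang_losses.items():
    if not losses:
        # Corner case from the WARNING: torch.stack([]) raises
        # "stack expects a non-empty TensorList", so a real fix has to handle
        # the empty list before averaging across the DP group.
        print(f"{lang}: no samples on this rank")
        continue
    lang_losses[lang] = torch.mean(torch.stack(losses))
    # In the trainer this is followed by:
    # dist.all_reduce(lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG)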
@@ -833,7 +836,7 @@ def validation_step_logs(
), # , "1.6E"),
LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"),
LogItem("validation_model_tflops_per_gpu", model_tflops / 3, "human_format"), # , ".2f"),
LogItem("validation_hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"),
LogItem("validation_hardware_tflops_per_gpu", hardware_tflops / 3, "human_format"), # , ".2f"),
]

# NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI.
