
Commit

Just in case
TJ-Solergibert committed Jul 22, 2024
1 parent eed7bce commit 27133e1
Showing 6 changed files with 173 additions and 34 deletions.
2 changes: 1 addition & 1 deletion examples/config_multilingual_nanoset.yaml
@@ -131,4 +131,4 @@ tokens:
micro_batch_size: 4
sequence_length: 1024
train_steps: 200
val_check_interval: -1
val_check_interval: 3
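
For reference, a tiny sketch (not part of the commit) of the cadence implied by this config change: with tokens.val_check_interval set to 3, the train loop runs a full validation pass every 3 training steps.

val_check_interval = 3  # value from the config above
for iteration_step in range(1, 10):
    if iteration_step % val_check_interval == 0:  # same check added to trainer.py's train loop
        print(f"step {iteration_step}: run validation")  # fires at steps 3, 6 and 9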
4 changes: 2 additions & 2 deletions run_train.py
@@ -238,10 +238,10 @@ def get_valid_dataloader_from_data_stage(

with main_rank_first(trainer.parallel_context.world_pg):
valid_dataset = MultilingualNanoset(
dataset_folders=data.dataset.validation_folder,
dataset_folders=data.dataset.validation_folder, # TODO Just 1 folder
sequence_length=trainer.sequence_length,
token_size=token_size,
dataset_tokens=data.dataset.dataset_tokens,
dataset_tokens=data.dataset.dataset_tokens, # TODO Just 1 lang
is_valid=True,
random_seed=data.seed,
)
1 change: 1 addition & 0 deletions src/nanotron/config/config.py
@@ -125,6 +125,7 @@ def __post_init__(self):
self.training_folder = list(tmp_training_folder.keys())
self.dataset_weights = list(tmp_training_folder.values())

self.ids_to_lang = {v: k for k, v in self.lang_to_ids.items()}
self.dataset_tokens = list(self.lang_to_ids.values())
assert len(self.training_folder) == len(
self.validation_folder
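
As a minimal illustration of the new mapping (the language names and token ids below are invented), the config now derives both an id-to-language lookup and the per-language token list from lang_to_ids:

lang_to_ids = {"en": 128002, "es": 128003, "fr": 128004}  # hypothetical config values

ids_to_lang = {v: k for k, v in lang_to_ids.items()}  # {128002: "en", 128003: "es", 128004: "fr"}
dataset_tokens = list(lang_to_ids.values())           # [128002, 128003, 128004]

assert ids_to_lang[dataset_tokens[1]] == "es"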
30 changes: 20 additions & 10 deletions src/nanotron/models/llama.py
@@ -801,7 +801,14 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch
@torch.jit.script
def masked_mean(loss, label_mask, dtype):
# type: (Tensor, Tensor, torch.dtype) -> Tensor
return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(
dim=1
) # TODO for now this gives float / float = float


# TODO the loss of each sample !!!! ((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1))
# And it passes the assert_close!!
# assert_close(((loss * label_mask).sum(dtype=dtype) / label_mask.sum()), torch.mean((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1)))


class Loss(nn.Module):
@@ -818,14 +825,16 @@ def forward(
# Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision.
# https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38

loss = sharded_cross_entropy(
sample_loss = sharded_cross_entropy(
sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float
).transpose(0, 1)
# TODO @thomasw21: It's unclear what kind of normalization we want to do.
loss = masked_mean(loss, label_mask, dtype=torch.float)
# I think indexing causes a sync we don't actually want
# loss = loss[label_mask].sum()
return {"loss": loss}
sample_loss = masked_mean(sample_loss, label_mask, dtype=torch.float)
# NOTE @tj.solergibert: masked_mean returns a single scalar with the batch loss. We've changed it to compute the SAMPLE loss.
# We will continue using "loss" as the batch loss but we add "sample_loss" for the multilingual effort.
# TODO @thomasw21: I think indexing causes a sync we don't actually want
# TODO @thomasw21: loss = loss[label_mask].sum()
return {"sample_loss": sample_loss}


class LlamaForTraining(NanotronModel):
@@ -847,7 +856,7 @@ def __init__(
"label_ids",
"label_mask",
},
module_output_keys={"loss"},
module_output_keys={"sample_loss"},
)
self.parallel_context = parallel_context
self.config = config
@@ -864,12 +873,13 @@ def forward(
input_ids=input_ids,
input_mask=input_mask,
)
loss = self.loss(
outputs = self.loss(
sharded_logits=sharded_logits,
label_ids=label_ids,
label_mask=label_mask,
)["loss"]
return {"loss": loss}
)
outputs["loss"] = torch.mean(outputs["sample_loss"])
return outputs

@torch.no_grad()
def init_model_randomly(self, config: Config):
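
To make the masked_mean change concrete, here is a small standalone sketch (not part of the commit) contrasting the old batch-level reduction with the new per-sample one. The mean of the per-sample means only matches the old global masked mean when every sample contributes the same number of unmasked tokens, as with the all-ones mask below, which is what the assert_close in the TODO above relies on:

import torch
from torch.testing import assert_close

def masked_mean_per_sample(loss, label_mask, dtype):
    # New behaviour: one masked mean per sample in the batch
    return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1)

def masked_mean_global(loss, label_mask, dtype):
    # Old behaviour: a single masked mean over the whole batch
    return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()

loss = torch.rand(4, 8)        # (batch, seq) per-token losses
label_mask = torch.ones(4, 8)  # every sample has the same number of unmasked tokens

sample_loss = masked_mean_per_sample(loss, label_mask, torch.float)  # shape (4,)
batch_loss = sample_loss.mean()  # what LlamaForTraining now stores under "loss"

assert_close(batch_loss, masked_mean_global(loss, label_mask, torch.float))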
27 changes: 20 additions & 7 deletions src/nanotron/parallel/pipeline_parallel/engine.py
@@ -2,6 +2,9 @@
from typing import Dict, Iterable, Optional, Union

import torch
from torch import nn as torch_nn
from torch.nn.parallel import DistributedDataParallel

from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
@@ -12,8 +15,6 @@
from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.utils import ContextManagers
from torch import nn as torch_nn
from torch.nn.parallel import DistributedDataParallel

logger = logging.get_logger(__name__)

@@ -29,6 +30,7 @@ def forward(
state: PipelineTrainBatchState,
micro_batch: Dict[str, Union[torch.Tensor, TensorPointer]],
model: torch_nn.Module,
is_validation: bool = False,
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# Increment the number of backwards
state.nb_forwards += 1
@@ -52,7 +54,7 @@ def forward(
output["loss"] = output["loss"] / self.nb_microbatches

# Add output as activations that require backward pass
if not isinstance(output["loss"], TensorPointer):
if not isinstance(output["loss"], TensorPointer) and not is_validation:
assert output["loss"].requires_grad
state.register_activation_requiring_backward(output["loss"])
return output
@@ -138,28 +140,39 @@ def validate_batch_iter(
self.nb_microbatches = nb_microbatches

outputs = []
lang_ids = []

with attach_pipeline_state_to_model(model=model, pipeline_state=state):
# All forward
for micro_batch in batch:
context = self._get_fwd_context(model=model)
output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
output = self.forward(
context=context, state=state, micro_batch=micro_batch, model=model, is_validation=True
)
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_activations_to_send)):
send_activation = state.microbatches_activations_to_send.popleft()
# Execute
send_activation()

# We make `output` a dict
# TODO convert to dict other items returned by the model (MoE aux loss for example)
# But in next if statement be careful if we return other items in all of the pp processes
# This conversion to dicts is kind of useless, as the model already returns a dict with a loss key. Maybe the PP ranks return TensorPointer objects?
if not isinstance(output, dict):
output = {"loss": output}

# Store the loss for each microbatch
if not isinstance(output["loss"], TensorPointer):
output = {k: v.detach() for k, v in output.items()}
outputs.append(output)

return outputs
# TODO check what this output is and also how it is stored in outputs. Where is the mean taken? In the training step lol
# Here we should split by language, because this is the only point where we have the language!! Or at least "tag" it or accumulate it per language
# 1. Build a dict with a key per language 2. each key holds a list to which we append the tensors 3. in the valid step we do the stack and the all-reduces
# Finally: here we only collect the lang ids; in trainer.py we will accumulate the results and so on.
outputs.extend(list(output["sample_loss"])) # TODO flatten?????? or extend??????
lang_ids.extend(micro_batch["input_ids"][:, 0].tolist()) # TODO should this be an extend????

return outputs, lang_ids


class AllForwardAllBackwardPipelineEngine(PipelineEngine):
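
A minimal sketch (not part of the commit, with invented token ids) of how validate_batch_iter now pairs each per-sample loss with a language id, assuming, as the code above suggests, that the first token of every sequence is the language token injected by the dataset:

import torch

outputs, lang_ids = [], []
for _ in range(3):  # pretend micro-batches
    micro_batch = {"input_ids": torch.tensor([[128002, 11, 7, 2], [128003, 5, 9, 4]])}
    sample_loss = torch.rand(2)  # one loss per sample, as returned under "sample_loss"

    outputs.extend(list(sample_loss))                         # flatten per-sample losses across micro-batches
    lang_ids.extend(micro_batch["input_ids"][:, 0].tolist())  # language token of each sample

assert len(outputs) == len(lang_ids)  # one (loss, language id) pair per validation sample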
143 changes: 129 additions & 14 deletions src/nanotron/trainer.py
@@ -300,7 +300,9 @@ def _print_training_plan(self):
)
log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0)

def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]):
def _update_dataloader_based_on_training_stages(
self, dataloaders: Union[List[DataLoader], DataLoader], is_validation: bool = False
):
from collections.abc import Generator

if not hasattr(self.config, "data_stages") or self.config.data_stages is None:
@@ -309,9 +311,16 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da
dataloader = dataloaders[0]
else:
dataloader = dataloaders
self.current_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)

if is_validation:
self.current_validation_dataloader_lenght = len(dataloader)
self.current_validation_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
else:
self.current_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
return
elif isinstance(dataloaders, Generator):
# TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader
@@ -328,7 +337,7 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
import gc

log_rank(
f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory",
f"[{'Validation' if is_validation else 'Training'} Stage: {stage_name}] Clearing the previous {'validation' if is_validation else 'training'} stage's dataloader and datasets from memory",
logger=logger,
level=logging.INFO,
)
@@ -369,7 +378,7 @@ def find_stage_idx_to_resume():

self.metadata.last_stage_idx = stage_idx

if is_resume_from_training:
if is_resume_from_training and not is_validation:
remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
stage, self.config, self.metadata
)
@@ -387,9 +396,15 @@ def find_stage_idx_to_resume():
break

if dataloader is not None:
self.current_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
if is_validation:
self.current_validation_dataloader_lenght = len(dataloader)
self.current_validation_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
else:
self.current_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)

def train(
self,
@@ -428,9 +443,23 @@ def train(

self.iteration_start_time = time.time()
self._update_dataloader_based_on_training_stages(train_dataloader_or_dls)
self._update_dataloader_based_on_training_stages(valid_dataloader_or_dls, is_validation=True)

# Training step
outputs, loss_avg = self.training_step(dataloader=self.current_dataloader)
self.training_step_time = time.time()

# Validation step
# TODO Let's see, in this loop only one training iteration is carried out, but of course there are a ton of validation iterations... mmmmm
# Maybe we should move this somewhere else? That is, here we do one training step but we run several validation steps
# We can leave it here, only the throughput and consumed-tokens metrics would have to be reviewed,
# because they currently use the global batch size, which is correct since it is what each training step has, but of course
# each validation pass is much longer than a training step
# len(valid dataloader) probably gives the number of valid batches, so with that and the batch size we can get going
if self.iteration_step % self.config.tokens.val_check_interval == 0:
global_loss, lang_losses = self.validation_step(dataloader=self.current_validation_dataloader)
self.validation_step_time = time.time()
self.validation_step_logs(global_loss, lang_losses)

# Training Logs
# TODO(xrsrke): refactor using callbacks would be better
@@ -546,12 +575,36 @@ def training_step(
return outputs, loss_avg

def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]:
outputs = self.pipeline_engine.validate_batch_iter(
outputs, lang_ids = self.pipeline_engine.validate_batch_iter(
model=self.model,
batch=(next(dataloader) for _ in range(self.limit_val_batches)),
nb_microbatches=self.limit_val_batches,
batch=(next(dataloader) for _ in range(self.current_validation_dataloader_lenght)),
nb_microbatches=self.current_validation_dataloader_lenght,
)
return outputs

lang_losses = {
lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.lang_to_ids.keys()
}
# Compute losses
if isinstance(outputs[0], torch.Tensor):
# Multilingual losses
for loss, lang_id in zip(outputs, lang_ids):
lang_losses[
self.config.data_stages[self.metadata.last_stage_idx].data.dataset.ids_to_lang[lang_id]
].append(loss)
# Global loss
global_loss_avg = torch.stack(outputs).sum()
# Sync losses across DP
for lang in lang_losses.keys():
lang_losses[lang] = torch.stack(lang_losses[lang]).sum()
dist.all_reduce(
lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG
) # TODO These averages come out huge because it must be averaging over a single value!!!!!!!! It should be the loss per batch or something, right? Otherwise, take the "loss" from the outputs above and compare by hand...
dist.all_reduce(global_loss_avg, group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG)
else:
global_loss_avg = None
lang_losses = None

return global_loss_avg, lang_losses
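
For clarity, a CPU-only sketch (not part of the commit, with made-up numbers) of the per-language grouping performed above; the dist.all_reduce across data-parallel ranks is omitted because it needs an initialized process group:

import torch

ids_to_lang = {128002: "en", 128003: "es"}                 # assumed mapping from the config
outputs = [torch.tensor(x) for x in (2.1, 1.7, 2.4, 1.9)]  # per-sample validation losses
lang_ids = [128002, 128003, 128002, 128003]                # language token of each sample

lang_losses = {lang: [] for lang in ids_to_lang.values()}
for loss, lang_id in zip(outputs, lang_ids):
    lang_losses[ids_to_lang[lang_id]].append(loss)

global_loss = torch.stack(outputs).sum()  # summed, as in the code above
per_lang = {lang: torch.stack(losses).sum() for lang, losses in lang_losses.items()}
print(per_lang)  # {'en': tensor(4.5000), 'es': tensor(3.6000)}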

def train_step_logs(
self,
Expand All @@ -561,7 +614,7 @@ def train_step_logs(
# TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607
dist.barrier()
torch.cuda.synchronize()
elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000
elapsed_time_per_iteration_ms = (self.training_step_time - self.iteration_start_time) * 1000
tokens_per_sec = (
self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000)
) # tokens_per_sec is calculated using sequence_length
@@ -641,6 +694,68 @@ def train_step_logs(
else:
exit(0)

def validation_step_logs(
self,
global_loss: torch.Tensor,
lang_losses: torch.Tensor,
) -> None:
# TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607
dist.barrier()
torch.cuda.synchronize()
total_validation_samples = self.current_validation_dataloader_lenght * self.micro_batch_size
elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000
tokens_per_sec = (
total_validation_samples * self.sequence_length / (elapsed_time_per_iteration_ms / 1000)
) # tokens_per_sec is calculated using sequence_length
# TODO for validation, careful with changing global_batch_size = len(dataloader) * mbs
model_tflops, hardware_tflops = self.unwrapped_model.get_flops_per_sec(
iteration_time_in_sec=elapsed_time_per_iteration_ms / 1000,
sequence_length=self.sequence_length,
global_batch_size=total_validation_samples, # TODO regarding the global batch size, I would set it to 1 and multiply by the number of batches
)

if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks:
assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks"

log_entries = [
# LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"),
LogItem(
"validation_consumed_tokens",
self.metadata.consumed_train_samples * self.config.tokens.sequence_length,
"human_format",
), # , "12d"),
LogItem(
"validation_elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"
), # , ".1f"),
LogItem("validation_tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"),
LogItem(
"validation_tokens_per_sec_per_gpu",
tokens_per_sec / self.parallel_context.world_pg.size(),
"human_format",
), # , "1.6E"),
LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"),
LogItem("validation_model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"),
LogItem("validation_hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"),
]

# NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI.
# Example: https://community.wandb.ai/t/log-multiple-variables-at-the-same-plot/2474
# GitHub complains: https://github.com/wandb/wandb/issues/3035
log_entries.extend(
[LogItem(f"{lang}_validation_loss", loss.item(), "human_format") for lang, loss in lang_losses.items()]
)

# NOTE: only one rank writes to wandb
if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
wandb.log(
{
**{log_item.tag: log_item.scalar_value for log_item in log_entries},
"iteration_step": self.iteration_step,
}
)

self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step)

def init_model(self) -> Union[NanotronModel, DistributedDataParallel]:
"""Initialize the model and load weights from checkpoint if needed."""
# TODO: add max_position_embeddings
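
Finally, a back-of-the-envelope check (invented numbers) of the validation throughput formula used in validation_step_logs, where the sample count comes from the length of the validation dataloader rather than from the training global batch size:

num_valid_batches = 50   # len(validation dataloader), i.e. current_validation_dataloader_lenght
micro_batch_size = 4
sequence_length = 1024
elapsed_s = 12.5         # validation_step_time - training_step_time

total_validation_samples = num_valid_batches * micro_batch_size          # 200 samples
tokens_per_sec = total_validation_samples * sequence_length / elapsed_s  # 204_800 tokens / 12.5 s
print(f"{tokens_per_sec:,.0f} validation tokens/s")                      # 16,384 validation tokens/s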
