Commit d0c14e3

Before lunch
TJ-Solergibert committed Jul 17, 2024
1 parent d91f9e1 commit d0c14e3
Showing 3 changed files with 37 additions and 58 deletions.
13 changes: 3 additions & 10 deletions run_train.py
@@ -194,7 +194,6 @@ def get_dataloader_from_data_stage(
         sequence_length=trainer.sequence_length,
         token_size=token_size,
         train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
-        valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size,
         dataset_tokens=data.dataset.dataset_tokens,
         random_seed=data.seed,
     )
@@ -222,7 +221,6 @@ def get_dataloader_from_data_stage(
 def get_valid_dataloader_from_data_stage(
     trainer: DistributedTrainer,
     data: DataArgs,
-    valid_split_num_samples: int,
     # consumed_train_samples: int, We will never use this because in each valid iteration we consume all the samples
 ):

@@ -245,7 +243,6 @@ def get_valid_dataloader_from_data_stage(
         sequence_length=trainer.sequence_length,
         token_size=token_size,
         train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
-        valid_split_num_samples=valid_split_num_samples,
         dataset_tokens=data.dataset.dataset_tokens,
         is_valid=True,
         random_seed=data.seed,
@@ -259,7 +256,6 @@ def get_valid_dataloader_from_data_stage(
         input_pp_rank=input_pp_rank,
         output_pp_rank=output_pp_rank,
         micro_batch_size=trainer.micro_batch_size,
-        consumed_train_samples=0,
         dataloader_num_workers=data.num_loading_workers,
         dataloader_drop_last=True,
     )
@@ -319,21 +315,18 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
         # NOTE: we only create the dataloader for the first stage,
         # then we lazy initialize the dataloader for the other stages
         stage = cast(DatasetStageArgs, stage)
-        valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size

         log_rank(
-            f"[Validation Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set",
+            f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set",
             logger=logger,
             level=logging.INFO,
             rank=0,
         )

         dataloader = (
-            get_valid_dataloader_from_data_stage(trainer, stage.data, valid_split_num_samples=valid_split_num_samples)
+            get_valid_dataloader_from_data_stage(trainer, stage.data)
             if stage_idx == 0
-            else lambda stage=stage: get_dataloader_from_data_stage(
-                trainer, stage.data, valid_split_num_samples=valid_split_num_samples
-            )
+            else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data)
         )
         dataloaders[stage.name] = dataloader
     return dataloaders
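As context for the change above: validation no longer takes a valid_split_num_samples budget; each validation run consumes every sample in the configured validation folders. Below is a minimal usage sketch of the dataloaders mapping returned by get_valid_dataloader, assuming a generic trainer object; only get_valid_dataloader itself and the lazy-lambda convention come from this file, the rest is illustrative and not part of this commit.

    # Hypothetical caller (not part of this commit).
    valid_dataloaders = get_valid_dataloader(trainer)
    for stage_name, dataloader in valid_dataloaders.items():
        # Stages after the first are stored as zero-argument lambdas; build them on first use.
        if callable(dataloader):
            dataloader = dataloader()
        for batch in dataloader:
            ...  # run a validation step on the batch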
6 changes: 4 additions & 2 deletions src/nanotron/config/config.py
@@ -109,7 +109,8 @@ def __post_init__(self):

 @dataclass
 class MultilingualNanosetDatasetsArgs:
-    dataset_folder: Union[str, dict, List[str]]
+    training_folder: Union[str, dict, List[str]]
+    validation_folder: Union[str, dict, List[str]]
     dataset_tokens: List[
         int
     ]  # Set token for each language previously defined. We use a List and not a dict because this way we support specifying weights (dict) or not (List[str])
@@ -125,7 +126,8 @@ def __post_init__(self):
             self.dataset_folder = list(tmp_dataset_folder.keys())
             self.dataset_weights = list(tmp_dataset_folder.values())

-        assert len(self.dataset_folder) == len(self.dataset_tokens)
+        assert len(self.training_folder) == len(self.validation_folder)
+        assert len(self.training_folder) == len(self.dataset_tokens)


 @dataclass
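To make the new schema concrete, here is a hedged sketch of how the dataclass might be instantiated. The folder paths and token ids are invented, and it is assumed no other fields are required; only the field names training_folder, validation_folder and dataset_tokens, plus the length checks in __post_init__, come from this commit.

    # Hypothetical values: one language token per (training, validation) folder pair.
    data_args = MultilingualNanosetDatasetsArgs(
        training_folder=["datasets/en/train", "datasets/es/train"],
        validation_folder=["datasets/en/validation", "datasets/es/validation"],
        dataset_tokens=[128002, 128003],
    )
    # __post_init__ now enforces:
    #   len(training_folder) == len(validation_folder) == len(dataset_tokens)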
76 changes: 30 additions & 46 deletions src/nanotron/data/multilingual_nanoset.py
@@ -1,6 +1,5 @@
 import os
 import warnings
-from math import ceil
 from typing import Dict, List, Tuple, Union

 import numpy as np
@@ -32,7 +31,6 @@ def __init__(
         sequence_length: int,
         token_size: int,
         train_split_num_samples: int,
-        valid_split_num_samples: int,
         dataset_tokens: List[int],
         is_valid: bool = False,
         dataset_weights: Union[List[float], None] = None,
@@ -49,7 +47,6 @@ def __init__(
         self.sequence_length = sequence_length
         self.token_size = token_size
         self.train_split_num_samples = train_split_num_samples
-        self.valid_split_num_samples = valid_split_num_samples
         self.dataset_tokens = dataset_tokens
         self.is_valid = is_valid
         self.random_seed = random_seed
@@ -80,36 +77,11 @@ def __init__(
                 self.dataset_weights
             ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided."
         ## Build dataset index and dataset sample index
-        ### Split dataset_lengths into train_dataset_lenghts & valid_dataset_lenghts
-        self.valid_dataset_lenghts = [
-            ceil(weight * valid_split_num_samples) for weight in self.dataset_weights
-        ]  # Better not tu use numpy so we don't get overflow issues
-        # Assert that we have sufficient samples to build the valid split
-        for ds_index in range(len(self.dataset_lengths)):
-            assert (
-                self.dataset_lengths[ds_index] > self.valid_dataset_lenghts[ds_index]
-            ), f"Trying to build validation dataset with {self.valid_dataset_lenghts[ds_index]} samples but {dataset_folders[ds_index]} just have {self.dataset_lengths[ds_index]} samples."
-        self.train_dataset_lenghts = [
-            a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lenghts)
-        ]  # Subtract the valid samples from the training dataset
-
         if is_valid:  # Valid MultilingualNanoset
-            self.split_num_samples = valid_split_num_samples
-            self.split_samples_per_epoch = valid_split_num_samples
-            self.num_epochs = 1
-            self.split_dataset_lenghts = self.valid_dataset_lenghts
-            self.split_dataset_offsets = self.train_dataset_lenghts
+            self.dataset_index, self.dataset_sample_index = self.build_valid_nanoset_index(self.dataset_lengths)

         else:  # Train MultilingualNanoset
-            self.split_num_samples = train_split_num_samples
-            self.split_samples_per_epoch = sum(self.train_dataset_lenghts)
-            self.num_epochs = int(self.split_num_samples / self.split_samples_per_epoch) + 1
-            self.split_dataset_lenghts = self.train_dataset_lenghts
-            self.split_dataset_offsets = [
-                0 for _ in range(len(self.dataset_lengths))
-            ]  # For training there is NO offset
-
-        self.dataset_index, self.dataset_sample_index = self.build_nanoset_index()
+            self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index()

         self.print_nanoset_info()

Expand Down Expand Up @@ -139,31 +111,45 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:

return tokens

def build_nanoset_index(self) -> np.ndarray:
def build_train_nanoset_index(self) -> np.ndarray:
"""
Build dataset index and dataset sample index
Build train dataset index and dataset sample index
"""
# Compute samples per epoch and number of epochs
samples_per_epoch = sum(self.dataset_lengths)
num_epochs = int(self.train_split_num_samples / samples_per_epoch) + 1
# Build the dataset indexes for 1 epoch
dataset_index, dataset_sample_index = build_nanoset_index_helper(
n_samples=self.split_samples_per_epoch,
weights=self.dataset_weights,
dataset_sizes=self.split_dataset_lenghts,
offsets=self.split_dataset_offsets,
dataset_index, dataset_sample_index = build_train_nanoset_index_helper(
n_samples=samples_per_epoch, weights=self.dataset_weights, dataset_sizes=self.dataset_lengths
)
# Shuffle the indexes the same way
numpy_random_state = np.random.RandomState(self.random_seed)
numpy_random_state.shuffle(dataset_index)
numpy_random_state = np.random.RandomState(self.random_seed)
numpy_random_state.shuffle(dataset_sample_index)
# Concatenate num_epochs the shuffled indexes
dataset_index = np.concatenate([dataset_index for _ in range(self.num_epochs)])
dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(self.num_epochs)])
dataset_index = np.concatenate([dataset_index for _ in range(num_epochs)])
dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(num_epochs)])
# Just keep the necessary samples
dataset_index = dataset_index[: self.split_num_samples]
dataset_sample_index = dataset_sample_index[: self.split_num_samples]
dataset_index = dataset_index[: self.train_split_num_samples]
dataset_sample_index = dataset_sample_index[: self.train_split_num_samples]

return dataset_index, dataset_sample_index

@jit(nopython=True, cache=True)
def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray:
"""
Build valid dataset index and dataset sample index
"""
dataset_index = []
dataset_sample_index = []

for i, length in enumerate(dataset_lengths):
dataset_index.extend([i] * length)
dataset_sample_index.extend(range(length))

return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long")

def print_nanoset_info(self):

log_rank(
@@ -191,8 +177,8 @@ def print_nanoset_info(self):


 @jit(nopython=True, cache=True)
-def build_nanoset_index_helper(
-    n_samples: int, weights: np.ndarray, dataset_sizes: List[int], offsets: List[int]
+def build_train_nanoset_index_helper(
+    n_samples: int, weights: np.ndarray, dataset_sizes: List[int]
 ) -> Tuple[np.ndarray, np.ndarray]:
     """
     Given multiple datasets and a weighting array, build samples indexes
@@ -219,9 +205,7 @@ def build_nanoset_index_helper(

         # Assign the dataset index and update the sample index
         dataset_index[sample_idx] = max_error_index
-        dataset_sample_index[sample_idx] = (
-            current_samples[max_error_index] % dataset_sizes[max_error_index]
-        ) + offsets[max_error_index]
+        dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index]

         # Update the total samples for the selected dataset
         current_samples[max_error_index] += 1
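The selection logic of this helper is only partially visible in the hunk above, so the sketch below is an assumption based on the variable names (max_error_index, current_samples): a greedy scheme that always draws the next sample from the dataset most under-represented relative to its weight. It is a plain-Python illustration, not the numba-jitted implementation from the commit.

    import numpy as np

    def toy_weighted_index(n_samples, weights, dataset_sizes):
        # Greedy weighted interleaving: pick the dataset whose observed share lags
        # its target weight the most, wrapping around each dataset with a modulo.
        weights = np.asarray(weights, dtype=np.float64)
        dataset_index = np.empty(n_samples, dtype=np.int64)
        dataset_sample_index = np.empty(n_samples, dtype=np.int64)
        current_samples = np.zeros(len(weights), dtype=np.int64)
        for sample_idx in range(n_samples):
            errors = weights * (sample_idx + 1) - current_samples
            max_error_index = int(np.argmax(errors))
            dataset_index[sample_idx] = max_error_index
            dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index]
            current_samples[max_error_index] += 1
        return dataset_index, dataset_sample_index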
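For the validation side, build_valid_nanoset_index (added earlier in this file) simply enumerates every sample of every validation dataset exactly once. A standalone toy version without the numba decorator, using made-up dataset lengths, shows the resulting indexes:

    import numpy as np

    def toy_valid_index(dataset_lengths):
        # Flatten all validation datasets: one entry per sample, in dataset order.
        dataset_index, dataset_sample_index = [], []
        for i, length in enumerate(dataset_lengths):
            dataset_index.extend([i] * length)          # which dataset the sample belongs to
            dataset_sample_index.extend(range(length))  # position of the sample inside that dataset
        return np.array(dataset_index), np.array(dataset_sample_index)

    # Two validation folders with 3 and 2 samples:
    #   dataset_index        -> [0, 0, 0, 1, 1]
    #   dataset_sample_index -> [0, 1, 2, 0, 1]
    idx, sample_idx = toy_valid_index([3, 2])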
