Commit

Added MultilingualNanoset
TJ-Solergibert committed Jul 16, 2024
1 parent 0485fd6 commit 539832a
Showing 2 changed files with 343 additions and 3 deletions.
125 changes: 122 additions & 3 deletions run_train.py
@@ -12,7 +12,13 @@

import numpy as np
from nanotron import logging
from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
from nanotron.config import (
DataArgs,
DatasetStageArgs,
MultilingualNanosetDatasetsArgs,
NanosetDatasetsArgs,
PretrainDatasetsArgs,
)
from nanotron.data.dataloader_builder import build_nanoset_dataloader
from nanotron.dataloader import (
clm_process,
@@ -171,13 +177,98 @@ def get_dataloader_from_data_stage(
dataloader_drop_last=True,
)

return train_dataloader
# Case 4: MultilingualNanosets
elif isinstance(data.dataset, MultilingualNanosetDatasetsArgs):
# Get tokenizer cardinality
tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
del tokenizer
# Create Nanoset
from nanotron.data.multilingual_nanoset import MultilingualNanoset

with main_rank_first(trainer.parallel_context.world_pg):
train_dataset = MultilingualNanoset(
dataset_folders=data.dataset.dataset_folder,
dataset_weights=data.dataset.dataset_weights,
sequence_length=trainer.sequence_length,
token_size=token_size,
train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size,
random_seed=data.seed,
)

# Prepare dataloader
train_dataloader = build_nanoset_dataloader(
train_dataset,
trainer.sequence_length,
parallel_context=trainer.parallel_context,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
micro_batch_size=trainer.micro_batch_size,
consumed_train_samples=consumed_train_samples,
dataloader_num_workers=data.num_loading_workers,
dataloader_drop_last=True,
)

return train_dataloader
else:
raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")

return dataloader


def get_valid_dataloader_from_data_stage(
trainer: DistributedTrainer,
data: DataArgs,
valid_split_num_samples: int,
# consumed_train_samples: int,  # Not needed: every validation iteration consumes all the samples
):

# First, we need to know which ranks to feed the dataloader to
input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)

# Validation is only supported with MultilingualNanosets
if isinstance(data.dataset, MultilingualNanosetDatasetsArgs):
# Get tokenizer cardinality
tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
del tokenizer
# Create Multilingual Nanoset
from nanotron.data.multilingual_nanoset import MultilingualNanoset

with main_rank_first(trainer.parallel_context.world_pg):
valid_dataset = MultilingualNanoset(
dataset_folders=data.dataset.dataset_folder,
dataset_weights=data.dataset.dataset_weights,
sequence_length=trainer.sequence_length,
token_size=token_size,
train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
valid_split_num_samples=valid_split_num_samples,
is_valid=True,
random_seed=data.seed,
)

# Prepare dataloader
valid_dataloader = build_nanoset_dataloader(
valid_dataset,
trainer.sequence_length,
parallel_context=trainer.parallel_context,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
micro_batch_size=trainer.micro_batch_size,
consumed_train_samples=0,
dataloader_num_workers=data.num_loading_workers,
dataloader_drop_last=True,
)

return valid_dataloader
else:
raise ValueError(
f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}. Validation is currently just supported for MultilingualNanoset"
)


def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
dataloaders = {}

@@ -219,6 +310,33 @@ def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
return dataloaders


def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
dataloaders = {}

for stage_idx, stage in enumerate(trainer.config.data_stages):
# NOTE: we only create the dataloader for the first stage,
# then we lazily initialize the dataloaders for the other stages
stage = cast(DatasetStageArgs, stage)
valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size

log_rank(
f"[Training Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set",
logger=logger,
level=logging.INFO,
rank=0,
)

dataloader = (
get_valid_dataloader_from_data_stage(trainer, stage.data, valid_split_num_samples=valid_split_num_samples)
if stage_idx == 0
else lambda stage=stage: get_valid_dataloader_from_data_stage(
trainer, stage.data, valid_split_num_samples=valid_split_num_samples
)
)
dataloaders[stage.name] = dataloader
return dataloaders


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file")
@@ -231,7 +349,8 @@ def get_args():

# Load trainer and data
trainer = DistributedTrainer(config_file)
dataloader = get_dataloader(trainer)
train_dataloader = get_dataloader(trainer)
valid_dataloader = get_valid_dataloader(trainer)

# Train
trainer.train(dataloader)
trainer.train(train_dataloader, valid_dataloader)
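
A note on the lazy per-stage construction used in both `get_dataloader` and `get_valid_dataloader`: later stages are wrapped in `lambda stage=stage: ...`, and the default argument is what pins each closure to its own stage; without it, Python's late binding would make every lazily built dataloader see the last value of the loop variable. A minimal, self-contained sketch of that behaviour (the stage names below are made up for the example):

# Toy illustration of late binding in Python closures; the stage names are hypothetical.
stages = ["stable", "annealing"]
late_bound = [lambda: f"dataloader for {stage}" for stage in stages]
pinned = [lambda stage=stage: f"dataloader for {stage}" for stage in stages]
print([f() for f in late_bound])  # ['dataloader for annealing', 'dataloader for annealing']
print([f() for f in pinned])      # ['dataloader for stable', 'dataloader for annealing']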
221 changes: 221 additions & 0 deletions src/nanotron/data/multilingual_nanoset.py
@@ -0,0 +1,221 @@
import os
import warnings
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
from datatrove.utils.dataset import DatatroveFolderDataset
from nanotron import logging
from nanotron.data.utils import count_dataset_indexes, normalize
from nanotron.logging import log_rank
from numba import jit

logger = logging.get_logger(__name__)


class MultilingualNanoset(torch.utils.data.Dataset):
"""
The MultilingualNanoset dataset
Args:
dataset_folders (List[str]): List of folders with tokenized datasets
dataset_weights (Union[List[float], None]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__
sequence_length (int): Sequence length of the built samples
token_size (int): Number of bytes per token in the processed dataset files. 2 if the vocab size fits in an uint16, 4 otherwise
train_split_num_samples (int): Number of samples needed for training, i.e. training steps * global batch size
valid_split_num_samples (int): Number of samples needed for validation, i.e. limit_val_batches * global batch size
is_valid (bool): If True, build the validation split; otherwise build the training split
random_seed (int): Seed used to shuffle the sample indexes
"""

def __init__(
self,
dataset_folders: List[str],
sequence_length: int,
token_size: int,
train_split_num_samples: int,
valid_split_num_samples: int,
is_valid: bool = False,
dataset_weights: Union[List[float], None] = None,
random_seed: int = 1234,
) -> None:

# Checks
if isinstance(dataset_folders, str):
warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]")
dataset_folders = [dataset_folders]

# Init
self.dataset_folders = dataset_folders
self.sequence_length = sequence_length
self.token_size = token_size
self.train_split_num_samples = train_split_num_samples
self.valid_split_num_samples = valid_split_num_samples
self.is_valid = is_valid
self.random_seed = random_seed
self.datatrove_datasets = []
for dataset_folder in self.dataset_folders:
self.datatrove_datasets.append(
DatatroveFolderDataset(
folder_path=dataset_folder,
filename_pattern=os.path.join(dataset_folder, "*.ds"),
seq_len=sequence_length,
recursive=False,
token_size=token_size,
shuffle=True,
)
)

# Build Nanoset Index
## To build the index we need the length of each dataset
self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets]
## Set dataset weights
if (
dataset_weights is None
): # Case of training with > 1 datasets without weighting them: consume all datasets entirely on each epoch
self.dataset_weights = normalize(self.dataset_lengths)
else:
self.dataset_weights = normalize(dataset_weights)
assert len(dataset_folders) == len(
self.dataset_weights
), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided."
## Build dataset index and dataset sample index
### Split dataset_lengths into train_dataset_lengths & valid_dataset_lengths
self.valid_dataset_lengths = (self.dataset_weights * valid_split_num_samples).astype(int) # Validation samples taken from each dataset, proportional to its weight
# Assert that we have sufficient samples to build the valid split
for ds_index in range(len(self.dataset_lengths)):
assert (
self.valid_dataset_lengths[ds_index] <= self.dataset_lengths[ds_index]
), f"Trying to build a validation split with {self.valid_dataset_lengths[ds_index]} samples but {dataset_folders[ds_index]} only has {self.dataset_lengths[ds_index]} samples."
self.train_dataset_lengths = [
a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lengths)
] # Subtract the valid samples from each training dataset

if is_valid: # Valid MultilingualNanoset
self.split_num_samples = valid_split_num_samples
self.split_samples_per_epoch = valid_split_num_samples
self.num_epochs = 1
self.split_dataset_lengths = self.valid_dataset_lengths
self.split_dataset_offsets = self.train_dataset_lengths # The valid split of each dataset starts right after its train split

else: # Train MultilingualNanoset
self.split_num_samples = train_split_num_samples
self.split_samples_per_epoch = sum(self.train_dataset_lengths)
self.num_epochs = int(self.split_num_samples / self.split_samples_per_epoch) + 1
self.split_dataset_lengths = self.train_dataset_lengths
self.split_dataset_offsets = [
0 for _ in range(len(self.dataset_lengths))
] # For training there is NO offset

self.dataset_index, self.dataset_sample_index = self.build_nanoset_index()

self.print_nanoset_info()

def __len__(self) -> int:
"""
Returns:
int: The number of samples of the Nanoset
"""

return len(self.dataset_index)

def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
"""
Returns sequence_length + 1 tokens from the memmap dataset
Args:
idx (int): The index into the dataset
Returns:
Dict[str, torch.LongTensor]: The input ids wrapped in a dictionary
"""
dataset = self.dataset_index[idx]
dataset_sample = self.dataset_sample_index[idx]

return self.datatrove_datasets[dataset][dataset_sample]

def build_nanoset_index(self) -> np.ndarray:
"""
Build dataset index and dataset sample index
"""
# Build the dataset indexes for 1 epoch
dataset_index, dataset_sample_index = build_nanoset_index_helper(
n_samples=self.split_samples_per_epoch,
weights=self.dataset_weights,
dataset_sizes=self.split_dataset_lengths,
offsets=self.split_dataset_offsets,
)
# Shuffle the indexes the same way
numpy_random_state = np.random.RandomState(self.random_seed)
numpy_random_state.shuffle(dataset_index)
numpy_random_state = np.random.RandomState(self.random_seed)
numpy_random_state.shuffle(dataset_sample_index)
# Concatenate num_epochs the shuffled indexes
dataset_index = np.concatenate([dataset_index for _ in range(self.num_epochs)])
dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(self.num_epochs)])
# Just keep the necessary samples
dataset_index = dataset_index[: self.split_num_samples]
dataset_sample_index = dataset_sample_index[: self.split_num_samples]

return dataset_index, dataset_sample_index

def print_nanoset_info(self):

log_rank(
f"> [{'Validation' if self.is_valid else 'Training'} dataset] Total number of samples: {len(self)}",
logger=logger,
level=logging.INFO,
rank=0,
)
log_rank(
f"> [{'Validation' if self.is_valid else 'Training'} dataset] Total number of tokens: {len(self) * self.sequence_length}",
logger=logger,
level=logging.INFO,
rank=0,
)

# Print samples from each dataset + weight
dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders))
for index, sample_count in enumerate(dataset_sample_count):
log_rank(
f"> Total number of {'validation' if self.is_valid else 'training'} samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})",
logger=logger,
level=logging.INFO,
rank=0,
)


@jit(nopython=True, cache=True)
def build_nanoset_index_helper(
n_samples: int, weights: np.ndarray, dataset_sizes: List[int], offsets: List[int]
) -> Tuple[np.ndarray, np.ndarray]:
"""
Given multiple datasets and a weighting array, build samples indexes
such that it follows those weights.
Each dataset_folder is split into a train split (first part) and a valid split (last part). For the valid split, the offsets are set to
the per-dataset train lengths so that validation samples are drawn from the region after the training samples.
"""
# Create empty arrays for dataset indices and dataset sample indices
dataset_index = np.empty((n_samples,), dtype="uint")
dataset_sample_index = np.empty((n_samples,), dtype="long") # int64: supports datasets with up to 2**63 samples

# Initialize buffer for number of samples used for each dataset
current_samples = np.zeros((len(weights),), dtype="long")

# Iterate over all samples
for sample_idx in range(n_samples):

# Convert sample index to float for comparison against weights
sample_idx_float = max(sample_idx, 1.0)

# Find the dataset with the highest error
errors = weights * sample_idx_float - current_samples
max_error_index = np.argmax(errors)

# Assign the dataset index and update the sample index
dataset_index[sample_idx] = max_error_index
dataset_sample_index[sample_idx] = (
current_samples[max_error_index] % dataset_sizes[max_error_index]
) + offsets[max_error_index]

# Update the total samples for the selected dataset
current_samples[max_error_index] += 1

return dataset_index, dataset_sample_index
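
To make the error-based weighting and the validation offsets above concrete, here is a minimal sketch that calls the helper directly. The import path matches the new module introduced in this commit; the sample counts, weights, sizes and offsets are toy values chosen only for illustration:

import numpy as np
from nanotron.data.multilingual_nanoset import build_nanoset_index_helper

# Two toy datasets with 6 and 2 training samples, weighted 0.75 / 0.25 (already normalized).
weights = np.array([0.75, 0.25])

# Train-style call: zero offsets, draw 8 samples following the weights.
dataset_index, dataset_sample_index = build_nanoset_index_helper(8, weights, [6, 2], [0, 0])
print(dataset_index)  # expected: [0 1 0 0 0 1 0 0] -> roughly three samples of dataset 0 per sample of dataset 1
print(dataset_sample_index)  # expected: [0 0 1 2 3 1 4 5] -> per-dataset running counter, wrapped by the modulo

# Valid-style call: dataset_sizes are the valid lengths and offsets are the train lengths,
# so every returned sample index points past the train region of its dataset_folder.
valid_index, valid_sample_index = build_nanoset_index_helper(4, weights, [3, 1], [6, 2])
print(valid_sample_index)  # expected: [6 2 7 8] -> each index >= its dataset's train length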
