diff --git a/docs/nanoset.md b/docs/nanoset.md index 61393438..02649bd0 100644 --- a/docs/nanoset.md +++ b/docs/nanoset.md @@ -1,42 +1,41 @@ # Nanosets -Nanotron incorporates [`Nanosets`](../src/nanotron/data/nanoset.py), a dataset for processing tokenized documents with [`datatrove`](https://github.com/huggingface/datatrove). They allow reading tokens from one or multiple datasets and even specifying the weight of each dataset when building batches. +Nanotron incorporates [`Nanosets`](../src/nanotron/data/nanoset.py), a kind of datasets based on [numpy memory-mapped arrays](https://numpy.org/doc/stable/reference/generated/numpy.memmap.html). `Nanosets` are capable of serving batches from files containing pre-tokenized datasets. They allow reading tokens from one or multiple datasets and even specifying the weight of each dataset when building batches. ## Install To use `Nanosets`, it's necessary to install Nanotron with the `nanosets` flavor. ``` -pip install nanotron[nanosets] +pip install -e '.[nanosets]' ``` This will install the following dependencies: -- `datatrove`: To preprocess the datasets +- `transformers`: To tokenize the datasets +- `datasets`: To preprocess the datasets - `numba`: To compile helper functions in order to speed up the creation of `Nanosets` -- `transformers`: For the tokenizers ## Data pre-processing -To use this dataset, first, we need to preprocess the data using `datatrove`'s `DocumentTokenizer` pipeline. We invite you to take a look at `datatrove`, since it contains multiple features that allow, for example, filter out documents based on specific rules/criteria, extract text content from raw formats or scheduling the preprocessing in a Slurm cluster. We have also added a simple script capable of tokenizing datasets. - -The preprocessing is done using the [`tools/preprocess_data.py`](../tools/preprocess_data.py) script. The input format can either be a Hugging Face Dataset, a path to a `.jsonl` or a path to a folder containing multiple `.jsonl` files. Below we show an example for processing a Hugging Face Dataset from the Hub with the Llama3 tokenizer. +To use these datasets, first, we need to preprocess the data. The input format can either be a column of a Hugging Face Dataset or a .json file containing a text sample per line. For example:
-python3 tools/preprocess_data.py \
-       --tokenizer-name-or-path meta-llama/Meta-Llama-3-8B \
-       --output-folder datasets/emotion \
-       --n-tasks 16 \
-       hf \
-       --dataset dair-ai/emotion \
+{"src": "www.nvidia.com", "text": "The quick brown fox", "type": "Eng", "id": "0", "title": "First Part"}
+{"src": "The Internet", "text": "jumps over the lazy dog", "type": "Eng", "id": "42", "title": "Second Part"}
 
-First with `--tokenizer-name-or-path` we will specify a tokenizer in the same way as we do when using `AutoTokenizers.from_pretrained(...)`. Then we specify the `--output-folder` where we will store the tokenized documents and the number of workers with `--n-tasks`. Finally we will indicate the type of dataset (whether if it's a Hugging Face Dataset ["**hf**"] or in jsonl ["**jsonl**"] format) and the dataset that we want to preprocess. Check the different settings with `python3 tools/preprocess_data.py --help`, `python3 tools/preprocess_data.py hf --help` & `python3 tools/preprocess_data.py jsonl --help`. +The preprocessing is done using the [`tools/preprocess_data.py`](../tools/preprocess_data.py) script. Below we show an example for processing a corpus with the Llama2 tokenizer. -Every worker will store in `--output-folder` 3 different kind of files: -- `*.ds` Containing the tokenized documents -- `*.ds.index` Containing the bounds of each tokenized document -- `*.ds.metadata` Containing the number of tokens and tokenizer used +
+torchrun --nproc-per-node 16 tools/preprocess_data.py \
+       --input HuggingFaceH4/testing_alpaca_small \
+       --split train \
+       --column completion \
+       --output-prefix datasets/testing_alpaca_small \
+       --tokenizer-name-or-path openai-community/gpt2
+
-> [!IMPORTANT] -Remember to introduce the type of dataset to process. e.g. python3 tools/preprocess_data.py --tokenizer-name-or-path gpt2 --n-tasks 16 **jsonl** --dataset raw_datasets/c4-es-json-files +The preprocessing script has to be launched with `torchrun` in order to spawn `--nproc-per-node` workers that will preprocess the dataset concurrently. The `--input` dataset can be either a Hugging Face Dataset from the Hub or a `.json` file. The processed dataset will be stored in *`--output-prefix`_input_ids.npy*. In `--tokenizer-name-or-path`, we will have to specify a tokenizer in the same way as we do when using `AutoTokenizers.from_pretrained(...)`. + +The output will be one file named, in this case, `datasets/testing_alpaca_small_input_ids.npy`. We will then have to specify this file in the `dataset_path` field in the config file. ## Working with Nanosets To work with `Nanosets`, we just need to configure 1 argument: -1. `dataset_folder`: This argument specifies the file or files that will compose the `Nanoset`. There are 3 ways to specify it: +1. `dataset_path`: This argument specifies the file or files that will compose the `Nanoset`. There are 3 ways to specify it: 1. If we specify a single path, we will create a `Nanoset` from a single dataset file. ```yaml data_stages: @@ -44,7 +43,7 @@ To work with `Nanosets`, we just need to configure 1 argument: start_training_step: 1 data: dataset: - dataset_folder: datasets/SlimPajama-6B + dataset_path: datasets/SlimPajama-6B_input_ids.npy num_loading_workers: 0 seed: 1234 ``` @@ -55,9 +54,9 @@ To work with `Nanosets`, we just need to configure 1 argument: start_training_step: 15 data: dataset: - dataset_folder: - - datasets/SlimPajama-6B - - datasets/testing_alpaca_small + dataset_path: + - datasets/SlimPajama-6B_input_ids.npy + - datasets/testing_alpaca_small_input_ids.npy num_loading_workers: 0 seed: 1234 ``` @@ -68,9 +67,9 @@ To work with `Nanosets`, we just need to configure 1 argument: start_training_step: 25 data: dataset: - dataset_folder: - datasets/SlimPajama-6B: 0.8 - datasets/testing_alpaca_small: 0.2 + dataset_path: + datasets/SlimPajama-6B_input_ids.npy: 0.8 + datasets/testing_alpaca_small_input_ids.npy: 0.2 num_loading_workers: 0 seed: 1234 ``` @@ -79,14 +78,11 @@ To work with `Nanosets`, we just need to configure 1 argument: Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py). ```shell -torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml +torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml ``` ## Under the hood -`Nanosets` are responsible of building samples of `sequence length + 1` tokens from the preprocessed dataset files. Despite most of the extracting logic lies in `DatatroveFolderDataset`, `Nanosets` will take care of the following: -1. Creating dataset mixtures from different dataset folder paths -2. Ensure that in each epoch, we consume each sample only once -3. Ensure that we never exhaust the `DataLoader` +`Nanosets` are responsible of building samples of `sequence length + 1` tokens from the preprocessed dataset files. The `dataset lengths` of each dataset will be determined by the `(dataset_number_of_tokens - 1) / sequence length`, discarding the last sample if its length < `sequence length`. Based on the `dataset lengths`, the `dataset weights` and the `number of samples per epoch` (defined as the `sum(dataset lengths)`), we build the two indexes we need in order to extract samples from the `Nanoset` ([build_nanoset_index_helper](../src/nanotron/data/nanoset.py)): - `dataset index`: Contains the index of the dataset from the list of `dataset paths` from which to extract the sample, respecting the established dataset weight. diff --git a/examples/config_nanoset.yaml b/examples/config_nanoset.yaml index 127ddb5e..31f23bf0 100644 --- a/examples/config_nanoset.yaml +++ b/examples/config_nanoset.yaml @@ -7,25 +7,25 @@ checkpoints: data_stages: - data: dataset: - dataset_folder: datasets/c4-es/tokenized + dataset_path: datasets/testing_alpaca_small_input_ids.npy num_loading_workers: 1 seed: 42 name: General purpose training (Single dataset) start_training_step: 1 - data: dataset: - dataset_folder: - - datasets/SlimPajama-6B/tokenized - - datasets/c4-es/tokenized + dataset_path: + - datasets/yelp_review_full_input_ids.npy + - datasets/testing_alpaca_small_input_ids.npy num_loading_workers: 1 seed: 42 name: Second purpose training (> 1 dataset) start_training_step: 15 - data: dataset: - dataset_folder: - datasets/SlimPajama-6B/tokenized: 0.8 - datasets/c4-es/tokenized: 0.2 + dataset_path: + datasets/testing_alpaca_small_input_ids.npy: 0.8 + datasets/yelp_review_full_input_ids.npy: 0.2 num_loading_workers: 1 seed: 42 name: Third purpose training (Blended dataset) @@ -57,7 +57,7 @@ model: initializer_range: 0.02 intermediate_size: 64 is_llama_config: true - max_position_embeddings: 1024 + max_position_embeddings: 256 num_attention_heads: 4 num_hidden_layers: 2 num_key_value_heads: 4 @@ -67,7 +67,7 @@ model: rope_scaling: null tie_word_embeddings: true use_cache: true - vocab_size: 50257 + vocab_size: 32000 optimizer: accumulate_grad_in_fp32: true clip_grad: 1.0 @@ -88,11 +88,11 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 1 + dp: 2 expert_parallel_size: 1 pp: 1 pp_engine: 1f1b - tp: 1 + tp: 2 tp_linear_async_communication: true tp_mode: REDUCE_SCATTER profiler: null @@ -105,6 +105,6 @@ tokens: limit_test_batches: 0 limit_val_batches: 0 micro_batch_size: 2 - sequence_length: 1024 + sequence_length: 128 train_steps: 200 val_check_interval: -1 diff --git a/examples/mamba/README.md b/examples/mamba/README.md index 8eefa9c2..5c31d07f 100644 --- a/examples/mamba/README.md +++ b/examples/mamba/README.md @@ -18,18 +18,6 @@ pip install -r requirements.txt > https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5 -## Bug related to nanotron -Encountered the following issue when ran train_mamba.sh: -``` -causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv -``` -Solved this by doing: -pip uninstall mamba-ssm -pip install causal_conv1d==1.1.1 -pip install mamba-ssm --no-cache-dir -https://github.com/state-spaces/mamba/issues/169 - - ## Credits Credits to the following repositories from which the code was adapted: - https://github.com/state-spaces/mamba diff --git a/pyproject.toml b/pyproject.toml index 6a0cfb83..e65f37a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ fast-modeling = [ nanosets = [ "transformers", - "datatrove[io,processing]@git+https://github.com/huggingface/datatrove", + "datasets", "numba", ] diff --git a/run_train.py b/run_train.py index 021d955d..b33231f4 100644 --- a/run_train.py +++ b/run_train.py @@ -143,17 +143,17 @@ def get_dataloader_from_data_stage( elif isinstance(data.dataset, NanosetDatasetsArgs): # Get tokenizer cardinality tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) - token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 + token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16 del tokenizer # Create Nanoset from nanotron.data.nanoset import Nanoset with main_rank_first(trainer.parallel_context.world_pg): train_dataset = Nanoset( - dataset_folders=data.dataset.dataset_folder, + dataset_paths=data.dataset.dataset_path, dataset_weights=data.dataset.dataset_weights, sequence_length=trainer.sequence_length, - token_size=token_size, + token_dtype=token_dtype, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, random_seed=data.seed, ) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 40b8f1dd..65201bdb 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -93,20 +93,25 @@ def __post_init__(self): @dataclass class NanosetDatasetsArgs: - dataset_folder: Union[str, List[str]] - dataset_weights: Optional[List[float]] = None + dataset_path: Union[str, dict, List[str]] def __post_init__(self): - if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder - self.dataset_folder = [self.dataset_folder] + if isinstance(self.dataset_path, str): # Case 1: 1 Dataset file + self.dataset_path = [self.dataset_path] self.dataset_weights = [1] + elif isinstance(self.dataset_path, List): # Case 2: > 1 Dataset file + self.dataset_weights = None # Set to None so we consume all the samples randomly + elif isinstance(self.dataset_path, dict): # Case 3: dict with > 1 dataset_path and weights + tmp_dataset_path = self.dataset_path.copy() + self.dataset_path = list(tmp_dataset_path.keys()) + self.dataset_weights = list(tmp_dataset_path.values()) @dataclass class DataArgs: """Arguments related to the data and data files processing""" - dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]] + dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs] seed: Optional[int] num_loading_workers: Optional[int] = 1 @@ -140,7 +145,6 @@ class CheckpointsArgs: checkpoints_path: Path checkpoint_interval: int save_initial_state: Optional[bool] = False - save_final_state: Optional[bool] = False resume_checkpoint_path: Optional[Path] = None checkpoints_path_is_shared_file_system: Optional[bool] = False diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index d92de405..57225243 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -48,9 +48,6 @@ class LlamaConfig: rms_norm_eps: float = 1e-6 rope_scaling: Optional[dict] = None rope_theta: float = 10000.0 - rope_interleaved: bool = ( - False # The default value has been True, but for loading Llama3 checkpoints you have to set it to False - ) tie_word_embeddings: bool = False use_cache: bool = True vocab_size: int = 32000 diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 7f20ad99..321ee045 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -34,8 +34,6 @@ class ParallelismArgs: tp_linear_async_communication: Optional[bool] = None recompute_layer: bool = False - tp_recompute_allgather: bool = True - expert_parallel_size: int = 1 def __post_init__(self): diff --git a/src/nanotron/data/collator.py b/src/nanotron/data/collator.py deleted file mode 100644 index 199527e1..00000000 --- a/src/nanotron/data/collator.py +++ /dev/null @@ -1,80 +0,0 @@ -import dataclasses -from typing import Dict, List, Union - -import numpy as np -import torch -from nanotron import distributed as dist -from nanotron.parallel.context import ParallelContext -from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer - - -@dataclasses.dataclass -class NanosetDataCollatorForCLM: - """ - Data collator used for causal language modeling with Nanosets dataset. - - - input_pp_rank: Discards last input id token - - output_pp_rank: Discards first label id token - - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data. - """ - - sequence_length: int - input_pp_rank: int - output_pp_rank: int - parallel_context: ParallelContext - - def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: - # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data. - current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) - if current_pp_rank not in [ - self.input_pp_rank, - self.output_pp_rank, - ]: - assert all(len(example) == 0 for example in examples) - return { - "input_ids": TensorPointer(group_rank=self.input_pp_rank), - "input_mask": TensorPointer(group_rank=self.input_pp_rank), - "label_ids": TensorPointer(group_rank=self.output_pp_rank), - "label_mask": TensorPointer(group_rank=self.output_pp_rank), - } - - # Make sure we load only what's necessary, ie we only load a `input_ids` column. - assert all(list(example.keys()) == ["input_ids"] for example in examples) - - # TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor? - input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s) - batch_size, expanded_input_length = input_ids.shape - - result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {} - - result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) - result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank) - result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) - result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) - - assert ( - expanded_input_length == self.sequence_length + 1 - ), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}" - - # Process inputs: last token is the label - if current_pp_rank == self.input_pp_rank: - result["input_ids"] = input_ids[:, :-1] - result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) - - # Process labels: shift them to the left - if current_pp_rank == self.output_pp_rank: - result["label_ids"] = input_ids[:, 1:] - result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) - - if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length: - raise ValueError( - f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be" - f" {self.sequence_length}." - ) - if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length: - raise ValueError( - f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be" - f" {self.sequence_length}." - ) - - return result diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index 9d3285f6..4719c476 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -1,7 +1,7 @@ import nanotron.distributed as dist from nanotron import logging -from nanotron.data.collator import NanosetDataCollatorForCLM from nanotron.dataloader import ( + DataCollatorForCLM, EmptyInfiniteDataset, get_dataloader_worker_init, get_sampler, @@ -32,7 +32,7 @@ def build_nanoset_dataloader( # No need to spawn a lot of workers, we can just use main dataloader_num_workers = 0 - data_collator = NanosetDataCollatorForCLM( + data_collator = DataCollatorForCLM( sequence_length=sequence_length, input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, diff --git a/src/nanotron/data/nanoset.py b/src/nanotron/data/nanoset.py index 90200967..9d62b33d 100644 --- a/src/nanotron/data/nanoset.py +++ b/src/nanotron/data/nanoset.py @@ -1,10 +1,7 @@ -import os -import warnings from typing import Dict, List, Tuple, Union import numpy as np import torch -from datatrove.utils.dataset import DatatroveFolderDataset from nanotron import logging from nanotron.data.utils import count_dataset_indexes, normalize from nanotron.logging import log_rank @@ -18,60 +15,49 @@ class Nanoset(torch.utils.data.Dataset): The Nanoset dataset Args: - dataset_folders (List[str]): List of folders with tokenized datasets - dataset_weights (Union[List[float], None]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__ + dataset_paths (List[str]): List of paths to tokenized datasets + dataset_weights (List[float]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__ sequence_length (int): Sequence length of the built samples - token_size (int): Number of bytes for the tokens stored in the processed dataset files. 2 for vocab sizes < 65535, 4 otherwise + token_dtype (Union[np.uint16, np.int32]): dtype of the tokens stored in the processed dataset files. np.uin16 for vocab sizes < 65535, np.int32 otherwise train_split_num_samples (int): Number of samples the dataset needs. It's the training steps * global batch size """ def __init__( self, - dataset_folders: List[str], + dataset_paths: List[str], + dataset_weights: Union[List[float], None], sequence_length: int, - token_size: int, + token_dtype: Union[np.uint16, np.int32], train_split_num_samples: int, - dataset_weights: Union[List[float], None] = None, random_seed: int = 1234, ) -> None: - # Checks - if isinstance(dataset_folders, str): - warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]") - dataset_folders = [dataset_folders] - # Init - self.dataset_folders = dataset_folders + self.dataset_paths = dataset_paths + self.dataset_weights = dataset_weights self.sequence_length = sequence_length - self.token_size = token_size + self.token_dtype = token_dtype self.train_split_num_samples = train_split_num_samples self.random_seed = random_seed - self.datatrove_datasets = [] - for dataset_folder in self.dataset_folders: - self.datatrove_datasets.append( - DatatroveFolderDataset( - folder_path=dataset_folder, - filename_pattern=os.path.join(dataset_folder, "*.ds"), - seq_len=sequence_length, - recursive=False, - token_size=token_size, - shuffle=True, - ) - ) # Build Nanoset Index ## To build the index we need the length of each dataset - self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets] + self.dataset_lengths = [] + for dataset_path in self.dataset_paths: + self.dataset_buffer_mmap = np.memmap(dataset_path, mode="r", order="C", dtype=self.token_dtype) + self.dataset_buffer = memoryview(self.dataset_buffer_mmap) + dataset_number_of_tokens = int(len(self.dataset_buffer)) + number_of_samples = int( + (dataset_number_of_tokens - 1) / sequence_length + ) # Discard last sample if length < sequence_length + self.dataset_lengths.append(number_of_samples) ## Set dataset weights if ( - dataset_weights is None + self.dataset_weights is None ): # Case of training with > 1 datasets without weighting them: Consume both datasets entirely on each epoch self.dataset_weights = normalize(self.dataset_lengths) else: self.dataset_weights = normalize(dataset_weights) - assert len(dataset_folders) == len( - self.dataset_weights - ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." ## Build dataset index and dataset sample index self.dataset_index, self.dataset_sample_index = self.build_nanoset_index() @@ -93,12 +79,25 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: idx (int): The index into the dataset Returns: - Dict[str, torch.LongTensor]: The input ids wrapped in a dictionary + Dict[str, numpy.ndarray]: The input ids wrapped in a dictionary """ + dataset = self.dataset_index[idx] dataset_sample = self.dataset_sample_index[idx] - return self.datatrove_datasets[dataset][dataset_sample] + # Rebuild the memmap in every access to free memory + # https://stackoverflow.com/a/61472122 + self.dataset_buffer_mmap = np.memmap(self.dataset_paths[dataset], mode="r", order="C", dtype=self.token_dtype) + self.dataset_buffer = memoryview(self.dataset_buffer_mmap) + + # uint16 -> 2 bytes per token, int32 -> 4 bytes per token + offset = dataset_sample * self.sequence_length * (np.iinfo(self.token_dtype).bits / 8) + input_ids_tokens = np.frombuffer( + self.dataset_buffer, dtype=self.token_dtype, count=(self.sequence_length + 1), offset=int(offset) + ) + + # Return tokens as np.int32 as Torch can't handle uint16 + return {"input_ids": input_ids_tokens.astype(np.int32)} def build_nanoset_index(self) -> np.ndarray: """ @@ -125,6 +124,15 @@ def build_nanoset_index(self) -> np.ndarray: return dataset_index, dataset_sample_index + def __del__(self) -> None: + """ + Clean up Nanoset + """ + + if hasattr(self, "dataset_buffer_mmap"): + self.dataset_buffer_mmap._mmap.close() + del self.dataset_buffer_mmap + def print_nanoset_info(self): log_rank(f"> Total number of samples: {len(self)}", logger=logger, level=logging.INFO, rank=0) @@ -133,10 +141,10 @@ def print_nanoset_info(self): ) # Print samples from each dataset + weight - dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders)) + dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_paths)) for index, sample_count in enumerate(dataset_sample_count): log_rank( - f"> Total number of samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})", + f"> Total number of samples from the {self.dataset_paths[index].rsplit('/', 1)[-1]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})", logger=logger, level=logging.INFO, rank=0, diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 88fb6bcb..5f77100e 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. """PyTorch LLaMa model.""" -from typing import Dict, List, Optional, Union +from typing import Dict, Optional, Union, List import torch from torch import nn @@ -74,10 +74,9 @@ def init_rotary_embeddings(self): self.freqs_cis = self.freqs_cis.to(torch.float) assert self.freqs_cis.dtype == torch.float freqs = 1.0 / ( - self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu")[: (self.dim // 2)] / self.dim) - ).to( - "cuda" - ) # should be computed on CPU, otherwise different results with Transformers. + self.theta + ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda")[: (self.dim // 2)] / self.dim) + ) t = torch.arange(self.end, device="cuda") freqs = torch.outer(t, freqs).float() complex_freqs = torch.polar(torch.ones_like(freqs), freqs) @@ -119,78 +118,6 @@ def forward( return x_out.type(dtype) -## Copy from transformers. Non interleaved version of RoPE. Will be refactored later -class LlamaRotaryEmbedding(nn.Module): - def __init__(self, dim: int, end: int, theta: float = 500000.0): - super().__init__() - self.dim = dim - self.end = end - self.theta = theta - self.init_rotary_embeddings() - - def init_rotary_embeddings(self): - inv_freq = 1.0 / ( - self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu") / self.dim) - ) # important to compute on CPU - self.register_buffer( - "inv_freq", torch.empty(self.dim // 2, dtype=torch.float, device="cuda"), persistent=False - ) - self.inv_freq = self.inv_freq.to( - torch.float - ) # make it float32 before copy to avoid precision loss during copy_ - self.inv_freq.copy_(inv_freq) - - @torch.no_grad() - def forward( - self, - x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk] - position_ids: Optional[torch.LongTensor], # [batch_size, seq_length] - ): - # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - def rotate_half(self, x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=2): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (self.rotate_half(q) * sin) - k_embed = (k * cos) + (self.rotate_half(k) * sin) - return q_embed, k_embed - - class GLUActivation(nn.Module): def __init__(self, act_fn_name: str): super().__init__() @@ -228,7 +155,6 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, - tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) self.down_proj = TensorParallelRowLinear( config.intermediate_size, @@ -238,6 +164,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, ) + # TODO @nouamane: why can't we torch.jit.script GLUActivation? self.split_silu_mul = GLUActivation(config.hidden_act) def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] @@ -389,27 +316,16 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, - tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. - if config.rope_interleaved: - self.rotary_embedding = RotaryEmbedding( - dim=self.d_qk, - end=config.max_position_embeddings, - theta=config.rope_theta, - ) - else: - self.rotary_embedding = LlamaRotaryEmbedding( - dim=self.d_qk, - end=config.max_position_embeddings, - theta=config.rope_theta, - ) - self.rope_interleaved = config.rope_interleaved + self.rotary_embedding = RotaryEmbedding( + dim=self.d_qk, + end=config.max_position_embeddings, + theta=config.rope_theta, + ) # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet) - self.flash_rotary_embedding = FlashRotaryEmbedding( - dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved - ) + self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, base=config.rope_theta, interleaved=True) self.o_proj = TensorParallelRowLinear( config.num_attention_heads * self.d_qk, @@ -488,16 +404,8 @@ def forward( # Compute rotary embeddings # Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache old_rotary_embed_end = self.rotary_embedding.end - # interleaved version. - if self.rope_interleaved: - query_states = self.rotary_embedding(query_states, position_ids=position_ids) - key_states = self.rotary_embedding(key_states, position_ids=position_ids) - # non interleaved version. - else: - cos, sin = self.rotary_embedding(value_states, position_ids) - query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) + query_states = self.rotary_embedding(query_states, position_ids=position_ids) + key_states = self.rotary_embedding(key_states, position_ids=position_ids) if "key" not in store: # First inference iteration (Prefill) @@ -636,7 +544,7 @@ def forward( cache_seqlens=position_offsets.contiguous(), softmax_scale=softmax_scale, causal=True, - rotary_interleaved=False, # the value is not used unless rotary_cos/sin is provided. https://github.com/Dao-AILab/flash-attention + rotary_interleaved=False, # GPT-NeoX style ) store.update( @@ -711,9 +619,9 @@ def __init__( self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) - + self.recompute_layer = parallel_config.recompute_layer - + def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], @@ -732,12 +640,12 @@ def _core_forward( hidden_states = hidden_states + residual return hidden_states, output["sequence_mask"] - + def _checkpointed_forward( self, hidden_states: torch.Tensor, sequence_mask: torch.Tensor, - ) -> List[torch.Tensor]: + ) -> List[torch.Tensor]: return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask) def forward( @@ -745,7 +653,7 @@ def forward( hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: - + if self.recompute_layer and not isinstance(hidden_states, TensorPointer): hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask) else: @@ -756,7 +664,6 @@ def forward( "sequence_mask": sequence_mask, } - class Embedding(nn.Module, AttachableStore): def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): super().__init__() @@ -819,14 +726,7 @@ def __init__( module_input_keys={"input_ids", "input_mask"}, module_output_keys={"input_embeds"}, ) - log_rank(f"Initialize RoPE Theta = {config.rope_theta}", logger=logger, level=logging.INFO, rank=0) - if config.rope_interleaved: - log_rank( - "The RoPE interleaved version differs from the Transformers implementation. It's better to set rope_interleaved=False if you need to convert the weights to Transformers", - logger=logger, - level=logging.INFO, - rank=0, - ) + self.decoder = nn.ModuleList( [ PipelineBlock( @@ -865,7 +765,6 @@ def __init__( # TODO @thomasw21: refactor so that we store that default in a single place. "mode": self.tp_mode, "async_communication": tp_linear_async_communication, - "tp_recompute_allgather": parallel_config.tp_recompute_allgather, }, module_input_keys={"x"}, module_output_keys={"logits"}, diff --git a/src/nanotron/nn/layer_norm.py b/src/nanotron/nn/layer_norm.py index ef3b4c50..688eaa78 100644 --- a/src/nanotron/nn/layer_norm.py +++ b/src/nanotron/nn/layer_norm.py @@ -22,8 +22,6 @@ def forward( ) -# This is equivalent to LLaMA RMSNorm -# https://github.com/huggingface/transformers/blob/28952248b19db29ca25ccf34a5eec413376494a9/src/transformers/models/llama/modeling_llama.py#L112 class TritonRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index bd41347a..873d77df 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -85,8 +85,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): group = ctx.group - out = DifferentiableReduceScatterSum.apply(grad_output, group) - return out, None + return DifferentiableReduceScatterSum.apply(grad_output, group), None class DifferentiableReduceScatterSum(torch.autograd.Function): @@ -114,7 +113,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): *rest_size, device=tensor.device, dtype=tensor.dtype, - requires_grad=False, + requires_grad=tensor.requires_grad, ) dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM) return sharded_tensor diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index e2ee3a29..fdef48ac 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -20,12 +20,13 @@ import nanotron.distributed as dist from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( + differentiable_all_gather, differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 +from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1 class _ShardedCrossEntropy(torch.autograd.Function): @@ -88,10 +89,10 @@ def forward( @staticmethod def backward(ctx, grad_output): - # Retrieve tensors from the forward path. + # Retreive tensors from the forward path. softmax, target_mask, masked_target_1d = ctx.saved_tensors - # All the inputs have softmax as their gradient. + # All the inputs have softmax as thier gradient. grad_input = softmax # For simplicity, work with the 2D gradient. sharded_hidden_size = softmax.size()[-1] @@ -120,12 +121,10 @@ class _ColumnLinearAsyncCommunication(torch.autograd.Function): @staticmethod @assert_cuda_max_connections_set_to_1 - def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather): + def forward(ctx, tensor, weight, bias, group, tp_mode): ctx.use_bias = bias is not None ctx.tp_mode = tp_mode ctx.group = group - ctx.tp_recompute_allgather = tp_recompute_allgather - ctx.tensor_shape = tensor.size() if tp_mode is TensorParallelLinearMode.ALL_REDUCE: gathered_tensor = tensor @@ -142,7 +141,7 @@ def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather): # `tensor` can sometimes not be contiguous # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317 tensor = tensor.contiguous() - # ctx.save_for_backward(tensor, weight) + ctx.save_for_backward(tensor, weight) # TODO @thomasw21: gather along another dimension sharded_batch_size, *intermediate_size, hidden_size = tensor.shape @@ -150,19 +149,14 @@ def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather): group = dist.distributed_c10d._get_default_group() gathered_batch_size = sharded_batch_size * group.size() - if tp_recompute_allgather: - gathered_tensor = MemoryBuffer().get( - "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype - ) - else: - gathered_tensor = torch.empty( - gathered_batch_size, - *intermediate_size, - hidden_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=False, - ) + gathered_tensor = torch.empty( + gathered_batch_size, + *intermediate_size, + hidden_size, + device=tensor.device, + dtype=tensor.dtype, + requires_grad=tensor.requires_grad, + ) handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True) @@ -210,10 +204,6 @@ def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather): # Wait communication handle.wait() - if tp_recompute_allgather: - ctx.save_for_backward(tensor, weight) - else: - ctx.save_for_backward(gathered_tensor, weight) # Compute all the other shards that are obtained from AllGather # weights: w0 w1 w2 w3 @@ -271,8 +261,8 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias tp_mode = ctx.tp_mode - handle1: Optional[dist.Work] = None - if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather: + handle: Optional[dist.Work] = None + if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = tensor.shape if group is None: @@ -283,10 +273,14 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = MemoryBuffer().get( - "allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype + unsharded_tensor = torch.empty( + unsharded_batch_size, + *rest_size, + device=tensor.device, + dtype=tensor.dtype, + requires_grad=False, ) - handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) + handle = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # gather is scheduled before the tensor gradient computation total_tensor = unsharded_tensor @@ -295,6 +289,9 @@ def backward(ctx, grad_output): grad_tensor = grad_output.matmul(weight) + if handle is not None: + handle.wait() + # Doing gather + slicing during the NeMo forward pass can make this tensor # not be contiguous. PyTorch only checks if the tensor is contiguous, and only # clones it if it's not contiguous: @@ -306,128 +303,41 @@ def backward(ctx, grad_output): grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim) - handle2: Optional[dist.Work] = None + handle: Optional[dist.Work] = None if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: if group.size() == 1: sub_grad_tensor = grad_tensor else: sub_grad_tensor = torch.empty( - ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False + tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False ) # reduce_scatter - handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) + handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # reduce scatter is scheduled before the weight gradient computation elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: # Asynchronous all-reduce - handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True) + handle = dist.all_reduce(grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # all-reduce is scheduled before the weight gradient computation else: raise ValueError() - grad_bias = grad_output.sum(dim=0) if use_bias else None - - if handle1 is not None: - handle1.wait() - # TODO @thomasw21: This sounds like we don't have the optimal physical layout grad_weight = grad_output.t().matmul(total_tensor) + grad_bias = grad_output.sum(dim=0) if use_bias else None - if handle2 is not None: - handle2.wait() + if handle is not None: + handle.wait() if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - return sub_grad_tensor, grad_weight, grad_bias, None, None, None + return sub_grad_tensor, grad_weight, grad_bias, None, None elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: - return grad_tensor, grad_weight, grad_bias, None, None, None + return grad_tensor, grad_weight, grad_bias, None, None else: raise ValueError(f"Got unexpected mode: {tp_mode}.") -class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function): - """ - Column linear with memory_buffer for the allgather, context parallel - enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and - async communication disabled. - """ - - @staticmethod - def forward( - ctx, - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - group: dist.ProcessGroup, - tp_recompute_allgather: bool, - ): - - # Do allgather. - sharded_batch_size, *rest_size = input.shape - unsharded_batch_size = sharded_batch_size * group.size() - if group.size() == 1: - total_input = input.contiguous() - elif tp_recompute_allgather: - total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) - dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) - else: - total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device) - dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) - - # Prepare context. - ctx.group = group - ctx.tp_recompute_allgather = tp_recompute_allgather - ctx.input_size = input.shape - if tp_recompute_allgather: - ctx.save_for_backward(input, weight, bias) - else: - ctx.save_for_backward(total_input, weight, bias) - - # Get linear output. - out = F.linear(total_input, weight, bias) - return out - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - # Either allgather the inputs again or get them from context. - group = ctx.group - tp_recompute_allgather = ctx.tp_recompute_allgather - input_size = ctx.input_size - if group.size() == 1 or not tp_recompute_allgather: - total_input, weight, bias = ctx.saved_tensors - else: - input, weight, bias = ctx.saved_tensors - sharded_batch_size, *rest_size = input.shape - total_input = sharded_batch_size * group.size() - unsharded_batch_size = sharded_batch_size * group.size() - total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) - dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) - - # Convert the tensor shapes to 2D for execution compatibility - grad_output = grad_output.contiguous() - grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1] - total_input_first_dims, total_input_last_dim = total_input.shape[:-1], total_input.shape[-1] - grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) - total_input = total_input.view(math.prod(total_input_first_dims), total_input_last_dim) - - # Compute gradients. - grad_weight = grad_output.T @ total_input - grad_input = grad_output @ weight - if group.size() == 1: - sub_grad_input = grad_input - else: - # Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305 - # We set grad_input to be contiguous in case it isn't already. - grad_input = grad_input.contiguous() - sub_grad_input = torch.empty( - input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False - ) - dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) - grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None - - return sub_grad_input, grad_weight, grad_bias, None, None - - def column_linear( input: torch.Tensor, weight: torch.Tensor, @@ -435,19 +345,18 @@ def column_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, - tp_recompute_allgather: bool = True, ): if async_communication: - return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) + return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) - return F.linear(input, weight, bias) - if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( - input, weight, bias, group, tp_recompute_allgather - ) - raise ValueError(f"Got unexpected mode: {tp_mode}.") + elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + input = differentiable_all_gather(input, group=group) + else: + raise ValueError(f"Got unexpected mode: {tp_mode}.") + + return F.linear(input, weight, bias) class _RowLinearAsyncCommunication(torch.autograd.Function): @@ -478,7 +387,8 @@ def backward(ctx, grad_output): group = ctx.group use_bias = ctx.use_bias - handle: Optional[dist.Work] = None + handle_0: Optional[dist.Work] = None + handle_1: Optional[dist.Work] = None # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = grad_output.shape @@ -488,8 +398,12 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_output = MemoryBuffer().get( - "allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype + total_grad_output = torch.empty( + unsharded_batch_size, + *rest_size, + device=grad_output.device, + dtype=grad_output.dtype, + requires_grad=False, ) # Doing gather + slicing during the NeMo forward pass can make this tensor @@ -498,69 +412,31 @@ def backward(ctx, grad_output): # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761 grad_output = grad_output.contiguous() - handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) + handle_0 = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) + + grad_tensor = grad_output.matmul(weight) + + # wait for the first all_gather to finish before starting the second all_gather + if handle_0 is not None: + handle_0.wait() - # total_grad_output: [b, s, h_out] - # weight: [h_out, h_in/n] - # total_grad_tensor: [b, s, h_in/n] - # grad_output: [b/n, s, h_out] - sharded_batch_size, *rest_size_grad_output = grad_output.shape - rest_size_grad_tensor = rest_size_grad_output[:-1] + [weight.shape[1]] + # TODO @thomasw21: gather along another dimension + sharded_batch_size, *rest_size = grad_tensor.shape if group.size() == 1: - total_grad_tensor = grad_output.matmul(weight) + total_grad_tensor = grad_tensor else: unsharded_batch_size = sharded_batch_size * group.size() + total_grad_tensor = torch.empty( unsharded_batch_size, - *rest_size_grad_tensor, - device=grad_output.device, - dtype=grad_output.dtype, + *rest_size, + device=grad_tensor.device, + dtype=grad_tensor.dtype, requires_grad=False, ) - before_shard_grad_tensor, same_device_shard_grad_tensor, after_shard_grad_tensor = torch.split( - total_grad_tensor, - split_size_or_sections=[ - sharded_batch_size * dist.get_rank(group), - sharded_batch_size, - sharded_batch_size * (group.size() - dist.get_rank(group) - 1), - ], - dim=0, - ) - # compute local shard - torch.mm( - input=grad_output.view(-1, grad_output.shape[-1]), - mat2=weight, - out=same_device_shard_grad_tensor.view(-1, weight.shape[1]), - ) - if handle is not None: - handle.wait() - - before_shard_grad_output, _, after_shard_grad_output = torch.split( - total_grad_output, - split_size_or_sections=[ - sharded_batch_size * dist.get_rank(group), - sharded_batch_size, - sharded_batch_size * (group.size() - dist.get_rank(group) - 1), - ], - dim=0, - ) - - # before shard compute - if before_shard_grad_tensor.numel() > 0: - torch.mm( - input=before_shard_grad_output.view(-1, before_shard_grad_output.shape[-1]), - mat2=weight, - out=before_shard_grad_tensor.view(-1, weight.shape[1]), - ) - # after shard compute - if after_shard_grad_tensor.numel() > 0: - torch.mm( - input=after_shard_grad_output.view(-1, after_shard_grad_output.shape[-1]), - mat2=weight, - out=after_shard_grad_tensor.view(-1, weight.shape[1]), - ) + handle_1 = dist.all_gather_into_tensor(total_grad_tensor, grad_tensor, group=group, async_op=True) # Convert the tensor shapes to 2D for execution compatibility tensor = tensor.contiguous() @@ -578,6 +454,9 @@ def backward(ctx, grad_output): grad_weight = total_grad_output.t().matmul(tensor) grad_bias = total_grad_output.sum(dim=0) if use_bias else None + if handle_1 is not None: + handle_1.wait() + return total_grad_tensor, grad_weight, grad_bias, None, None diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 4c7325cd..40e89968 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -51,7 +51,6 @@ def __init__( dtype=None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, - tp_recompute_allgather: bool = True, ): self.pg = pg self.world_size = pg.size() @@ -60,7 +59,6 @@ def __init__( self.in_features = in_features self.out_features = out_features // self.world_size - self.tp_recompute_allgather = tp_recompute_allgather super().__init__( in_features=self.in_features, @@ -93,7 +91,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, - tp_recompute_allgather=self.tp_recompute_allgather, ) def extra_repr(self) -> str: diff --git a/src/nanotron/parallel/utils.py b/src/nanotron/parallel/utils.py index f694b0e6..b9ac12ae 100644 --- a/src/nanotron/parallel/utils.py +++ b/src/nanotron/parallel/utils.py @@ -1,31 +1,11 @@ import functools -import operator import os -import torch from torch import nn from nanotron import distributed as dist from nanotron.parallel import ParallelContext from nanotron.parallel.tied_parameters import get_tied_id_to_param -from nanotron.utils import Singleton - - -class MemoryBuffer(metaclass=Singleton): - """ - Global memory buffer to store intermediate activations that need not to be cached for the backward pass. - """ - - def __init__(self): - self.buffer = {} - - def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: - required_numel = functools.reduce(operator.mul, shape, 1) - if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel: - self.buffer[name, dtype] = torch.empty( - required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False - ) - return self.buffer[name, dtype][:required_numel].view(shape) def assert_cuda_max_connections_set_to_1(func): diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 42613410..870f3b64 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -445,10 +445,7 @@ def train( self.save_checkpoint() dist.barrier() # let's wait for everyone before leaving - - if self.config.checkpoints.save_final_state: - self.save_checkpoint() - + self.post_training() def training_step( @@ -703,7 +700,7 @@ def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: ) reloaded_from_checkpoint = True if not reloaded_from_checkpoint: - log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO, rank=0) + log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO) if isinstance(self.config.model.init_method, ExistingCheckpointInit): # Initialize model from an pretrained model checkpoint self.param_shard_metadata = load_weights( diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index b3831801..14fe1ca8 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -1,10 +1,11 @@ import functools import inspect +import math import os import random import socket from contextlib import ExitStack, contextmanager -from typing import ContextManager, List, Optional +from typing import Callable, ContextManager, List, Optional import torch from packaging import version @@ -14,25 +15,6 @@ from nanotron import distributed as dist -class Singleton(type): - """ - Singleton metaclass. - Create objects using this class as the metaclass to enable singleton behaviour. - For instance: - ``` - class Logger(metaclass=Singleton): - ... - ``` - """ - - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) - return cls._instances[cls] - - class ContextManagers: """ Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers` @@ -70,7 +52,7 @@ def main_rank_first(group: dist.ProcessGroup): @contextmanager def local_ranks_zero_first(group: Optional[dist.ProcessGroup] = None): """Context manager that executes the code in the context with all the local rank zero of the group going first. - Useful to run only once per node first (e.g. to create local files, etc) + Usefull to run only once per node first (e.g. to create local files, etc) """ is_main = int(os.environ.get("LOCAL_RANK", 0)) == 0 if is_main: @@ -141,7 +123,6 @@ def get_untyped_storage(tensor: torch.Tensor) -> torch.UntypedStorage: else: return tensor.storage().untyped() - def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: torch.dtype): # TODO @thomasw21: Figure out what's the best Pytorch way of building a tensor from a storage. device = untyped_storage.device diff --git a/tests/helpers/data.py b/tests/helpers/data.py index 72deb7f5..33bb2480 100644 --- a/tests/helpers/data.py +++ b/tests/helpers/data.py @@ -3,7 +3,6 @@ import json import os import sys -from argparse import Namespace from collections import OrderedDict from pathlib import Path @@ -11,6 +10,8 @@ package_path = Path(package.__file__).parent.parent.parent sys.path.append(str(package_path)) +from argparse import Namespace + import nanotron.distributed as dist import torch from nanotron.data.nanoset import Nanoset @@ -22,34 +23,31 @@ def create_dataset_paths(tmp_dir: str, quantity: int): - json_dataset_path = [os.path.join(tmp_dir, f"pytest_{i}.json") for i in range(quantity)] - datatrove_tokenized_dataset_paths = [os.path.join(tmp_dir, f"tokenized_documents_{i}") for i in range(quantity)] + json_dataset_path = [os.path.join(tmp_dir, f"pytest_{i}") for i in range(quantity)] + mmap_dataset_path = [f"{path}_input_ids.npy" for path in json_dataset_path] - return json_dataset_path, datatrove_tokenized_dataset_paths + return json_dataset_path, mmap_dataset_path def create_dummy_json_dataset(path_to_json: str, dummy_text: str, n_samples: int = 50000): - with open(path_to_json, "a") as json_file: + with open(path_to_json + ".json", "a") as json_file: for sample in range(n_samples): sample_dict = {"text": f"[{sample}] Hello! Im sample {sample}! And this is my dummy text: {dummy_text}"} json_file.write(json.dumps(sample_dict)) json_file.write("\n") -def preprocess_dummy_dataset(json_dataset_path: str, datatrove_tokenized_dataset_path: str, tokenizer: str): +def preprocess_dummy_dataset(path_to_json: str, tokenizer: str): # Create args for preprocessing args = Namespace( - readers="jsonl", - dataset=json_dataset_path, + input=path_to_json + ".json", column="text", - glob_pattern=None, - output_folder=datatrove_tokenized_dataset_path, + output_prefix=path_to_json, tokenizer_name_or_path=tokenizer, - eos_token=None, - n_tasks=1, - logging_dir=None, + add_special_tokens=False, ) + # tools/preprocess_data.py main main(args) @@ -124,7 +122,7 @@ def assert_nanoset_sync_across_all_ranks(nanoset: Nanoset, parallel_context: Par IDX_SAMPLE = 23 nanoset_identifiers = OrderedDict() - nanoset_identifiers["dataset_folders"] = nanoset.dataset_folders + nanoset_identifiers["dataset_paths"] = nanoset.dataset_paths nanoset_identifiers["dataset_weights"] = nanoset.dataset_weights.tolist() nanoset_identifiers["sequence_length"] = nanoset.sequence_length nanoset_identifiers["train_split_num_samples"] = nanoset.train_split_num_samples @@ -133,7 +131,6 @@ def assert_nanoset_sync_across_all_ranks(nanoset: Nanoset, parallel_context: Par nanoset_identifiers["input_ids"] = nanoset[IDX_SAMPLE]["input_ids"].tolist() nanoset_identifiers["dataset_index"] = nanoset.dataset_index.tolist() nanoset_identifiers["dataset_sample_index"] = nanoset.dataset_sample_index.tolist() - nanoset_identifiers["token_size"] = nanoset.token_size unique_description_hash = compute_hash(nanoset_identifiers) assert_tensor_synced_across_pg( diff --git a/tests/nanoset/test_build_nanoset_dataloader.py b/tests/nanoset/test_build_nanoset_dataloader.py index 113c545c..2c3ff542 100644 --- a/tests/nanoset/test_build_nanoset_dataloader.py +++ b/tests/nanoset/test_build_nanoset_dataloader.py @@ -1,7 +1,6 @@ import sys from math import isclose from pathlib import Path -from typing import List package_path = Path(__file__).parent.parent sys.path.append(str(package_path)) @@ -34,7 +33,7 @@ for all_3d_configs in get_all_3d_configurations(gpus) ], ) -@pytest.mark.parametrize("train_steps", [500, 10000]) +@pytest.mark.parametrize("train_steps", [5, 100]) @pytest.mark.parametrize("sequence_length", [512, 8192]) @pytest.mark.parametrize("tokenizer_name_or_path", ["openai-community/gpt2", "unsloth/llama-3-8b-bnb-4bit"]) @rerun_if_address_is_in_use() @@ -43,21 +42,16 @@ def test_build_nanoset_dataloader( ): test_context = TestContext() - # Create dataset folders - json_paths, datatrove_tokenized_dataset_folders = create_dataset_paths( - tmp_dir=test_context.get_auto_remove_tmp_dir(), quantity=2 - ) + # Create dataset files + json_paths, mmap_dataset_paths = create_dataset_paths(tmp_dir=test_context.get_auto_remove_tmp_dir(), quantity=2) # Create dummy json datasets for idx, json_path in enumerate(json_paths): create_dummy_json_dataset(path_to_json=json_path, dummy_text=f"Nanoset {idx}!", n_samples=(idx + 1) * 50000) - # Preprocess json dataset with datatrove - for json_path, datatrove_tokenized_dataset_folder in zip(json_paths, datatrove_tokenized_dataset_folders): - preprocess_dummy_dataset(json_path, datatrove_tokenized_dataset_folder, tokenizer_name_or_path) - init_distributed(tp=tp, dp=dp, pp=pp)(_test_build_nanoset_dataloader)( - datatrove_tokenized_dataset_folders=datatrove_tokenized_dataset_folders, + json_paths=json_paths, + path_to_mmap_files=mmap_dataset_paths, train_steps=train_steps, sequence_length=sequence_length, tokenizer_name_or_path=tokenizer_name_or_path, @@ -66,7 +60,8 @@ def test_build_nanoset_dataloader( def _test_build_nanoset_dataloader( parallel_context: ParallelContext, - datatrove_tokenized_dataset_folders: List[str], + json_paths: str, + path_to_mmap_files: str, train_steps: int, sequence_length: int, tokenizer_name_or_path: str, @@ -76,37 +71,41 @@ def _test_build_nanoset_dataloader( N_MICRO_BATCHES_PER_BATCH = 8 GLOBAL_BATCH_SIZE = MICRO_BATCH_SIZE * N_MICRO_BATCHES_PER_BATCH * parallel_context.dp_pg.size() + # Preprocess dummy json datasets + for json_path in json_paths: + preprocess_dummy_dataset(path_to_json=json_path, tokenizer=tokenizer_name_or_path) + input_pp_rank, output_pp_rank = 0, int(parallel_context.pp_pg.size() - 1) # Get tokenizer cardinality tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) - token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 + token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16 del tokenizer # Create Nanoset configs: 1. Normal 2. Blended 3. Blended with weights nanoset_config = { - "dataset_folders": [datatrove_tokenized_dataset_folders[0]], + "dataset_paths": [path_to_mmap_files[0]], "dataset_weights": [1], "sequence_length": sequence_length, - "token_size": token_size, + "token_dtype": token_dtype, "train_split_num_samples": train_steps * GLOBAL_BATCH_SIZE, "random_seed": SEED, } blended_nanoset_config = { - "dataset_folders": datatrove_tokenized_dataset_folders, + "dataset_paths": [path_to_mmap_files[0], path_to_mmap_files[1]], "dataset_weights": None, "sequence_length": sequence_length, - "token_size": token_size, + "token_dtype": token_dtype, "train_split_num_samples": train_steps * GLOBAL_BATCH_SIZE, "random_seed": SEED, } blended_weighted_nanoset_config = { - "dataset_folders": datatrove_tokenized_dataset_folders, + "dataset_paths": [path_to_mmap_files[0], path_to_mmap_files[1]], "dataset_weights": [8, 2], "sequence_length": sequence_length, - "token_size": token_size, + "token_dtype": token_dtype, "train_split_num_samples": train_steps * GLOBAL_BATCH_SIZE, "random_seed": SEED, } @@ -120,7 +119,7 @@ def _test_build_nanoset_dataloader( # Assert we have the same Nanoset in all ranks assert_nanoset_sync_across_all_ranks(train_dataset, parallel_context) - dataset_sample_count = count_dataset_indexes(train_dataset.dataset_index, len(train_dataset.dataset_folders)) + dataset_sample_count = count_dataset_indexes(train_dataset.dataset_index, len(train_dataset.dataset_paths)) for idx, ds_length in enumerate(train_dataset.dataset_lengths): # Assert Nanoset doesn't sample indexes greater than the datasets assert ( @@ -130,7 +129,7 @@ def _test_build_nanoset_dataloader( # Assert Nanoset builds up the correct blend WRT the dataset_weights assert isclose( normalize(dataset_sample_count).tolist()[idx], train_dataset.dataset_weights[idx], abs_tol=0.05 - ), f"Requested Nanoset to contain {round(train_dataset.dataset_weights[idx]*100, 2)}% of samples from {train_dataset.dataset_folders[idx]} but got {round(normalize(dataset_sample_count).tolist()[idx]*100, 2)}%" + ), f"Requested Nanoset to contain {round(train_dataset.dataset_weights[idx]*100, 2)}% of samples from {train_dataset.dataset_paths[idx]} but got {round(normalize(dataset_sample_count).tolist()[idx]*100, 2)}%" # Create Dataloaders dataloader = build_nanoset_dataloader( train_dataset, @@ -163,27 +162,22 @@ def _test_build_nanoset_dataloader( for all_3d_configs in get_all_3d_configurations(gpus) ], ) -@pytest.mark.parametrize("skipped_batches", [20, 5555]) +@pytest.mark.parametrize("skipped_batches", [20, 50]) @pytest.mark.parametrize("tokenizer_name_or_path", ["openai-community/gpt2", "unsloth/llama-3-8b-bnb-4bit"]) @rerun_if_address_is_in_use() def test_recover_nanoset_dataloader(tp: int, dp: int, pp: int, skipped_batches: int, tokenizer_name_or_path: str): test_context = TestContext() - # Create dataset folders - json_paths, datatrove_tokenized_dataset_folders = create_dataset_paths( - tmp_dir=test_context.get_auto_remove_tmp_dir(), quantity=2 - ) + # Create dataset files + json_paths, mmap_dataset_paths = create_dataset_paths(tmp_dir=test_context.get_auto_remove_tmp_dir(), quantity=2) # Create dummy json datasets for idx, json_path in enumerate(json_paths): create_dummy_json_dataset(path_to_json=json_path, dummy_text=f"Nanoset {idx}!", n_samples=(idx + 1) * 50000) - # Preprocess json dataset with datatrove - for json_path, datatrove_tokenized_dataset_folder in zip(json_paths, datatrove_tokenized_dataset_folders): - preprocess_dummy_dataset(json_path, datatrove_tokenized_dataset_folder, tokenizer_name_or_path) - init_distributed(tp=tp, dp=dp, pp=pp)(_test_recover_nanoset_dataloader)( - datatrove_tokenized_dataset_folders=datatrove_tokenized_dataset_folders, + json_paths=json_paths, + path_to_mmap_files=mmap_dataset_paths, skipped_batches=skipped_batches, tokenizer_name_or_path=tokenizer_name_or_path, ) @@ -191,7 +185,8 @@ def test_recover_nanoset_dataloader(tp: int, dp: int, pp: int, skipped_batches: def _test_recover_nanoset_dataloader( parallel_context: ParallelContext, - datatrove_tokenized_dataset_folders: List[str], + json_paths: str, + path_to_mmap_files: str, skipped_batches: int, tokenizer_name_or_path: str, ): @@ -200,39 +195,43 @@ def _test_recover_nanoset_dataloader( N_MICRO_BATCHES_PER_BATCH = 8 GLOBAL_BATCH_SIZE = MICRO_BATCH_SIZE * N_MICRO_BATCHES_PER_BATCH * parallel_context.dp_pg.size() SEQUENCE_LENGTH = 1024 - TRAIN_STEPS = 10000 + TRAIN_STEPS = 100 + + # Preprocess dummy json datasets + for json_path in json_paths: + preprocess_dummy_dataset(path_to_json=json_path, tokenizer=tokenizer_name_or_path) input_pp_rank, output_pp_rank = 0, int(parallel_context.pp_pg.size() - 1) # Get tokenizer cardinality tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) - token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 + token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16 del tokenizer # Create Nanoset configs: 1. Normal 2. Blended 3. Blended with weights nanoset_config = { - "dataset_folders": [datatrove_tokenized_dataset_folders[0]], + "dataset_paths": [path_to_mmap_files[0]], "dataset_weights": [1], "sequence_length": SEQUENCE_LENGTH, - "token_size": token_size, + "token_dtype": token_dtype, "train_split_num_samples": TRAIN_STEPS * GLOBAL_BATCH_SIZE, "random_seed": SEED, } blended_nanoset_config = { - "dataset_folders": datatrove_tokenized_dataset_folders, + "dataset_paths": [path_to_mmap_files[0], path_to_mmap_files[1]], "dataset_weights": None, "sequence_length": SEQUENCE_LENGTH, - "token_size": token_size, + "token_dtype": token_dtype, "train_split_num_samples": TRAIN_STEPS * GLOBAL_BATCH_SIZE, "random_seed": SEED, } blended_weighted_nanoset_config = { - "dataset_folders": datatrove_tokenized_dataset_folders, + "dataset_paths": [path_to_mmap_files[0], path_to_mmap_files[1]], "dataset_weights": [8, 2], "sequence_length": SEQUENCE_LENGTH, - "token_size": token_size, + "token_dtype": token_dtype, "train_split_num_samples": TRAIN_STEPS * GLOBAL_BATCH_SIZE, "random_seed": SEED, } diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 16008eaa..127ba2fa 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -18,30 +18,17 @@ @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)]) @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) @pytest.mark.parametrize("async_communication", [False, True]) -@pytest.mark.parametrize("tp_recompute_allgather", [False, True]) @rerun_if_address_is_in_use() -def test_column_linear( - tp: int, - dp: int, - pp: int, - tp_mode: TensorParallelLinearMode, - async_communication: bool, - tp_recompute_allgather: bool, -): +def test_column_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool): if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: pytest.skip("ALL_REDUCE mode does not support async communication") - if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather: - pytest.skip("ALL_REDUCE mode is unaffected by tp_recompute_allgather") init_distributed(tp=tp, dp=dp, pp=pp)(_test_column_linear)( - tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather + tp_mode=tp_mode, async_communication=async_communication ) def _test_column_linear( - parallel_context: ParallelContext, - tp_mode: TensorParallelLinearMode, - async_communication: bool, - tp_recompute_allgather: bool, + parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool ): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" @@ -57,7 +44,6 @@ def _test_column_linear( mode=tp_mode, device="cuda", async_communication=async_communication, - tp_recompute_allgather=tp_recompute_allgather, ) # Un-sharded @@ -100,7 +86,7 @@ def _test_column_linear( random_input = sharded_random_input else: ValueError(f"Unsupported mode: {tp_mode}") - # It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage + # It's important that `random_input` and `sharded_random_input` are two seperate tensors with seperate storage sharded_random_input = sharded_random_input.clone() random_input.requires_grad = True sharded_random_input.requires_grad = True @@ -164,32 +150,15 @@ def _test_column_linear( @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)]) @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) @pytest.mark.parametrize("async_communication", [False, True]) -@pytest.mark.parametrize("tp_recompute_allgather", [False, True]) @rerun_if_address_is_in_use() -def test_row_linear( - tp: int, - dp: int, - pp: int, - tp_mode: TensorParallelLinearMode, - async_communication: bool, - tp_recompute_allgather: bool, -): +def test_row_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool): if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: pytest.skip("ALL_REDUCE mode does not support async communication") - if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather: - pytest.skip("ALL_REDUCE mode is not affected by tp_recompute_allgather") - init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)( - tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather - ) + init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(tp_mode=tp_mode, async_communication=async_communication) -def _test_row_linear( - parallel_context: ParallelContext, - tp_mode: TensorParallelLinearMode, - async_communication: bool, - tp_recompute_allgather: bool, -): +def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" out_features = 3 @@ -239,19 +208,14 @@ def _test_row_linear( random_input = torch.randn(batch_size, in_features, device="cuda") # synchronize random_input across tp dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) - random_input.requires_grad = True + # Row linear receives as input sharded input - random_sharded_input = ( - random_input[ - :, - dist.get_rank(parallel_context.tp_pg) - * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) - * in_features_per_rank, - ] - .detach() - .clone() - ) - random_sharded_input.requires_grad = True + random_sharded_input = random_input[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ] # Test that we get the same output after forward pass # TODO @kunhao: We may want to have our custom error type @@ -297,16 +261,6 @@ def _test_row_linear( else: assert row_linear.bias is None - torch.testing.assert_close( - random_sharded_input.grad, - random_input.grad[ - :, - dist.get_rank(parallel_context.tp_pg) - * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) - * in_features_per_rank, - ], - ) - parallel_context.destroy() diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index f3cdab70..465d22f0 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -1,21 +1,26 @@ -""" -To process HuggingFace Datasets: - python3 tools/preprocess_data.py --tokenizer-name-or-path meta-llama/Meta-Llama-3-8B --output-folder datasets/emotion --n-tasks 16 hf --dataset dair-ai/emotion -To process Jsonl files: - python3 tools/preprocess_data.py --tokenizer-name-or-path meta-llama/Meta-Llama-3-8B --output-folder datasets/c4-es --n-tasks 16 jsonl --dataset raw_datasets/c4-es-json-files -""" - import argparse +import os +import shutil +import sys + +import numpy as np +import torch.distributed as dist +from tqdm import tqdm +from transformers import AutoTokenizer -from datatrove.executor.local import LocalPipelineExecutor -from datatrove.pipeline.readers import HuggingFaceDatasetReader, JsonlReader -from datatrove.pipeline.tokens import DocumentTokenizer +from datasets import concatenate_datasets, load_dataset def get_args(): parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="input data") + group.add_argument( + "--input", type=str, required=True, help="Path to local stored dataset or repository on the Hugging Face hub" + ) + group.add_argument("--column", type=str, default="text", help="Column to preprocess from the Dataset") + parser.add_argument("--split", type=str, default="train", help="Which split of the data to process") - group = parser.add_argument_group(title="Tokenizer") + group = parser.add_argument_group(title="tokenizer") group.add_argument( "--tokenizer-name-or-path", type=str, @@ -23,54 +28,13 @@ def get_args(): help="A path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub.", ) group.add_argument( - "--eos-token", - type=str, - default=None, - help="EOS token to add after each document. Default: None", - ) - - group = parser.add_argument_group(title="Output data") - group.add_argument( - "--output-folder", type=str, required=True, help="Path to the output folder to store the tokenized documents" - ) - group = parser.add_argument_group(title="Miscellaneous configs") - group.add_argument( - "--logging-dir", - type=str, - default=None, - help="Path to a folder for storing the logs of the preprocessing step. Default: None", - ) - group.add_argument( - "--n-tasks", type=int, default=8, help="Total number of tasks to run the preprocessing step. Default: 8" - ) - # Subparsers for processing either Hugging Face datasets or jsonl files - sp = parser.add_subparsers( - dest="readers", - required=True, - description="Type of dataset to process. It can be either a Hugging Face Dataset loaded with datasets.load_data ('hf') or a .jsonl dataset ('jsonl')", - ) - - p1 = sp.add_parser(name="hf") - p1.add_argument( - "--dataset", - type=str, - required=True, - help="Path to local stored dataset or repository on the Hugging Face hub that can be loaded with datasets.load_dataset", + "--add-special-tokens", + action="store_true", + help="Whether or not to add special tokens when encoding the sequences. This will be passed to the Tokenizer", ) - p1.add_argument("--column", type=str, default="text", help="Column to preprocess from the Dataset. Default: text") - p1.add_argument("--split", type=str, default="train", help="Which split of the data to process. Default: train") - p2 = sp.add_parser(name="jsonl") - p2.add_argument( - "--dataset", - type=str, - required=True, - help="Path to a .jsonl file or a folder containing multiple .jsonl files", - ) - p2.add_argument("--column", type=str, default="text", help="Column to preprocess from the Dataset. Default: text") - p2.add_argument( - "--glob-pattern", type=str, default=None, help="A glob pattern to filter files to read. Default: None" - ) + group = parser.add_argument_group(title="output data") + group.add_argument("--output-prefix", type=str, required=True, help="Path to the output processed dataset file") args = parser.parse_args() @@ -78,33 +42,74 @@ def get_args(): def main(args): - # Build datatrove reader - if args.readers == "hf": - datatrove_reader = HuggingFaceDatasetReader( - dataset=args.dataset, - text_key=args.column, - dataset_options={"split": args.split}, - ) + + world_size, rank = int(os.environ["WORLD_SIZE"]), int(os.environ["RANK"]) + + # Remove stdout from all processes except main to not flood the stdout + if rank: + sys.stdout = open(os.devnull, "w") + + # Check if output directory exists + if not os.path.isdir(os.path.abspath(os.path.join(args.output_prefix, os.path.pardir))): + print(f"Creating {os.path.abspath(os.path.join(args.output_prefix, os.path.pardir))} directory...") + os.makedirs(os.path.abspath(os.path.join(args.output_prefix, os.path.pardir)), exist_ok=True) + + if args.input.endswith(".json"): # For processing JSON files (Cross compatibility with other projects) + ds = load_dataset("json", data_files=args.input) + ds = concatenate_datasets( + [ds[splits] for splits in ds.keys()] + ) # load_dataset returns DatasetDict and we want a Dataset else: - datatrove_reader = JsonlReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern) - - preprocess_executor = LocalPipelineExecutor( - pipeline=[ - datatrove_reader, - DocumentTokenizer( - output_folder=args.output_folder, - tokenizer_name_or_path=args.tokenizer_name_or_path, - eos_token=args.eos_token, - shuffle=False, - max_tokens_per_file=1e9, - ), - ], - tasks=args.n_tasks, - logging_dir=args.logging_dir, + ds = load_dataset(args.input, split=args.split) + + ds = ds.shard(num_shards=world_size, index=rank, contiguous=True) + ds = ds.select_columns(args.column) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path) + token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16 + + # Create tmp directory for worker outputs + tmp_folder = os.path.abspath(os.path.join(args.output_prefix, os.pardir, "tmp")) + os.makedirs(tmp_folder, exist_ok=True) + + print("Creating workers output files...") + worker_output_file = os.path.join(tmp_folder, f"worker_{rank}_input_ids.npy") + ds = ds.map( + lambda x: {"input_ids": tokenizer(x, add_special_tokens=args.add_special_tokens).input_ids}, + input_columns=args.column, + batched=True, + desc="Tokenizing Dataset", + remove_columns=[args.column], ) - preprocess_executor.run() + + worker_input_ids_file = open(worker_output_file, "wb") + for sample in ds: + np_array = np.array(sample["input_ids"], dtype=token_dtype) + worker_input_ids_file.write(np_array.tobytes(order="C")) + worker_input_ids_file.close() + + # Wait for all workers to process each shard of the Dataset + dist.barrier() + + # Only the main rank merges the worker files + if not rank: + output_file = f"{args.output_prefix}_input_ids.npy" + input_ids_file = open(output_file, "wb") + for worker_idx in tqdm(range(world_size), desc="Merging workers output files"): + worker_output_file = os.path.join(tmp_folder, f"worker_{worker_idx}_input_ids.npy") + with open(worker_output_file, "rb") as f: + shutil.copyfileobj(f, input_ids_file) + os.remove(worker_output_file) + + input_ids_file.close() + os.rmdir(tmp_folder) + print(f"Done! {args.input} processed dataset stored in {output_file}") + + else: # Close devnull stdout redirect + sys.stdout.close() if __name__ == "__main__": _args = get_args() + dist.init_process_group(backend="gloo") main(_args)