From 8ed5bcb54aacaf231fbdf592ca5f432200ad01d5 Mon Sep 17 00:00:00 2001 From: "jinwonkim93@github.com" Date: Fri, 15 Dec 2023 03:14:40 +0000 Subject: [PATCH 1/9] [Feat] streaming multipack --- src/axolotl/utils/data.py | 70 +++++++++++++++++++++++++ src/axolotl/utils/trainer.py | 10 ++++ tests/test_packed_pretraining.py | 87 ++++++++++++++++++++++++++++++++ 3 files changed, 167 insertions(+) create mode 100644 tests/test_packed_pretraining.py diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 5c41d16fe..165e2c8fd 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -2,9 +2,11 @@ import functools import hashlib import logging +from collections import defaultdict from pathlib import Path from typing import Dict, List, Tuple, Union +import numpy as np import torch from datasets import ( Dataset, @@ -14,6 +16,7 @@ load_from_disk, ) from huggingface_hub import hf_hub_download +from torch.utils.data import RandomSampler from transformers import PreTrainedTokenizerBase from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH @@ -41,9 +44,11 @@ ) from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first +from axolotl.utils.samplers.multipack import MultipackBatchSampler from axolotl.utils.trainer import ( calculate_total_num_steps, process_datasets_for_packing, + process_pretraining_datasets_for_packing, ) LOG = logging.getLogger("axolotl") @@ -819,3 +824,68 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42): remove_columns=dataset.features.keys(), ) return dataset + + +def encode_packed_pretraining( + tokenizer: PreTrainedTokenizerBase, + examples: List[str], + max_seq_length: int = 8192, + sample_packing_efficiency: int = 1, +) -> Dict[str, List]: + # pylint: disable=duplicate-code + # tokenize all the examples + # rows get split with stride (overlap) + res = tokenizer( + examples, + truncation=True, + max_length=max_seq_length - 1, + 
add_special_tokens=True, + return_overflowing_tokens=True, + stride=256, + ) + + input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]] + attention_mask = [seq + [1] for seq in res["attention_mask"]] + + tokenized_examples = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + train_dataset = Dataset.from_dict(tokenized_examples) + train_dataset = process_pretraining_datasets_for_packing( + train_dataset, max_seq_length + ) + + sampler = MultipackBatchSampler( + RandomSampler(train_dataset), + batch_size=1, + drop_last=True, + batch_max_len=max_seq_length, + lengths=( + train_dataset.data.column("position_ids") + .to_pandas() + .apply(lambda x: x[-1] + 1) + .values + ), + packing_efficiency_estimate=sample_packing_efficiency, + ) + + chunked_data = defaultdict(list) + + for data in sampler: + features = train_dataset[data] + + for feature in features.keys(): + if feature == "length": + continue + if feature == "attention_mask": + arrays = [(1) * np.array(item) for item in features[feature]] + chunked_data[feature].append(np.concatenate(arrays)) + else: + arrays = [np.array(item) for item in features[feature]] + chunked_data[feature].append(np.concatenate(arrays)) + + chunked_data["labels"] = chunked_data["input_ids"].copy() + + return chunked_data diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index d975bb9a2..3139f5600 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -143,6 +143,16 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer): return train_dataset, eval_dataset +def process_pretraining_datasets_for_packing(train_dataset, sequence_len): + drop_long = partial(drop_long_seq, sequence_len=sequence_len) + + train_dataset = train_dataset.filter(drop_long) + train_dataset = train_dataset.map( + add_position_ids, + ) + return train_dataset + + def calculate_total_num_steps(cfg, train_dataset, update=True): if not cfg.total_num_tokens: 
total_num_tokens = np.sum( diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py new file mode 100644 index 000000000..9dee6ed7f --- /dev/null +++ b/tests/test_packed_pretraining.py @@ -0,0 +1,87 @@ +"""Module for testing streaming dataset sequence packing""" +import math +import unittest +from functools import partial + +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import AutoTokenizer + +from axolotl.utils.collators import DataCollatorForSeq2Seq +from axolotl.utils.data import encode_packed_pretraining + + +class TestPacking(unittest.TestCase): + """ + Test class for packing streaming dataset sequences + """ + + def setUp(self) -> None: + # pylint: disable=duplicate-code + self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") + self.tokenizer.add_special_tokens( + { + "bos_token": "", + "eos_token": "", + "unk_token": "", + "pad_token": "[PAD]", + } + ) + self.max_seq_length = 8192 + self.batch_size = 6 + self.sample_packing_efficiency = 1 + self.data_collator_kwargs = { + "padding": True, + "pad_to_multiple_of": 64 * math.ceil(self.max_seq_length / 64), + } + + def test_packing_stream_dataset(self): + # pylint: disable=duplicate-code + dataset = load_dataset( + "c4", + "en", + streaming=True, + )["train"] + + encode = partial( + encode_packed_pretraining, + self.tokenizer, + max_seq_length=self.max_seq_length, + sample_packing_efficiency=self.sample_packing_efficiency, + ) + + dataset = dataset.map( + encode, + batched=True, + input_columns="text", + remove_columns=dataset.features.keys(), + ) + + data_collator_fn = DataCollatorForSeq2Seq( + self.tokenizer, + return_tensors="pt", + **self.data_collator_kwargs, + ) + + trainer_loader = DataLoader( + dataset, + batch_size=self.batch_size, + collate_fn=data_collator_fn, + drop_last=True, + ) + idx = 0 + for data in trainer_loader: + if idx > 10: + break + assert data["input_ids"].shape == (self.batch_size, 
self.max_seq_length) + assert data["position_ids"].shape == (self.batch_size, self.max_seq_length) + assert data["labels"].shape == (self.batch_size, self.max_seq_length) + assert data["attention_mask"].shape == ( + self.batch_size, + self.max_seq_length, + ) + idx += 1 + + +if __name__ == "__main__": + unittest.main() From da9aee1358e65965dd4972ba6b585b127f9e7d36 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Jan 2024 16:04:37 +0000 Subject: [PATCH 2/9] WIP make continued pretraining work w multipack --- examples/tiny-llama/pretrain.yml | 56 +++++++++++++++++++++++++++++ src/axolotl/cli/train.py | 2 ++ src/axolotl/core/trainer_builder.py | 6 ++-- src/axolotl/utils/collators.py | 26 ++++++++++++++ src/axolotl/utils/data.py | 37 ++++++++++--------- 5 files changed, 107 insertions(+), 20 deletions(-) create mode 100644 examples/tiny-llama/pretrain.yml diff --git a/examples/tiny-llama/pretrain.yml b/examples/tiny-llama/pretrain.yml new file mode 100644 index 000000000..b0319a95f --- /dev/null +++ b/examples/tiny-llama/pretrain.yml @@ -0,0 +1,56 @@ +base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 + +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: false +load_in_4bit: false +strict: false + +max_steps: 200 +pretraining_dataset: c4 +dataset_prepared_path: +val_set_size: 0.0 +output_dir: ./model-out + +sequence_len: 2048 +sample_packing: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: +eval_table_size: +saves_per_epoch: 1 +debug: +deepspeed: 
+weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 2248784df..54242dd58 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -5,6 +5,7 @@ from pathlib import Path import fire +import torch import transformers from axolotl.cli import ( @@ -19,6 +20,7 @@ from axolotl.train import train LOG = logging.getLogger("axolotl.cli.train") +# torch.set_printoptions(threshold=10000) def do_cli(config: Path = Path("examples/"), **kwargs): diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 465cfa1af..7b5fef570 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -157,7 +157,7 @@ def create_scheduler( return self.lr_scheduler def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing: + if self.args.sample_packing and False: return MultipackBatchSampler( RandomSampler(self.train_dataset), self.args.train_batch_size, @@ -193,7 +193,7 @@ def _get_eval_sampler( return super()._get_eval_sampler(eval_dataset) def get_train_dataloader(self) -> DataLoader: - if self.args.sample_packing: + if self.args.sample_packing and False: train_dataset = self.train_dataset train_dataset = train_dataset.remove_columns(["length"]) data_collator = self.data_collator @@ -808,7 +808,7 @@ def build(self, total_num_steps): train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, args=training_args, - data_collator=self.build_collator(**data_collator_kwargs), + # data_collator=self.build_collator(**data_collator_kwargs), bench_data_collator=transformers.DataCollatorForSeq2Seq( self.tokenizer, return_tensors="pt", diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py index 0f0eb5a95..b4c4fa4df 100644 --- a/src/axolotl/utils/collators.py +++ b/src/axolotl/utils/collators.py @@ -178,3 +178,29 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, 
torch.Tensor]: "input_ids": input_ids, "labels": labels, } + +@dataclass +class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): + """ + Collator for multipack specific to the using the BatchSampler + """ + + def __call__(self, features, return_tensors=None): + chunked_data = {} + for feature in features.keys(): + if feature == "length": + continue + if feature == "attention_mask": + arrays = [ + (1) * np.array(item) + for item in features[feature] + ] + chunked_data[feature] = np.concatenate(arrays) + else: + arrays = [ + np.array(item) for item in features[feature] + ] + chunked_data[feature] = np.concatenate(arrays) + features = [chunked_data] + return super().__call__(features, return_tensors=return_tensors) + diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 165e2c8fd..199d4a64a 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -4,7 +4,7 @@ import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple, Union, Optional import numpy as np import torch @@ -42,6 +42,7 @@ SummarizeTLDRPrompter, UnsupportedPrompter, ) +from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq, PretrainingBatchSamplerDataCollatorForSeq2Seq from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.samplers.multipack import MultipackBatchSampler @@ -72,6 +73,7 @@ def prepare_dataset(cfg, tokenizer): train_dataset = load_pretraining_dataset( cfg.pretraining_dataset, tokenizer, + cfg, max_tokens=cfg.sequence_len, seed=cfg.seed or 42, ) @@ -811,9 +813,15 @@ def encode_pretraining( return ret -def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42): - encode = functools.partial(encode_pretraining, tokenizer, max_tokens) - dataset = load_dataset(path, streaming=True, split="train") +def load_pretraining_dataset(path, tokenizer, 
cfg, name=None, max_tokens=2048, seed=42): + if cfg.sample_packing: + collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens) + encode = functools.partial(encode_packed_pretraining, tokenizer, collate_fn, max_seq_length=max_tokens, batch_size=cfg.micro_batch_size) + cfg.micro_batch_size = 1 + else: + encode = functools.partial(encode_pretraining, tokenizer, max_tokens) + + dataset = load_dataset(path, streaming=True, split="train", name="en") dataset = dataset.shuffle(seed=seed, buffer_size=10_000) dataset = dataset.map( encode, @@ -828,9 +836,10 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42): def encode_packed_pretraining( tokenizer: PreTrainedTokenizerBase, + collate_fn, examples: List[str], - max_seq_length: int = 8192, - sample_packing_efficiency: int = 1, + max_seq_length: int = 2048, + batch_size: int = 4, ) -> Dict[str, List]: # pylint: disable=duplicate-code # tokenize all the examples @@ -859,33 +868,27 @@ def encode_packed_pretraining( sampler = MultipackBatchSampler( RandomSampler(train_dataset), - batch_size=1, + batch_size=batch_size, drop_last=True, - batch_max_len=max_seq_length, + batch_max_len=batch_size * max_seq_length, lengths=( train_dataset.data.column("position_ids") .to_pandas() .apply(lambda x: x[-1] + 1) .values ), - packing_efficiency_estimate=sample_packing_efficiency, ) chunked_data = defaultdict(list) for data in sampler: features = train_dataset[data] + features["labels"] = features["input_ids"].copy() + collated_features = collate_fn(features) for feature in features.keys(): if feature == "length": continue - if feature == "attention_mask": - arrays = [(1) * np.array(item) for item in features[feature]] - chunked_data[feature].append(np.concatenate(arrays)) - else: - arrays = [np.array(item) for item in features[feature]] - chunked_data[feature].append(np.concatenate(arrays)) - - chunked_data["labels"] = 
chunked_data["input_ids"].copy() + chunked_data[feature].append(collated_features[feature].squeeze(0)) return chunked_data From 36b244db2e88f718b184d314fcbbed8801937ed4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Jan 2024 11:15:43 -0500 Subject: [PATCH 3/9] fix up hardcoding, lint --- examples/tiny-llama/pretrain.yml | 6 ++++-- src/axolotl/cli/train.py | 2 -- src/axolotl/core/trainer_builder.py | 11 +++++++++-- src/axolotl/utils/collators.py | 11 +++-------- src/axolotl/utils/data.py | 29 ++++++++++++++++++++++------- 5 files changed, 38 insertions(+), 21 deletions(-) diff --git a/examples/tiny-llama/pretrain.yml b/examples/tiny-llama/pretrain.yml index b0319a95f..dfd1bfca2 100644 --- a/examples/tiny-llama/pretrain.yml +++ b/examples/tiny-llama/pretrain.yml @@ -9,7 +9,9 @@ load_in_4bit: false strict: false max_steps: 200 -pretraining_dataset: c4 +pretraining_dataset: + path: c4 + name: en dataset_prepared_path: val_set_size: 0.0 output_dir: ./model-out @@ -45,7 +47,7 @@ xformers_attention: flash_attention: true warmup_steps: 10 -evals_per_epoch: +evals_per_epoch: eval_table_size: saves_per_epoch: 1 debug: diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 54242dd58..2248784df 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -5,7 +5,6 @@ from pathlib import Path import fire -import torch import transformers from axolotl.cli import ( @@ -20,7 +19,6 @@ from axolotl.train import train LOG = logging.getLogger("axolotl.cli.train") -# torch.set_printoptions(threshold=10000) def do_cli(config: Path = Path("examples/"), **kwargs): diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 7b5fef570..dc8b1501e 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -60,6 +60,12 @@ class AxolotlTrainingArguments(TrainingArguments): default=False, metadata={"help": "Use quadratic warmup for cosine scheduling."}, ) + pretraining: bool = field(
default=False, + metadata={ + "help": "Indicates to trainer whether we are doing continued pretraining." + }, + ) sample_packing: bool = field( default=False, metadata={"help": "Use sample packing for efficient training."}, @@ -157,7 +163,7 @@ def create_scheduler( return self.lr_scheduler def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing and False: + if self.args.sample_packing and not self.args.pretraining: return MultipackBatchSampler( RandomSampler(self.train_dataset), self.args.train_batch_size, @@ -193,7 +199,7 @@ def _get_eval_sampler( return super()._get_eval_sampler(eval_dataset) def get_train_dataloader(self) -> DataLoader: - if self.args.sample_packing and False: + if self.args.sample_packing and not self.args.pretraining: train_dataset = self.train_dataset train_dataset = train_dataset.remove_columns(["length"]) data_collator = self.data_collator @@ -768,6 +774,7 @@ def build(self, total_num_steps): training_arguments_kwargs ) training_arguments_kwargs["model_type"] = self.cfg.model_config_type + training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) if self.cfg.neftune_noise_alpha is not None: training_arguments_kwargs[ diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py index b4c4fa4df..b9c1c3b3c 100644 --- a/src/axolotl/utils/collators.py +++ b/src/axolotl/utils/collators.py @@ -179,6 +179,7 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: "labels": labels, } + @dataclass class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): """ @@ -191,16 +192,10 @@ def __call__(self, features, return_tensors=None): if feature == "length": continue if feature == "attention_mask": - arrays = [ - (1) * np.array(item) - for item in features[feature] - ] + arrays = [(1) * np.array(item) for item in features[feature]] chunked_data[feature] = np.concatenate(arrays) else: - arrays = [ - np.array(item) for item in 
features[feature] - ] + arrays = [np.array(item) for item in features[feature]] chunked_data[feature] = np.concatenate(arrays) features = [chunked_data] return super().__call__(features, return_tensors=return_tensors) - diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 199d4a64a..0de5d4ee3 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -4,9 +4,8 @@ import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple, Union, Optional +from typing import Dict, List, Tuple, Union -import numpy as np import torch from datasets import ( Dataset, @@ -42,7 +41,7 @@ SummarizeTLDRPrompter, UnsupportedPrompter, ) -from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq, PretrainingBatchSamplerDataCollatorForSeq2Seq +from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.samplers.multipack import MultipackBatchSampler @@ -70,10 +69,17 @@ def prepare_dataset(cfg, tokenizer): tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH ) else: + path = cfg.pretraining_dataset + name = None + if isinstance(dict, cfg.pretraining_dataset): + path = cfg.pretraining_dataset.path + name = cfg.pretraining_dataset.name + train_dataset = load_pretraining_dataset( - cfg.pretraining_dataset, + path, tokenizer, cfg, + name=name, max_tokens=cfg.sequence_len, seed=cfg.seed or 42, ) @@ -815,13 +821,22 @@ def encode_pretraining( def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42): if cfg.sample_packing: - collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens) - encode = functools.partial(encode_packed_pretraining, tokenizer, collate_fn, max_seq_length=max_tokens, batch_size=cfg.micro_batch_size) + collate_fn = 
PretrainingBatchSamplerDataCollatorForSeq2Seq( + tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens + ) + encode = functools.partial( + encode_packed_pretraining, + tokenizer, + collate_fn, + max_seq_length=max_tokens, + batch_size=cfg.micro_batch_size, + ) + # set this to 1 so downstream data_loader doesn't try to increase the batch again cfg.micro_batch_size = 1 else: encode = functools.partial(encode_pretraining, tokenizer, max_tokens) - dataset = load_dataset(path, streaming=True, split="train", name="en") + dataset = load_dataset(path, streaming=True, split="train", name=name) dataset = dataset.shuffle(seed=seed, buffer_size=10_000) dataset = dataset.map( encode, From 680cbe21bb5d7f1ca996301cf9d4130b1566efc7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Jan 2024 11:22:34 -0500 Subject: [PATCH 4/9] fix dict check --- src/axolotl/utils/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 0de5d4ee3..40a330602 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -71,9 +71,9 @@ def prepare_dataset(cfg, tokenizer): else: path = cfg.pretraining_dataset name = None - if isinstance(dict, cfg.pretraining_dataset): - path = cfg.pretraining_dataset.path - name = cfg.pretraining_dataset.name + if isinstance(cfg.pretraining_dataset, dict): + path = cfg.pretraining_dataset["path"] + name = cfg.pretraining_dataset["name"] train_dataset = load_pretraining_dataset( path, From 789c972f87944677ed05ca05959e8c39d6442e31 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Jan 2024 11:37:37 -0500 Subject: [PATCH 5/9] update test for updated pretraining multipack code --- tests/test_packed_pretraining.py | 59 +++++++++++++++----------------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py index 9dee6ed7f..a47e3983f 100644 --- a/tests/test_packed_pretraining.py 
+++ b/tests/test_packed_pretraining.py @@ -1,13 +1,13 @@ """Module for testing streaming dataset sequence packing""" -import math import unittest from functools import partial +import torch from datasets import load_dataset from torch.utils.data import DataLoader from transformers import AutoTokenizer -from axolotl.utils.collators import DataCollatorForSeq2Seq +from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq from axolotl.utils.data import encode_packed_pretraining @@ -19,21 +19,9 @@ class TestPacking(unittest.TestCase): def setUp(self) -> None: # pylint: disable=duplicate-code self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") - self.tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - "pad_token": "[PAD]", - } - ) - self.max_seq_length = 8192 - self.batch_size = 6 - self.sample_packing_efficiency = 1 - self.data_collator_kwargs = { - "padding": True, - "pad_to_multiple_of": 64 * math.ceil(self.max_seq_length / 64), - } + self.tokenizer.pad_token = "" + self.max_seq_length = 2048 + self.batch_size = 2 def test_packing_stream_dataset(self): # pylint: disable=duplicate-code @@ -43,11 +31,19 @@ def test_packing_stream_dataset(self): streaming=True, )["train"] + collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( + self.tokenizer, + return_tensors="pt", + padding=True, + pad_to_multiple_of=self.max_seq_length, + ) + encode = partial( encode_packed_pretraining, self.tokenizer, + collate_fn, max_seq_length=self.max_seq_length, - sample_packing_efficiency=self.sample_packing_efficiency, + batch_size=self.batch_size, ) dataset = dataset.map( @@ -57,28 +53,27 @@ def test_packing_stream_dataset(self): remove_columns=dataset.features.keys(), ) - data_collator_fn = DataCollatorForSeq2Seq( - self.tokenizer, - return_tensors="pt", - **self.data_collator_kwargs, - ) - trainer_loader = DataLoader( dataset, - batch_size=self.batch_size, - collate_fn=data_collator_fn, + batch_size=1, 
+ collate_fn=None, drop_last=True, ) idx = 0 for data in trainer_loader: if idx > 10: break - assert data["input_ids"].shape == (self.batch_size, self.max_seq_length) - assert data["position_ids"].shape == (self.batch_size, self.max_seq_length) - assert data["labels"].shape == (self.batch_size, self.max_seq_length) - assert data["attention_mask"].shape == ( - self.batch_size, - self.max_seq_length, + assert data["input_ids"].shape == torch.Size( + [1, self.batch_size * self.max_seq_length] + ) + assert data["position_ids"].shape == torch.Size( + [1, self.batch_size * self.max_seq_length] + ) + assert data["labels"].shape == torch.Size( + [1, self.batch_size * self.max_seq_length] + ) + assert data["attention_mask"].shape == torch.Size( + [1, self.batch_size * self.max_seq_length] ) idx += 1 From 2a4924882c21c29f4d8fed72c6c05e3386f57760 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Jan 2024 12:16:45 -0500 Subject: [PATCH 6/9] fix hardcoded data collator fix for multipack pretraining --- src/axolotl/core/trainer_builder.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index dc8b1501e..26cc91ed5 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -815,7 +815,7 @@ def build(self, total_num_steps): train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, args=training_args, - # data_collator=self.build_collator(**data_collator_kwargs), + data_collator=self.build_collator(training_args, **data_collator_kwargs), bench_data_collator=transformers.DataCollatorForSeq2Seq( self.tokenizer, return_tensors="pt", @@ -836,7 +836,10 @@ def build(self, total_num_steps): return trainer - def build_collator(self, **kwargs): + def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs): + if training_args.pretraining: + return None + if self.cfg.model_config_type == "mamba": return 
MambaDataCollator(tokenizer=self.tokenizer) From 7c3be2ef5ea918106b64e0b762bc161859c327d4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Jan 2024 15:58:34 -0500 Subject: [PATCH 7/9] fix the collator to be the max length for multipack pretraining --- src/axolotl/utils/data.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 40a330602..b3c7606eb 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -822,7 +822,10 @@ def encode_pretraining( def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42): if cfg.sample_packing: collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( - tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens + tokenizer, + return_tensors="pt", + padding=True, + pad_to_multiple_of=max_tokens * cfg.micro_batch_size, ) encode = functools.partial( encode_packed_pretraining, From bea8beec9131669d21f389ab17dbcc12160a20ae Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Jan 2024 16:25:34 -0500 Subject: [PATCH 8/9] don't bother with latest tag for test --- .github/workflows/tests-docker.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/tests-docker.yml b/.github/workflows/tests-docker.yml index 0aba6d505..1df5db40b 100644 --- a/.github/workflows/tests-docker.yml +++ b/.github/workflows/tests-docker.yml @@ -20,7 +20,6 @@ jobs: python_version: "3.10" pytorch: 2.0.1 axolotl_extras: - is_latest: true - cuda: 121 cuda_version: 12.1.0 python_version: "3.10" From 5a321c3763c66c5baca055ded03af1eb5115843e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Jan 2024 16:42:58 -0500 Subject: [PATCH 9/9] cleanup docker build/test --- .github/workflows/tests-docker.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests-docker.yml b/.github/workflows/tests-docker.yml index 1df5db40b..e93884e64 100644 --- 
a/.github/workflows/tests-docker.yml +++ b/.github/workflows/tests-docker.yml @@ -36,7 +36,7 @@ jobs: images: winglian/axolotl - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - - name: Build and export to Docker + - name: Build Docker image uses: docker/build-push-action@v5 with: context: . @@ -48,8 +48,7 @@ jobs: file: ./docker/Dockerfile tags: | ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} - ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} labels: ${{ steps.metadata.outputs.labels }} - - name: Unit Tests + - name: Unit Tests w docker image run: | docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/