streaming multipack for pretraining dataset #959

Merged (9 commits, Jan 6, 2024)
6 changes: 2 additions & 4 deletions .github/workflows/tests-docker.yml
@@ -20,7 +20,6 @@ jobs:
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
is_latest: true
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
@@ -37,7 +36,7 @@ jobs:
images: winglian/axolotl
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build and export to Docker
- name: Build Docker image
uses: docker/build-push-action@v5
with:
context: .
@@ -49,8 +48,7 @@
file: ./docker/Dockerfile
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
- name: Unit Tests
- name: Unit Tests w docker image
run: |
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
58 changes: 58 additions & 0 deletions examples/tiny-llama/pretrain.yml
@@ -0,0 +1,58 @@
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0

model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: false
strict: false

max_steps: 200
pretraining_dataset:
  path: c4
  name: en
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./model-out

sequence_len: 2048
sample_packing: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
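For orientation, here is roughly what the `pretraining_dataset` block above resolves to once `load_pretraining_dataset` (see the `src/axolotl/utils/data.py` changes below) picks it up. The call mirrors the diff; the seed and the peek at a record are illustrative assumptions.

```python
# Sketch of the streaming load driven by the pretraining_dataset path/name above.
from datasets import load_dataset

dataset = load_dataset("c4", name="en", streaming=True, split="train")
dataset = dataset.shuffle(seed=42, buffer_size=10_000)  # buffer size as in the diff

print(next(iter(dataset))["text"][:80])  # peek at one raw document
```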
18 changes: 14 additions & 4 deletions src/axolotl/core/trainer_builder.py
@@ -60,6 +60,12 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
@@ -157,7 +163,7 @@ def create_scheduler(
return self.lr_scheduler

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing:
if self.args.sample_packing and not self.args.pretraining:
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
self.args.train_batch_size,
@@ -193,7 +199,7 @@ def _get_eval_sampler(
return super()._get_eval_sampler(eval_dataset)

def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing:
if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
@@ -768,6 +774,7 @@ def build(self, total_num_steps):
training_arguments_kwargs
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)

if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[
@@ -808,7 +815,7 @@ def build(self, total_num_steps):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=self.build_collator(**data_collator_kwargs),
data_collator=self.build_collator(training_args, **data_collator_kwargs),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
@@ -829,7 +836,10 @@ def build(self, total_num_steps):

return trainer

def build_collator(self, **kwargs):
def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
if training_args.pretraining:
return None

if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)

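The purpose of the new `pretraining` flag: when streaming a packed pretraining dataset, every item the trainer receives is already a fully packed, collated row, so the multipack sampler and the custom collator must be bypassed (`build_collator` returns `None`, and `micro_batch_size` is forced to 1 elsewhere in this PR). A minimal stand-alone sketch of that idea, using toy tensors rather than axolotl code:

```python
# Toy illustration: pre-packed rows flow through a DataLoader untouched.
import torch
from torch.utils.data import DataLoader, IterableDataset


class PrePackedStream(IterableDataset):
    """Stands in for the streamed dataset: each item is already one packed batch."""

    def __iter__(self):
        for _ in range(3):
            yield {
                "input_ids": torch.zeros(1, 4096, dtype=torch.long),
                "attention_mask": torch.ones(1, 4096, dtype=torch.long),
            }


# batch_size=None disables automatic batching, so each packed row is one step.
loader = DataLoader(PrePackedStream(), batch_size=None)
for batch in loader:
    print(batch["input_ids"].shape)  # torch.Size([1, 4096])
```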
21 changes: 21 additions & 0 deletions src/axolotl/utils/collators.py
@@ -178,3 +178,24 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
"input_ids": input_ids,
"labels": labels,
}


@dataclass
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
Collator for multipack, specific to using the BatchSampler
"""

def __call__(self, features, return_tensors=None):
chunked_data = {}
for feature in features.keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [(1) * np.array(item) for item in features[feature]]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [np.array(item) for item in features[feature]]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
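A hedged usage sketch of the collator added above. The tokenizer choice, token ids, and sizes are assumptions; the calling convention (a dict of per-batch lists with a `length` key that gets dropped) matches how `encode_packed_pretraining` in `src/axolotl/utils/data.py` invokes it.

```python
from transformers import AutoTokenizer

from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token

collator = PretrainingBatchSamplerDataCollatorForSeq2Seq(
    tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=4096
)

# One multipack batch: two short sequences that get concatenated into a single
# row, then padded out to the 4096-token multiple.
features = {
    "input_ids": [[1, 306, 763, 445, 2], [1, 910, 338, 1790, 697, 2]],
    "attention_mask": [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]],
    "labels": [[1, 306, 763, 445, 2], [1, 910, 338, 1790, 697, 2]],
    "length": [5, 6],  # skipped by the collator
}
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([1, 4096])
```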
99 changes: 95 additions & 4 deletions src/axolotl/utils/data.py
@@ -2,6 +2,7 @@
import functools
import hashlib
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Union

@@ -14,6 +15,7 @@
load_from_disk,
)
from huggingface_hub import hf_hub_download
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -39,11 +41,14 @@
SummarizeTLDRPrompter,
UnsupportedPrompter,
)
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.samplers.multipack import MultipackBatchSampler
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
process_pretraining_datasets_for_packing,
)

LOG = logging.getLogger("axolotl")
@@ -64,9 +69,17 @@ def prepare_dataset(cfg, tokenizer):
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
path = cfg.pretraining_dataset
name = None
if isinstance(cfg.pretraining_dataset, dict):
path = cfg.pretraining_dataset["path"]
name = cfg.pretraining_dataset["name"]

train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset,
path,
tokenizer,
cfg,
name=name,
max_tokens=cfg.sequence_len,
seed=cfg.seed or 42,
)
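The branch above keeps backward compatibility: `pretraining_dataset` may still be a plain dataset path, or now a mapping with `path` and `name`. A small illustration of both accepted shapes (toy values; `DictDefault` is axolotl's config wrapper, imported in this file's diff above):

```python
from axolotl.utils.dict import DictDefault

for cfg in (
    DictDefault({"pretraining_dataset": "c4"}),                          # legacy string form
    DictDefault({"pretraining_dataset": {"path": "c4", "name": "en"}}),  # new dict form
):
    path = cfg.pretraining_dataset
    name = None
    if isinstance(cfg.pretraining_dataset, dict):
        path = cfg.pretraining_dataset["path"]
        name = cfg.pretraining_dataset["name"]
    print(path, name)  # -> "c4" None, then "c4" "en"
```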
@@ -806,9 +819,27 @@ def encode_pretraining(
return ret


def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train")
def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
if cfg.sample_packing:
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
)
encode = functools.partial(
encode_packed_pretraining,
tokenizer,
collate_fn,
max_seq_length=max_tokens,
batch_size=cfg.micro_batch_size,
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)

dataset = load_dataset(path, streaming=True, split="train", name=name)
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
dataset = dataset.map(
encode,
@@ -819,3 +850,63 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
remove_columns=dataset.features.keys(),
)
return dataset


def encode_packed_pretraining(
tokenizer: PreTrainedTokenizerBase,
collate_fn,
examples: List[str],
max_seq_length: int = 2048,
batch_size: int = 4,
) -> Dict[str, List]:
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
res = tokenizer(
examples,
truncation=True,
max_length=max_seq_length - 1,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)

input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
attention_mask = [seq + [1] for seq in res["attention_mask"]]

tokenized_examples = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}

train_dataset = Dataset.from_dict(tokenized_examples)
train_dataset = process_pretraining_datasets_for_packing(
train_dataset, max_seq_length
)

sampler = MultipackBatchSampler(
RandomSampler(train_dataset),
batch_size=batch_size,
drop_last=True,
batch_max_len=batch_size * max_seq_length,
lengths=(
train_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
)

chunked_data = defaultdict(list)

for data in sampler:
features = train_dataset[data]
features["labels"] = features["input_ids"].copy()
collated_features = collate_fn(features)

for feature in features.keys():
if feature == "length":
continue
chunked_data[feature].append(collated_features[feature].squeeze(0))

return chunked_data
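A hedged end-to-end sketch of how `encode_packed_pretraining` is meant to be driven, mirroring the `functools.partial` wiring in `load_pretraining_dataset` above; the tokenizer, corpus text, and sizes are placeholder assumptions.

```python
import functools

from transformers import AutoTokenizer

from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data import encode_packed_pretraining

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
tokenizer.pad_token = tokenizer.eos_token

max_seq_length, micro_batch_size = 2048, 2
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
    tokenizer,
    return_tensors="pt",
    padding=True,
    pad_to_multiple_of=max_seq_length * micro_batch_size,
)
encode = functools.partial(
    encode_packed_pretraining,
    tokenizer,
    collate_fn,
    max_seq_length=max_seq_length,
    batch_size=micro_batch_size,
)

# Placeholder corpus rows; in axolotl these come from the streamed dataset.map call.
texts = ["Axolotl streams raw text for continued pretraining. " * 40] * 64

packed = encode(texts)
# Each entry is one pre-collated row of packed sequences, padded to 4096 tokens.
print(len(packed["input_ids"]), packed["input_ids"][0].shape)
```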
10 changes: 10 additions & 0 deletions src/axolotl/utils/trainer.py
@@ -143,6 +143,16 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
return train_dataset, eval_dataset


def process_pretraining_datasets_for_packing(train_dataset, sequence_len):
drop_long = partial(drop_long_seq, sequence_len=sequence_len)

train_dataset = train_dataset.filter(drop_long)
train_dataset = train_dataset.map(
add_position_ids,
)
return train_dataset


def calculate_total_num_steps(cfg, train_dataset, update=True):
if not cfg.total_num_tokens:
total_num_tokens = np.sum(
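For reference, the new helper above reuses `drop_long_seq` and `add_position_ids`, so every packed row carries per-token positions; `encode_packed_pretraining` then recovers each row's length as `position_ids[-1] + 1` when building the sampler. A tiny worked illustration of that convention (toy numbers, assumed `add_position_ids` behavior):

```python
import numpy as np

# A 5-token row is assumed to get positions [0, 1, 2, 3, 4] ...
position_ids = np.arange(5)

# ... so the sampler's lengths expression recovers the row length without
# re-tokenizing: last position + 1.
print(position_ids[-1] + 1)  # 5
```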