Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for auto packing ratio #683

Merged
merged 26 commits into from
Nov 5, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
6d53fca
Add support for auto packing ratio
irenedea Sep 22, 2023
1c0f157
Add test
irenedea Oct 17, 2023
b32bac0
Refactor and change to generator
irenedea Oct 19, 2023
d9dcdbc
Add simple tests
irenedea Oct 19, 2023
93d7926
Add auto packing tests
irenedea Oct 19, 2023
ec71fec
Add auto packing to test_dataloader
irenedea Oct 19, 2023
239fce4
Merge branch 'main' into packing-collator
irenedea Oct 19, 2023
0db972a
use correct max leftovers to keep
irenedea Oct 19, 2023
6c321d3
Handle dataspec change
irenedea Oct 20, 2023
d6793bb
Merge branch 'main' into packing-collator
irenedea Oct 20, 2023
a852c23
Add dataloader test
irenedea Oct 20, 2023
d48fb97
Add distributed autopacking
irenedea Oct 21, 2023
8c08405
Update comments for profile_packing script refactor
irenedea Oct 21, 2023
6aab1ad
add torch cuda check
irenedea Oct 21, 2023
aeffb4b
Use 0 workers for profiling because one batch is loaded per worker an…
irenedea Oct 21, 2023
044bb00
Merge branch 'main' into packing-collator
irenedea Oct 21, 2023
96b4829
Fix code quality
irenedea Oct 24, 2023
2ad0c31
Merge branch 'main' into packing-collator
irenedea Oct 24, 2023
83e8d3a
Merge branch 'main' into packing-collator
irenedea Oct 25, 2023
f8ba32f
Merge branch 'main' into packing-collator
dakinggg Oct 28, 2023
1ee68b8
Merge branch 'main' into packing-collator
irenedea Nov 2, 2023
d88cdcc
Address PR comments
irenedea Nov 2, 2023
913c47f
Merge branch 'main' into packing-collator
irenedea Nov 2, 2023
57cb170
Set random seed for auto packing to make it deterministic
irenedea Nov 3, 2023
de6b45d
Fix typo
irenedea Nov 3, 2023
2ff88c2
Update max_leftover_bins_to_keep to keep all and remove unused variables
irenedea Nov 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llmfoundry/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from llmfoundry.data.text_data import (StreamingTextDataset,
build_text_dataloader)

from llmfoundry.data.dataloader import build_dataloader

__all__ = [
'MixtureOfDenoisersCollator',
'build_text_denoising_dataloader',
Expand All @@ -18,4 +20,5 @@
'build_text_dataloader',
'NoConcatDataset',
'ConcatTokensDataset',
'build_dataloader',
]
46 changes: 46 additions & 0 deletions llmfoundry/data/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Dataloader builder utilities."""

from composer import DataSpec
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.text_data import build_text_dataloader

from llmfoundry.data.denoising import build_text_denoising_dataloader

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader


def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
irenedea marked this conversation as resolved.
Show resolved Hide resolved
device_batch_size: int) -> DataSpec:
"""Builds a dataloader from a config.

Args:
cfg (DictConfig): An omegaconf dictionary used to configure the loader.
tokenizer (PreTrainedTokenizerBase): The tokenizer that the model will use.
device_batch_size (int): The size of the batches (number of examples)
that the dataloader will produce.
"""
if cfg.name == 'text':
return build_text_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'text_denoising':
return build_text_denoising_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'finetuning':
return build_finetuning_dataloader(
cfg,
tokenizer,
device_batch_size,
)
else:
raise ValueError(f'Not sure how to build dataloader with config: {cfg}')
8 changes: 4 additions & 4 deletions llmfoundry/data/denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.packing import BinPackCollator
from llmfoundry.data.text_data import (StreamingTextDataset,
get_tokens_per_batch_func)
from llmfoundry.models import utils
Expand Down Expand Up @@ -387,7 +387,7 @@ def build_text_denoising_dataloader(
packing.
Select packing_ratio **carefully** based on the dataset
statistics, max_seq_len, and tolerance for discarding samples!
The packing code in `./packing.py` provides a script that can help
The script `scripts/misc/profile_packing.py` can help
you choose the best packing_ratio.
See :class:`StreamingTextDataset` for info on other standard config
options within `cfg.dataset`.
Expand Down Expand Up @@ -419,7 +419,7 @@ def build_text_denoising_dataloader(
that the dataloader will produce.

Note:
You can run the script inside `./packing.py` to quickly test the
You can use the script `scripts/misc/profile_packing.py` to quickly test the
irenedea marked this conversation as resolved.
Show resolved Hide resolved
padding/waste rates for different `cfg.dataset.packing_ratio` choices,
given a starting workload YAML.
"""
Expand Down Expand Up @@ -492,7 +492,7 @@ def build_text_denoising_dataloader(
raise NotImplementedError(
'On-the-fly packing is currently only supported for decoder-only formats.'
)
collate_fn = BinPackWrapper(
collate_fn = BinPackCollator(
collator=collate_fn,
target_batch_size=device_batch_size,
max_seq_len=cfg.dataset.max_seq_len,
Expand Down
21 changes: 13 additions & 8 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import get_tokens_per_batch_func

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -86,7 +86,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
packing.
Select `packing_ratio` **carefully** based on the dataset
statistics, `max_seq_len`, and tolerance for discarding samples!
The packing code in `../packing.py` provides a script that can help
`scripts/misc/profile_packing.py` is a script that can help
you choose the best `packing_ratio`.
cfg.dataset.shuffle (bool): Whether to shuffle the dataset.
___
Expand All @@ -106,7 +106,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
A pytorch dataloader

Note:
You can run the script inside `../packing.py` to quickly test the
You can run the script inside `scripts/misc/profile_packing.py` to quickly test the
irenedea marked this conversation as resolved.
Show resolved Hide resolved
padding/waste rates for different `cfg.dataset.packing_ratio` choices,
given a starting workload YAML.
"""
Expand Down Expand Up @@ -143,7 +143,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)
cfg, tokenizer, device_batch_size)

dl = DataLoader(
dataset,
Expand Down Expand Up @@ -174,7 +174,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)
cfg, tokenizer, device_batch_size)

if cfg.drop_last:
world_size = dist.get_world_size()
Expand Down Expand Up @@ -367,9 +367,10 @@ def _build_hf_dataset_from_remote(


def _build_collate_fn(
dataset_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int
) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackWrapper], int]:
) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]:
dataset_cfg = dataloader_cfg.dataset
collate_fn = Seq2SeqFinetuningCollator(
tokenizer=tokenizer,
max_seq_len=dataset_cfg.max_seq_len,
Expand All @@ -386,6 +387,10 @@ def _build_collate_fn(
'the latter to turn on packing or remove the former from the config.')
return collate_fn, device_batch_size

if packing_ratio == 'auto':
packing_ratio = auto_packing_ratio(dataloader_cfg, tokenizer,
device_batch_size)

if packing_ratio == 1.0:
return collate_fn, device_batch_size
elif packing_ratio < 1.0:
Expand All @@ -396,7 +401,7 @@ def _build_collate_fn(
'On-the-fly packing is currently only supported for decoder-only formats.'
)

collate_fn = BinPackWrapper(
collate_fn = BinPackCollator(
collator=collate_fn,
target_batch_size=device_batch_size,
max_seq_len=dataset_cfg.max_seq_len,
Expand Down
Loading
Loading