Deprecate max packed sequence len #1141

Merged: 6 commits, Jan 20, 2024
Changes from 2 commits
4 changes: 0 additions & 4 deletions README.md
@@ -639,10 +639,6 @@ sequence_len: 2048
# Pad inputs so each step uses constant sized buffers
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:
# Max sequence length to concatenate training samples together up to
# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED
max_packed_sequence_len: 1024
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
sample_packing:
# Set to 'false' if getting errors during eval with sample_packing on.
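The keys that remain documented here are the replacement path. A minimal sketch of the migrated settings, expressed as the DictDefault config used elsewhere in this diff (YAML keys map one-to-one; the values and the DictDefault import path are assumptions based on the test suite):

```python
from axolotl.utils.dict import DictDefault  # import path assumed, matching the test suite's usage

cfg = DictDefault(
    {
        "sequence_len": 2048,
        "sample_packing": True,       # replaces the removed max_packed_sequence_len
        "pad_to_sequence_len": True,  # recommended alongside sample_packing
    }
)
```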
12 changes: 0 additions & 12 deletions src/axolotl/utils/config.py
@@ -186,18 +186,6 @@ def validate_config(cfg):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
if cfg.max_packed_sequence_len and cfg.sample_packing:
raise ValueError(
"please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
)
if cfg.max_packed_sequence_len:
LOG.warning(
str(
PendingDeprecationWarning(
"max_packed_sequence_len will be deprecated in favor of sample_packing"
)
)
)

if cfg.sample_packing and not cfg.pad_to_sequence_len:
LOG.warning(
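With the checks above removed, the packing-related validation that remains in the surrounding context is the pad_to_sequence_len warning. A minimal sketch of exercising it, mirroring the remaining test in tests/test_validation.py (import path for DictDefault is an assumption; the exact warning text is whatever validate_config logs):

```python
import logging

from axolotl.utils.config import validate_config  # module shown in this diff
from axolotl.utils.dict import DictDefault        # import path assumed

logging.basicConfig(level=logging.WARNING)

# sample_packing without pad_to_sequence_len should hit the remaining LOG.warning branch
cfg = DictDefault({"sample_packing": True})
validate_config(cfg)
```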
130 changes: 12 additions & 118 deletions src/axolotl/utils/data.py
@@ -19,7 +19,7 @@
from transformers import PreTrainedTokenizerBase

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import (
AlpacaMultipleChoicePromptTokenizingStrategy,
@@ -74,6 +74,11 @@ def prepare_dataset(cfg, tokenizer):
if isinstance(cfg.pretraining_dataset, dict):
path = cfg.pretraining_dataset["path"]
name = cfg.pretraining_dataset["name"]
elif isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]

train_dataset = load_pretraining_dataset(
path,
Expand All @@ -88,11 +93,6 @@ def prepare_dataset(cfg, tokenizer):
eval_dataset = None
return train_dataset, eval_dataset, cfg.max_steps, prompters

with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset, tokenizer
)

if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
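The new elif branch above means cfg.pretraining_dataset may be given either as a single mapping or as a list of mappings; in the list form only the first entry's path and name are read here. Illustrative shapes (the dataset names are placeholders, not real datasets):

```python
# single mapping
pretraining_dataset = {"path": "my-org/pretraining-corpus", "name": "default"}

# list of mappings; per the branch above, path/name come from the first entry only
pretraining_dataset = [
    {"path": "my-org/pretraining-corpus", "name": "default"},
]
```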
@@ -382,6 +382,9 @@ def for_d_in_datasets(dataset_configs):
if len(datasets) > 1:
LOG.info("shuffle merged datasets")
dataset = dataset.shuffle(seed=seed)

dataset, _ = process_datasets_for_packing(cfg, dataset, None, tokenizer)

if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path)
@@ -419,119 +422,9 @@ def load_prepare_datasets(
cfg,
default_dataset_prepared_path,
) -> Tuple[Dataset, Dataset, List[Prompter]]:
max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)
max_packed_sequence_len = min(
max_packed_sequence_len, cfg.sequence_len
) # make sure we don't accidentally set it larger than sequence_len

tokenizer_name = tokenizer.__class__.__name__
prompters: List[Prompter] = []
if cfg.max_packed_sequence_len is not None:
# see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
ds_hash = str(
md5(
(
str(cfg.sequence_len)
+ "@"
+ str(max_packed_sequence_len)
+ seed
+ "|".join(
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
)
+ "|"
+ tokenizer_name
)
)
)
prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash
if cfg.dataset_prepared_path
else Path(default_dataset_prepared_path) / ds_hash
)

dataset = None
use_auth_token = cfg.hf_use_auth_token
try:
if cfg.push_dataset_to_hub:
LOG.info(
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
token=use_auth_token,
)
dataset = dataset["train"]
except Exception: # pylint: disable=broad-except # nosec
pass

if dataset:
...
elif (
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.is_preprocess
):
LOG.info(
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
)
dataset = load_from_disk(str(prepared_ds_path))
LOG.info("Prepared packed dataset loaded from disk...")
if cfg.push_dataset_to_hub:
LOG.info(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
)
else:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)

if cfg.seed:
dataset = dataset.shuffle(seed=cfg.seed)

constant_len_dataset = ConstantLengthDataset(
tokenizer,
[dataset],
seq_length=max_packed_sequence_len,
)
LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
dataset = Dataset.from_list(list(constant_len_dataset))

# filter out bad data
# TODO convert to dataset.filter(...)
dataset = Dataset.from_list(
[
d
for d in dataset
if len(d["input_ids"]) <= cfg.sequence_len
and len(d["input_ids"]) > 0
and len(d["input_ids"]) == len(d["attention_mask"])
and len(d["input_ids"]) == len(d["labels"])
]
)

if cfg.local_rank == 0:
LOG.info(
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
)
dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub:
LOG.info(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
private=True,
)
else:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)

if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
LOG.info(
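For context on the large deletion above: the old max_packed_sequence_len branch hashed a separate packed-dataset cache, concatenated tokenized samples into fixed-length rows via ConstantLengthDataset, and then filtered out malformed rows. A rough sketch of that style of fixed-length concatenation, assuming nothing about ConstantLengthDataset's actual internals, for contrast with multipack sample packing:

```python
from typing import Dict, Iterable, Iterator, List

def naive_fixed_length_pack(
    examples: Iterable[Dict[str, List[int]]],
    seq_length: int,
    eos_token_id: int,
) -> Iterator[Dict[str, List[int]]]:
    """Concatenate tokenized examples and emit constant-size rows of seq_length tokens.

    Illustrative only: the removed code delegated this to
    axolotl.datasets.ConstantLengthDataset, whose implementation may differ.
    """
    buffer: List[int] = []
    for example in examples:
        buffer.extend(example["input_ids"] + [eos_token_id])
        while len(buffer) >= seq_length:
            chunk, buffer = buffer[:seq_length], buffer[seq_length:]
            yield {
                "input_ids": chunk,
                "attention_mask": [1] * seq_length,
                "labels": list(chunk),
            }
```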
@@ -872,6 +765,7 @@ def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, s
dataset = dataset.map(
encode,
batched=True,
batch_size=10_000,
input_columns="text",
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
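The only functional change in the last hunk is the explicit batch_size=10_000 on the batched map. A minimal sketch of a map call with the same shape; the tokenizer choice, data file, and encode body are stand-ins, not the project's actual implementation:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer

def encode(texts):
    # stand-in for the project's encode helper: tokenize a batch of raw text rows
    return tokenizer(texts, truncation=True, max_length=2048)

dataset = load_dataset("text", data_files={"train": "corpus.txt"}, split="train")
dataset = dataset.map(
    encode,
    batched=True,
    batch_size=10_000,                    # tokenize 10k rows per call, as added above
    input_columns="text",
    remove_columns=dataset.column_names,  # drop raw columns, per the comment above
)
```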
6 changes: 1 addition & 5 deletions src/axolotl/utils/models.py
@@ -301,11 +301,7 @@ def load_model(
LOG.info("patching with flash attention")
replace_mixtral_attn_with_multipack_flash_attn()

if (
cfg.is_llama_derived_model
and (cfg.max_packed_sequence_len or cfg.sample_packing)
and not inference
):
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask

LOG.info("patching _expand_mask")
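The simplified condition drops max_packed_sequence_len from the gate entirely. A hypothetical helper restating the new gate (this function does not exist in the code base; it only mirrors the condition above):

```python
def should_patch_llama_expand_mask(cfg, inference: bool) -> bool:
    # Only llama-derived models with sample_packing enabled, and only outside
    # inference, get the _expand_mask monkeypatch now; the deprecated
    # max_packed_sequence_len no longer factors in.
    return bool(cfg.is_llama_derived_model and cfg.sample_packing and not inference)
```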
21 changes: 10 additions & 11 deletions src/axolotl/utils/trainer.py
@@ -81,6 +81,15 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
return weighted_cross_entropy(logits, labels, weights)


@contextmanager
def disable_datasets_caching():
try:
set_caching_enabled(False)
yield
finally:
set_caching_enabled(True)


def add_position_ids(sample):
sample_len = len(sample["input_ids"])
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
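disable_datasets_caching is only relocated in this hunk. As a usage note, it suppresses the Hugging Face datasets cache for whatever runs inside the with block and re-enables it on exit; a minimal, self-contained sketch (the toy dataset is illustrative):

```python
from datasets import Dataset
from axolotl.utils.trainer import disable_datasets_caching  # defined just above

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]]})

with disable_datasets_caching():
    # this map runs without writing cache files; caching is restored afterwards
    ds = ds.map(lambda example: {"length": len(example["input_ids"])})
```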
@@ -97,15 +106,6 @@ def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0


@contextmanager
def disable_datasets_caching():
try:
set_caching_enabled(False)
yield
finally:
set_caching_enabled(True)


def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
with zero_first(is_main_process()):
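process_datasets_for_packing applies drop_long_seq through functools.partial to discard samples that are empty or longer than sequence_len. A small sketch on a toy dataset, using the same partial pattern shown above (the example values are illustrative):

```python
from functools import partial

from datasets import Dataset
from axolotl.utils.trainer import drop_long_seq

ds = Dataset.from_dict({"input_ids": [[1] * 10, [1] * 5000, []]})
drop_long = partial(drop_long_seq, sequence_len=2048)

# keeps the 10-token sample; drops the 5000-token and empty samples
filtered = ds.filter(drop_long)
```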
@@ -226,8 +226,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
sampler=RandomSampler(train_dataset),
batch_size=cfg.micro_batch_size,
drop_last=True,
batch_max_len=cfg.micro_batch_size
* (cfg.max_packed_sequence_len or cfg.sequence_len),
batch_max_len=cfg.micro_batch_size * cfg.sequence_len,
lengths=get_dataset_lengths(train_dataset),
)

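With the deprecated key gone, the multipack sampler's token budget is derived purely from sequence_len. Illustrative arithmetic (values are examples, not defaults):

```python
micro_batch_size = 2
sequence_len = 2048

# mirrors batch_max_len=cfg.micro_batch_size * cfg.sequence_len in the hunk above
batch_max_len = micro_batch_size * sequence_len  # 4096 tokens per packed batch
```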
23 changes: 0 additions & 23 deletions tests/test_validation.py
@@ -325,19 +325,6 @@ def test_adamw_hyperparams(self):
validate_config(cfg)

def test_packing(self):
cfg = DictDefault(
{
"max_packed_sequence_len": 2048,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"max_packed_sequence_len will be deprecated in favor of sample_packing"
in record.message
for record in self._caplog.records
)

cfg = DictDefault(
{
"sample_packing": True,
@@ -352,16 +339,6 @@ def test_packing(self):
for record in self._caplog.records
)

cfg = DictDefault(
{
"max_packed_sequence_len": 2048,
"sample_packing": True,
}
)
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)

@pytest.mark.skipif(
is_torch_bf16_gpu_available(),
reason="test should only run on gpus w/o bf16 support",