diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 81307b6..86f504a 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -5,6 +5,7 @@ import logging
 from pathlib import Path
 
 import fire
+import torch
 import transformers
 
 from axolotl.cli import (
@@ -18,6 +19,7 @@ from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 
 LOG = logging.getLogger("axolotl.cli.train")
+torch.set_printoptions(threshold=10000)
 
 
 def do_cli(config: Path = Path("examples/"), **kwargs):
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 165e2c8..282c933 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -4,7 +4,7 @@ import hashlib
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -72,6 +72,7 @@ def prepare_dataset(cfg, tokenizer):
         train_dataset = load_pretraining_dataset(
             cfg.pretraining_dataset,
             tokenizer,
+            cfg,
             max_tokens=cfg.sequence_len,
             seed=cfg.seed or 42,
         )
@@ -811,9 +812,14 @@ def encode_pretraining(
     return ret
 
 
-def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
-    encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
-    dataset = load_dataset(path, streaming=True, split="train")
+def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
+    if cfg.sample_packing:
+        encode = functools.partial(encode_packed_pretraining, tokenizer, max_seq_length=max_tokens, sample_packing_efficiency=1.0)
+        cfg.sample_packing = False
+    else:
+        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
+
+    dataset = load_dataset(path, streaming=True, split="train", name="en")
     dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
     dataset = dataset.map(
         encode,
@@ -829,8 +835,8 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
 def encode_packed_pretraining(
     tokenizer: PreTrainedTokenizerBase,
     examples: List[str],
-    max_seq_length: int = 8192,
-    sample_packing_efficiency: int = 1,
+    max_seq_length: int = 2048,
+    sample_packing_efficiency: float = 1.0,
 ) -> Dict[str, List]:
     # pylint: disable=duplicate-code
     # tokenize all the examples
@@ -859,7 +865,7 @@ def encode_packed_pretraining(
 
     sampler = MultipackBatchSampler(
         RandomSampler(train_dataset),
-        batch_size=1,
+        batch_size=4,
         drop_last=True,
         batch_max_len=max_seq_length,
         lengths=(
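
For anyone wanting to poke at this locally, here is a minimal usage sketch, not part of the patch, of how the updated `load_pretraining_dataset` gets exercised with sample packing turned on. The `SimpleNamespace` stands in for axolotl's `DictDefault` config, and the dataset and tokenizer names are illustrative placeholders:

```python
# Hypothetical usage sketch for the patched load_pretraining_dataset.
# The cfg fields and the dataset/tokenizer names are placeholders, not axolotl defaults.
from types import SimpleNamespace

from transformers import AutoTokenizer

from axolotl.utils.data import load_pretraining_dataset

cfg = SimpleNamespace(
    pretraining_dataset="c4",  # the loader currently hardcodes name="en", so pick a dataset with an "en" config
    sample_packing=True,       # routes encoding through encode_packed_pretraining
    sequence_len=2048,
    seed=42,
)

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # placeholder model

# Note: the patched loader flips cfg.sample_packing back to False after building
# the packed encode function, since packing already happens at encode time.
train_dataset = load_pretraining_dataset(
    cfg.pretraining_dataset,
    tokenizer,
    cfg,
    max_tokens=cfg.sequence_len,
    seed=cfg.seed,
)

# Streaming dataset: pull the first mapped item to sanity-check the output;
# with the packed path each item is expected to be an already-packed batch.
first = next(iter(train_dataset))
print(first.keys())
```

The only caller-visible change is the new `cfg` positional argument (mirrored in the `prepare_dataset` hunk above); `name`, `max_tokens`, and `seed` remain keyword arguments. Note that the loader still hardcodes `name="en"` in `load_dataset` even though a `name` parameter was added, so the dataset path must expose an `en` config (e.g. C4) until that is plumbed through.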