cleanup the old multipack dataloader (#841)
winglian committed Nov 12, 2023
1 parent 105d0b3 commit 1a6309c
Showing 4 changed files with 25 additions and 364 deletions.
9 changes: 3 additions & 6 deletions src/axolotl/core/trainer_builder.py
@@ -11,7 +11,7 @@
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional

 import torch
 import transformers
@@ -31,7 +31,6 @@
     log_prediction_callback_factory,
 )
 from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
-from axolotl.utils.dataloader import MultipackDistributedDataloader
 from axolotl.utils.samplers import MultipackBatchSampler
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

@@ -215,9 +214,7 @@ def get_train_dataloader(self) -> DataLoader:
             )
         return super().get_train_dataloader()

-    def get_eval_dataloader(
-        self, eval_dataset: Optional[Dataset] = None
-    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         if self.args.sample_packing and self.args.eval_sample_packing is not False:
             eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
@@ -260,7 +257,7 @@ def _get_bench_sampler(
     def get_bench_dataloader(
         self,
         bench_dataset: Dataset,
-    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+    ) -> DataLoader:
         dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": self.bench_data_collator,
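
With MultipackDistributedDataloader removed, both dataloader methods return a plain torch DataLoader and the packing logic lives in the batch sampler (MultipackBatchSampler, still imported above). A minimal sketch of that pattern, using torch's generic BatchSampler as a stand-in rather than the exact axolotl sampler API:

import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

# Toy dataset; in axolotl this would be the tokenized training set.
dataset = TensorDataset(torch.arange(100).unsqueeze(1))

# The sampler decides how examples are grouped into batches (packing would
# happen here); the DataLoader itself stays stock.
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=8, drop_last=False)
loader = DataLoader(dataset, batch_sampler=batch_sampler)

for batch in loader:
    pass  # batches come from the sampler, not from batch_size/shuffle arguments

Because the return type is always DataLoader, the Union annotations above can simply be dropped.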
14 changes: 10 additions & 4 deletions src/axolotl/prompters.py
@@ -22,7 +22,13 @@ class PromptStyle(Enum):
     CHATML = "chatml"


-class AlpacaPrompter:
+class Prompter:
+    """
+    Base prompter class for all prompters
+    """
+
+
+class AlpacaPrompter(Prompter):
     """
     Base class for alpaca prompters
     """
@@ -159,7 +165,7 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
     """


-class ReflectAlpacaPrompter:
+class ReflectAlpacaPrompter(Prompter):
     """
     Prompter for ReflectAlpaca
     """
@@ -254,7 +260,7 @@ def __repr__(self) -> str:
         )


-class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
+class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
     """
     A prompter that generates prompts for the ShareGPT
     """
@@ -349,7 +355,7 @@ def __init__(
         )


-class UnsupportedPrompter:
+class UnsupportedPrompter(Prompter):
     """
     A dummy class for custom prompters
     """
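
The new Prompter base class gives every concrete prompter a common ancestor, which is what lets the type hints in data.py below tighten from List[Any] to List[Prompter]. A minimal sketch of the resulting hierarchy; the describe helper is hypothetical and not part of the diff:

class Prompter:
    """Base prompter class for all prompters."""


class AlpacaPrompter(Prompter):
    """Base class for alpaca prompters (simplified; real prompt methods omitted)."""


class UnsupportedPrompter(Prompter):
    """A dummy class for custom prompters."""


def describe(prompter: Prompter) -> str:
    # Hypothetical helper: any prompter subclass satisfies the Prompter
    # annotation, so callers no longer need Any or per-class Unions.
    return type(prompter).__name__


print(describe(AlpacaPrompter()))  # -> "AlpacaPrompter"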
24 changes: 12 additions & 12 deletions src/axolotl/utils/data.py
@@ -3,7 +3,7 @@
 import hashlib
 import logging
 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union

 import torch
 from datasets import (
@@ -34,6 +34,7 @@
     JeopardyPrompter,
     MultipleChoiceConcisePrompter,
     MultipleChoiceExplainPrompter,
+    Prompter,
     ReflectAlpacaPrompter,
     SummarizeTLDRPrompter,
     UnsupportedPrompter,
@@ -90,7 +91,7 @@ def prepare_dataset(cfg, tokenizer):

 def load_tokenized_prepared_datasets(
     tokenizer, cfg, default_dataset_prepared_path
-) -> DatasetDict:
+) -> Tuple[DatasetDict, List[Prompter]]:
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
         md5(
@@ -302,7 +303,7 @@ def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
     default_dataset_prepared_path,
-) -> Tuple[Dataset, Dataset, List[Any]]:
+) -> Tuple[Dataset, Dataset, List[Prompter]]:
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
@@ -311,7 +312,7 @@
     )  # make sure we don't accidentally set it larger than sequence_len

     tokenizer_name = tokenizer.__class__.__name__
-    prompters = []
+    prompters: List[Prompter] = []
     if cfg.max_packed_sequence_len is not None:
         # see if we can go ahead and load the stacked dataset
         seed = f"@{str(cfg.seed)}" if cfg.seed else ""
@@ -445,14 +446,13 @@ load_prepare_datasets
         train_fingerprint = md5(to_hash_train)
         test_fingerprint = md5(to_hash_test)

-        with zero_first(is_main_process()):
-            dataset = dataset.train_test_split(
-                test_size=cfg.val_set_size,
-                shuffle=False,
-                seed=cfg.seed or 42,
-                train_new_fingerprint=train_fingerprint,
-                test_new_fingerprint=test_fingerprint,
-            )
+        dataset = dataset.train_test_split(
+            test_size=cfg.val_set_size,
+            shuffle=False,
+            seed=cfg.seed or 42,
+            train_new_fingerprint=train_fingerprint,
+            test_new_fingerprint=test_fingerprint,
+        )

         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
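
A hypothetical call site showing the effect of the tightened signatures: load_prepare_datasets now advertises List[Prompter] rather than List[Any], so the prompters it returns can be type-checked. The axolotl function and its three positional arguments are real; the wrapper and its parameter names are assumptions for illustration:

from typing import Tuple

from datasets import Dataset

from axolotl.prompters import Prompter
from axolotl.utils.data import load_prepare_datasets


def build_splits(tokenizer, cfg, prepared_path) -> Tuple[Dataset, Dataset]:
    # The third element is now typed as List[Prompter] instead of List[Any].
    train_ds, eval_ds, prompters = load_prepare_datasets(
        tokenizer, cfg, prepared_path
    )
    assert all(isinstance(p, Prompter) for p in prompters)
    return train_ds, eval_ds

The now-unguarded train_test_split call stays deterministic via the explicit train_new_fingerprint/test_new_fingerprint arguments, which is presumably why the zero_first(is_main_process()) wrapper could be dropped.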
(Fourth changed file not rendered here: the remainder of the 364 deletions removes the old multipack dataloader module behind the deleted axolotl.utils.dataloader import.)
