cleanup the old multipack dataloader #841

Merged · 1 commit · Nov 12, 2023
9 changes: 3 additions & 6 deletions src/axolotl/core/trainer_builder.py
@@ -11,7 +11,7 @@
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional

 import torch
 import transformers
@@ -31,7 +31,6 @@
     log_prediction_callback_factory,
 )
 from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
-from axolotl.utils.dataloader import MultipackDistributedDataloader
 from axolotl.utils.samplers import MultipackBatchSampler
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

@@ -215,9 +214,7 @@ def get_train_dataloader(self) -> DataLoader:
             )
         return super().get_train_dataloader()

-    def get_eval_dataloader(
-        self, eval_dataset: Optional[Dataset] = None
-    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         if self.args.sample_packing and self.args.eval_sample_packing is not False:
             eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
@@ -260,7 +257,7 @@ def _get_bench_sampler(
     def get_bench_dataloader(
         self,
         bench_dataset: Dataset,
-    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+    ) -> DataLoader:
         dataloader_params = {
             "batch_size": self.args.eval_batch_size,
             "collate_fn": self.bench_data_collator,
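For context on the trainer changes: the custom MultipackDistributedDataloader is gone, and packing is handled by MultipackBatchSampler feeding a stock torch DataLoader, which is why the Union[...] return annotations collapse to plain DataLoader. Below is a minimal sketch of the batch-sampler pattern; ToyDataset and the pass-through collator are hypothetical stand-ins, and only the torch.utils.data names are real APIs.

```python
# Sketch: packing-style batching via a BatchSampler plugged into a stock
# torch DataLoader. In axolotl, MultipackBatchSampler fills the sampler role.
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler

class ToyDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return {"input_ids": list(range(idx % 5 + 1))}

dataset = ToyDataset()

# A BatchSampler yields lists of indices; a packing sampler would group
# indices to fill a token budget instead of a fixed item count.
sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=True)

# Batching is delegated to the sampler, so a plain DataLoader suffices,
# hence the narrowed `-> DataLoader` annotations in this diff.
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=lambda batch: batch)

for batch in loader:
    print(len(batch))  # 4 examples per batch
```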
14 changes: 10 additions & 4 deletions src/axolotl/prompters.py
@@ -22,7 +22,13 @@ class PromptStyle(Enum):
     CHATML = "chatml"


-class AlpacaPrompter:
+class Prompter:
+    """
+    Base prompter class for all prompters
+    """
+
+
+class AlpacaPrompter(Prompter):
     """
     Base class for alpaca prompters
     """
@@ -159,7 +165,7 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
     """


-class ReflectAlpacaPrompter:
+class ReflectAlpacaPrompter(Prompter):
     """
     Prompter for ReflectAlpaca
     """
@@ -254,7 +260,7 @@ def __repr__(self) -> str:
         )


-class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
+class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
     """
     A prompter that generates prompts for the ShareGPT
     """
@@ -349,7 +355,7 @@ def __init__(
         )


-class UnsupportedPrompter:
+class UnsupportedPrompter(Prompter):
     """
     A dummy class for custom prompters
     """
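The new Prompter base carries no behavior of its own; it exists so the data-loading code can be annotated with List[Prompter] rather than List[Any] (see the data.py changes below). A minimal sketch of the typing payoff; describe() is a hypothetical helper, while the two classes mirror this diff:

```python
from typing import List

class Prompter:
    """Base prompter class for all prompters"""

class AlpacaPrompter(Prompter):
    """Base class for alpaca prompters"""

def describe(prompters: List[Prompter]) -> List[str]:
    # Any Prompter subclass satisfies the List[Prompter] annotation,
    # so type checkers can verify call sites that used to take List[Any].
    return [type(p).__name__ for p in prompters]

print(describe([AlpacaPrompter()]))  # ['AlpacaPrompter']
```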
24 changes: 12 additions & 12 deletions src/axolotl/utils/data.py
@@ -3,7 +3,7 @@
 import hashlib
 import logging
 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union

 import torch
 from datasets import (
@@ -34,6 +34,7 @@
     JeopardyPrompter,
     MultipleChoiceConcisePrompter,
     MultipleChoiceExplainPrompter,
+    Prompter,
     ReflectAlpacaPrompter,
     SummarizeTLDRPrompter,
     UnsupportedPrompter,
@@ -90,7 +91,7 @@ def prepare_dataset(cfg, tokenizer):

 def load_tokenized_prepared_datasets(
     tokenizer, cfg, default_dataset_prepared_path
-) -> DatasetDict:
+) -> Tuple[DatasetDict, List[Prompter]]:
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
         md5(
@@ -302,7 +303,7 @@ def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
     default_dataset_prepared_path,
-) -> Tuple[Dataset, Dataset, List[Any]]:
+) -> Tuple[Dataset, Dataset, List[Prompter]]:
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
@@ -311,7 +312,7 @@
     )  # make sure we don't accidentally set it larger than sequence_len

     tokenizer_name = tokenizer.__class__.__name__
-    prompters = []
+    prompters: List[Prompter] = []
     if cfg.max_packed_sequence_len is not None:
         # see if we can go ahead and load the stacked dataset
         seed = f"@{str(cfg.seed)}" if cfg.seed else ""
@@ -445,14 +446,13 @@ def load_prepare_datasets(
         train_fingerprint = md5(to_hash_train)
         test_fingerprint = md5(to_hash_test)

-        with zero_first(is_main_process()):
-            dataset = dataset.train_test_split(
-                test_size=cfg.val_set_size,
-                shuffle=False,
-                seed=cfg.seed or 42,
-                train_new_fingerprint=train_fingerprint,
-                test_new_fingerprint=test_fingerprint,
-            )
+        dataset = dataset.train_test_split(
+            test_size=cfg.val_set_size,
+            shuffle=False,
+            seed=cfg.seed or 42,
+            train_new_fingerprint=train_fingerprint,
+            test_new_fingerprint=test_fingerprint,
+        )

         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]

Review thread on the removed `with zero_first(is_main_process()):` line:

Collaborator: Is there a reason we had this with context previously? Is it for caching?

Collaborator (author): Yeah, because the function it is in is only ever called in a zero_first context, this is a bit redundant.
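For readers unfamiliar with the helper discussed in the review thread: zero_first lets the main rank run a block (for example, building the datasets cache) before the other ranks replay it against the warm cache, which is why nesting it inside a function that is already called under zero_first is redundant. A rough sketch of the idea, assuming an initialized torch.distributed process group; the real helper lives in axolotl.utils.distributed and may differ in detail:

```python
from contextlib import contextmanager

import torch.distributed as dist

@contextmanager
def zero_first(is_main: bool):
    """Sketch: make the main rank execute the wrapped block first."""
    if not is_main:
        dist.barrier()  # non-main ranks wait until rank 0 finishes the block
    yield
    if is_main:
        dist.barrier()  # rank 0 releases the waiting ranks

# Usage sketch: every rank enters the context; rank 0 populates the cache
# and the remaining ranks hit the already-built cache afterwards.
# with zero_first(is_main_process()):
#     dataset = load_and_tokenize(...)
```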