feat(dataset): add config to keep processed dataset in memory #1152

Merged 1 commit on Jan 20, 2024
3 changes: 3 additions & 0 deletions README.md
@@ -618,6 +618,9 @@ push_dataset_to_hub: # repo path
 # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
 # if not set.
 dataset_processes: # defaults to os.cpu_count() if not set
+# Keep dataset in memory while preprocessing
+# Only needed if cached dataset is taking too much storage
+dataset_keep_in_memory:
 # push checkpoints to hub
 hub_model_id: # repo path to push finetuned model
 # how to push checkpoints to hub
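For context on what the new flag toggles: axolotl's tokenization pipeline runs through Hugging Face `datasets`' `Dataset.map`, whose `keep_in_memory` argument controls whether the mapped result is held in RAM or written to an on-disk Arrow cache (under `~/.cache/huggingface/datasets` by default), which is the storage the README note refers to. A minimal sketch of that behavior; the toy dataset and `tokenize` function are illustrative stand-ins, not code from this PR:

```python
from datasets import Dataset

# Toy stand-in for an axolotl prompt dataset.
ds = Dataset.from_dict({"text": ["hello world", "keep me in memory"]})

def tokenize(example):
    # Stand-in for PromptTokenizingStrategy.tokenize_prompt.
    return {"input_ids": [len(word) for word in example["text"].split()]}

# keep_in_memory=True trades RAM for disk: the mapped dataset is kept
# in memory instead of being written out as an Arrow cache file.
tokenized = ds.map(tokenize, remove_columns=["text"], keep_in_memory=True)
print(tokenized[0])  # {'input_ids': [5, 5]}
```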
13 changes: 7 additions & 6 deletions src/axolotl/datasets.py
@@ -24,29 +24,30 @@ class TokenizedPromptDataset(Dataset):
     Args:
         prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
         dataset (dataset.Dataset): Dataset with text files.
+        process_count (int): Number of processes to use for tokenizing.
+        keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
     """

     def __init__(  # pylint: disable=super-init-not-called
         self,
         prompt_tokenizer: PromptTokenizingStrategy,
         dataset: IterableDataset,
         process_count: Optional[int] = None,
+        keep_in_memory: Optional[bool] = False,
         **kwargs,
     ):
         self.prompt_tokenizer = prompt_tokenizer
         self.process_count = process_count
+        self.keep_in_memory = keep_in_memory
         super().__init__(
             self.process(dataset).data,
             **kwargs,
         )

     def process(self, dataset):
         features = dataset.features.keys()
-        num_proc = (
-            min(64, self.process_count)
-            if self.process_count
-            else min(64, os.cpu_count())
-        )
+        num_proc = min(64, self.process_count if self.process_count else os.cpu_count())

         map_kwargs = {}
         if self.prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
@@ -55,7 +56,7 @@ def process(self, dataset):
             self.prompt_tokenizer.tokenize_prompt,
             num_proc=num_proc,
             remove_columns=features,
-            keep_in_memory=True,
+            keep_in_memory=self.keep_in_memory,
             **map_kwargs,
         )
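Two things change in `process()` above: the previously hard-coded `keep_in_memory=True` becomes configurable (defaulting to `False`, i.e. cache to disk), and the `num_proc` computation collapses to one line. The new one-liner is behavior-for-behavior equivalent to the four-line conditional it replaces: use the configured process count when set, fall back to `os.cpu_count()` otherwise, and cap the result at 64. A quick sketch; the `resolve_num_proc` helper name is hypothetical, for illustration only:

```python
import os

def resolve_num_proc(process_count=None, cap=64):
    # Same logic as the one-liner in TokenizedPromptDataset.process():
    # prefer the configured count, else os.cpu_count(), clamped to the cap.
    return min(cap, process_count if process_count else os.cpu_count())

assert resolve_num_proc(8) == 8                       # explicit count, under the cap
assert resolve_num_proc(200) == 64                    # explicit count, clamped
assert resolve_num_proc() == min(64, os.cpu_count())  # unset: fall back to CPU count
```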
25 changes: 15 additions & 10 deletions src/axolotl/utils/data.py
@@ -588,6 +588,11 @@ def get_dataset_wrapper(
     dataset_wrapper = None
     dataset_prompter = None

+    ds_kwargs = {
+        "process_count": cfg.dataset_processes,
+        "keep_in_memory": cfg.dataset_keep_in_memory is True,
+    }
+
     if (
         "input_ids" in dataset.features
         and "attention_mask" in dataset.features
@@ -604,14 +609,14 @@
         dataset_wrapper = TokenizedPromptDataset(
             ds_strategy,
             dataset,
-            process_count=cfg.dataset_processes,
+            **ds_kwargs,
         )
     elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
         dataset_prompter = UnsupportedPrompter()
         dataset_wrapper = TokenizedPromptDataset(
             ds_strategy,
             dataset,
-            process_count=cfg.dataset_processes,
+            **ds_kwargs,
         )
     elif d_base_type == "alpaca":
         dataset_prompter = AlpacaPrompter(d_prompt_style)
@@ -624,7 +629,7 @@
         ds_wrapper = TokenizedPromptDataset(
             ds_strategy,
             dataset,
-            process_count=cfg.dataset_processes,
+            **ds_kwargs,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "explainchoice":
@@ -638,7 +643,7 @@
         ds_wrapper = TokenizedPromptDataset(
             ds_strategy,
             dataset,
-            process_count=cfg.dataset_processes,
+            **ds_kwargs,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "concisechoice":
@@ -652,7 +657,7 @@
         ds_wrapper = TokenizedPromptDataset(
             ds_strategy,
             dataset,
-            process_count=cfg.dataset_processes,
+            **ds_kwargs,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "summarizetldr":
@@ -666,7 +671,7 @@
         ds_wrapper = TokenizedPromptDataset(
             ds_strategy,
             dataset,
-            process_count=cfg.dataset_processes,
+            **ds_kwargs,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "jeopardy":
@@ -680,7 +685,7 @@
         ds_wrapper = TokenizedPromptDataset(
             ds_strategy,
             dataset,
-            process_count=cfg.dataset_processes,
+            **ds_kwargs,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "oasst":
@@ -694,7 +699,7 @@
         ds_wrapper = TokenizedPromptDataset(
             ds_strategy,
             dataset,
-            process_count=cfg.dataset_processes,
+            **ds_kwargs,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "gpteacher":
@@ -708,7 +713,7 @@
         ds_wrapper = TokenizedPromptDataset(
             ds_strategy,
             dataset,
-            process_count=cfg.dataset_processes,
+            **ds_kwargs,
        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "reflection":
@@ -722,7 +727,7 @@
         ds_wrapper = TokenizedPromptDataset(
             ds_strategy,
             dataset,
-            process_count=cfg.dataset_processes,
+            **ds_kwargs,
         )
         dataset_wrapper = ds_wrapper
     else:
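Every branch of `get_dataset_wrapper` constructs a `TokenizedPromptDataset`, so building `ds_kwargs` once and splatting it into all ten call sites keeps them from drifting apart. The `is True` comparison also normalizes an unset config value (`None` when `dataset_keep_in_memory` is absent from the YAML) to a plain `False` before it reaches `Dataset.map`. A hedged sketch of that normalization, with `SimpleNamespace` standing in for axolotl's config object:

```python
from types import SimpleNamespace

# Illustrative stand-in for the axolotl cfg; the flag was never set in YAML.
cfg = SimpleNamespace(dataset_processes=None, dataset_keep_in_memory=None)

ds_kwargs = {
    "process_count": cfg.dataset_processes,
    # `is True` maps None (unset) and any non-True value to a plain bool,
    # so every TokenizedPromptDataset call receives a real True/False.
    "keep_in_memory": cfg.dataset_keep_in_memory is True,
}

assert ds_kwargs["keep_in_memory"] is False
```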