Expand Finetuning Dataset Registry to Match One-Shot (#1940)
* add finetuning README

* add finetuning datasets
Satrat committed Jan 8, 2024
1 parent 9b8dfca commit 8683a06
Showing 6 changed files with 336 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/sparseml/transformers/finetune/data/__init__.py
@@ -16,5 +16,9 @@

from .base import TextGenerationDataset
from .c4 import C4Dataset
from .evolcodealpaca import EvolCodeAlpacaDataset
from .gsm8k import GSM8KDataset
from .open_platypus import OpenPlatypusDataset
from .ptb import PtbDataset
from .ultrachat_200k import UltraChatDataset
from .wikitext import WikiTextDataset
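
For orientation, a minimal illustrative sketch of the registration pattern the new modules below follow; the registry name and Hugging Face dataset id here are hypothetical, not part of this commit:

from copy import deepcopy

from sparseml.transformers.finetune.data import TextGenerationDataset


@TextGenerationDataset.register(name="my_dataset")
class MyDataset(TextGenerationDataset):
    # hypothetical dataset registered under the name "my_dataset"
    def __init__(self, data_args, split, tokenizer):
        data_args = deepcopy(data_args)
        data_args.dataset_name = "my-org/my-dataset"  # hypothetical HF dataset id
        super().__init__(
            text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
        )

Once registered, a class can be resolved by name with TextGenerationDataset.load_from_registry, as the updated tests at the end of this diff do.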
74 changes: 74 additions & 0 deletions src/sparseml/transformers/finetune/data/evolcodealpaca.py
@@ -0,0 +1,74 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Optional

from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.finetune.data.data_helpers import get_raw_dataset


@TextGenerationDataset.register(name="evolcodealpaca")
class EvolCodeAlpacaDataset(TextGenerationDataset):
"""
Child text generation class for the Evol Code Alpaca dataset
:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param tokenizer: tokenizer to use on dataset
"""

EVOL_ALPACA_TEMPLATE = (
"Below is an instruction that describes a "
"programming task. Write a program that appropriately "
"completes the request.\n\n### Instruction:\n{instruction}"
"\n\n### Response:\n"
)

def __init__(self, data_args, split, tokenizer):
data_args = deepcopy(data_args)
data_args.dataset_name = "theblackcat102/evol-codealpaca-v1"
super().__init__(
text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
)

def get_raw_dataset(self, cache_dir: Optional[str] = None):
"""
Load the raw dataset from Hugging Face, using cached copy if available.
Additionally reformats the entries to fit the alpaca template.
:param cache_dir: disk location to search for cached dataset
:return: the requested dataset
"""
raw_dataset = get_raw_dataset(
self.data_args, cache_dir, split=self.split, **self.raw_kwargs
)

# helper fn for restructuring each dataset entry using the alpaca template
def restructure_fn(sample):
sample["text"] = self.EVOL_ALPACA_TEMPLATE.format(
instruction=sample["instruction"]
)
if "output" in sample:
sample["text"] += sample["output"]
return sample

raw_dataset = raw_dataset.map(
restructure_fn,
batched=False,
remove_columns=["output", "instruction"],
num_proc=self.data_args.preprocessing_num_workers,
load_from_cache_file=not self.data_args.overwrite_cache,
desc="Restructuring Evol Code Alpaca Dataset",
)
return raw_dataset
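
To make the restructuring concrete, a short illustrative sketch of the text produced for one entry; the sample values are made up:

from sparseml.transformers.finetune.data import EvolCodeAlpacaDataset

# hypothetical entry in the theblackcat102/evol-codealpaca-v1 format
sample = {
    "instruction": "Write a Python function that reverses a string.",
    "output": "def reverse(s):\n    return s[::-1]",
}

# prompt template followed by the completion, as restructure_fn builds it
text = EvolCodeAlpacaDataset.EVOL_ALPACA_TEMPLATE.format(
    instruction=sample["instruction"]
)
if "output" in sample:
    text += sample["output"]
# text: instruction preamble, "### Instruction:" block, then "### Response:" followed by the code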
68 changes: 68 additions & 0 deletions src/sparseml/transformers/finetune/data/gsm8k.py
@@ -0,0 +1,68 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Optional

from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.finetune.data.data_helpers import get_raw_dataset


@TextGenerationDataset.register(name="gsm8k")
class GSM8KDataset(TextGenerationDataset):
"""
Child text generation class for the Grade School Math 8k dataset
:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param tokenizer: tokenizer to use on dataset
"""

GSM_TEMPLATE = "Question: {question}.\nAnswer:"

def __init__(self, data_args, split, tokenizer):
data_args = deepcopy(data_args)
data_args.dataset_name = "gsm8k"
super().__init__(
text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
)

def get_raw_dataset(self, cache_dir: Optional[str] = None):
"""
Load the raw dataset from Hugging Face, using cached copy if available.
Additionally reformats the entries to fit the alpaca template.
:param cache_dir: disk location to search for cached dataset
:return: the requested dataset
"""
raw_dataset = get_raw_dataset(
self.data_args, cache_dir, split=self.split, **self.raw_kwargs
)

# helper fn for restructuring each dataset entry using the gsm template
def restructure_fn(sample):
sample["text"] = self.GSM_TEMPLATE.format(question=sample["question"])
if "answer" in sample:
sample["text"] += " " + sample["answer"]
return sample

raw_dataset = raw_dataset.map(
restructure_fn,
batched=False,
remove_columns=["question", "answer"],
num_proc=self.data_args.preprocessing_num_workers,
load_from_cache_file=not self.data_args.overwrite_cache,
desc="Restructuring GSM Dataset",
)
return raw_dataset
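
Similarly, a small illustrative sketch of how one GSM8K-style entry is flattened into a single training string; the question and answer are invented placeholders:

from sparseml.transformers.finetune.data import GSM8KDataset

# made-up entry in the gsm8k question/answer format
sample = {"question": "What is 2 + 3", "answer": "2 + 3 = 5\n#### 5"}

text = GSM8KDataset.GSM_TEMPLATE.format(question=sample["question"])
if "answer" in sample:
    text += " " + sample["answer"]
# text == "Question: What is 2 + 3.\nAnswer: 2 + 3 = 5\n#### 5"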
38 changes: 38 additions & 0 deletions src/sparseml/transformers/finetune/data/ptb.py
@@ -0,0 +1,38 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

from sparseml.transformers.finetune.data import TextGenerationDataset


@TextGenerationDataset.register(name="ptb")
class PtbDataset(TextGenerationDataset):
"""
Child text generation class for the PTB dataset
:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param tokenizer: tokenizer to use on dataset
"""

def __init__(self, data_args, split, tokenizer):
data_args = deepcopy(data_args)
data_args.dataset_name = "ptb_text_only"
super().__init__(
text_column="sentence",
data_args=data_args,
split=split,
tokenizer=tokenizer,
)
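
As a usage note, the "ptb" entry trains directly on the sentence column with no prompt template; a loading sketch mirroring the test parameters later in this diff, where the DataTrainingArguments import path and the tokenizer choice are assumptions:

from transformers import AutoTokenizer

from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments  # import path assumed

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")  # placeholder tokenizer

data_args = DataTrainingArguments(
    dataset_name="ptb",
    dataset_config_name="penn_treebank",
    concatenate_data=False,
)
ptb_manager = TextGenerationDataset.load_from_registry(
    data_args.dataset_name,
    data_args=data_args,
    split="train[:5%]",
    tokenizer=tokenizer,
)
raw_dataset = ptb_manager.get_raw_dataset()
tokenized_dataset = ptb_manager.tokenize_and_process(raw_dataset)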
90 changes: 90 additions & 0 deletions src/sparseml/transformers/finetune/data/ultrachat_200k.py
@@ -0,0 +1,90 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Optional

from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.finetune.data.data_helpers import get_raw_dataset


@TextGenerationDataset.register(name="ultrachat_200k")
class UltraChatDataset(TextGenerationDataset):
"""
Child text generation class for the Ultra Chat 200k dataset
:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param tokenizer: tokenizer to use on dataset
"""

DEFAULT_CHAT_TEMPLATE = (
"{% for message in messages %}\n"
"{% if message['role'] == 'user' %}\n"
"{{ '<|user|>\n' + message['content'] + eos_token }}\n"
"{% elif message['role'] == 'system' %}\n"
"{{ '<|system|>\n' + message['content'] + eos_token }}\n"
"{% elif message['role'] == 'assistant' %}\n"
"{{ '<|assistant|>\n' + message['content'] + eos_token }}\n"
"{% endif %}\n"
"{% if loop.last and add_generation_prompt %}\n"
"{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
)

def __init__(self, data_args, split, tokenizer):
data_args = deepcopy(data_args)
data_args.dataset_name = "HuggingFaceH4/ultrachat_200k"
super().__init__(
text_column="messages",
data_args=data_args,
split=split,
tokenizer=tokenizer,
)

if (
not hasattr(self.tokenizer, "chat_template")
or self.tokenizer.chat_template is None
):
self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE

def get_raw_dataset(self, cache_dir: Optional[str] = None):
"""
Load the raw dataset from Hugging Face, using cached copy if available.
Additionally reformats the entries to fit the alpaca template.
:param cache_dir: disk location to search for cached dataset
:return: the requested dataset
"""
raw_dataset = get_raw_dataset(
self.data_args, cache_dir, split=self.split, **self.raw_kwargs
)

# helper fn for restructuring each dataset entry using the chat template
def restructure_fn(sample):
if sample["messages"][0]["role"] != "system":
sample["messages"].insert(0, {"role": "system", "content": ""})

sample["messages"] = self.tokenizer.apply_chat_template(
sample["messages"], tokenize=False, add_generation_prompt=False
)
return sample

raw_dataset = raw_dataset.map(
restructure_fn,
batched=False,
remove_columns=[],
num_proc=self.data_args.preprocessing_num_workers,
load_from_cache_file=not self.data_args.overwrite_cache,
desc="Restructuring Ultra Chat Dataset",
)
return raw_dataset
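
Finally, an illustrative sketch of what the restructuring step does to one conversation; the messages and the tokenizer id are placeholders, not part of this commit:

from transformers import AutoTokenizer

from sparseml.transformers.finetune.data import UltraChatDataset

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")  # placeholder
if getattr(tokenizer, "chat_template", None) is None:
    tokenizer.chat_template = UltraChatDataset.DEFAULT_CHAT_TEMPLATE

# made-up conversation in the ultrachat_200k messages format
messages = [
    {"role": "user", "content": "What is pruning?"},
    {"role": "assistant", "content": "Removing low-importance weights from a network."},
]
# the dataset class prepends an empty system message when one is missing
if messages[0]["role"] != "system":
    messages.insert(0, {"role": "system", "content": ""})

text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=False
)
# text is a single string of <|system|>, <|user|> and <|assistant|> turns,
# each terminated by the tokenizer's eos_token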
@@ -29,7 +29,7 @@ def test_concatenation_tokenization(tiny_llama_tokenizer):
    wiki_manager = TextGenerationDataset.load_from_registry(
        data_args.dataset_name,
        data_args=data_args,
-        split="train[:5%]",
+        split="train[:2%]",
        tokenizer=tiny_llama_tokenizer,
    )
    raw_dataset = wiki_manager.get_raw_dataset()
@@ -107,3 +107,64 @@ def test_dataset_kwargs_and_percentages(tiny_llama_tokenizer):
    raw_dataset_b = c4_manager_b.get_raw_dataset()

    assert len(raw_dataset_b) == 2 * len(raw_dataset_a)


@pytest.mark.usefixtures("tiny_llama_tokenizer")
@pytest.mark.parametrize(
"dataset_key,dataset_config,split,do_concat",
[
("ptb", "penn_treebank", "train[:5%]", False),
("gsm8k", "main", "train[:5%]", True),
("ultrachat_200k", "default", "train_sft[:2%]", False),
],
)
def test_datasets(tiny_llama_tokenizer, dataset_key, dataset_config, split, do_concat):
data_args = DataTrainingArguments(
dataset_name=dataset_key,
dataset_config_name=dataset_config,
concatenate_data=do_concat,
)
manager = TextGenerationDataset.load_from_registry(
data_args.dataset_name,
data_args=data_args,
split=split,
tokenizer=tiny_llama_tokenizer,
)
raw_dataset = manager.get_raw_dataset()
assert len(raw_dataset) > 0
assert raw_dataset.split == split
assert raw_dataset.info.config_name == dataset_config

tokenized_dataset = manager.tokenize_and_process(raw_dataset)
assert "input_ids" in tokenized_dataset.features
assert "labels" in tokenized_dataset.features
for i in range(len(tokenized_dataset)):
if do_concat:
assert len(tokenized_dataset[i]["input_ids"]) == manager.max_seq_length
else:
assert len(tokenized_dataset[i]["input_ids"]) <= manager.max_seq_length


@pytest.mark.skip("Dataset load broken on Hugging Face")
@pytest.mark.usefixtures("tiny_llama_tokenizer")
def test_evol(tiny_llama_tokenizer):
    data_args = DataTrainingArguments(
        dataset_name="evolcodealpaca",
        dataset_config_name=None,
        concatenate_data=False,
    )
    evol_manager = TextGenerationDataset.load_from_registry(
        data_args.dataset_name,
        data_args=data_args,
        split="train[:2%]",
        tokenizer=tiny_llama_tokenizer,
    )
    raw_dataset = evol_manager.get_raw_dataset()
    assert len(raw_dataset) > 0
    assert raw_dataset.split == "train[:2%]"

    tokenized_dataset = evol_manager.tokenize_and_process(raw_dataset)
    assert "input_ids" in tokenized_dataset.features
    assert "labels" in tokenized_dataset.features
    for i in range(len(tokenized_dataset)):
        assert len(tokenized_dataset[i]["input_ids"]) <= evol_manager.max_seq_length
