From 868202f3f8e82f274b1e6027945707e80b54d635 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 28 Sep 2023 19:38:00 -0400 Subject: [PATCH] add support for defined train split --- README.md | 10 +++++ src/axolotl/prompt_strategies/context_qa.py | 41 +++++++++++++++++++++ src/axolotl/utils/data.py | 10 +++++ 3 files changed, 61 insertions(+) diff --git a/README.md b/README.md index 3a1eb0cd7..3f1767ea4 100644 --- a/README.md +++ b/README.md @@ -250,6 +250,10 @@ Have dataset(s) in one of the following format (JSONL recommended): ```json {"article": "...", "question": "...", "answer": "..."} ``` +- `context_qa.load_v2`: in context question answering (alternate) + ```json + {"context": "...", "question": "...", "answer": "..."} + ``` - `context_qa.load_404`: in context question answering from an article, with default response for no answer from context ```json {"article": "...", "unanswerable_question": "..."} @@ -356,6 +360,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod - path: data.jsonl # or json ds_type: json # see other options below type: alpaca + + # dataset with splits, but no train split + dataset: + - path: knowrohit07/know_sql + type: context_qa.load_v2 + train_on_split: validation ``` - loading diff --git a/src/axolotl/prompt_strategies/context_qa.py b/src/axolotl/prompt_strategies/context_qa.py index f7027c7e2..f87dd8b5c 100644 --- a/src/axolotl/prompt_strategies/context_qa.py +++ b/src/axolotl/prompt_strategies/context_qa.py @@ -24,6 +24,15 @@ def load(tokenizer, cfg): ) +def load_v2(tokenizer, cfg): + return ContextQaV2PromptTokenizingStrategy( + ContextV2Prompter(), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + + class AlpacaContextPrompter(AlpacaPrompter): """ Customized system prompted for concise QA @@ -50,6 +59,38 @@ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: ) +class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy): + """ + Tokenization Strategy to combine in-context article with a question and answer + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: + return ( + "Context: " + + prompt["context"] + + "\nQuestion: " + + prompt["question"] + + "\n", + "", + "Answer: " + prompt["answer"], + ) + + +class ContextV2Prompter(AlpacaPrompter): + """ + Customized system prompted for concise QA + """ + + system_prompt = "" + system_no_input_prompt = "" + + def match_prompt_style(self): + # pylint: disable=duplicate-code + self.turn_format = "{instruction}\n{input}" + self.turn_no_input_format = "{instruction}" + self.system_format = "{system}" + + class AlpacaMissingInfoContextPromptTokenizingStrategy( InstructionPromptTokenizingStrategy ): diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 271379677..34a5baaff 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -247,6 +247,16 @@ def for_d_in_datasets(dataset_configs): d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None if "train" in ds: ds = ds["train"] + elif ( + isinstance(ds, DatasetDict) + and d.train_on_split + and d.train_on_split in ds + ): + ds = ds[d.train_on_split] + elif isinstance(ds, DatasetDict): + raise ValueError( + f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `" + ) if ( "input_ids" in ds.features and "attention_mask" in ds.features