add support for defined train split (axolotl-ai-cloud#654)
winglian committed Sep 29, 2023
1 parent ff69330 commit 22a3fe9
Showing 3 changed files with 61 additions and 0 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -250,6 +250,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
  ```json
  {"article": "...", "question": "...", "answer": "..."}
  ```
- `context_qa.load_v2`: in context question answering (alternate)
  ```json
  {"context": "...", "question": "...", "answer": "..."}
  ```
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
  ```json
  {"article": "...", "unanswerable_question": "..."}
@@ -356,6 +360,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
  - path: data.jsonl # or json
    ds_type: json # see other options below
    type: alpaca

  # dataset with splits, but no train split
  dataset:
    - path: knowrohit07/know_sql
      type: context_qa.load_v2
      train_on_split: validation
```

- loading
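The new `train_on_split` option targets datasets such as `knowrohit07/know_sql` that publish splits but no `train` split. A minimal sketch (assuming the dataset is publicly reachable on the Hugging Face Hub) of how to check which splits a dataset exposes before choosing a value:

```python
# Sketch: list the splits a Hub dataset provides so you know what to put in
# train_on_split. Assumes the dataset is publicly reachable on the Hub.
from datasets import get_dataset_split_names

print(get_dataset_split_names("knowrohit07/know_sql"))
# e.g. ['validation'] -- no 'train' split, so train_on_split is required
```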
41 changes: 41 additions & 0 deletions src/axolotl/prompt_strategies/context_qa.py
@@ -24,6 +24,15 @@ def load(tokenizer, cfg):
    )


def load_v2(tokenizer, cfg):
    return ContextQaV2PromptTokenizingStrategy(
        ContextV2Prompter(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )


class AlpacaContextPrompter(AlpacaPrompter):
    """
    Customized system prompted for concise QA
@@ -50,6 +59,38 @@ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        )


class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenization Strategy to combine in-context article with a question and answer
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        return (
            "Context: "
            + prompt["context"]
            + "\nQuestion: "
            + prompt["question"]
            + "\n",
            "",
            "Answer: " + prompt["answer"],
        )


class ContextV2Prompter(AlpacaPrompter):
    """
    Customized system prompted for concise QA
    """

    system_prompt = ""
    system_no_input_prompt = ""

    def match_prompt_style(self):
        # pylint: disable=duplicate-code
        self.turn_format = "{instruction}\n{input}"
        self.turn_no_input_format = "{instruction}"
        self.system_format = "{system}"


class AlpacaMissingInfoContextPromptTokenizingStrategy(
    InstructionPromptTokenizingStrategy
):
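For reference, a small self-contained sketch (not using axolotl itself) of the text layout the new strategy's `parse_instruction_fields` produces for one record; the sample record below is invented:

```python
# Illustrative only: mirror the (instruction, input, output) tuple that
# ContextQaV2PromptTokenizingStrategy.parse_instruction_fields returns.
# The record is made up for the example.
record = {
    "context": "The users table has columns id and name.",
    "question": "How do I list all names?",
    "answer": "SELECT name FROM users;",
}

instruction = "Context: " + record["context"] + "\nQuestion: " + record["question"] + "\n"
model_input = ""  # the v2 strategy leaves the input field empty
output = "Answer: " + record["answer"]

print(instruction + output)
# Context: The users table has columns id and name.
# Question: How do I list all names?
# Answer: SELECT name FROM users;
```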
10 changes: 10 additions & 0 deletions src/axolotl/utils/data.py
@@ -247,6 +247,16 @@ def for_d_in_datasets(dataset_configs):
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if "train" in ds:
    ds = ds["train"]
elif (
    isinstance(ds, DatasetDict)
    and d.train_on_split
    and d.train_on_split in ds
):
    ds = ds[d.train_on_split]
elif isinstance(ds, DatasetDict):
    raise ValueError(
        f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
    )
if (
    "input_ids" in ds.features
    and "attention_mask" in ds.features
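The split-selection rule added above can be read in isolation as: use `train` if present, otherwise the split named by `train_on_split`, otherwise fail. A standalone sketch of that rule (the helper name `pick_training_split` is invented for illustration, not part of axolotl):

```python
# Standalone sketch of the split-selection rule; pick_training_split is a
# made-up name, not part of axolotl.
from datasets import Dataset, DatasetDict


def pick_training_split(ds, train_on_split=None):
    if isinstance(ds, DatasetDict):
        if "train" in ds:
            return ds["train"]
        if train_on_split and train_on_split in ds:
            return ds[train_on_split]
        raise ValueError(
            "no train split found; set 'train_on_split' in the dataset config"
        )
    return ds  # already a plain Dataset


# Example: a DatasetDict that only ships a 'validation' split
dd = DatasetDict({"validation": Dataset.from_dict({"question": ["q"], "answer": ["a"]})})
train_ds = pick_training_split(dd, train_on_split="validation")
print(len(train_ds))  # 1
```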
