diff --git a/README.md b/README.md index a58ba5965..fcb912d5b 100644 --- a/README.md +++ b/README.md @@ -322,6 +322,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod - path: EleutherAI/pile name: enron_emails type: completion # format from earlier + field: text # Optional[str] default: text, field to use for completion data # huggingface repo with multiple named configurations/subsets datasets: @@ -444,6 +445,9 @@ datasets: # 'no_input_format' cannot include {input} no_input_format: "{instruction} " + # for completion datasets, uses the provided field if not `text` + field: + # axolotl attempts to save the dataset as an arrow after packing the data together so # subsequent training attempts load faster, relative path dataset_prepared_path: data/last_run_prepared diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index e9e567953..e62a5c20c 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -1,6 +1,7 @@ """Module to load prompt strategies.""" import importlib +import inspect from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig @@ -16,6 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg): load_kwargs = {} if strategy == "user_defined": load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg) + else: + sig = inspect.signature(func) + if "ds_cfg" in sig.parameters: + load_kwargs["ds_cfg"] = ds_cfg return func(tokenizer, cfg, **load_kwargs) except Exception: # pylint: disable=broad-exception-caught return None diff --git a/src/axolotl/prompt_strategies/completion.py b/src/axolotl/prompt_strategies/completion.py new file mode 100644 index 000000000..ee5b4cb3e --- /dev/null +++ b/src/axolotl/prompt_strategies/completion.py @@ -0,0 +1,20 @@ +""" +Basic completion text +""" +from typing import Any, Dict, Optional + +from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy +from axolotl.prompters 
import CompletionPrompter + + +def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + strat = CompletionPromptTokenizingStrategy( + CompletionPrompter(), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + if ds_cfg and "field" in ds_cfg: + strat.field = ds_cfg["field"] + + return strat diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index ed32ab24a..b1aaeb350 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -245,8 +245,31 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): Tokenizing strategy for Completion prompts. """ + _field: str = "text" + + @property + def field(self) -> str: + return self._field + + @field.setter + def field(self, new_field: str): + self._field = new_field + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: + return ( + prompt[self.field], + "", + "", + ) + def tokenize_prompt(self, prompt): - full_prompt = self._build_full_prompt(prompt["text"], None, None) + ( + instruction, + _, + _, + ) = self.parse_instruction_fields(prompt) + + full_prompt = self._build_full_prompt(instruction, None, None) tokenized_full_prompt = self._tokenize(full_prompt) return tokenized_full_prompt diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 7ad8b34ee..66d207374 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -22,7 +22,6 @@ AlpacaMultipleChoicePromptTokenizingStrategy, AlpacaPromptTokenizingStrategy, AlpacaReflectionPTStrategy, - CompletionPromptTokenizingStrategy, GPTeacherPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy, OpenAssistantPromptTokenizingStrategy, @@ -31,7 +30,6 @@ ) from axolotl.prompters import ( AlpacaPrompter, - CompletionPrompter, GPTeacherPrompter, JeopardyPrompter, MultipleChoiceConcisePrompter, @@ -327,15 +325,6 @@ def for_d_in_datasets(dataset_configs): ) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) datasets.append(ds_wrapper) 
- elif d_base_type == "completion": - ds_strategy = CompletionPromptTokenizingStrategy( - CompletionPrompter(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) - datasets.append(ds_wrapper) else: suffix = "" if ":load_" in d.type: