diff --git a/README.md b/README.md index 16dfe6991..596641219 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,6 @@ Features: - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset) - [Config](#config) - [Train](#train) - - [Training w/ Deepspeed](#training-with-deepspeed) - [Inference](#inference) - [Merge LORA to Base](#merge-lora-to-base) - [Common Errors](#common-errors-) @@ -824,14 +823,41 @@ Run accelerate launch -m axolotl.cli.train your_config.yml ``` -#### Multi-GPU +#### Preprocess dataset + +You can optionally pre-tokenize dataset with the following before finetuning. +This is recommended for large datasets. + +- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface. +- Use `--debug` to see preprocessed examples. -You can optionally pre-tokenize dataset with the following before finetuning: ```bash -CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only +python -m axolotl.cli.preprocess your_config.yml ``` -##### Config +#### Multi-GPU + +Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed +is the recommended multi-GPU option currently because FSDP may experience +[loss instability](https://github.com/huggingface/transformers/issues/26498). + +##### DeepSpeed + +Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you +might typically be able to fit into your GPU's VRAM. More information about the various optimization types +for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated + +We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3. + +```yaml +deepspeed: deepspeed/zero1.json +``` + +```shell +accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json +``` + +##### FSDP - llama FSDP ```yaml @@ -856,24 +882,6 @@ wandb_run_id: wandb_log_model: ``` -### Training with Deepspeed - -Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you -might typically be able to fit into your GPU's VRAM. More information about the various optimization types -for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated - -We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3. - -```shell -accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json -``` - -or - -```yaml -deepspeed: deepspeed/zero1.json -``` - ### Inference Pass the appropriate flag to the train command: diff --git a/scripts/finetune.py b/scripts/finetune.py index 118a97b84..d5bbcaf8f 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -45,8 +45,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs): shard(cfg=parsed_cfg, cli_args=parsed_cli_args) else: dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) - if parsed_cli_args.prepare_ds_only: - return train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 07a6209e4..27d5df386 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -222,7 +222,9 @@ def load_datasets( ) -> TrainDatasetMeta: tokenizer = load_tokenizer(cfg) - train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer) + train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( + cfg, tokenizer + ) if cli_args.debug or cfg.debug: LOG.info("check_dataset_labels...") @@ -238,6 +240,10 @@ def load_datasets( text_only=cli_args.debug_text_only, ) + LOG.info("printing prompters...") + for prompter in prompters: + LOG.info(prompter) + return TrainDatasetMeta( train_dataset=train_dataset, eval_dataset=eval_dataset, diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py new file mode 100644 index 000000000..e0eeea6b3 --- /dev/null +++ b/src/axolotl/cli/preprocess.py @@ -0,0 +1,53 @@ +""" +CLI to run training on a model +""" +import logging +from pathlib import Path + +import fire +import transformers +from colorama import Fore + +from axolotl.cli import ( + check_accelerate_default_config, + check_user_token, + load_cfg, + load_datasets, + print_axolotl_text_art, +) +from axolotl.common.cli import PreprocessCliArgs +from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH + +LOG = logging.getLogger("axolotl.cli.preprocess") + + +def do_cli(config: Path = Path("examples/"), **kwargs): + # pylint: disable=duplicate-code + print_axolotl_text_art() + parsed_cfg = load_cfg(config, **kwargs) + check_accelerate_default_config() + check_user_token() + parser = transformers.HfArgumentParser((PreprocessCliArgs)) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + if not parsed_cfg.dataset_prepared_path: + msg = ( + Fore.RED + + "preprocess CLI called without dataset_prepared_path set, " + + f"using default path: {DEFAULT_DATASET_PREPARED_PATH}" + + Fore.RESET + ) + LOG.warning(msg) + parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH + + _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + LOG.info( + Fore.GREEN + + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`" + + Fore.RESET + ) + + +if __name__ == "__main__": + fire.Fire(do_cli) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index b49cbc6b6..1e6fbc320 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -6,7 +6,6 @@ import fire import transformers -from colorama import Fore from axolotl.cli import ( check_accelerate_default_config, @@ -16,7 +15,6 @@ print_axolotl_text_art, ) from axolotl.common.cli import TrainerCliArgs -from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.train import train LOG = logging.getLogger("axolotl.cli.train") @@ -32,18 +30,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs): parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) - if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path: - msg = ( - Fore.RED - + "--prepare_ds_only called without dataset_prepared_path set." - + Fore.RESET - ) - LOG.warning(msg) - parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH - dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) - if parsed_cli_args.prepare_ds_only: - return train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index 62f2b1061..c8aea4a71 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -25,11 +25,22 @@ class TrainerCliArgs: debug_num_examples: int = field(default=5) inference: bool = field(default=False) merge_lora: bool = field(default=False) - prepare_ds_only: bool = field(default=False) prompter: Optional[str] = field(default=None) shard: bool = field(default=False) +@dataclass +class PreprocessCliArgs: + """ + dataclass representing arguments for preprocessing only + """ + + debug: bool = field(default=False) + debug_text_only: bool = field(default=False) + debug_num_examples: int = field(default=1) + prompter: Optional[str] = field(default=None) + + def load_model_and_tokenizer( *, cfg: DictDefault, diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 23ea38da0..fe4f3b62f 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -245,6 +245,7 @@ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]: raise NotImplementedError def tokenize_prompt(self, prompt): + # pylint: disable=duplicate-code ( instruction, input, # pylint: disable=redefined-builtin diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 7cd89886a..2839c946e 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -4,10 +4,12 @@ from enum import Enum from typing import Generator, Optional, Union +from colorama import Fore from fastchat.conversation import Conversation, get_conv_template LOG = logging.getLogger("axolotl") IGNORE_TOKEN_ID = -100 +REPR_TEMPLATE = "\n\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n\n" class PromptStyle(Enum): @@ -55,20 +57,15 @@ def match_prompt_style(self): ) self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" - def build_prompt( - self, - instruction: str, - input: Union[None, str] = None, # pylint: disable=redefined-builtin - output: Union[None, str] = None, - ) -> Generator[str, None, None]: + def _build_result(self, instruction, input_text, output): # returns the full prompt from instruction and optional input # if a label (=response, =output) is provided, it's also appended. - if input: + if input_text: res = ( self.system_format.format(system=self.system_prompt) if self.system_prompt else "" - ) + self.turn_format.format(instruction=instruction, input=input) + ) + self.turn_format.format(instruction=instruction, input=input_text) else: res = ( self.system_format.format(system=self.system_no_input_prompt) @@ -77,7 +74,21 @@ def build_prompt( ) + self.turn_no_input_format.format(instruction=instruction) if output: res = f"{res}{output}" - yield res + + return res + + def build_prompt( + self, + instruction: str, + input: Union[None, str] = None, # pylint: disable=redefined-builtin + output: Union[None, str] = None, + ) -> Generator[str, None, None]: + yield self._build_result(instruction, input, output) + + def __repr__(self) -> str: + return REPR_TEMPLATE.format( + full_prompt=self._build_result("{instruction}", "{input}", "{output}") + ) class UnpromptedPrompter(AlpacaPrompter): @@ -191,14 +202,14 @@ def match_prompt_style(self): ) self.response_split = "ASSISTANT:" - def build_prompt( + def _build_result( self, instruction: str, input: Union[None, str] = None, # pylint: disable=redefined-builtin output: Union[None, str] = None, reflection: Union[None, str] = None, corrected: Union[None, str] = None, - ) -> Generator[str, None, None]: + ): # returns the full prompt from instruction and optional input # if a label (=response, =output) is provided, it's also appended. if input: @@ -212,7 +223,30 @@ def build_prompt( corrected=corrected, ) res = f"{res}{label}" - yield res + + return res + + def build_prompt( + self, + instruction: str, + input: Union[None, str] = None, # pylint: disable=redefined-builtin + output: Union[None, str] = None, + reflection: Union[None, str] = None, + corrected: Union[None, str] = None, + ) -> Generator[str, None, None]: + # pylint: disable=duplicate-code + yield self._build_result( + instruction, + input, + output, + reflection, + corrected, + ) + + def __repr__(self) -> str: + return REPR_TEMPLATE.format( + full_prompt=self._build_result("{instruction}", "{input}", "{output}") + ) SHAREGPT_ASSERTION_FAILED_ROLE = ( @@ -247,7 +281,7 @@ def __init__( if role_key_model: self.role_key_model = role_key_model - def build_prompt(self, source) -> Generator[str, None, None]: + def _build_result(self, source): if len(source) < 2: # If there isn't a back and forth conversation, ignore it # also happens on the data splitting leaving empty conversations @@ -282,11 +316,20 @@ def build_prompt(self, source) -> Generator[str, None, None]: LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") conv.append_message(role, sentence["value"]) - for part in conv.get_turns(): + return conv.get_turns() + + def build_prompt(self, source) -> Generator[str, None, None]: + turns = self._build_result(source) + + for part in turns: if part[0] and not part[1]: LOG.warning(f"role with empty message: {part[0]}") yield part + def __repr__(self) -> str: + turns = self._build_result([{"from": "{from}", "value": "{value}"}]) + return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns]) + class ShareGPTPrompterV2(ShareGPTPrompter): """ @@ -304,3 +347,15 @@ def __init__( role_key_human=role_key_human, role_key_model=role_key_model, ) + + +class UnsupportedPrompter: + """ + A dummy class for custom prompters + """ + + def __init__(self) -> None: + pass + + def __repr__(self): + return "Pre-tokenized or custom dataset types are unsupported for logging" diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 99697de32..124b607b3 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -3,7 +3,7 @@ import hashlib import logging from pathlib import Path -from typing import Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Union import torch from datasets import ( @@ -36,6 +36,7 @@ MultipleChoiceExplainPrompter, ReflectAlpacaPrompter, SummarizeTLDRPrompter, + UnsupportedPrompter, ) from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first @@ -55,9 +56,10 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str: def prepare_dataset(cfg, tokenizer): + prompters = [] if not cfg.pretraining_dataset: with zero_first(is_main_process()): - train_dataset, eval_dataset = load_prepare_datasets( + train_dataset, eval_dataset, prompters = load_prepare_datasets( tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH ) else: @@ -70,7 +72,7 @@ def prepare_dataset(cfg, tokenizer): # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 train_dataset = train_dataset.with_format("torch") eval_dataset = None - return train_dataset, eval_dataset, cfg.max_steps + return train_dataset, eval_dataset, cfg.max_steps, prompters with zero_first(is_main_process()): train_dataset, eval_dataset = process_datasets_for_packing( @@ -83,7 +85,7 @@ def prepare_dataset(cfg, tokenizer): LOG.info(f"Maximum number of steps set at {total_num_steps}") else: total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer) - return train_dataset, eval_dataset, total_num_steps + return train_dataset, eval_dataset, total_num_steps, prompters def load_tokenized_prepared_datasets( @@ -109,6 +111,7 @@ def load_tokenized_prepared_datasets( else Path(default_dataset_prepared_path) / ds_hash ) dataset = None + prompters = [] use_auth_token = cfg.hf_use_auth_token try: if cfg.push_dataset_to_hub: @@ -147,13 +150,13 @@ def for_d_in_datasets(dataset_configs): yield dataset # pylint: disable=invalid-name - for d in for_d_in_datasets(cfg.datasets): + for config_dataset in for_d_in_datasets(cfg.datasets): ds: Union[Dataset, DatasetDict] = None ds_from_hub = False try: load_dataset( - d.path, - name=d.name, + config_dataset.path, + name=config_dataset.name, streaming=True, token=use_auth_token, ) @@ -162,33 +165,33 @@ def for_d_in_datasets(dataset_configs): pass # prefer local dataset, even if hub exists - local_path = Path(d.path) + local_path = Path(config_dataset.path) if local_path.exists(): if local_path.is_dir(): # TODO dirs with arrow or parquet files could be loaded with `load_from_disk` ds = load_dataset( - d.path, - name=d.name, - data_files=d.data_files, + config_dataset.path, + name=config_dataset.name, + data_files=config_dataset.data_files, streaming=False, split=None, ) elif local_path.is_file(): ds_type = "json" - if d.ds_type: - ds_type = d.ds_type - elif ".parquet" in d.path: + if config_dataset.ds_type: + ds_type = config_dataset.ds_type + elif ".parquet" in config_dataset.path: ds_type = "parquet" - elif ".arrow" in d.path: + elif ".arrow" in config_dataset.path: ds_type = "arrow" - elif ".csv" in d.path: + elif ".csv" in config_dataset.path: ds_type = "csv" - elif ".txt" in d.path: + elif ".txt" in config_dataset.path: ds_type = "text" ds = load_dataset( ds_type, - name=d.name, - data_files=d.path, + name=config_dataset.name, + data_files=config_dataset.path, streaming=False, split=None, ) @@ -198,25 +201,25 @@ def for_d_in_datasets(dataset_configs): ) elif ds_from_hub: ds = load_dataset( - d.path, - name=d.name, + config_dataset.path, + name=config_dataset.name, streaming=False, - data_files=d.data_files, + data_files=config_dataset.data_files, token=use_auth_token, ) else: - if isinstance(d.data_files, str): + if isinstance(config_dataset.data_files, str): fp = hf_hub_download( - repo_id=d.path, + repo_id=config_dataset.path, repo_type="dataset", - filename=d.data_files, + filename=config_dataset.data_files, ) - elif isinstance(d.data_files, list): + elif isinstance(config_dataset.data_files, list): fp = [] - for file in d.data_files: + for file in config_dataset.data_files: fp.append( hf_hub_download( - repo_id=d.path, + repo_id=config_dataset.path, repo_type="dataset", filename=file, ) @@ -226,21 +229,27 @@ def for_d_in_datasets(dataset_configs): "data_files must be either a string or list of strings" ) ds = load_dataset( - "json", name=d.name, data_files=fp, streaming=False, split=None + "json", + name=config_dataset.name, + data_files=fp, + streaming=False, + split=None, ) if not ds: raise ValueError("unhandled dataset load") # support for using a subset of the data - if d.shards: + if config_dataset.shards: if "train" in ds: ds = ds.shuffle(seed=seed)["train"].shard( - num_shards=d.shards, index=0 + num_shards=config_dataset.shards, index=0 ) else: - ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0) + ds = ds.shuffle(seed=seed).shard( + num_shards=config_dataset.shards, index=0 + ) d_base_type = d_prompt_style = None - d_type = d.type + d_type = config_dataset.type if isinstance(d_type, str): d_type_split = d_type.split(":") d_base_type = d_type_split[0] @@ -249,108 +258,26 @@ def for_d_in_datasets(dataset_configs): ds = ds["train"] elif ( isinstance(ds, DatasetDict) - and d.train_on_split - and d.train_on_split in ds + and config_dataset.train_on_split + and config_dataset.train_on_split in ds ): - ds = ds[d.train_on_split] + ds = ds[config_dataset.train_on_split] elif isinstance(ds, DatasetDict): raise ValueError( - f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `" - ) - if ( - "input_ids" in ds.features - and "attention_mask" in ds.features - and "labels" in ds.features - ): - # dataset is already tokenized, just drop it straight in - datasets.append(ds) - elif isinstance(d.type, DictDefault): - ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict()) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) - datasets.append(ds_wrapper) - elif ds_strategy := load(d.type, tokenizer, cfg, d): - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) - datasets.append(ds_wrapper) - elif d_base_type == "alpaca": - ds_strategy = AlpacaPromptTokenizingStrategy( - AlpacaPrompter(d_prompt_style), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) - datasets.append(ds_wrapper) - elif d_base_type == "explainchoice": - ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( - MultipleChoiceExplainPrompter(d_prompt_style), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) - datasets.append(ds_wrapper) - elif d_base_type == "concisechoice": - ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( - MultipleChoiceConcisePrompter(d_prompt_style), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) - datasets.append(ds_wrapper) - elif d_base_type == "summarizetldr": - ds_strategy = SummarizeTLDRPromptTokenizingStrategy( - SummarizeTLDRPrompter(d_prompt_style), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) - datasets.append(ds_wrapper) - elif d_base_type == "jeopardy": - ds_strategy = JeopardyPromptTokenizingStrategy( - JeopardyPrompter(d_prompt_style), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) - datasets.append(ds_wrapper) - elif d_base_type == "oasst": - ds_strategy = OpenAssistantPromptTokenizingStrategy( - AlpacaPrompter(d_prompt_style), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) - datasets.append(ds_wrapper) - elif d_base_type == "gpteacher": - ds_strategy = GPTeacherPromptTokenizingStrategy( - GPTeacherPrompter(d_prompt_style), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) - datasets.append(ds_wrapper) - elif d_base_type == "reflection": - ds_strategy = AlpacaReflectionPTStrategy( - ReflectAlpacaPrompter(d_prompt_style), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) - datasets.append(ds_wrapper) - else: - suffix = "" - if ":load_" in d.type: - suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?" - LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}") - raise ValueError( - f"unhandled prompt tokenization strategy: {d.type} {suffix}" + f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `" ) + + dataset_wrapper, dataset_prompter = get_dataset_wrapper( + config_dataset=config_dataset, + dataset=ds, + tokenizer=tokenizer, + cfg=cfg, + d_base_type=d_base_type, + d_prompt_style=d_prompt_style, + ) + datasets.append(dataset_wrapper) + prompters.append(dataset_prompter) + LOG.info("merging datasets") dataset = concatenate_datasets(datasets) @@ -368,14 +295,14 @@ def for_d_in_datasets(dataset_configs): f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True ) - return dataset + return dataset, prompters def load_prepare_datasets( tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path, -) -> Tuple[Dataset, Dataset]: +) -> Tuple[Dataset, Dataset, List[Any]]: max_packed_sequence_len = ( cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len ) @@ -384,6 +311,7 @@ def load_prepare_datasets( ) # make sure we don't accidentally set it larger than sequence_len tokenizer_name = tokenizer.__class__.__name__ + prompters = [] if cfg.max_packed_sequence_len is not None: # see if we can go ahead and load the stacked dataset seed = f"@{str(cfg.seed)}" if cfg.seed else "" @@ -439,7 +367,7 @@ def load_prepare_datasets( f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True ) else: - dataset = load_tokenized_prepared_datasets( + dataset, prompters = load_tokenized_prepared_datasets( tokenizer, cfg, default_dataset_prepared_path ) @@ -481,7 +409,7 @@ def load_prepare_datasets( private=True, ) else: - dataset = load_tokenized_prepared_datasets( + dataset, prompters = load_tokenized_prepared_datasets( tokenizer, cfg, default_dataset_prepared_path ) @@ -532,7 +460,124 @@ def load_prepare_datasets( train_dataset = dataset eval_dataset = None - return train_dataset, eval_dataset + return train_dataset, eval_dataset, prompters + + +def get_dataset_wrapper( + config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style +): + dataset_wrapper = None + dataset_prompter = None + + if ( + "input_ids" in dataset.features + and "attention_mask" in dataset.features + and "labels" in dataset.features + ): + # dataset is already tokenized, just drop it straight in + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = dataset + elif isinstance(config_dataset.type, DictDefault): + ds_strategy = load( + "user_defined", tokenizer, cfg, config_dataset.type.to_dict() + ) + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset): + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + elif d_base_type == "alpaca": + dataset_prompter = AlpacaPrompter(d_prompt_style) + ds_strategy = AlpacaPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + dataset_wrapper = ds_wrapper + elif d_base_type == "explainchoice": + dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style) + ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + dataset_wrapper = ds_wrapper + elif d_base_type == "concisechoice": + dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style) + ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + dataset_wrapper = ds_wrapper + elif d_base_type == "summarizetldr": + dataset_prompter = SummarizeTLDRPrompter(d_prompt_style) + ds_strategy = SummarizeTLDRPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + dataset_wrapper = ds_wrapper + elif d_base_type == "jeopardy": + dataset_prompter = JeopardyPrompter(d_prompt_style) + ds_strategy = JeopardyPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + dataset_wrapper = ds_wrapper + elif d_base_type == "oasst": + dataset_prompter = AlpacaPrompter(d_prompt_style) + ds_strategy = OpenAssistantPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + dataset_wrapper = ds_wrapper + elif d_base_type == "gpteacher": + dataset_prompter = GPTeacherPrompter(d_prompt_style) + ds_strategy = GPTeacherPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + dataset_wrapper = ds_wrapper + elif d_base_type == "reflection": + dataset_prompter = ReflectAlpacaPrompter(d_prompt_style) + ds_strategy = AlpacaReflectionPTStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + dataset_wrapper = ds_wrapper + else: + suffix = "" + if ":load_" in config_dataset.type: + suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?" + LOG.error( + f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}" + ) + raise ValueError( + f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}" + ) + + return dataset_wrapper, dataset_prompter def encode_pretraining(