DPO cleanup (#1126)
* cleanup dpo to be a little more extensible, add zephyr/nectar strategy

* fix eos slash

* support for eval split

* fix kwargs

* handle empty evals

* don't load peft model for dpo

* ensure dpo training args gets bf16 for peft if applicable

* fix duplicate kwargs for bf16

* make sure to respect the configured lr scheduler

* support trainer callback to push config to wandb

* set dataloader preload args

* ensure that we are loading the lora when merging

* Update src/axolotl/utils/data.py

Co-authored-by: Agus <agustin.piqueres@gmail.com>

* support local datasets for dpo

Co-authored-by: Agus <agustin.piqueres@gmail.com>

* chore: lint

* dpo/kto/ipo smoke tests w lora, simplify dpo dataset type names

* add split to dpo tests

* fix rebase/merging error

* handle edge case w logging

* use accelerator for dpo datasets so it doesn't break the logger

* missing args

* validate checkpoint is an adapter for now

* log warning when dataset strategy is not loadable

---------

Co-authored-by: Agus <agustin.piqueres@gmail.com>
winglian and plaguss authored Jan 23, 2024
1 parent 5439707 commit 7523d1f
Showing 10 changed files with 440 additions and 106 deletions.
79 changes: 2 additions & 77 deletions src/axolotl/cli/__init__.py
@@ -17,7 +17,6 @@
# add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args
from art import text2art
from datasets import concatenate_datasets, load_dataset
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
@@ -30,7 +29,7 @@
normalize_config,
validate_config,
)
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
@@ -343,81 +342,7 @@ def load_rl_datasets(
cfg: DictDefault,
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
) -> TrainDatasetMeta:
train_datasets: List[Any] = []
for i, ds_cfg in enumerate(cfg.datasets):
train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
# eval_dataset = load_dataset(
# cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"]
# )
eval_dataset = None

def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
return sample

def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample

def apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample

def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample

for i, data_set in enumerate(train_datasets):
_type = cfg.datasets[i]["type"]
ds_type_fn = locals()[_type]
train_datasets[i] = data_set.map(
ds_type_fn,
desc="Mapping RL Dataset",
)
train_dataset = concatenate_datasets(train_datasets)

# eval_dataset = eval_dataset.map(intel_apply_chatml)

train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
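
The per-dataset ChatML helpers deleted above are replaced by loadable strategy modules under axolotl.prompt_strategies.dpo (the loader is added later in this diff). The strategy files themselves are not expanded in this view, so the following is only an illustrative sketch with assumed names: a strategy is a function that receives the cfg and returns a per-sample transform, mirroring the deleted intel_apply_chatml helper.

# illustrative sketch only -- module/function names are assumptions, not taken from this diff
# e.g. axolotl/prompt_strategies/dpo/chatml.py (hypothetical path)
def intel(cfg, **kwargs):  # loaded via the dpo strategy loader shown further down
    def transform_fn(sample):
        # same ChatML formatting the removed intel_apply_chatml helper performed
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
            )
        else:
            sample["prompt"] = (
                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
            )
        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
        return sample

    return transform_fn

load_prepare_dpo_datasets presumably maps each configured dataset with the transform returned by its strategy, concatenates the train splits, and returns the optional eval split alongside (per "support for eval split" and "handle empty evals" in the commit message).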
109 changes: 86 additions & 23 deletions src/axolotl/core/trainer_builder.py
@@ -12,14 +12,19 @@
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Optional, Type, Union
from typing import List, Optional, Type, Union

import torch
import transformers
from datasets import Dataset
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers import (
EarlyStoppingCallback,
Trainer,
TrainerCallback,
TrainingArguments,
)
from transformers.trainer_utils import seed_worker
from trl import DPOTrainer

@@ -460,6 +465,7 @@ class TrainerBuilderBase(abc.ABC):
_train_dataset = None
_eval_dataset = None
_model_ref = None
_peft_config = None

def __init__(self, cfg, model, tokenizer):
self.cfg = cfg
@@ -490,26 +496,33 @@ def eval_dataset(self):
def eval_dataset(self, dataset):
self._eval_dataset = dataset

@property
def peft_config(self):
return self._peft_config

@peft_config.setter
def peft_config(self, peft_config):
self._peft_config = peft_config

@abstractmethod
def build(self, total_num_steps):
pass

@abstractmethod
def get_callbacks(self):
pass
def get_callbacks(self) -> List[TrainerCallback]:
callbacks = []
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)

return callbacks

@abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""


class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for Causal models
"""

def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
return training_arguments_kwargs
@@ -526,10 +539,16 @@ def hook_post_create_trainer(self, trainer):
# TODO
return trainer


class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for Causal models
"""

def get_callbacks(self):
callbacks = []
callbacks = super().get_callbacks()
callbacks.append(GPUStatsCallback(self.cfg))
callbacks.append(EvalFirstStepCallback)
callbacks.append(EvalFirstStepCallback())

if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))
@@ -538,7 +557,7 @@ def get_callbacks(self):
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback)
callbacks.append(SaveBetterTransformerModelCallback())

if self.cfg.use_wandb:
callbacks.append(
@@ -931,7 +950,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
"""

def get_callbacks(self):
callbacks = []
callbacks = super().get_callbacks()
return callbacks

def get_post_trainer_create_callbacks(self, trainer):
@@ -949,21 +968,60 @@ def build_training_arguments(self, total_num_steps):
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)

if self.cfg.hub_model_id:
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_args_kwargs["push_to_hub"] = True
training_args_kwargs["hub_private_repo"] = True
training_args_kwargs["hub_always_push"] = True

if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy

if self.cfg.save_safetensors is not None:
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors

if self.eval_dataset:
training_args_kwargs["evaluation_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
else:
training_args_kwargs["evaluation_strategy"] = "no"
if self.cfg.bf16 or self.cfg.bfloat16:
training_args_kwargs["bf16"] = True

training_args_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_args_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)

if self.cfg.dataloader_pin_memory is not None:
training_args_kwargs[
"dataloader_pin_memory"
] = self.cfg.dataloader_pin_memory
if self.cfg.dataloader_num_workers is not None:
training_args_kwargs[
"dataloader_num_workers"
] = self.cfg.dataloader_num_workers
if self.cfg.dataloader_prefetch_factor is not None:
training_args_kwargs[
"dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor

training_args = TrainingArguments(
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=total_num_steps,
max_steps=self.cfg.max_steps or total_num_steps,
remove_unused_columns=False,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
evaluation_strategy="no",
# eval_steps=self.cfg.eval_steps,
save_strategy="steps",
save_steps=self.cfg.save_steps,
output_dir=self.cfg.output_dir,
warmup_steps=self.cfg.warmup_steps,
bf16=True,
gradient_checkpointing=self.cfg.gradient_checkpointing,
gradient_checkpointing_kwargs={"use_reentrant": False},
gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
or {"use_reentrant": False},
logging_first_step=True,
logging_steps=1,
optim=self.cfg.optimizer,
@@ -982,22 +1040,27 @@ def build(self, total_num_steps):
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
elif self.cfg.rl == "kto_pair":
dpo_trainer_kwargs["loss_type"] = "kto_pair"

if self.eval_dataset:
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
dpo_trainer_kwargs["peft_config"] = self.peft_config
dpo_trainer = DPOTrainer(
self.model,
self.model_ref,
args=training_args,
beta=self.cfg.dpo_beta or 0.1,
train_dataset=self.train_dataset,
# eval_dataset=self.eval_dataset,
eval_dataset=None,
tokenizer=self.tokenizer,
max_length=self.cfg.sequence_len,
max_target_length=None,
max_prompt_length=self.cfg.sequence_len,
generate_during_eval=True,
callbacks=self.get_callbacks(),
**dpo_trainer_kwargs,
)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
dpo_trainer.add_callback(callback)

return dpo_trainer
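
For orientation, a rough sketch of how this builder is now driven from the training entrypoint; the real call site is in files not expanded in this view, so treat the wiring below as an assumption rather than a quote of the new code.

# hypothetical usage sketch -- the actual call site is not shown in this diff
builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
builder.train_dataset = train_dataset
builder.eval_dataset = eval_dataset    # may be None; evaluation_strategy then stays "no"
builder.model_ref = model_ref          # reference model; setter assumed to mirror the dataset properties
builder.peft_config = peft_config      # only when cfg.adapter is set, so TRL applies the LoRA itself
dpo_trainer = builder.build(total_num_steps)

Because the PEFT model is no longer loaded up front for DPO ("don't load peft model for dpo" above), passing peft_config through to DPOTrainer lets TRL wrap the base model, and build_training_arguments now forwards bf16, the configured lr scheduler, and the dataloader pin-memory/num-workers/prefetch settings from the cfg instead of hard-coding them.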

21 changes: 21 additions & 0 deletions src/axolotl/prompt_strategies/dpo/__init__.py
@@ -0,0 +1,21 @@
"""
module for DPO style dataset transform strategies
"""

import importlib
import logging

LOG = logging.getLogger("axolotl")


def load(strategy, cfg):
try:
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
func = getattr(mod, load_fn)
load_kwargs = {}
return func(cfg, **load_kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}")
return None
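
The loader splits the dotted type name: the last component is the function, everything before it is the module under axolotl.prompt_strategies.dpo. A hedged resolution example follows; the concrete strategy name chatml.intel is an assumption, since the strategy modules are among the files not expanded above.

# hypothetical resolution example -- strategy/module names are assumptions
from axolotl.prompt_strategies.dpo import load

# a cfg.datasets entry such as
#   {"path": "Intel/orca_dpo_pairs", "split": "train", "type": "chatml.intel"}
# would resolve "chatml.intel" to the module axolotl.prompt_strategies.dpo.chatml
# and call its intel(cfg) function to obtain the per-sample transform
transform_fn = load("chatml.intel", cfg)
if transform_fn is None:
    # load() swallowed an import error and logged a warning instead of raising
    # ("log warning when dataset strategy is not loadable" in the commit message);
    # the caller decides how to handle the missing strategy
    ...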
