Skip to content

Commit

Permalink
continue to support scripts/finetune.py
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 13, 2023
1 parent f82d21c commit b049704
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
52 changes: 52 additions & 0 deletions scripts/finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import logging
from pathlib import Path

import fire
import transformers

from axolotl.cli import (
check_accelerate_default_config,
do_inference,
do_merge_lora,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.cli.shard import shard
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train

LOG = logging.getLogger("axolotl.scripts.finetune")


def do_cli(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
LOG.warning(
str(
PendingDeprecationWarning(
"scripts/finetune.py will me replaced with calling axolotl.cli.train"
)
)
)
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if parsed_cli_args.inference:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.merge_lora:
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)


if __name__ == "__main__":
fire.Fire(do_cli)
9 changes: 7 additions & 2 deletions src/axolotl/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,21 @@
import fire
import transformers

from axolotl.cli import load_cfg, load_datasets, print_axolotl_text_art, check_accelerate_default_config
from axolotl.cli import (
check_accelerate_default_config,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train


def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
Expand Down

0 comments on commit b049704

Please sign in to comment.