Skip to content

Commit

Permalink
Disable caching on --disable_caching in CLI (#1110)
Browse files Browse the repository at this point in the history
* Disable caching on `--disable_caching` in CLI

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
  • Loading branch information
casper-hansen and winglian committed Jan 13, 2024
1 parent 304ea1b commit d66b101
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/axolotl/cli/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import fire
import transformers
from colorama import Fore
from datasets import disable_caching

from axolotl.cli import (
check_accelerate_default_config,
Expand All @@ -28,9 +29,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((PreprocessCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
parsed_cli_args, remaining_args = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)

if (
remaining_args.get("disable_caching") is not None
and remaining_args["disable_caching"]
):
disable_caching()
if not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
Expand Down
9 changes: 8 additions & 1 deletion src/axolotl/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import fire
import transformers
from datasets import disable_caching

from axolotl.cli import (
check_accelerate_default_config,
Expand All @@ -28,9 +29,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
parsed_cli_args, remaining_args = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)

if (
remaining_args.get("disable_caching") is not None
and remaining_args["disable_caching"]
):
disable_caching()
if parsed_cfg.rl:
dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
Expand Down

0 comments on commit d66b101

Please sign in to comment.