Skip to content

Commit

Permalink
add e2e tests for checking functionality of resume from checkpoint (#865
Browse files Browse the repository at this point in the history
)

* use tensorboard to see if resume from checkpoint works

* make sure e2e test is either fp16 or bf16

* set max_steps and save limit so we have the checkpoint when testing resuming

* fix test parameters
  • Loading branch information
winglian committed Nov 16, 2023
1 parent 8a8d1c4 commit b3a61e8
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 1 deletion.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ pynvml
art
fschat==0.2.29
gradio
tensorboard
1 change: 1 addition & 0 deletions tests/e2e/test_lora_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_lora_packing(self, temp_dir):
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"bf16": True,
}
)
normalize_config(cfg)
Expand Down
95 changes: 95 additions & 0 deletions tests/e2e/test_resume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
E2E tests for resuming training
"""

import logging
import os
import re
import subprocess
import unittest
from pathlib import Path

from transformers.utils import is_torch_bf16_gpu_available

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import most_recent_subdir, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestResumeLlama(unittest.TestCase):
"""
Test case for resuming training of llama models
"""

@with_temp_dir
def test_resume_qlora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_steps": 10,
"save_total_limit": 5,
"max_steps": 40,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

resume_cfg = cfg | DictDefault(
{
"resume_from_checkpoint": f"{temp_dir}/checkpoint-30/",
}
)
normalize_config(resume_cfg)
cli_args = TrainerCliArgs()

train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"
res = subprocess.run(
cmd, shell=True, text=True, capture_output=True, check=True
)
pattern = r"first_step\s+(\d+)"
first_steps = int(re.findall(pattern, res.stdout)[0])
assert first_steps == 31
13 changes: 12 additions & 1 deletion tests/e2e/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""
helper utils for tests
"""

import os
import shutil
import tempfile
from functools import wraps
from pathlib import Path


def with_temp_dir(test_func):
Expand All @@ -20,3 +21,13 @@ def wrapper(*args, **kwargs):
shutil.rmtree(temp_dir)

return wrapper


def most_recent_subdir(path):
base_path = Path(path)
subdirectories = [d for d in base_path.iterdir() if d.is_dir()]
if not subdirectories:
return None
subdir = max(subdirectories, key=os.path.getctime)

return subdir

0 comments on commit b3a61e8

Please sign in to comment.