Skip to content

Commit

Permalink
Missing update
Browse files Browse the repository at this point in the history
Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com>
  • Loading branch information
Maxusmusti committed Oct 10, 2024
1 parent bdbd494 commit d628634
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from copy import deepcopy
from pathlib import Path
import argparse
import json
import math
import os
import re
Expand Down Expand Up @@ -650,6 +651,16 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
)

if os.path.exists(train_args.model_path):
if not os.path.isdir(train_args.model_path):
raise RuntimeError(
"Model path does not appear to be a dir, please validate or update the path"
)
else:
raise RuntimeError(
"Model Path cannot be found, please verify existense and permissions"
)

# process the training data
if not os.path.exists(train_args.data_output_dir):
os.makedirs(train_args.data_output_dir, exist_ok=True)
Expand Down Expand Up @@ -705,6 +716,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
command.append(f"--mock_len={train_args.mock_len}")

if train_args.use_dolomite:
with open(Path(train_args.model_path) / "config.json") as conf_json:
model_conf = json.load(conf_json)
if model_conf["model_type"] == "granite":
raise RuntimeError(
"Converting Granite models to Dolomite format is currently unsupported."
)
command.append("--use_dolomite")

if train_args.disable_flash_attn:
Expand Down

0 comments on commit d628634

Please sign in to comment.