
Commit

wip (#1406)
ofrimasad committed Aug 24, 2023
1 parent a6d9003 · commit 7240726
Showing 2 changed files with 9 additions and 3 deletions.
@@ -1,4 +1,5 @@
 experiment_name: # The experiment name used to train the model (optional - ignored when checkpoint_path is given)
+run_id: # The directory name of the required checkpoint, e.g. RUN_20230823_154026_757034 - if left empty, the last run will be used
 ckpt_root_dir: # The checkpoint root directory, s.t. ckpt_root_dir/experiment_name/ckpt_name resides.
              # Can be ignored if the checkpoints directory is the default (i.e. the path to the checkpoints module from the contents root), or when checkpoint_path is given
 ckpt_name: ckpt_best.pth # Name of the checkpoint to export ("ckpt_latest.pth", "average_model.pth" or "ckpt_best.pth" for instance).
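For context, a minimal sketch of how these fields combine into a concrete checkpoint path, assuming the layout ckpt_root_dir/experiment_name/<run_id>/<ckpt_name> implied by the comments above; all values below are hypothetical examples, not defaults shipped with super-gradients:

from pathlib import Path

# Hypothetical example values; run_id is given explicitly here for clarity.
ckpt_root_dir = "/checkpoints"
experiment_name = "my_resnet_experiment"
run_id = "RUN_20230823_154026_757034"
ckpt_name = "ckpt_best.pth"

checkpoint_path = Path(ckpt_root_dir) / experiment_name / run_id / ckpt_name
print(checkpoint_path)  # /checkpoints/my_resnet_experiment/RUN_20230823_154026_757034/ckpt_best.pth

Leaving run_id empty instead defers to the latest run, which is what the code change below implements.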
src/super_gradients/training/models/conversion.py: 11 changes (8 additions & 3 deletions)
@@ -13,7 +13,7 @@
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.decorators.factory_decorator import resolve_param
 from super_gradients.common.environment.cfg_utils import load_experiment_cfg
-from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path
+from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path, get_latest_run_id
 from super_gradients.common.factories.transforms_factory import TransformsFactory
 from super_gradients.training import models
 from super_gradients.training.utils.sg_trainer_utils import parse_args
@@ -237,8 +237,13 @@ def prepare_conversion_cfgs(cfg: DictConfig):
             "checkpoint_params.checkpoint_path was not provided, so the model will be converted using weights from "
             "checkpoints_dir/training_hyperparams.ckpt_name "
         )
-        checkpoints_dir = Path(get_checkpoints_dir_path(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir))
-        cfg.checkpoint_path = str(checkpoints_dir / cfg.ckpt_name)
+        if cfg.run_id is None:
+            checkpoints_dir = Path(get_latest_run_id(experiment_name=cfg.experiment_name, checkpoints_root_dir=cfg.ckpt_root_dir))
+        else:
+            checkpoints_dir = Path(get_checkpoints_dir_path(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir))
+            checkpoints_dir = os.path.join(checkpoints_dir, cfg.run_id)
+
+        cfg.checkpoint_path = os.path.join(checkpoints_dir, cfg.ckpt_name)
     cfg.out_path = cfg.out_path or cfg.checkpoint_path.replace(".pth", ".onnx")
     logger.info(f"Exporting checkpoint: {cfg.checkpoint_path} to ONNX.")
     return cfg, experiment_cfg
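The new branch relies on get_latest_run_id, imported above, to discover the most recent run when cfg.run_id is not set. The sketch below is only an assumption about what such a lookup does (return the name of the newest RUN_* subdirectory of the experiment's checkpoint folder); it is not the actual super-gradients implementation:

import os
from typing import Optional

def latest_run_id_sketch(experiment_name: str, checkpoints_root_dir: str) -> Optional[str]:
    # Assumed behavior: return e.g. "RUN_20230823_154026_757034", or None if no runs exist yet.
    experiment_dir = os.path.join(checkpoints_root_dir, experiment_name)
    run_ids = [d for d in os.listdir(experiment_dir) if d.startswith("RUN_") and os.path.isdir(os.path.join(experiment_dir, d))]
    return max(run_ids, default=None)  # run ids embed a timestamp, so max() picks the latest

Once cfg.checkpoint_path is resolved, the unchanged trailing lines derive cfg.out_path by swapping the .pth suffix for .onnx, so ckpt_best.pth is exported as ckpt_best.onnx by default.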