Skip to content

Commit

Permalink
update train
Browse files Browse the repository at this point in the history
  • Loading branch information
cxnt committed Jun 6, 2024
1 parent 05bb75f commit 1de3ac9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
11 changes: 10 additions & 1 deletion supervisely/train/src/sly_train_globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import yaml
import supervisely as sly
from supervisely.nn.checkpoints.yolov5 import YOLOv5Checkpoint
from supervisely.app.v1.app_service import AppService
from dotenv import load_dotenv

Expand Down Expand Up @@ -54,6 +55,14 @@
experiment_name = str(task_id)
local_artifacts_dir = os.path.join(runs_dir, experiment_name)
sly.logger.info(f"All training artifacts will be saved to local directory {local_artifacts_dir}")
remote_artifacts_dir = os.path.join("/yolov5_train", project_info.name, experiment_name)

checkpoint = YOLOv5Checkpoint(team_id)
model_dir = checkpoint.get_model_dir()

remote_artifacts_dir = os.path.join(model_dir, project_info.name, experiment_name)
remote_artifacts_dir = api.file.get_free_dir_name(team_id, remote_artifacts_dir)

remote_weights_dir = os.path.join(remote_artifacts_dir, checkpoint.weights_dir)
remote_weights_dir = api.file.get_free_dir_name(team_id, remote_artifacts_dir)

sly.logger.info(f"After training artifacts will be uploaded to Team Files: {remote_artifacts_dir}")
13 changes: 12 additions & 1 deletion supervisely/train/src/sly_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,15 @@ def _gen_message(current, total):
globals.api.file.upload(globals.team_id, local_path, remote_path,
lambda monitor: progress_cb(progress_last + monitor.bytes_read))
progress.message = _gen_message(idx + 1, len(local_files))
time.sleep(0.5)
time.sleep(0.5)

# generate metadata
globals.checkpoint.generate_sly_metadata(
app_name=globals.checkpoint.app_name,
session_id=globals.experiment_name,
session_path=globals.remote_artifacts_dir,
weights_dir=globals.remote_weights_dir,
training_project_name=globals.project_info.name,
task_type=globals.checkpoint.task_type,
config_path=None,
)

0 comments on commit 1de3ac9

Please sign in to comment.