Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolaiPetukhov committed Mar 4, 2024
1 parent 00d5936 commit 3dae0ac
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 3 deletions.
2 changes: 1 addition & 1 deletion supervisely/train/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
supervisely==6.73.4
supervisely==6.73.41
74 changes: 74 additions & 0 deletions supervisely/train/src/sly_project_cached.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os

import supervisely as sly
from supervisely.project.download import (
download_to_cache,
copy_from_cache,
is_cached,
get_cache_size,
)
from sly_utils import get_progress_cb
import sly_train_globals as g


def download_project(
api: sly.Api,
project_info: sly.ProjectInfo,
project_dir: str,
use_cache: bool,
):
if os.path.exists(project_dir):
sly.fs.clean_dir(project_dir)
if not use_cache:
total = project_info.items_count
download_progress = get_progress_cb("Downloading input data...", total * 2)
sly.download(
api=api,
project_id=project_info.id,
dest_dir=project_dir,
dataset_ids=None,
log_progress=True,
progress_cb=download_progress,
cache=g.my_app.cache
)
return

# get datasets to download and cached
dataset_infos = api.dataset.get_list(project_info.id)
to_download = [info for info in dataset_infos if not is_cached(project_info.id, info.name)]
cached = [info for info in dataset_infos if is_cached(project_info.id, info.name)]
if len(cached) == 0:
log_msg = "No cached datasets found"
else:
log_msg = "Using cached datasets: " + ", ".join(
f"{ds_info.name} ({ds_info.id})" for ds_info in cached
)
sly.logger.info(log_msg)
if len(to_download) == 0:
log_msg = "All datasets are cached. No datasets to download"
else:
log_msg = "Downloading datasets: " + ", ".join(
f"{ds_info.name} ({ds_info.id})" for ds_info in to_download
)
sly.logger.info(log_msg)
# get images count
total = sum([ds_info.images_count for ds_info in to_download])
# download
download_progress = get_progress_cb("Downloading input data...", total * 2)
download_to_cache(
api=api,
project_id=project_info.id,
dataset_infos=to_download,
log_progress=True,
progress_cb=download_progress,
)
# copy datasets from cache
total = sum([get_cache_size(project_info.id, ds.name) for ds in dataset_infos])
dataset_names = [ds_info.name for ds_info in dataset_infos]
download_progress = get_progress_cb("Retreiving data from cache...", total, is_size=True)
copy_from_cache(
project_id=project_info.id,
dest_dir=project_dir,
dataset_names=dataset_names,
progress_cb=download_progress,
)
9 changes: 7 additions & 2 deletions supervisely/train/src/sly_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
root_source_dir, scratch_str, finetune_str

import ui as ui
from sly_project_cached import download_project
from sly_train_utils import init_script_arguments
from sly_utils import get_progress_cb, upload_artifacts
from splits import get_train_val_sets, verify_train_val_sets
Expand Down Expand Up @@ -39,9 +40,13 @@ def train(api: sly.Api, task_id, context, state, app_logger):
sly.fs.mkdir(project_dir, remove_content_if_exists=True) # clean content for debug, has no effect in prod

# download and preprocess Sypervisely project (using cache)
download_progress = get_progress_cb("Download data (using cache)", g.project_info.items_count * 2)
try:
sly.download_project(api, project_id, project_dir, cache=my_app.cache, progress_cb=download_progress)
download_project(
api=api,
project_info=g.project_info,
project_dir=project_dir,
use_cache=True,
)
except Exception as e:
sly.logger.warn("Can not download project")
raise Exception(
Expand Down

0 comments on commit 3dae0ac

Please sign in to comment.