def download_project(
    api: sly.Api,
    project_info: sly.ProjectInfo,
    project_dir: str,
    use_cache: bool,
):
    """Download a Supervisely project into ``project_dir``.

    If ``project_dir`` already exists its content is removed first.

    With ``use_cache`` disabled the whole project is fetched directly with
    ``sly.download``. With ``use_cache`` enabled, datasets that are not yet
    cached are downloaded into the cache, and then *all* datasets of the
    project are copied from the cache into ``project_dir``.

    Args:
        api: Supervisely API client.
        project_info: Info of the project to download (``id`` and
            ``items_count`` are read).
        project_dir: Destination directory of the project.
        use_cache: Whether to go through the dataset cache.
    """
    if os.path.exists(project_dir):
        sly.fs.clean_dir(project_dir)

    if not use_cache:
        # NOTE(review): the progress total is twice the item count, presumably
        # one tick per item plus one per annotation -- confirm against sly.download.
        total = project_info.items_count
        download_progress = get_progress_cb("Downloading input data...", total * 2)
        sly.download(
            api=api,
            project_id=project_info.id,
            dest_dir=project_dir,
            dataset_ids=None,
            log_progress=True,
            progress_cb=download_progress,
            cache=g.my_app.cache,
        )
        return

    # Split datasets into cached / not cached in a single pass, so that
    # is_cached() is evaluated exactly once per dataset and every dataset
    # ends up in exactly one of the two lists.
    dataset_infos = api.dataset.get_list(project_info.id)
    cached = []
    to_download = []
    for ds_info in dataset_infos:
        if is_cached(project_info.id, ds_info.name):
            cached.append(ds_info)
        else:
            to_download.append(ds_info)

    if cached:
        log_msg = "Using cached datasets: " + ", ".join(
            f"{ds_info.name} ({ds_info.id})" for ds_info in cached
        )
    else:
        log_msg = "No cached datasets found"
    sly.logger.info(log_msg)

    if to_download:
        log_msg = "Downloading datasets: " + ", ".join(
            f"{ds_info.name} ({ds_info.id})" for ds_info in to_download
        )
    else:
        log_msg = "All datasets are cached. No datasets to download"
    sly.logger.info(log_msg)

    # Download the missing datasets into the cache. This is still called with
    # an empty list when everything is cached, same as before.
    total = sum(ds_info.images_count for ds_info in to_download)
    download_progress = get_progress_cb("Downloading input data...", total * 2)
    download_to_cache(
        api=api,
        project_id=project_info.id,
        dataset_infos=to_download,
        log_progress=True,
        progress_cb=download_progress,
    )

    # Copy every dataset of the project (previously cached and freshly
    # downloaded) from the cache; progress is measured by cache size.
    total = sum(get_cache_size(project_info.id, ds_info.name) for ds_info in dataset_infos)
    dataset_names = [ds_info.name for ds_info in dataset_infos]
    download_progress = get_progress_cb("Retreiving data from cache...", total, is_size=True)
    copy_from_cache(
        project_id=project_info.id,
        dest_dir=project_dir,
        dataset_names=dataset_names,
        progress_cb=download_progress,
    )
sly.logger.warn("Can not download project") raise Exception(