From df6d8c882b0514c12698d936834005dbd334319f Mon Sep 17 00:00:00 2001 From: Nikolai Petukhov Date: Mon, 4 Mar 2024 14:27:48 -0300 Subject: [PATCH] test --- supervisely/train/requirements.txt | 2 +- supervisely/train/src/sly_project_cached.py | 74 +++++++++++++++++++++ supervisely/train/src/sly_train.py | 9 ++- supervisely/train/src/ui/input_project.html | 10 ++- supervisely/train/src/ui/input_project.py | 5 +- supervisely/train/src/ui/ui.py | 2 +- 6 files changed, 94 insertions(+), 8 deletions(-) create mode 100644 supervisely/train/src/sly_project_cached.py diff --git a/supervisely/train/requirements.txt b/supervisely/train/requirements.txt index a39f6deb9b05..073b87a59366 100644 --- a/supervisely/train/requirements.txt +++ b/supervisely/train/requirements.txt @@ -1 +1 @@ -supervisely==6.73.4 +supervisely==6.73.41 diff --git a/supervisely/train/src/sly_project_cached.py b/supervisely/train/src/sly_project_cached.py new file mode 100644 index 000000000000..5fcce7d5b886 --- /dev/null +++ b/supervisely/train/src/sly_project_cached.py @@ -0,0 +1,74 @@ +import os + +import supervisely as sly +from supervisely.project.download import ( + download_to_cache, + copy_from_cache, + is_cached, + get_cache_size, +) +from sly_utils import get_progress_cb +import sly_train_globals as g + + +def download_project( + api: sly.Api, + project_info: sly.ProjectInfo, + project_dir: str, + use_cache: bool, +): + if os.path.exists(project_dir): + sly.fs.clean_dir(project_dir) + if not use_cache: + total = project_info.items_count + download_progress = get_progress_cb("Downloading input data...", total * 2) + sly.download( + api=api, + project_id=project_info.id, + dest_dir=project_dir, + dataset_ids=None, + log_progress=True, + progress_cb=download_progress, + cache=g.my_app.cache + ) + return + + # get datasets to download and cached + dataset_infos = api.dataset.get_list(project_info.id) + to_download = [info for info in dataset_infos if not is_cached(project_info.id, info.name)] + cached = [info for info in dataset_infos if is_cached(project_info.id, info.name)] + if len(cached) == 0: + log_msg = "No cached datasets found" + else: + log_msg = "Using cached datasets: " + ", ".join( + f"{ds_info.name} ({ds_info.id})" for ds_info in cached + ) + sly.logger.info(log_msg) + if len(to_download) == 0: + log_msg = "All datasets are cached. No datasets to download" + else: + log_msg = "Downloading datasets: " + ", ".join( + f"{ds_info.name} ({ds_info.id})" for ds_info in to_download + ) + sly.logger.info(log_msg) + # get images count + total = sum([ds_info.images_count for ds_info in to_download]) + # download + download_progress = get_progress_cb("Downloading input data...", total * 2) + download_to_cache( + api=api, + project_id=project_info.id, + dataset_infos=to_download, + log_progress=True, + progress_cb=download_progress, + ) + # copy datasets from cache + total = sum([get_cache_size(project_info.id, ds.name) for ds in dataset_infos]) + dataset_names = [ds_info.name for ds_info in dataset_infos] + download_progress = get_progress_cb("Retreiving data from cache...", total, is_size=True) + copy_from_cache( + project_id=project_info.id, + dest_dir=project_dir, + dataset_names=dataset_names, + progress_cb=download_progress, + ) diff --git a/supervisely/train/src/sly_train.py b/supervisely/train/src/sly_train.py index 1aadd1d3137e..ea59a9c400dc 100644 --- a/supervisely/train/src/sly_train.py +++ b/supervisely/train/src/sly_train.py @@ -10,6 +10,7 @@ root_source_dir, scratch_str, finetune_str import ui as ui +from sly_project_cached import download_project from sly_train_utils import init_script_arguments from sly_utils import get_progress_cb, upload_artifacts from splits import get_train_val_sets, verify_train_val_sets @@ -39,9 +40,13 @@ def train(api: sly.Api, task_id, context, state, app_logger): sly.fs.mkdir(project_dir, remove_content_if_exists=True) # clean content for debug, has no effect in prod # download and preprocess Sypervisely project (using cache) - download_progress = get_progress_cb("Download data (using cache)", g.project_info.items_count * 2) try: - sly.download_project(api, project_id, project_dir, cache=my_app.cache, progress_cb=download_progress) + download_project( + api=api, + project_info=g.project_info, + project_dir=project_dir, + use_cache=state.get("useCache", True), + ) except Exception as e: sly.logger.warn("Can not download project") raise Exception( diff --git a/supervisely/train/src/ui/input_project.html b/supervisely/train/src/ui/input_project.html index ff0ef021bc63..76831a592e9c 100644 --- a/supervisely/train/src/ui/input_project.html +++ b/supervisely/train/src/ui/input_project.html @@ -1,8 +1,12 @@ - {{data.projectName}} ({{data.projectImagesCount}} + {{data.projectName}} + ({{data.projectImagesCount}} images) - + + + Use cached data stored on the agent to optimize project downlaod + Cache data on the agent to optimize project download for future trainings + \ No newline at end of file diff --git a/supervisely/train/src/ui/input_project.py b/supervisely/train/src/ui/input_project.py index aeda250bd7f5..bee970dff4ad 100644 --- a/supervisely/train/src/ui/input_project.py +++ b/supervisely/train/src/ui/input_project.py @@ -1,8 +1,11 @@ +from supervisely.project.download import is_cached import sly_train_globals as g -def init(data): +def init(data, state): data["projectId"] = g.project_info.id data["projectName"] = g.project_info.name data["projectImagesCount"] = g.project_info.items_count data["projectPreviewUrl"] = g.api.image.preview_url(g.project_info.reference_image_url, 100, 100) + data["isCached"] = is_cached(g.project_info.id) + state["useCache"] = True \ No newline at end of file diff --git a/supervisely/train/src/ui/ui.py b/supervisely/train/src/ui/ui.py index 1c64e75cd7a1..3d22f51b8053 100644 --- a/supervisely/train/src/ui/ui.py +++ b/supervisely/train/src/ui/ui.py @@ -9,7 +9,7 @@ def init(data, state): - input_project.init(data) + input_project.init(data, state) training_classes.init(g.api, data, state, g.project_meta, g.project_stats) train_val_split.init(g.project_info, g.project_meta, data, state) model_architectures.init(data, state)