diff --git a/src/sparsezoo/api/query_parser.py b/src/sparsezoo/api/query_parser.py index a8f8b64a..2e5994ac 100644 --- a/src/sparsezoo/api/query_parser.py +++ b/src/sparsezoo/api/query_parser.py @@ -20,7 +20,7 @@ DEFAULT_MODELS_FIELDS = ["modelId", "stub"] -DEFAULT_FILES_FIELDS = ["displayName", "fileSize", "modelId", "fileType"] +DEFAULT_FILES_FIELDS = ["displayName", "downloadUrl", "fileSize", "fileType", "modelId"] DEFAULT_TRAINING_RESULTS_FIELDS = [ "datasetName", diff --git a/src/sparsezoo/model/utils.py b/src/sparsezoo/model/utils.py index 8089ca4b..966e605c 100644 --- a/src/sparsezoo/model/utils.py +++ b/src/sparsezoo/model/utils.py @@ -31,7 +31,7 @@ ValidationResult, ) from sparsezoo.objects import Directory, File, NumpyDirectory, OnnxGz, Recipes -from sparsezoo.utils import BASE_API_URL, convert_to_bool, save_numpy +from sparsezoo.utils import convert_to_bool, save_numpy __all__ = [ @@ -599,9 +599,8 @@ def _copy_and_overwrite(from_path, to_path, func): def include_file_download_url(files: List[Dict]): for file in files: - file["url"] = get_file_download_url( - model_id=file["model_id"], file_name=file["display_name"] - ) + file["url"] = get_file_download_url(file["download_url"]) + del file["download_url"] def get_model_metadata_from_stub(stub: str) -> Dict[str, str]: @@ -637,15 +636,12 @@ def is_stub(candidate: str) -> bool: def get_file_download_url( - model_id: str, - file_name: str, - base_url: str = BASE_API_URL, + download_url: str, ): """Url to download a file""" - download_url = f"{base_url}/v2/models/{model_id}/files/{file_name}" - # important, do not remove if convert_to_bool(os.getenv("SPARSEZOO_TEST_MODE")): - download_url += "?increment_download=False" + delimiter = "&" if "?" in download_url else "?" + download_url += f"{delimiter}increment_download=False" return download_url diff --git a/tests/sparsezoo/api/test_query_parser.py b/tests/sparsezoo/api/test_query_parser.py index dce7472e..82622799 100644 --- a/tests/sparsezoo/api/test_query_parser.py +++ b/tests/sparsezoo/api/test_query_parser.py @@ -148,7 +148,7 @@ "fields": ( "modelId displayName benchmarkResults " "{ batchSize deviceInfo numCores recordedUnits recordedValue } " - "files { displayName fileSize modelId fileType } " + "files { displayName downloadUrl fileSize fileType modelId } " ), }, ),