Skip to content

Commit

Permalink
PLFM-3938 Create a code-loading utility in DeciClient (#1393)
Browse files Browse the repository at this point in the history
* PLFM-3938 Create a code-loading utility in DeciClient

* PLFM-3938 CR

* PLFM-3938 Better
  • Loading branch information
roikoren755 committed Aug 21, 2023
1 parent 51d61bd commit 1680783
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions src/super_gradients/common/plugins/deci_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,38 +149,46 @@ def get_model_weights(self, model_name: str) -> Optional[str]:
:return: model_weights path. None if weights were not found for this specific model on this SG version."""
return self._get_file(model_name=model_name, file_name=AutoNACFileName.WEIGHTS_PTH)

@staticmethod
def load_code_from_zipfile(*, file: str, target_path: str, package_name: str = "deci_model_code") -> None:
"""Load additional code files.
The zip file is extracted, and code files will be placed in the target_path/package_name and imported dynamically,
:param file: path to zip file to extract code files from.
:param target_path: path to place code files.
:param package_name: name of the package to place code files in."""
package_path = os.path.join(target_path, package_name)
# create the directory
os.makedirs(package_path, exist_ok=True)

# extract code files
with ZipFile(file) as zipfile:
zipfile.extractall(package_path)

# add an init file that imports all code files
with open(os.path.join(package_path, "__init__.py"), "w") as init_file:
all_str = "\n\n__all__ = ["
for code_file in os.listdir(path=package_path):
if code_file.endswith(".py") and not code_file.startswith("__init__"):
init_file.write(f'import {code_file.replace(".py", "")}\n')
all_str += f'"{code_file.replace(".py", "")}", '

all_str += "]\n\n"
init_file.write(all_str)

# include in path and import
sys.path.insert(1, package_path)
importlib.import_module(package_name)

def download_and_load_model_additional_code(self, model_name: str, target_path: str, package_name: str = "deci_model_code") -> None:
"""
try to download code files for this model.
if found, code files will be placed in the target_path/package_name and imported dynamically
"""

file = self._get_file(model_name=model_name, file_name=AutoNACFileName.CODE_ZIP)

package_path = os.path.join(target_path, package_name)
if file is not None:
# crete the directory
os.makedirs(package_path, exist_ok=True)

# extract code files
with ZipFile(file) as zipfile:
zipfile.extractall(package_path)

# add an init file that imports all code files
with open(os.path.join(package_path, "__init__.py"), "w") as init_file:
all_str = "\n\n__all__ = ["
for code_file in os.listdir(path=package_path):
if code_file.endswith(".py") and not code_file.startswith("__init__"):
init_file.write(f'import {code_file.replace(".py", "")}\n')
all_str += f'"{code_file.replace(".py", "")}", '

all_str += "]\n\n"
init_file.write(all_str)

# include in path and import
sys.path.insert(1, package_path)
importlib.import_module(package_name)

self.load_code_from_zipfile(file=file, target_path=target_path, package_name=package_name)
logger.info(
f"*** IMPORTANT ***: files required for the model {model_name} were downloaded and added to your code in:\n{package_path}\n"
f"These files will be downloaded to the same location each time the model is fetched from the deci-client.\n"
Expand Down

0 comments on commit 1680783

Please sign in to comment.