Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PLFM-3938 Create a code-loading utility in DeciClient #1393

Merged
merged 5 commits into from
Aug 21, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 39 additions & 29 deletions src/super_gradients/common/plugins/deci_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,43 +149,53 @@ def get_model_weights(self, model_name: str) -> Optional[str]:
:return: model_weights path. None if weights were not found for this specific model on this SG version."""
return self._get_file(model_name=model_name, file_name=AutoNACFileName.WEIGHTS_PTH)

@staticmethod
def load_code_from_zipfile(*, file: str, target_path: str, success_message: str, package_name: str = "deci_model_code") -> None:
"""
load additional code files.
the zip file is extracted, and code files will be placed in the target_path/package_name and imported dynamically
"""
roikoren755 marked this conversation as resolved.
Show resolved Hide resolved
package_path = os.path.join(target_path, package_name)
# create the directory
os.makedirs(package_path, exist_ok=True)

# extract code files
with ZipFile(file) as zipfile:
zipfile.extractall(package_path)

# add an init file that imports all code files
with open(os.path.join(package_path, "__init__.py"), "w") as init_file:
all_str = "\n\n__all__ = ["
for code_file in os.listdir(path=package_path):
if code_file.endswith(".py") and not code_file.startswith("__init__"):
init_file.write(f'import {code_file.replace(".py", "")}\n')
all_str += f'"{code_file.replace(".py", "")}", '

all_str += "]\n\n"
init_file.write(all_str)

# include in path and import
sys.path.insert(1, package_path)
importlib.import_module(package_name)

logger.info(success_message)

def download_and_load_model_additional_code(self, model_name: str, target_path: str, package_name: str = "deci_model_code") -> None:
"""
try to download code files for this model.
if found, code files will be placed in the target_path/package_name and imported dynamically
"""

file = self._get_file(model_name=model_name, file_name=AutoNACFileName.CODE_ZIP)

package_path = os.path.join(target_path, package_name)
if file is not None:
# crete the directory
os.makedirs(package_path, exist_ok=True)

# extract code files
with ZipFile(file) as zipfile:
zipfile.extractall(package_path)

# add an init file that imports all code files
with open(os.path.join(package_path, "__init__.py"), "w") as init_file:
all_str = "\n\n__all__ = ["
for code_file in os.listdir(path=package_path):
if code_file.endswith(".py") and not code_file.startswith("__init__"):
init_file.write(f'import {code_file.replace(".py", "")}\n')
all_str += f'"{code_file.replace(".py", "")}", '

all_str += "]\n\n"
init_file.write(all_str)

# include in path and import
sys.path.insert(1, package_path)
importlib.import_module(package_name)

logger.info(
f"*** IMPORTANT ***: files required for the model {model_name} were downloaded and added to your code in:\n{package_path}\n"
f"These files will be downloaded to the same location each time the model is fetched from the deci-client.\n"
f"you can override this by passing models.get(... download_required_code=False) and importing the files yourself"
)
self.load_code_from_zipfile(
file=file,
target_path=target_path,
package_name=package_name,
success_message=f"*** IMPORTANT ***: files required for the model {model_name} were downloaded and added to your code in:\n{package_path}\n"
f"These files will be downloaded to the same location each time the model is fetched from the deci-client.\n"
f"you can override this by passing models.get(... download_required_code=False) and importing the files yourself",
)
roikoren755 marked this conversation as resolved.
Show resolved Hide resolved

def upload_model(
self,
Expand Down