Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve env sanity check to be more robust #612

Merged
merged 12 commits into from
Jan 17, 2023
195 changes: 69 additions & 126 deletions src/super_gradients/sanity_check/env_sanity_check.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
import logging
import os
import sys
from pip._internal.operations.freeze import freeze
from typing import List, Dict, Union
import pkg_resources
from pkg_resources import parse_version
from packaging.specifiers import SpecifierSet
from typing import List, Optional
from pathlib import Path
from packaging.version import Version


from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.environment.ddp_utils import is_main_process

LIB_CHECK_IMPOSSIBLE_MSG = 'Library check is not supported when super_gradients installed through "git+https://github.com/..." command'
logger = get_logger(__name__, "DEBUG")


def format_error_msg(test_name: str, error_msg: str) -> str:
    """Format a sanity-check failure as a red, ANSI-colored message.

    :param test_name: Name of the check that failed (e.g. "operating system").
    :param error_msg: Description of what went wrong.
    :return: The text "Failed to verify <test_name>: <error_msg>", wrapped between the ANSI
             "red foreground" and "reset" escape sequences.
    """
    # "\33" is the octal escape for ESC (0x1b): "[31m" switches to red, "[0m" resets the style.
    return f"\33[31mFailed to verify {test_name}: {error_msg}\33[0m"


logger = get_logger(__name__, log_level=logging.DEBUG)
def check_os():
    """Check the operating system platform and log an error when it is not Linux.

    Nothing is returned and nothing is raised: a mismatch is only reported through the
    module-level ``logger``, formatted with ``format_error_msg``.
    """
    # sys.platform is compared case-insensitively; any value containing "linux" passes.
    if "linux" not in sys.platform.lower():
        error = "Deci officially supports only Linux kernels. Some features may not work as expected."
        logger.error(msg=format_error_msg(test_name="operating system", error_msg=error))


def get_requirements_path(requirements_file_name: str) -> Union[None, Path]:
def get_requirements_path(requirements_file_name: str) -> Optional[Path]:
"""Get the path of requirements.txt from the root, if it exists.
There is a difference when installed from artifact or locally.
- In the first case, requirements.txt is copied to the package during the CI.
Expand All @@ -26,9 +42,9 @@ def get_requirements_path(requirements_file_name: str) -> Union[None, Path]:
is copied/cloned from github, the requirements.txt was not copied to the super_gradients package root, so we
need to go to the project root (.) to find it.
"""
file_path = Path(__file__) # super-gradients/src/super_gradients/sanity_check/env_sanity_check.py
package_root = file_path.parent.parent # moving to super-gradients/src/super_gradients
project_root = package_root.parent.parent # moving to super-gradients
file_path = Path(__file__) # Refers to: .../super-gradients/src/super_gradients/sanity_check/env_sanity_check.py
package_root = file_path.parent.parent # Refers to: .../super-gradients/src/super_gradients
project_root = package_root.parent.parent # Refers to .../super-gradients

# If installed from artifact, requirements.txt is in package_root, if installed locally it is in project_root
if (package_root / requirements_file_name).exists():
Expand All @@ -39,138 +55,65 @@ def get_requirements_path(requirements_file_name: str) -> Union[None, Path]:
return None # Could happen when installed through github directly ("pip install git+https://github.com/...")


def get_installed_libs_with_version() -> Dict[str, str]:
    """Get all the installed libraries reported by pip's ``freeze``, as a dict: lib -> version.

    Only ``name==version`` entries are kept; library names are lower-cased. Comment lines,
    option lines (such as ``-e ...``) and direct references (``name @ url``) are skipped,
    because they are not ``name==version`` pins even when they contain ``==`` somewhere.

    :return: Mapping of lower-cased library name to its installed version string.
    """
    installed_libs_with_version = {}
    for lib_with_version in freeze():
        entry = lib_with_version.strip()
        # A "==" inside a comment or a URL must not be mistaken for a version pin.
        if not entry or entry.startswith(("#", "-")) or " @ " in entry:
            continue
        # Split on the first "==" only, so a later "==" in the line cannot break the unpacking.
        lib, separator, version = entry.partition("==")
        if not separator:
            continue
        # "name===version" (arbitrary equality) would otherwise leave a stray "=" before the version.
        installed_libs_with_version[lib.strip().lower()] = version.lstrip("=").strip()
    return installed_libs_with_version


def verify_installed_libraries() -> List[str]:
"""Check that all installed libs respect the requirement.txt"""

def get_requirements(use_pro_requirements: bool) -> Optional[List[str]]:
requirements_path = get_requirements_path("requirements.txt")
pro_requirements_path = get_requirements_path("requirements.pro.txt")

if requirements_path is None:
return [LIB_CHECK_IMPOSSIBLE_MSG]
if (requirements_path is None) or (pro_requirements_path is None):
return None

with open(requirements_path, "r") as f:
requirements = f.readlines()

installed_libs_with_version = get_installed_libs_with_version()
requirements = f.read().splitlines()

# if pro_requirements_path is not None:
with open(pro_requirements_path, "r") as f:
pro_requirements = f.readlines()
if "deci-lab-client" in installed_libs_with_version:
requirements += pro_requirements

errors = []
for requirement in requirements:
if ">=" in requirement:
constraint = ">="
elif "~=" in requirement:
constraint = "~="
elif "==" in requirement:
constraint = "=="
else:
continue
pro_requirements = f.read().splitlines()

return requirements + pro_requirements if use_pro_requirements else requirements


def check_packages():
"""Check that all installed libs respect the requirement.txt, and requirements.pro.txt if relevant.
Note: Requirement violations are only logged; no exception is raised.
"""
test_name = "installed packages"

installed_packages = {package.key.lower(): package.version for package in pkg_resources.working_set}
requirements = get_requirements(use_pro_requirements="deci-lab-client" in installed_packages)

lib, required_version_str = requirement.split(constraint)
if requirements is None:
logger.info(msg='Library check is not supported when super_gradients installed through "git+https://github.com/..." command')
return

if ",<=" in required_version_str:
upper_limit_version = Version(required_version_str.split(",<=")[1])
required_version_str = required_version_str.split(",<=")[0]
constraint += ",<="
for requirement in pkg_resources.parse_requirements(requirements):
package_name = requirement.name.lower()

if lib.lower() not in installed_libs_with_version.keys():
errors.append(f"{lib} required but not found")
if package_name not in installed_packages.keys():
error = f"{package_name} required but not found"
logger.error(msg=format_error_msg(test_name=test_name, error_msg=error))
continue

installed_version_str = installed_libs_with_version[lib.lower()]
installed_version, required_version = Version(installed_version_str), Version(required_version_str)

is_constraint_respected = {
">=,<=": required_version <= installed_version <= upper_limit_version,
">=": installed_version >= required_version,
"~=": (
installed_version.major == required_version.major
and installed_version.minor == required_version.minor
and installed_version.micro >= required_version.micro
),
"==": installed_version == required_version,
}
if not is_constraint_respected[constraint]:
errors.append(f"{lib} is installed with version {installed_version} which does not satisfy {requirement} (based on {requirements_path})")
return errors


def verify_os() -> List[str]:
    """Verify that the current platform is Linux.

    :return: A one-element list holding a warning message when ``sys.platform`` does not
             contain "linux" (case-insensitive); an empty list otherwise.
    """
    if "linux" not in sys.platform.lower():
        return ["Deci officially supports only Linux kernels. Some features may not work as expected."]
    return []


def run_env_sanity_check():
"""Run the sanity check tests and log everything that does not meet requirements"""

display_sanity_check = os.getenv("DISPLAY_SANITY_CHECK", "False") == "True"
stdout_log_level = logging.INFO if display_sanity_check else logging.DEBUG

logger.setLevel(logging.DEBUG) # We want to log everything regardless of DISPLAY_SANITY_CHECK

requirement_checkers = {
"operating_system": verify_os,
"libraries": verify_installed_libraries,
}

logger.log(stdout_log_level, "SuperGradients Sanity Check Started")
logger.log(stdout_log_level, f"Checking the following components: {list(requirement_checkers.keys())}")
logger.log(stdout_log_level, "_" * 20)

lib_check_is_impossible = False
sanity_check_errors = {}
for test_name, test_function in requirement_checkers.items():
logger.log(stdout_log_level, f"Verifying {test_name}...")

errors = test_function()
if errors == [LIB_CHECK_IMPOSSIBLE_MSG]:
lib_check_is_impossible = True
logger.log(stdout_log_level, LIB_CHECK_IMPOSSIBLE_MSG)
elif len(errors) > 0:
sanity_check_errors[test_name] = errors
for error in errors:
logger.log(logging.ERROR, f"\33[31mFailed to verify {test_name}: {error}\33[0m")
else:
logger.log(stdout_log_level, f"{test_name} OK")
logger.log(stdout_log_level, "_" * 20)

if sanity_check_errors:
logger.log(stdout_log_level, f'The current environment does not meet Deci\'s needs, errors found in: {", ".join(list(sanity_check_errors.keys()))}')
elif lib_check_is_impossible:
logger.log(stdout_log_level, LIB_CHECK_IMPOSSIBLE_MSG)
else:
logger.log(stdout_log_level, "Great, Looks like the current environment meet's Deci's requirements!")
installed_version_str = installed_packages[package_name]
for operator_str, req_version_str in requirement.specs:

# The last message needs to be displayed independently of DISPLAY_SANITY_CHECK
if display_sanity_check:
logger.info("** This check can be hidden by setting the env variable DISPLAY_SANITY_CHECK=False prior to import. **")
else:
logger.info(
"** A sanity check is done when importing super_gradients for the first time. ** "
"-> You can see the details by setting the env variable DISPLAY_SANITY_CHECK=True prior to import."
)
installed_version = parse_version(installed_version_str)
req_version = parse_version(req_version_str)
req_spec = SpecifierSet(operator_str + req_version_str)

if installed_version_str not in req_spec:
error = f"{package_name}=={installed_version} does not satisfy requirement {requirement}"

requires_at_least = operator_str in ("==", "~=", ">=", ">")
if requires_at_least and installed_version < req_version:
Louis-Dupont marked this conversation as resolved.
Show resolved Hide resolved
logger.error(msg=format_error_msg(test_name=test_name, error_msg=error))
else:
logger.debug(msg=error)


def env_sanity_check():
"""Run the sanity check tests and log everything that does not meet requirements"""
"""Run the sanity check tests and log everything that does not meet requirements."""
if is_main_process():
run_env_sanity_check()
check_os()
check_packages()


if __name__ == "__main__":
Expand Down