Skip to content

Commit

Permalink
Improve env sanity check to be more robust (#612)
Browse files Browse the repository at this point in the history
* first version - a few things to change

* improve

* imrpove doc

* undo requirements change

* mini update

* support wildcards and improve code

* move all to same file, and only raise error when version too small

* move all to same file, and only raise error when version too small

* remove arg
  • Loading branch information
Louis-Dupont committed Jan 17, 2023
1 parent a2e58bb commit 94ca017
Showing 1 changed file with 69 additions and 126 deletions.
195 changes: 69 additions & 126 deletions src/super_gradients/sanity_check/env_sanity_check.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
import logging
import os
import sys
from pip._internal.operations.freeze import freeze
from typing import List, Dict, Union
import pkg_resources
from pkg_resources import parse_version
from packaging.specifiers import SpecifierSet
from typing import List, Optional
from pathlib import Path
from packaging.version import Version


from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.environment.ddp_utils import is_main_process

LIB_CHECK_IMPOSSIBLE_MSG = 'Library check is not supported when super_gradients installed through "git+https://github.com/..." command'
logger = get_logger(__name__, "DEBUG")


def format_error_msg(test_name: str, error_msg: str) -> str:
"""Format an error message in the appropriate format.
:param test_name: Name of the test being tested.
:param error_msg: Message to format in appropriate format.
:return: Formatted message
"""
return f"\33[31mFailed to verify {test_name}: {error_msg}\33[0m"


logger = get_logger(__name__, log_level=logging.DEBUG)
def check_os():
"""Check the operating system name and platform."""

if "linux" not in sys.platform.lower():
error = "Deci officially supports only Linux kernels. Some features may not work as expected."
logger.error(msg=format_error_msg(test_name="operating system", error_msg=error))


def get_requirements_path(requirements_file_name: str) -> Union[None, Path]:
def get_requirements_path(requirements_file_name: str) -> Optional[Path]:
"""Get the path of requirement.txt from the root if exist.
There is a difference when installed from artifact or locally.
- In the first case, requirements.txt is copied to the package during the CI.
Expand All @@ -26,9 +42,9 @@ def get_requirements_path(requirements_file_name: str) -> Union[None, Path]:
is copied/cloned from github, the requirements.txt was not copied to the super_gradients package root, so we
need to go to the project root (.) to find it.
"""
file_path = Path(__file__) # super-gradients/src/super_gradients/sanity_check/env_sanity_check.py
package_root = file_path.parent.parent # moving to super-gradients/src/super_gradients
project_root = package_root.parent.parent # moving to super-gradients
file_path = Path(__file__) # Refers to: .../super-gradients/src/super_gradients/sanity_check/env_sanity_check.py
package_root = file_path.parent.parent # Refers to: .../super-gradients/src/super_gradients
project_root = package_root.parent.parent # Refers to .../super-gradients

# If installed from artifact, requirements.txt is in package_root, if installed locally it is in project_root
if (package_root / requirements_file_name).exists():
Expand All @@ -39,138 +55,65 @@ def get_requirements_path(requirements_file_name: str) -> Union[None, Path]:
return None # Could happen when installed through github directly ("pip install git+https://github.com/...")


def get_installed_libs_with_version() -> Dict[str, str]:
"""Get all the installed libraries, and outputs it as a dict: lib -> version"""
installed_libs_with_version = {}
for lib_with_version in freeze():
if "==" in lib_with_version:
lib, version = lib_with_version.split("==")
installed_libs_with_version[lib.lower()] = version
return installed_libs_with_version


def verify_installed_libraries() -> List[str]:
"""Check that all installed libs respect the requirement.txt"""

def get_requirements(use_pro_requirements: bool) -> Optional[List[str]]:
requirements_path = get_requirements_path("requirements.txt")
pro_requirements_path = get_requirements_path("requirements.pro.txt")

if requirements_path is None:
return [LIB_CHECK_IMPOSSIBLE_MSG]
if (requirements_path is None) or (pro_requirements_path is None):
return None

with open(requirements_path, "r") as f:
requirements = f.readlines()

installed_libs_with_version = get_installed_libs_with_version()
requirements = f.read().splitlines()

# if pro_requirements_path is not None:
with open(pro_requirements_path, "r") as f:
pro_requirements = f.readlines()
if "deci-lab-client" in installed_libs_with_version:
requirements += pro_requirements

errors = []
for requirement in requirements:
if ">=" in requirement:
constraint = ">="
elif "~=" in requirement:
constraint = "~="
elif "==" in requirement:
constraint = "=="
else:
continue
pro_requirements = f.read().splitlines()

return requirements + pro_requirements if use_pro_requirements else requirements


def check_packages():
"""Check that all installed libs respect the requirement.txt, and requirements.pro.txt if relevant.
Note: We only log an error
"""
test_name = "installed packages"

installed_packages = {package.key.lower(): package.version for package in pkg_resources.working_set}
requirements = get_requirements(use_pro_requirements="deci-lab-client" in installed_packages)

lib, required_version_str = requirement.split(constraint)
if requirements is None:
logger.info(msg='Library check is not supported when super_gradients installed through "git+https://github.com/..." command')
return

if ",<=" in required_version_str:
upper_limit_version = Version(required_version_str.split(",<=")[1])
required_version_str = required_version_str.split(",<=")[0]
constraint += ",<="
for requirement in pkg_resources.parse_requirements(requirements):
package_name = requirement.name.lower()

if lib.lower() not in installed_libs_with_version.keys():
errors.append(f"{lib} required but not found")
if package_name not in installed_packages.keys():
error = f"{package_name} required but not found"
logger.error(msg=format_error_msg(test_name=test_name, error_msg=error))
continue

installed_version_str = installed_libs_with_version[lib.lower()]
installed_version, required_version = Version(installed_version_str), Version(required_version_str)

is_constraint_respected = {
">=,<=": required_version <= installed_version <= upper_limit_version,
">=": installed_version >= required_version,
"~=": (
installed_version.major == required_version.major
and installed_version.minor == required_version.minor
and installed_version.micro >= required_version.micro
),
"==": installed_version == required_version,
}
if not is_constraint_respected[constraint]:
errors.append(f"{lib} is installed with version {installed_version} which does not satisfy {requirement} (based on {requirements_path})")
return errors


def verify_os() -> List[str]:
"""Verifying operating system name and platform"""
if "linux" not in sys.platform.lower():
return ["Deci officially supports only Linux kernels. Some features may not work as expected."]
return []


def run_env_sanity_check():
"""Run the sanity check tests and log everything that does not meet requirements"""

display_sanity_check = os.getenv("DISPLAY_SANITY_CHECK", "False") == "True"
stdout_log_level = logging.INFO if display_sanity_check else logging.DEBUG

logger.setLevel(logging.DEBUG) # We want to log everything regardless of DISPLAY_SANITY_CHECK

requirement_checkers = {
"operating_system": verify_os,
"libraries": verify_installed_libraries,
}

logger.log(stdout_log_level, "SuperGradients Sanity Check Started")
logger.log(stdout_log_level, f"Checking the following components: {list(requirement_checkers.keys())}")
logger.log(stdout_log_level, "_" * 20)

lib_check_is_impossible = False
sanity_check_errors = {}
for test_name, test_function in requirement_checkers.items():
logger.log(stdout_log_level, f"Verifying {test_name}...")

errors = test_function()
if errors == [LIB_CHECK_IMPOSSIBLE_MSG]:
lib_check_is_impossible = True
logger.log(stdout_log_level, LIB_CHECK_IMPOSSIBLE_MSG)
elif len(errors) > 0:
sanity_check_errors[test_name] = errors
for error in errors:
logger.log(logging.ERROR, f"\33[31mFailed to verify {test_name}: {error}\33[0m")
else:
logger.log(stdout_log_level, f"{test_name} OK")
logger.log(stdout_log_level, "_" * 20)

if sanity_check_errors:
logger.log(stdout_log_level, f'The current environment does not meet Deci\'s needs, errors found in: {", ".join(list(sanity_check_errors.keys()))}')
elif lib_check_is_impossible:
logger.log(stdout_log_level, LIB_CHECK_IMPOSSIBLE_MSG)
else:
logger.log(stdout_log_level, "Great, Looks like the current environment meet's Deci's requirements!")
installed_version_str = installed_packages[package_name]
for operator_str, req_version_str in requirement.specs:

# The last message needs to be displayed independently of DISPLAY_SANITY_CHECK
if display_sanity_check:
logger.info("** This check can be hidden by setting the env variable DISPLAY_SANITY_CHECK=False prior to import. **")
else:
logger.info(
"** A sanity check is done when importing super_gradients for the first time. ** "
"-> You can see the details by setting the env variable DISPLAY_SANITY_CHECK=True prior to import."
)
installed_version = parse_version(installed_version_str)
req_version = parse_version(req_version_str)
req_spec = SpecifierSet(operator_str + req_version_str)

if installed_version_str not in req_spec:
error = f"{package_name}=={installed_version} does not satisfy requirement {requirement}"

requires_at_least = operator_str in ("==", "~=", ">=", ">")
if requires_at_least and installed_version < req_version:
logger.error(msg=format_error_msg(test_name=test_name, error_msg=error))
else:
logger.debug(msg=error)


def env_sanity_check():
"""Run the sanity check tests and log everything that does not meet requirements"""
"""Run the sanity check tests and log everything that does not meet requirements."""
if is_main_process():
run_env_sanity_check()
check_os()
check_packages()


if __name__ == "__main__":
Expand Down

0 comments on commit 94ca017

Please sign in to comment.