diff --git a/src/lightning_utilities/core/imports.py b/src/lightning_utilities/core/imports.py index 45dba7f4..0e13b1b9 100644 --- a/src/lightning_utilities/core/imports.py +++ b/src/lightning_utilities/core/imports.py @@ -121,6 +121,42 @@ def __repr__(self) -> str: return self.__str__() +class ModuleAvailableCache: + """Boolean-like class for check of module availability. + + >>> ModuleAvailableCache("torch") + Module 'torch' available + >>> bool(ModuleAvailableCache("torch")) + True + >>> bool(ModuleAvailableCache("unknown_package")) + False + """ + + def __init__(self, module: str) -> None: + self.module = module + + def _check_requirement(self) -> None: + if hasattr(self, "available"): + return + + self.available = module_available(self.module) + if self.available: + self.message = f"Module {self.module!r} available" + else: + self.message = f"Module not found: {self.module!r}. HINT: Try running `pip install -U {self.module}`" + + def __bool__(self) -> bool: + self._check_requirement() + return self.available + + def __str__(self) -> str: + self._check_requirement() + return self.message + + def __repr__(self) -> str: + return self.__str__() + + def get_dependency_min_version_spec(package_name: str, dependency_name: str) -> str: """Returns the minimum version specifier of a dependency of a package. diff --git a/tests/unittests/core/test_imports.py b/tests/unittests/core/test_imports.py index ed0fb58b..161077f7 100644 --- a/tests/unittests/core/test_imports.py +++ b/tests/unittests/core/test_imports.py @@ -8,6 +8,7 @@ get_dependency_min_version_spec, lazy_import, module_available, + ModuleAvailableCache, RequirementCache, requires, ) @@ -52,6 +53,12 @@ def test_requirement_cache(): assert "pip install -U '-'" in str(RequirementCache("-")) +def test_module_available_cache(): + assert ModuleAvailableCache("pytest") + assert not ModuleAvailableCache("this_module_is_not_installed") + assert "pip install -U this_module_is_not_installed" in str(ModuleAvailableCache("this_module_is_not_installed")) + + def test_get_dependency_min_version_spec(): attrs_min_version_spec = get_dependency_min_version_spec("pytest", "attrs") assert re.match(r"^>=[\d.]+$", attrs_min_version_spec)