Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error with pip _get_installed_distributions #1494

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/super_gradients/common/environment/package_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import importlib.metadata
from typing import Dict

import pkg_resources


def get_installed_packages() -> Dict[str, str]:
    """Map all the installed packages to their version.

    :return: Dict mapping each installed distribution's name (lowercased,
             matching the normalization previously provided by
             ``pkg_resources.Distribution.key``) to its version string.
    """
    # importlib.metadata is the stdlib replacement for the deprecated
    # pkg_resources API (DeprecationWarning since setuptools 67.5, slated
    # for removal); it also avoids the heavy import cost of pkg_resources.
    return {dist.metadata["Name"].lower(): dist.version for dist in importlib.metadata.distributions()}
3 changes: 2 additions & 1 deletion src/super_gradients/sanity_check/env_sanity_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.environment.ddp_utils import is_main_process
from super_gradients.common.environment.package_utils import get_installed_packages

logger = get_logger(__name__, "DEBUG")

Expand Down Expand Up @@ -79,7 +80,7 @@ def check_packages():
"""
test_name = "installed packages"

installed_packages = {package.key.lower(): package.version for package in pkg_resources.working_set}
installed_packages = get_installed_packages()
requirements = get_requirements(use_pro_requirements="deci-platform-client" in installed_packages)

if requirements is None:
Expand Down
6 changes: 3 additions & 3 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch.nn
import torchmetrics
from omegaconf import DictConfig, OmegaConf
from piptools.scripts.sync import _get_installed_distributions

from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, SequentialSampler
Expand Down Expand Up @@ -40,6 +40,7 @@
from super_gradients.common.factories.list_factory import ListFactory
from super_gradients.common.factories.losses_factory import LossesFactory
from super_gradients.common.factories.metrics_factory import MetricsFactory
from super_gradients.common.environment.package_utils import get_installed_packages

from super_gradients.training import utils as core_utils, models, dataloaders
from super_gradients.training.datasets.samplers import RepeatAugSampler
Expand Down Expand Up @@ -1875,8 +1876,7 @@ def _get_hyper_param_config(self):
}
# ADD INSTALLED PACKAGE LIST + THEIR VERSIONS
if self.training_params.log_installed_packages:
pkg_list = list(map(lambda pkg: str(pkg), _get_installed_distributions()))
additional_log_items["installed_packages"] = pkg_list
additional_log_items["installed_packages"] = get_installed_packages()

dataset_params = {
"train_dataset_params": self.train_loader.dataset.dataset_params if hasattr(self.train_loader.dataset, "dataset_params") else None,
Expand Down