Skip to content

Commit

Permalink
first version (#1494)
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis-Dupont committed Oct 3, 2023
1 parent 8c7dc64 commit 5b6d74b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
7 changes: 7 additions & 0 deletions src/super_gradients/common/environment/package_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import pkg_resources
from typing import Dict


def get_installed_packages() -> Dict[str, str]:
"""Map all the installed packages to their version."""
return {package.key.lower(): package.version for package in pkg_resources.working_set}
3 changes: 2 additions & 1 deletion src/super_gradients/sanity_check/env_sanity_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.environment.ddp_utils import is_main_process
from super_gradients.common.environment.package_utils import get_installed_packages

logger = get_logger(__name__, "DEBUG")

Expand Down Expand Up @@ -79,7 +80,7 @@ def check_packages():
"""
test_name = "installed packages"

installed_packages = {package.key.lower(): package.version for package in pkg_resources.working_set}
installed_packages = get_installed_packages()
requirements = get_requirements(use_pro_requirements="deci-platform-client" in installed_packages)

if requirements is None:
Expand Down
6 changes: 3 additions & 3 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch.nn
import torchmetrics
from omegaconf import DictConfig, OmegaConf
from piptools.scripts.sync import _get_installed_distributions

from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, SequentialSampler
Expand All @@ -36,6 +36,7 @@
from super_gradients.common.factories.list_factory import ListFactory
from super_gradients.common.factories.losses_factory import LossesFactory
from super_gradients.common.factories.metrics_factory import MetricsFactory
from super_gradients.common.environment.package_utils import get_installed_packages

from super_gradients.training import utils as core_utils, models, dataloaders
from super_gradients.training.datasets.samplers import RepeatAugSampler
Expand Down Expand Up @@ -1800,8 +1801,7 @@ def _get_hyper_param_config(self):
}
# ADD INSTALLED PACKAGE LIST + THEIR VERSIONS
if self.training_params.log_installed_packages:
pkg_list = list(map(lambda pkg: str(pkg), _get_installed_distributions()))
additional_log_items["installed_packages"] = pkg_list
additional_log_items["installed_packages"] = get_installed_packages()

dataset_params = {
"train_dataset_params": self.train_loader.dataset.dataset_params if hasattr(self.train_loader.dataset, "dataset_params") else None,
Expand Down

0 comments on commit 5b6d74b

Please sign in to comment.