Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added recipe unifying script #1560

Merged
merged 20 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
7f26ddc
working script + docs
shaydeci Oct 23, 2023
d08558e
renamed script
shaydeci Oct 23, 2023
0463b0c
added tests
shaydeci Oct 23, 2023
f663300
added tests to suite
shaydeci Oct 23, 2023
3ea4d29
Merge branch 'master' into feature/SG-1179_unify_recipe_script
shaydeci Oct 23, 2023
eaadf73
Merge branch 'master' into feature/SG-1179_unify_recipe_script
shaydeci Oct 23, 2023
2095ed5
Merge branch 'master' into feature/SG-1179_unify_recipe_script
shaydeci Oct 24, 2023
742ede0
Merge branch 'master' into feature/SG-1179_unify_recipe_script
BloodAxe Oct 24, 2023
5fa8827
save path resolved in main
shaydeci Oct 29, 2023
e4c9daa
Merge remote-tracking branch 'origin/feature/SG-1179_unify_recipe_scr…
shaydeci Oct 29, 2023
f791e29
used logger for script print
shaydeci Oct 29, 2023
6ad2348
updated positional args
shaydeci Oct 29, 2023
9211aac
Merge remote-tracking branch 'origin/master' into feature/SG-1179_uni…
shaydeci Oct 29, 2023
6ea5ff2
updated test
shaydeci Oct 29, 2023
6276b08
Merge remote-tracking branch 'origin/master' into feature/SG-1179_uni…
shaydeci Oct 30, 2023
df1cbc6
removed redundant cleanup in test
shaydeci Oct 30, 2023
ade92bf
Merge branch 'master' into feature/SG-1179_unify_recipe_script
BloodAxe Oct 30, 2023
a070d3a
Merge branch 'master' into feature/SG-1179_unify_recipe_script
Louis-Dupont Oct 30, 2023
20fcd42
Merge branch 'master' into feature/SG-1179_unify_recipe_script
Louis-Dupont Oct 31, 2023
d8be0cc
Merge branch 'master' into feature/SG-1179_unify_recipe_script
shaydeci Oct 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/super_gradients/common/environment/cfg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf, open_dict, DictConfig

from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
from super_gradients.common.environment.path_utils import normalize_path
from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path

Expand Down Expand Up @@ -167,3 +168,29 @@ def override_cfg(cfg: DictConfig, overrides: Union[DictConfig, Dict[str, Any]])
"""
with open_dict(cfg): # This is required to add new fields to existing config
cfg.merge_with(overrides)


def export_recipe(config_name: str, save_path: str = None, config_dir: str = pkg_resources.resource_filename("super_gradients.recipes", "")):
"""
saves a complete (i.e no inheritance from other yaml configuration files),
.yaml file that can be ran on its own without the need to keep other configurations which the original
file inherits from.

:param config_name: The .yaml config filename (can leave the .yaml postfix out, but not mandatory).

:param save_path: The config directory path, as absolute file system path.
When None, will use SG's recipe directory (i.e path/to/super_gradients/recipes)

:param config_dir: The config directory path, as absolute file system path.
When None, will use SG's recipe directory (i.e path/to/super_gradients/recipes)

"""
if save_path is None:
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
save_path = os.path.join(os.getcwd(), config_name).replace(".yaml", "") + "_complete.yaml"
# NEED TO REGISTER RESOLVERS FIRST
register_hydra_resolvers()
GlobalHydra.instance().clear()
with initialize_config_dir(config_dir=normalize_path(config_dir), version_base="1.2"):
cfg = compose(config_name=config_name)
OmegaConf.save(config=cfg, f=save_path)
print(f"Successfully saved recipe at {save_path}. \n" f"Recipe content:\n {cfg}")
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
34 changes: 34 additions & 0 deletions src/super_gradients/scripts/export_recipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Script that saves a complete (i.e no inheritance from other yaml configuration files),
.yaml file that can be ran on its own without the need to keep other configurations which the original
file inherits from.

Usage:
python export_recipe --config_name=cifar10_resnet save_path=/other/recipes/dir/my_complete_recipe.yaml -> saves cifar10_resnet_complete.yaml
in current working directory

python export_recipe --config_dir=/path/to/recipes/ config_name=my_recipe.yaml -> saves config_name_complete.yaml in current working directory

python export_recipe --config_dir=/path/to/recipes/ config_name=my_recipe.yaml save_path=/other/recipes/dir/my_complete_recipe.yaml
-> saves the complete recipe in /other/recipes/dir/my_complete_recipe.yaml

:arg config_name: The .yaml config filename (can leave the .yaml postfix out, but not mandatory).

:arg config_dir: The config directory path, as absolute file system path.
When None, will use SG's recipe directory (i.e path/to/super_gradients/recipes)

:arg: The output path for the complete .yaml file.
When None, will use the current working directory.

"""

import argparse
from super_gradients.common.environment.cfg_utils import export_recipe

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config_dir", type=str, default=None, help="The config directory path")
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument("--config_name", type=str, help=".yaml filename")
parser.add_argument("--save_path", type=str, default=None, help="Destination path to the output .yaml file")
args = parser.parse_args()
export_recipe(config_dir=args.config_dir, config_name=args.config_name, save_path=args.save_path)
2 changes: 2 additions & 0 deletions tests/deci_core_unit_test_suite_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
TestPostPredictionCallback,
TestModelPredict,
TestDeprecationDecorator,
TestExportRecipe,
)
from tests.end_to_end_tests import TestTrainer
from tests.unit_tests.detection_utils_test import TestDetectionUtils
Expand Down Expand Up @@ -162,6 +163,7 @@ def _add_modules_to_unit_tests_suite(self):
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPoseEstimationModelExport))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(YoloNASPoseTests))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(PoseEstimationSampleTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestExportRecipe))

def _add_modules_to_end_to_end_tests_suite(self):
"""
Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tests.unit_tests.optimizer_params_override_test import TrainOptimizerParamsOverride
from tests.unit_tests.resume_training_test import ResumeTrainingTest
from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
from tests.unit_tests.test_export_recipe import TestExportRecipe
from tests.unit_tests.train_after_test_test import CallTrainAfterTestTest
from tests.unit_tests.zero_weight_decay_on_bias_bn_test import ZeroWdForBnBiasTest
from tests.unit_tests.save_ckpt_test import SaveCkptListUnitTest
Expand Down Expand Up @@ -55,4 +56,5 @@
"TestPostPredictionCallback",
"TestModelPredict",
"TestDeprecationDecorator",
"TestExportRecipe",
]
51 changes: 51 additions & 0 deletions tests/unit_tests/test_export_recipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import tempfile
import unittest
import os

import hydra
from hydra import initialize_config_dir, compose

from super_gradients.common.environment.cfg_utils import export_recipe


class TestExportRecipe(unittest.TestCase):
def test_export_recipe(self):
# Define the command to run your script
export_recipe(config_name="cifar10_resnet")

# Check if the output file was created
expected_output_path = "cifar10_resnet_complete.yaml"
self.assertTrue(os.path.exists(expected_output_path))

with initialize_config_dir(config_dir=os.getcwd(), version_base="1.2"):
cfg = compose(config_name="cifar10_resnet_complete.yaml")

cfg = hydra.utils.instantiate(cfg)

self.assertEqual(cfg.training_hyperparams.max_epochs, 250)

# Clean up the created file after the test
os.remove(expected_output_path)

def test_export_recipe_with_save_path(self):
with tempfile.TemporaryDirectory() as td:
save_path = os.path.join(td, "cifar10_resnet_complete.yaml")
# Define the command to run your script
export_recipe(config_name="cifar10_resnet", save_path=save_path)

# Check if the output file was created
self.assertTrue(os.path.exists(save_path))

with initialize_config_dir(config_dir=td, version_base="1.2"):
cfg = compose(config_name="cifar10_resnet_complete.yaml")

cfg = hydra.utils.instantiate(cfg)

self.assertEqual(cfg.training_hyperparams.max_epochs, 250)

# Clean up the created file after the test
os.remove(save_path)
shaydeci marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__":
unittest.main()