diff --git a/src/super_gradients/common/environment/cfg_utils.py b/src/super_gradients/common/environment/cfg_utils.py index a47936513e..d094d56552 100644 --- a/src/super_gradients/common/environment/cfg_utils.py +++ b/src/super_gradients/common/environment/cfg_utils.py @@ -9,8 +9,12 @@ from hydra.core.global_hydra import GlobalHydra from omegaconf import OmegaConf, open_dict, DictConfig +from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers from super_gradients.common.environment.path_utils import normalize_path from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path +from super_gradients.common.abstractions.abstract_logger import get_logger + +logger = get_logger(__name__) class RecipeNotFoundError(Exception): @@ -167,3 +171,27 @@ def override_cfg(cfg: DictConfig, overrides: Union[DictConfig, Dict[str, Any]]) """ with open_dict(cfg): # This is required to add new fields to existing config cfg.merge_with(overrides) + + +def export_recipe(config_name: str, save_path: str, config_dir: str = pkg_resources.resource_filename("super_gradients.recipes", "")): + """ + saves a complete (i.e no inheritance from other yaml configuration files), + .yaml file that can be ran on its own without the need to keep other configurations which the original + file inherits from. + + :param config_name: The .yaml config filename (can leave the .yaml postfix out, but not mandatory). + + :param save_path: The config directory path, as absolute file system path. + When None, will use SG's recipe directory (i.e path/to/super_gradients/recipes) + + :param config_dir: The config directory path, as absolute file system path. + When None, will use SG's recipe directory (i.e path/to/super_gradients/recipes) + + """ + # NEED TO REGISTER RESOLVERS FIRST + register_hydra_resolvers() + GlobalHydra.instance().clear() + with initialize_config_dir(config_dir=normalize_path(config_dir), version_base="1.2"): + cfg = compose(config_name=config_name) + OmegaConf.save(config=cfg, f=save_path) + logger.info(f"Successfully saved recipe at {save_path}. \n" f"Recipe content:\n {cfg}") diff --git a/src/super_gradients/scripts/export_recipe.py b/src/super_gradients/scripts/export_recipe.py new file mode 100644 index 0000000000..d4ae3aaaf9 --- /dev/null +++ b/src/super_gradients/scripts/export_recipe.py @@ -0,0 +1,38 @@ +""" +Script that saves a complete (i.e no inheritance from other yaml configuration files), + .yaml file that can be ran on its own without the need to keep other configurations which the original + file inherits from. + +Usage: + python export_recipe config_name=cifar10_resnet save_path=/other/recipes/dir/my_complete_recipe.yaml -> saves cifar10_resnet_complete.yaml + in current working directory + + python export_recipe --config_dir=/path/to/recipes/ config_name=my_recipe.yaml -> saves config_name_complete.yaml in current working directory + + python export_recipe --config_dir=/path/to/recipes/ config_name=my_recipe.yaml save_path=/other/recipes/dir/my_complete_recipe.yaml + -> saves the complete recipe in /other/recipes/dir/my_complete_recipe.yaml + +:arg config_name: The .yaml config filename (can leave the .yaml postfix out, but not mandatory). + +:arg config_dir: The config directory path, as absolute file system path. + When None, will use SG's recipe directory (i.e path/to/super_gradients/recipes) + +:arg: The output path for the complete .yaml file. + When None, will use the current working directory. + +""" + +import argparse +import os + +from super_gradients.common.environment.cfg_utils import export_recipe + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config_dir", type=str, default=None, help="The config directory path") + parser.add_argument("config_name", type=str, help=".yaml filename") + parser.add_argument("save_path", type=str, default=None, help="Destination path to the output .yaml file") + args = parser.parse_args() + if args.save_path is None: + args.save_path = os.path.join(os.getcwd(), args.config_name).replace(".yaml", "") + "_complete.yaml" + export_recipe(config_dir=args.config_dir, config_name=args.config_name, save_path=args.save_path) diff --git a/tests/deci_core_unit_test_suite_runner.py b/tests/deci_core_unit_test_suite_runner.py index 821f820270..66809dc1a3 100644 --- a/tests/deci_core_unit_test_suite_runner.py +++ b/tests/deci_core_unit_test_suite_runner.py @@ -25,6 +25,7 @@ TestModelPredict, TestDeprecationDecorator, DynamicModelTests, + TestExportRecipe, TestMixedPrecisionDisabled, ) from tests.end_to_end_tests import TestTrainer @@ -166,6 +167,7 @@ def _add_modules_to_unit_tests_suite(self): self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPoseEstimationModelExport)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(YoloNASPoseTests)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(PoseEstimationSampleTest)) + self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestExportRecipe)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestMixedPrecisionDisabled)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DynamicModelTests)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestConvertRecipeToCode)) diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py index 405ede310a..c36d7a979d 100644 --- a/tests/unit_tests/__init__.py +++ b/tests/unit_tests/__init__.py @@ -5,6 +5,7 @@ from tests.unit_tests.optimizer_params_override_test import TrainOptimizerParamsOverride from tests.unit_tests.resume_training_test import ResumeTrainingTest from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest +from tests.unit_tests.test_export_recipe import TestExportRecipe from tests.unit_tests.train_after_test_test import CallTrainAfterTestTest from tests.unit_tests.zero_weight_decay_on_bias_bn_test import ZeroWdForBnBiasTest from tests.unit_tests.save_ckpt_test import SaveCkptListUnitTest @@ -59,4 +60,5 @@ "TestDeprecationDecorator", "TestMixedPrecisionDisabled", "DynamicModelTests", + "TestExportRecipe", ] diff --git a/tests/unit_tests/test_export_recipe.py b/tests/unit_tests/test_export_recipe.py new file mode 100644 index 0000000000..5710161724 --- /dev/null +++ b/tests/unit_tests/test_export_recipe.py @@ -0,0 +1,30 @@ +import tempfile +import unittest +import os + +import hydra +from hydra import initialize_config_dir, compose + +from super_gradients.common.environment.cfg_utils import export_recipe + + +class TestExportRecipe(unittest.TestCase): + def test_export_recipe(self): + with tempfile.TemporaryDirectory() as td: + save_path = os.path.join(td, "cifar10_resnet_complete.yaml") + # Define the command to run your script + export_recipe(config_name="cifar10_resnet", save_path=save_path) + + # Check if the output file was created + self.assertTrue(os.path.exists(save_path)) + + with initialize_config_dir(config_dir=td, version_base="1.2"): + cfg = compose(config_name="cifar10_resnet_complete.yaml") + + cfg = hydra.utils.instantiate(cfg) + + self.assertEqual(cfg.training_hyperparams.max_epochs, 250) + + +if __name__ == "__main__": + unittest.main()