-
Notifications
You must be signed in to change notification settings - Fork 488
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added recipe unifying script (#1560)
* working script + docs * renamed script * added tests * added tests to suite * save path resolved in main * used logger for script print * updated positional args * updated test * removed redundant cleanup in test --------- Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: Louis-Dupont <35190946+Louis-Dupont@users.noreply.github.com>
- Loading branch information
1 parent
5316f51
commit 20ddf3d
Showing
5 changed files
with
100 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
""" | ||
Script that saves a complete (i.e no inheritance from other yaml configuration files), | ||
.yaml file that can be ran on its own without the need to keep other configurations which the original | ||
file inherits from. | ||
Usage: | ||
python export_recipe config_name=cifar10_resnet save_path=/other/recipes/dir/my_complete_recipe.yaml -> saves cifar10_resnet_complete.yaml | ||
in current working directory | ||
python export_recipe --config_dir=/path/to/recipes/ config_name=my_recipe.yaml -> saves config_name_complete.yaml in current working directory | ||
python export_recipe --config_dir=/path/to/recipes/ config_name=my_recipe.yaml save_path=/other/recipes/dir/my_complete_recipe.yaml | ||
-> saves the complete recipe in /other/recipes/dir/my_complete_recipe.yaml | ||
:arg config_name: The .yaml config filename (can leave the .yaml postfix out, but not mandatory). | ||
:arg config_dir: The config directory path, as absolute file system path. | ||
When None, will use SG's recipe directory (i.e path/to/super_gradients/recipes) | ||
:arg: The output path for the complete .yaml file. | ||
When None, will use the current working directory. | ||
""" | ||
|
||
import argparse | ||
import os | ||
|
||
from super_gradients.common.environment.cfg_utils import export_recipe | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--config_dir", type=str, default=None, help="The config directory path") | ||
parser.add_argument("config_name", type=str, help=".yaml filename") | ||
parser.add_argument("save_path", type=str, default=None, help="Destination path to the output .yaml file") | ||
args = parser.parse_args() | ||
if args.save_path is None: | ||
args.save_path = os.path.join(os.getcwd(), args.config_name).replace(".yaml", "") + "_complete.yaml" | ||
export_recipe(config_dir=args.config_dir, config_name=args.config_name, save_path=args.save_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import tempfile | ||
import unittest | ||
import os | ||
|
||
import hydra | ||
from hydra import initialize_config_dir, compose | ||
|
||
from super_gradients.common.environment.cfg_utils import export_recipe | ||
|
||
|
||
class TestExportRecipe(unittest.TestCase): | ||
def test_export_recipe(self): | ||
with tempfile.TemporaryDirectory() as td: | ||
save_path = os.path.join(td, "cifar10_resnet_complete.yaml") | ||
# Define the command to run your script | ||
export_recipe(config_name="cifar10_resnet", save_path=save_path) | ||
|
||
# Check if the output file was created | ||
self.assertTrue(os.path.exists(save_path)) | ||
|
||
with initialize_config_dir(config_dir=td, version_base="1.2"): | ||
cfg = compose(config_name="cifar10_resnet_complete.yaml") | ||
|
||
cfg = hydra.utils.instantiate(cfg) | ||
|
||
self.assertEqual(cfg.training_hyperparams.max_epochs, 250) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |