Added metrics logging to checkpoint and separate yaml file #1562

Merged
Merged 24 commits from feature/saving_metrics_to_yaml on Oct 31, 2023
Changes from all commits
Commits (24)
4b11551
Added metrics logging to checkpoint and separate yaml file
hakuryuu96 Oct 23, 2023
58c96ca
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 23, 2023
6b3999c
Fixed variable name in docstr to add_yaml_summary
hakuryuu96 Oct 23, 2023
fb8ae7d
Merged changes in feature/saving_metrics_to_yaml and making commit si…
hakuryuu96 Oct 23, 2023
9446489
Added method to abstract sglogger class and fixed some places comment…
hakuryuu96 Oct 24, 2023
80dc4b9
Added method to abstract sglogger class and fixed some places comment…
hakuryuu96 Oct 24, 2023
0c2e066
Fixed metric saving and added/fixed some tests
hakuryuu96 Oct 24, 2023
53b5294
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 24, 2023
cff188e
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 24, 2023
ca054c7
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 25, 2023
aea5860
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 25, 2023
c60612d
Changed casting function from __maybe_get_item_from_tensor to simple …
hakuryuu96 Oct 25, 2023
133b9f7
Merge branch 'feature/saving_metrics_to_yaml' of github.com:hakuryuu9…
hakuryuu96 Oct 25, 2023
2bf40e2
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 26, 2023
534fbfa
Left only metrics saved to checkpoint
hakuryuu96 Oct 26, 2023
34ab3fd
Removed test for yaml files
hakuryuu96 Oct 26, 2023
256dbae
Fixed some place in code according to comments in PR
hakuryuu96 Oct 26, 2023
cb84783
Changed float metric value back to int in test of schedulers :(
hakuryuu96 Oct 26, 2023
a96c2ab
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 27, 2023
e7b8d83
Merge branch 'master' into feature/saving_metrics_to_yaml
BloodAxe Oct 29, 2023
2ce24a7
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 30, 2023
412b62d
Fixed linters (trailing spaces)
hakuryuu96 Oct 30, 2023
6168a75
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 30, 2023
0a02eac
Merge branch 'master' into feature/saving_metrics_to_yaml
Louis-Dupont Oct 30, 2023
43 changes: 34 additions & 9 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -643,6 +643,7 @@ def _save_checkpoint(
self,
optimizer: torch.optim.Optimizer = None,
epoch: int = None,
train_metrics_dict: Optional[Dict[str, float]] = None,
validation_results_dict: Optional[Dict[str, float]] = None,
context: PhaseContext = None,
) -> None:
@@ -657,10 +658,28 @@ def _save_checkpoint(

# COMPUTE THE CURRENT metric
# IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST'S INDICES
- metric = validation_results_dict[self.metric_to_watch]
+ curr_tracked_metric = float(validation_results_dict[self.metric_to_watch])

# create metrics dict to save
valid_metrics_titles = get_metrics_titles(self.valid_metrics)

all_metrics = {
"tracked_metric_name": self.metric_to_watch,
"valid": {metric_name: float(validation_results_dict[metric_name]) for metric_name in valid_metrics_titles},
}

if train_metrics_dict is not None:
train_metrics_titles = get_metrics_titles(self.train_metrics)
all_metrics["train"] = {metric_name: float(train_metrics_dict[metric_name]) for metric_name in train_metrics_titles}

# BUILD THE state_dict
- state = {"net": unwrap_model(self.net).state_dict(), "acc": metric, "epoch": epoch, "packages": get_installed_packages()}
+ state = {
+     "net": unwrap_model(self.net).state_dict(),
+     "acc": curr_tracked_metric,
+     "epoch": epoch,
+     "metrics": all_metrics,
+     "packages": get_installed_packages(),
+ }

if optimizer is not None:
state["optimizer_state_dict"] = optimizer.state_dict()
@@ -686,17 +705,16 @@ def _save_checkpoint(
self.sg_logger.add_checkpoint(tag=f"ckpt_epoch_{epoch}.pth", state_dict=state, global_step=epoch)

# OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
- if (metric > self.best_metric and self.greater_metric_to_watch_is_better) or (metric < self.best_metric and not self.greater_metric_to_watch_is_better):
+ if (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
+     curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
+ ):
# STORE THE CURRENT metric AS BEST
- self.best_metric = metric
+ self.best_metric = curr_tracked_metric
self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)

# RUN PHASE CALLBACKS
self.phase_callback_handler.on_validation_end_best_epoch(context)

- if isinstance(metric, torch.Tensor):
-     metric = metric.item()
- logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(metric))
+ logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(curr_tracked_metric))

if self.training_params.average_best_models:
net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)
@@ -1187,6 +1205,7 @@ def forward(self, inputs, targets):
random_seed(is_ddp=device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, device=device_config.device, seed=self.training_params.seed)

silent_mode = self.training_params.silent_mode or self.ddp_silent_mode

# METRICS
self._set_train_metrics(train_metrics_list=self.training_params.train_metrics_list)
self._set_valid_metrics(valid_metrics_list=self.training_params.valid_metrics_list)
@@ -1938,7 +1957,13 @@ def _write_to_disk_operations(

# SAVE THE CHECKPOINT
if self.training_params.save_model:
- self._save_checkpoint(self.optimizer, epoch + 1, validation_results_dict, context)
+ self._save_checkpoint(
+     optimizer=self.optimizer,
+     epoch=epoch + 1,
+     train_metrics_dict=train_metrics_dict,
+     validation_results_dict=validation_results_dict,
+     context=context,
+ )

def _get_epoch_start_logging_values(self) -> dict:
"""Get all the values that should be logged at the start of each epoch.
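
For readers skimming the diff, the sketch below shows how the enriched checkpoint could be consumed after training. It is only an illustration, not part of the PR: the checkpoint path is hypothetical, and the structure of ckpt["metrics"] is read off the all_metrics dict built in _save_checkpoint above.

```python
import torch

# Hypothetical path to a checkpoint written by a Trainer from this branch.
ckpt_path = "checkpoints/my_experiment/ckpt_best.pth"

ckpt = torch.load(ckpt_path, map_location="cpu")

# "acc" still holds the single tracked metric value, exactly as before this PR.
print("tracked value:", ckpt["acc"])

# The new "metrics" entry mirrors the all_metrics dict built in _save_checkpoint:
# "tracked_metric_name", a "valid" dict, and a "train" dict when train metrics were passed.
metrics = ckpt["metrics"]
print("tracked metric:", metrics["tracked_metric_name"])
for split in ("train", "valid"):
    for name, value in metrics.get(split, {}).items():
        print(f"{split}/{name}: {value:.4f}")
```
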
13 changes: 11 additions & 2 deletions tests/end_to_end_tests/trainer_test.py
@@ -18,7 +18,16 @@ class TestTrainer(unittest.TestCase):
def setUp(cls):
super_gradients.init_trainer()
# NAMES FOR THE EXPERIMENTS TO LATER DELETE
- cls.experiment_names = ["test_train", "test_save_load", "test_load_w", "test_load_w2", "test_load_w3", "test_checkpoint_content", "analyze"]
+ cls.experiment_names = [
+     "test_train",
+     "test_save_load",
+     "test_load_w",
+     "test_load_w2",
+     "test_load_w3",
+     "test_checkpoint_content",
+     "analyze",
+     "test_yaml_metrics_present",
+ ]
cls.training_params = {
"max_epochs": 1,
"silent_mode": True,
@@ -79,7 +88,7 @@ def test_checkpoint_content(self):
ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
for ckpt_path in ckpt_paths:
ckpt = torch.load(ckpt_path)
- self.assertListEqual(sorted(["net", "acc", "epoch", "optimizer_state_dict", "scaler_state_dict", "packages"]), sorted(list(ckpt.keys())))
+ self.assertListEqual(sorted(["net", "acc", "epoch", "optimizer_state_dict", "scaler_state_dict", "metrics", "packages"]), sorted(list(ckpt.keys())))
trainer._save_checkpoint()
weights_only = torch.load(os.path.join(trainer.checkpoints_dir_path, "ckpt_latest_weights_only.pth"))
self.assertListEqual(["net"], list(weights_only.keys()))
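
As a complement to test_checkpoint_content, a hypothetical structural check on the new "metrics" entry might look like the sketch below. It is not part of this PR; the checkpoint path is assumed, and the assertions simply mirror the all_metrics layout from _save_checkpoint.

```python
import os
import unittest

import torch


class TestCheckpointMetricsEntry(unittest.TestCase):
    """Hypothetical sketch: verify the structure of the new "metrics" checkpoint entry."""

    # Assumed path to a checkpoint produced by a finished training run.
    CKPT_PATH = os.path.join("checkpoints", "test_checkpoint_content", "ckpt_latest.pth")

    def test_metrics_entry_structure(self):
        ckpt = torch.load(self.CKPT_PATH, map_location="cpu")

        metrics = ckpt["metrics"]
        # The tracked metric name is stored alongside the per-split values.
        self.assertIn("tracked_metric_name", metrics)
        self.assertIn("valid", metrics)

        # Every stored metric value should already be a plain float, not a tensor.
        for split in ("train", "valid"):
            for value in metrics.get(split, {}).values():
                self.assertIsInstance(value, float)


if __name__ == "__main__":
    unittest.main()
```
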