Skip to content

Commit

Permalink
Merged changes in feature/saving_metrics_to_yaml and making commit signed
Browse files Browse the repository at this point in the history
  • Loading branch information
hakuryuu96 committed Oct 23, 2023
1 parent f153c61 commit fb8ae7d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 9 deletions.
19 changes: 18 additions & 1 deletion src/super_gradients/common/sg_loggers/base_sg_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import os
import signal
import time
from typing import Union, Any
from typing import Union, Any, Optional

import matplotlib.pyplot as plt
import numpy as np
import psutil
import torch
from PIL import Image
import shutil
import yaml

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig
Expand Down Expand Up @@ -329,6 +330,22 @@ def _save_checkpoint(self, path: str, state_dict: dict) -> None:
if self.save_checkpoints_remote:
self.model_checkpoints_data_interface.save_remote_checkpoints_file(self.experiment_name, self._local_dir, name)

@multi_process_safe
def add_yaml_summary(self, tag: str, summary_dict: dict, global_step: Optional[int] = None) -> None:
    """Save ``summary_dict`` as YAML to ``<experiment_folder>/<tag>[_<global_step>].yml``.

    Initially added for saving metrics to YAML so they are stored in something
    easily parsable (easier than .pth checkpoints), but may be reused for any dict.

    :param tag: Identifier of the summary; used as the file-name stem.
    :param summary_dict: Dict to serialize (e.g. a checkpoint metrics summary).
    :param global_step: Epoch number; when given, appended to the file name as ``_<step>``.
    """
    name = tag + (f"_{global_step}" if global_step is not None else "") + ".yml"
    # `with` closes the file on exit — no explicit close() needed.
    with open(os.path.join(self._local_dir, name), "w") as outfile:
        yaml.dump(summary_dict, outfile, default_flow_style=False)

def add(self, tag: str, obj: Any, global_step: int = None):
    """Generic hook for logging an arbitrary object under ``tag``.

    No-op in this base logger; concrete logger implementations may provide behavior.

    :param tag: Identifier for the logged object.
    :param obj: Arbitrary object to log.
    :param global_step: Step/epoch index associated with the object, if any.
    """
    pass

Expand Down
37 changes: 29 additions & 8 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ def _save_checkpoint(
self,
optimizer: torch.optim.Optimizer = None,
epoch: int = None,
train_metrics_dict: Optional[Dict[str, float]] = None,
validation_results_dict: Optional[Dict[str, float]] = None,
context: PhaseContext = None,
) -> None:
Expand All @@ -654,10 +655,18 @@ def _save_checkpoint(

# COMPUTE THE CURRENT metric
# IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST'S INDICES
metric = validation_results_dict[self.metric_to_watch]
curr_tracked_metric = validation_results_dict[self.metric_to_watch]

# create metrics dict to save
all_metrics = {
"tracked_metric_name": self.metric_to_watch,
"metrics": {"valid": {metric_name: validation_results_dict[metric_name] for metric_name in self.valid_metrics}},
}
if train_metrics_dict is not None:
all_metrics["metrics"]["train"] = {metric_name: train_metrics_dict[metric_name].item() for metric_name in self.train_metrics}

# BUILD THE state_dict
state = {"net": unwrap_model(self.net).state_dict(), "acc": metric, "epoch": epoch}
state = {"net": unwrap_model(self.net).state_dict(), "acc": curr_tracked_metric, "epoch": epoch, "all_metrics": all_metrics}

if optimizer is not None:
state["optimizer_state_dict"] = optimizer.state_dict()
Expand All @@ -677,23 +686,28 @@ def _save_checkpoint(

# SAVES CURRENT MODEL AS ckpt_latest
self.sg_logger.add_checkpoint(tag="ckpt_latest.pth", state_dict=state, global_step=epoch)
self.sg_logger.add_yaml_summary(tag="metrics_latest", summary_dict=all_metrics)

# SAVE MODEL AT SPECIFIC EPOCHS DETERMINED BY save_ckpt_epoch_list
if epoch in self.training_params.save_ckpt_epoch_list:
self.sg_logger.add_checkpoint(tag=f"ckpt_epoch_{epoch}.pth", state_dict=state, global_step=epoch)
self.sg_logger.add_yaml_summary(tag="metrics_epoch", summary_dict=all_metrics, global_step=epoch)

# OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
if (metric > self.best_metric and self.greater_metric_to_watch_is_better) or (metric < self.best_metric and not self.greater_metric_to_watch_is_better):
if (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
):
# STORE THE CURRENT metric AS BEST
self.best_metric = metric
self.best_metric = curr_tracked_metric
self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
self.sg_logger.add_yaml_summary(tag="metrics_best", summary_dict=all_metrics)

# RUN PHASE CALLBACKS
self.phase_callback_handler.on_validation_end_best_epoch(context)

if isinstance(metric, torch.Tensor):
metric = metric.item()
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(metric))
if isinstance(curr_tracked_metric, torch.Tensor):
curr_tracked_metric = curr_tracked_metric.item()
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(curr_tracked_metric))

if self.training_params.average_best_models:
net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)
Expand Down Expand Up @@ -1184,6 +1198,7 @@ def forward(self, inputs, targets):
random_seed(is_ddp=device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, device=device_config.device, seed=self.training_params.seed)

silent_mode = self.training_params.silent_mode or self.ddp_silent_mode

# METRICS
self._set_train_metrics(train_metrics_list=self.training_params.train_metrics_list)
self._set_valid_metrics(valid_metrics_list=self.training_params.valid_metrics_list)
Expand Down Expand Up @@ -1925,7 +1940,13 @@ def _write_to_disk_operations(

# SAVE THE CHECKPOINT
if self.training_params.save_model:
self._save_checkpoint(self.optimizer, epoch + 1, validation_results_dict, context)
self._save_checkpoint(
optimizer=self.optimizer,
epoch=epoch + 1,
train_metrics_dict=train_metrics_dict,
validation_results_dict=validation_results_dict,
context=context,
)

def _get_epoch_start_logging_values(self) -> dict:
"""Get all the values that should be logged at the start of each epoch.
Expand Down

0 comments on commit fb8ae7d

Please sign in to comment.