Added metrics logging to checkpoint and separate yaml file (#1562)
* Added metrics logging to checkpoint and separate yaml file

* Fixed variable name in docstring to add_yaml_summary

* Merged changes from feature/saving_metrics_to_yaml and made the commit signed

* Added method to the abstract SGLogger class and fixed some places commented on in the PR

* Fixed metric saving and added/fixed some tests

* Changed casting function from __maybe_get_item_from_tensor to a simple float()

* Left only metrics saved to checkpoint

* Removed test for yaml files

* Fixed some places in the code according to comments in the PR

* Changed float metric value back to int in the schedulers test :(

* Fixed linters (trailing spaces)

---------

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Co-authored-by: Louis-Dupont <35190946+Louis-Dupont@users.noreply.github.com>
3 people committed Oct 31, 2023
1 parent 40b1f2c commit e91be8f
Showing 2 changed files with 45 additions and 11 deletions.
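
For orientation: after this change, the checkpoint state dict written by _save_checkpoint carries a "metrics" entry alongside "net", "acc", "epoch", and "packages". A minimal sketch (not part of the commit) of reading it back, assuming a checkpoint produced by this version; the path below is a placeholder:

import torch

# Placeholder path; any checkpoint written by Trainer._save_checkpoint after this commit.
ckpt = torch.load("checkpoints/my_experiment/ckpt_best.pth", map_location="cpu")

metrics = ckpt["metrics"]
print("tracked metric:", metrics["tracked_metric_name"], "=", ckpt["acc"])
print("validation metrics:", metrics["valid"])   # {metric_name: float, ...}
if "train" in metrics:                           # present only when train metrics were passed in
    print("train metrics:", metrics["train"])
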
43 changes: 34 additions & 9 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -643,6 +643,7 @@ def _save_checkpoint(
self,
optimizer: torch.optim.Optimizer = None,
epoch: int = None,
train_metrics_dict: Optional[Dict[str, float]] = None,
validation_results_dict: Optional[Dict[str, float]] = None,
context: PhaseContext = None,
) -> None:
@@ -657,10 +658,28 @@ def _save_checkpoint(

# COMPUTE THE CURRENT metric
# IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST'S INDICES
metric = validation_results_dict[self.metric_to_watch]
curr_tracked_metric = float(validation_results_dict[self.metric_to_watch])

# create metrics dict to save
valid_metrics_titles = get_metrics_titles(self.valid_metrics)

all_metrics = {
"tracked_metric_name": self.metric_to_watch,
"valid": {metric_name: float(validation_results_dict[metric_name]) for metric_name in valid_metrics_titles},
}

if train_metrics_dict is not None:
train_metrics_titles = get_metrics_titles(self.train_metrics)
all_metrics["train"] = {metric_name: float(train_metrics_dict[metric_name]) for metric_name in train_metrics_titles}

# BUILD THE state_dict
state = {"net": unwrap_model(self.net).state_dict(), "acc": metric, "epoch": epoch, "packages": get_installed_packages()}
state = {
"net": unwrap_model(self.net).state_dict(),
"acc": curr_tracked_metric,
"epoch": epoch,
"metrics": all_metrics,
"packages": get_installed_packages(),
}

if optimizer is not None:
state["optimizer_state_dict"] = optimizer.state_dict()
@@ -686,17 +705,16 @@ def _save_checkpoint(
self.sg_logger.add_checkpoint(tag=f"ckpt_epoch_{epoch}.pth", state_dict=state, global_step=epoch)

# OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
if (metric > self.best_metric and self.greater_metric_to_watch_is_better) or (metric < self.best_metric and not self.greater_metric_to_watch_is_better):
if (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
):
# STORE THE CURRENT metric AS BEST
self.best_metric = metric
self.best_metric = curr_tracked_metric
self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)

# RUN PHASE CALLBACKS
self.phase_callback_handler.on_validation_end_best_epoch(context)

if isinstance(metric, torch.Tensor):
metric = metric.item()
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(metric))
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(curr_tracked_metric))

if self.training_params.average_best_models:
net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)
@@ -1187,6 +1205,7 @@ def forward(self, inputs, targets):
random_seed(is_ddp=device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, device=device_config.device, seed=self.training_params.seed)

silent_mode = self.training_params.silent_mode or self.ddp_silent_mode

# METRICS
self._set_train_metrics(train_metrics_list=self.training_params.train_metrics_list)
self._set_valid_metrics(valid_metrics_list=self.training_params.valid_metrics_list)
@@ -1938,7 +1957,13 @@ def _write_to_disk_operations(

# SAVE THE CHECKPOINT
if self.training_params.save_model:
self._save_checkpoint(self.optimizer, epoch + 1, validation_results_dict, context)
self._save_checkpoint(
optimizer=self.optimizer,
epoch=epoch + 1,
train_metrics_dict=train_metrics_dict,
validation_results_dict=validation_results_dict,
context=context,
)

def _get_epoch_start_logging_values(self) -> dict:
"""Get all the values that should be logged at the start of each epoch.
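
Not part of the diff: since the saved "metrics" dict keeps validation and (optionally) train values side by side, a small helper along these lines could report the train/validation gap for the tracked metric. The function name below is illustrative only:

from typing import Dict, Optional

def tracked_metric_gap(metrics: Dict) -> Optional[float]:
    """Return valid - train for the tracked metric, or None when train metrics are absent."""
    name = metrics["tracked_metric_name"]
    train = metrics.get("train")
    if not train or name not in train or name not in metrics["valid"]:
        return None
    return metrics["valid"][name] - train[name]

# e.g. gap = tracked_metric_gap(torch.load("ckpt_best.pth", map_location="cpu")["metrics"])
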
13 changes: 11 additions & 2 deletions tests/end_to_end_tests/trainer_test.py
@@ -18,7 +18,16 @@ class TestTrainer(unittest.TestCase):
def setUp(cls):
super_gradients.init_trainer()
# NAMES FOR THE EXPERIMENTS TO LATER DELETE
cls.experiment_names = ["test_train", "test_save_load", "test_load_w", "test_load_w2", "test_load_w3", "test_checkpoint_content", "analyze"]
cls.experiment_names = [
"test_train",
"test_save_load",
"test_load_w",
"test_load_w2",
"test_load_w3",
"test_checkpoint_content",
"analyze",
"test_yaml_metrics_present",
]
cls.training_params = {
"max_epochs": 1,
"silent_mode": True,
@@ -79,7 +88,7 @@ def test_checkpoint_content(self):
ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
for ckpt_path in ckpt_paths:
ckpt = torch.load(ckpt_path)
self.assertListEqual(sorted(["net", "acc", "epoch", "optimizer_state_dict", "scaler_state_dict", "packages"]), sorted(list(ckpt.keys())))
self.assertListEqual(sorted(["net", "acc", "epoch", "optimizer_state_dict", "scaler_state_dict", "metrics", "packages"]), sorted(list(ckpt.keys())))
trainer._save_checkpoint()
weights_only = torch.load(os.path.join(trainer.checkpoints_dir_path, "ckpt_latest_weights_only.pth"))
self.assertListEqual(["net"], list(weights_only.keys()))
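
A possible follow-up assertion (not included in this commit) next to test_checkpoint_content, checking that every stored metric value is a plain Python float after the float() cast; "checkpoints_dir" stands in for trainer.checkpoints_dir_path in the test above:

import os
import torch

ckpt = torch.load(os.path.join("checkpoints_dir", "ckpt_latest.pth"), map_location="cpu")
for split in ("train", "valid"):
    for name, value in ckpt["metrics"].get(split, {}).items():
        assert isinstance(value, float), f"{split}/{name} stored as {type(value)}"
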
