Added metrics logging to checkpoint and separate yaml file #1562

Merged · 24 commits · Oct 31, 2023
Changes from 11 commits
Commits (24)
4b11551
Added metrics logging to checkpoint and separate yaml file
hakuryuu96 Oct 23, 2023
58c96ca
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 23, 2023
6b3999c
Fixed variable name in docstr to add_yaml_summary
hakuryuu96 Oct 23, 2023
fb8ae7d
Merged changes in feature/saving_metrics_to_yaml and making commit si…
hakuryuu96 Oct 23, 2023
9446489
Added method to abstract sglogger class and fixed some places comment…
hakuryuu96 Oct 24, 2023
80dc4b9
Added method to abstract sglogger class and fixed some places comment…
hakuryuu96 Oct 24, 2023
0c2e066
Fixed metric saving and added/fixed some tests
hakuryuu96 Oct 24, 2023
53b5294
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 24, 2023
cff188e
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 24, 2023
ca054c7
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 25, 2023
aea5860
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 25, 2023
c60612d
Changed casting function from __maybe_get_item_from_tensor to simple …
hakuryuu96 Oct 25, 2023
133b9f7
Merge branch 'feature/saving_metrics_to_yaml' of github.com:hakuryuu9…
hakuryuu96 Oct 25, 2023
2bf40e2
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 26, 2023
534fbfa
Left only metrics saved to checkpoint
hakuryuu96 Oct 26, 2023
34ab3fd
Removed test for yaml files
hakuryuu96 Oct 26, 2023
256dbae
Fixed some place in code according to comments in PR
hakuryuu96 Oct 26, 2023
cb84783
Changed float metric value back to int in test of schedulers :(
hakuryuu96 Oct 26, 2023
a96c2ab
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 27, 2023
e7b8d83
Merge branch 'master' into feature/saving_metrics_to_yaml
BloodAxe Oct 29, 2023
2ce24a7
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 30, 2023
412b62d
Fixed linters (trailing spaces)
hakuryuu96 Oct 30, 2023
6168a75
Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 Oct 30, 2023
0a02eac
Merge branch 'master' into feature/saving_metrics_to_yaml
Louis-Dupont Oct 30, 2023
13 changes: 12 additions & 1 deletion src/super_gradients/common/sg_loggers/abstract_sg_logger.py
@@ -1,5 +1,5 @@
from abc import abstractmethod, ABC
from typing import Union, Any
from typing import Union, Any, Optional

import numpy as np
from PIL import Image
@@ -137,6 +137,17 @@ def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = None):
"""
raise NotImplementedError

@abstractmethod
def add_yaml_summary(self, tag: str, summary_dict: dict, global_step: Optional[int] = None):
"""
Add any dict as yaml to SGLogger.

:param tag: Identifier of the summary.
:param summary_dict: Checkpoint summary_dict.
:param global_step: Epoch number.
"""
raise NotImplementedError

@abstractmethod
def add_file(self, file_name: str = None):
"""
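Since add_yaml_summary is declared as an @abstractmethod, any logger that derives directly from AbstractSGLogger now has to implement it. The sketch below is purely illustrative (the class name and target directory are made up, and the remaining abstract methods are omitted, so this class cannot be instantiated as-is); it only shows the contract the new method is expected to fulfil:

import os
from typing import Optional

import yaml

from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger


class MinimalYamlLogger(AbstractSGLogger):  # hypothetical subclass; other abstract methods omitted for brevity
    def __init__(self, local_dir: str):
        self._local_dir = local_dir

    def add_yaml_summary(self, tag: str, summary_dict: dict, global_step: Optional[int] = None) -> None:
        # Write summary_dict as <tag>[_<global_step>].yml, mirroring the naming scheme used by BaseSGLogger.
        name = tag + (f"_{global_step}" if global_step is not None else "") + ".yml"
        with open(os.path.join(self._local_dir, name), "w", encoding="utf-8") as outfile:
            yaml.safe_dump(summary_dict, outfile, default_flow_style=False)
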
18 changes: 17 additions & 1 deletion src/super_gradients/common/sg_loggers/base_sg_logger.py
@@ -2,14 +2,15 @@
import os
import signal
import time
from typing import Union, Any
from typing import Union, Any, Optional

import matplotlib.pyplot as plt
import numpy as np
import psutil
import torch
from PIL import Image
import shutil
import yaml

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig
@@ -329,6 +330,21 @@ def _save_checkpoint(self, path: str, state_dict: dict) -> None:
if self.save_checkpoints_remote:
self.model_checkpoints_data_interface.save_remote_checkpoints_file(self.experiment_name, self._local_dir, name)

@multi_process_safe
def add_yaml_summary(self, tag: str, summary_dict: dict, global_step: Optional[int] = None) -> None:
"""Saves any dict to <experiment_folder>/<tag>.yaml
Initially added for saving metrics to yaml to store it in something easily parsable (easier than .pth checkpoints),
but who knows what it will be suited for later.

:param tag: Identifier of the summary.
:param summary_dict: Checkpoint summary_dict.
:param global_step: Epoch number.
"""

name = tag + (f"_{global_step}" if global_step is not None else "") + ".yml"
with open(os.path.join(self._local_dir, name), "w", encoding="utf-8") as outfile:
yaml.safe_dump(summary_dict, outfile, default_flow_style=False)

def add(self, tag: str, obj: Any, global_step: int = None):
pass

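To make the on-disk format concrete, the snippet below is a small self-contained sketch of what BaseSGLogger.add_yaml_summary writes (the metric names and values are invented for illustration; the real dict is assembled by the trainer, see the next file):

import yaml

# Hypothetical summary dict, shaped like the all_metrics dict built in Trainer._save_checkpoint.
summary = {
    "tracked_metric_name": "Accuracy",
    "metrics": {
        "valid": {"Accuracy": 0.91, "Top5": 0.99},
        "train": {"Accuracy": 0.95, "Top5": 0.998},
    },
}

# add_yaml_summary uses yaml.safe_dump with default_flow_style=False, so the resulting
# metrics_latest.yml / metrics_best.yml files are plain block-style YAML.
print(yaml.safe_dump(summary, default_flow_style=False))
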
50 changes: 42 additions & 8 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -637,10 +637,16 @@ def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context
# RUN PHASE CALLBACKS
self.phase_callback_handler.on_train_batch_gradient_step_end(context)

def __maybe_get_item_from_tensor(self, value: Union[float, torch.Tensor]) -> float:
if isinstance(value, torch.Tensor):
return value.item()
return value

def _save_checkpoint(
self,
optimizer: torch.optim.Optimizer = None,
epoch: int = None,
train_metrics_dict: Optional[Dict[str, float]] = None,
validation_results_dict: Optional[Dict[str, float]] = None,
context: PhaseContext = None,
) -> None:
@@ -655,10 +661,27 @@

# COMPUTE THE CURRENT metric
# IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST'S INDICES
metric = validation_results_dict[self.metric_to_watch]
curr_tracked_metric = validation_results_dict[self.metric_to_watch]

# create metrics dict to save
valid_metrics_titles = get_metrics_titles(self.valid_metrics)

all_metrics = {
"tracked_metric_name": self.metric_to_watch,
"metrics": {
"valid": {metric_name: self.__maybe_get_item_from_tensor(validation_results_dict[metric_name]) for metric_name in valid_metrics_titles}
},
}

train_metrics_titles = get_metrics_titles(self.train_metrics)

if train_metrics_dict is not None:
all_metrics["metrics"]["train"] = {
metric_name: self.__maybe_get_item_from_tensor(train_metrics_dict[metric_name]) for metric_name in train_metrics_titles
}

# BUILD THE state_dict
state = {"net": unwrap_model(self.net).state_dict(), "acc": metric, "epoch": epoch}
state = {"net": unwrap_model(self.net).state_dict(), "acc": curr_tracked_metric, "epoch": epoch, "all_metrics": all_metrics}

if optimizer is not None:
state["optimizer_state_dict"] = optimizer.state_dict()
@@ -678,23 +701,27 @@

# SAVES CURRENT MODEL AS ckpt_latest
self.sg_logger.add_checkpoint(tag="ckpt_latest.pth", state_dict=state, global_step=epoch)
self.sg_logger.add_yaml_summary(tag="metrics_latest", summary_dict=all_metrics)

# SAVE MODEL AT SPECIFIC EPOCHS DETERMINED BY save_ckpt_epoch_list
if epoch in self.training_params.save_ckpt_epoch_list:
self.sg_logger.add_checkpoint(tag=f"ckpt_epoch_{epoch}.pth", state_dict=state, global_step=epoch)
self.sg_logger.add_yaml_summary(tag="metrics_epoch", summary_dict=all_metrics, global_step=epoch)

# OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
if (metric > self.best_metric and self.greater_metric_to_watch_is_better) or (metric < self.best_metric and not self.greater_metric_to_watch_is_better):
if (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
):
# STORE THE CURRENT metric AS BEST
self.best_metric = metric
self.best_metric = curr_tracked_metric
self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
self.sg_logger.add_yaml_summary(tag="metrics_best", summary_dict=all_metrics)

# RUN PHASE CALLBACKS
self.phase_callback_handler.on_validation_end_best_epoch(context)

if isinstance(metric, torch.Tensor):
metric = metric.item()
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(metric))
curr_tracked_metric = self.__maybe_get_item_from_tensor(curr_tracked_metric)
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(curr_tracked_metric))

if self.training_params.average_best_models:
net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)
@@ -1185,6 +1212,7 @@ def forward(self, inputs, targets):
random_seed(is_ddp=device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, device=device_config.device, seed=self.training_params.seed)

silent_mode = self.training_params.silent_mode or self.ddp_silent_mode

# METRICS
self._set_train_metrics(train_metrics_list=self.training_params.train_metrics_list)
self._set_valid_metrics(valid_metrics_list=self.training_params.valid_metrics_list)
@@ -1936,7 +1964,13 @@ def _write_to_disk_operations(

# SAVE THE CHECKPOINT
if self.training_params.save_model:
self._save_checkpoint(self.optimizer, epoch + 1, validation_results_dict, context)
self._save_checkpoint(
optimizer=self.optimizer,
epoch=epoch + 1,
train_metrics_dict=train_metrics_dict,
validation_results_dict=validation_results_dict,
context=context,
)

def _get_epoch_start_logging_values(self) -> dict:
"""Get all the values that should be logged at the start of each epoch.
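For readers who want to consume the new checkpoint field, here is a minimal sketch of reading it back (the checkpoint path is hypothetical; the keys shown are the ones written into the state dict above):

import torch

ckpt = torch.load("checkpoints/my_experiment/ckpt_best.pth", map_location="cpu")

print(ckpt["acc"], ckpt["epoch"])                  # tracked metric value and epoch, as before
print(ckpt["all_metrics"]["tracked_metric_name"])  # name of the metric_to_watch
print(ckpt["all_metrics"]["metrics"]["valid"])     # validation metric name -> float
print(ckpt["all_metrics"]["metrics"].get("train")) # present only when train metrics were passed
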
23 changes: 21 additions & 2 deletions tests/end_to_end_tests/trainer_test.py
@@ -18,7 +18,16 @@ class TestTrainer(unittest.TestCase):
def setUp(cls):
super_gradients.init_trainer()
# NAMES FOR THE EXPERIMENTS TO LATER DELETE
cls.experiment_names = ["test_train", "test_save_load", "test_load_w", "test_load_w2", "test_load_w3", "test_checkpoint_content", "analyze"]
cls.experiment_names = [
"test_train",
"test_save_load",
"test_load_w",
"test_load_w2",
"test_load_w3",
"test_checkpoint_content",
"analyze",
"test_yaml_metrics_present",
]
cls.training_params = {
"max_epochs": 1,
"silent_mode": True,
@@ -79,11 +88,21 @@ def test_checkpoint_content(self):
ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
for ckpt_path in ckpt_paths:
ckpt = torch.load(ckpt_path)
self.assertListEqual(["net", "acc", "epoch", "optimizer_state_dict", "scaler_state_dict"], list(ckpt.keys()))
self.assertListEqual(sorted(["net", "acc", "epoch", "optimizer_state_dict", "scaler_state_dict", "all_metrics"]), sorted(list(ckpt.keys())))
trainer._save_checkpoint()
weights_only = torch.load(os.path.join(trainer.checkpoints_dir_path, "ckpt_latest_weights_only.pth"))
self.assertListEqual(["net"], list(weights_only.keys()))

def test_yaml_metrics_present(self):
trainer, model = self.get_classification_trainer(self.experiment_names[7])
params = self.training_params.copy()
params["save_ckpt_epoch_list"] = [1]
trainer.train(model=model, training_params=params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
yml_filename = ["metrics_best.yml", "metrics_epoch_1.yml", "metrics_latest.yml"]
yml_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in yml_filename]
for yml in yml_paths:
assert os.path.exists(yml)


if __name__ == "__main__":
unittest.main()
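Beyond checking that the files exist, the yml summaries are meant to be easily parsable; a hedged sketch of reading one back (the directory name is taken from the test experiment, but the exact checkpoints path layout is an assumption):

import os
import yaml

yml_path = os.path.join("checkpoints", "test_yaml_metrics_present", "metrics_latest.yml")
with open(yml_path, encoding="utf-8") as f:
    metrics = yaml.safe_load(f)

print(metrics["tracked_metric_name"])
print(metrics["metrics"]["valid"])
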
2 changes: 1 addition & 1 deletion tests/unit_tests/test_train_with_torch_scheduler.py
@@ -13,7 +13,7 @@ def update(self, *args, **kwargs) -> None:
pass

def compute(self):
return 1
return 1.0


class TrainWithTorchSchedulerTest(unittest.TestCase):