Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Way to fix bug with validation frequency #1601

Merged
merged 14 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 33 additions & 25 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,17 +656,12 @@ def _save_checkpoint(
self.sg_logger.add_checkpoint(tag="ckpt_latest_weights_only.pth", state_dict={"net": self.net.state_dict()}, global_step=epoch)
return

# COMPUTE THE CURRENT metric
# IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST'S INDICES
curr_tracked_metric = float(validation_results_dict[self.metric_to_watch])
# Check whether validation results are available for this epoch, i.e. whether (1 + epoch) % val_freq == 0
is_validation_calculated = not bool(epoch % self.run_validation_freq)
curr_tracked_metric = None

# create metrics dict to save
valid_metrics_titles = get_metrics_titles(self.valid_metrics)
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved

all_metrics = {
"tracked_metric_name": self.metric_to_watch,
"valid": {metric_name: float(validation_results_dict[metric_name]) for metric_name in valid_metrics_titles},
}
# Create metrics dict to save
all_metrics = {"tracked_metric_name": self.metric_to_watch}

if train_metrics_dict is not None:
train_metrics_titles = get_metrics_titles(self.train_metrics)
Expand All @@ -675,7 +670,6 @@ def _save_checkpoint(
# BUILD THE state_dict
state = {
"net": unwrap_model(self.net).state_dict(),
"acc": curr_tracked_metric,
"epoch": epoch,
"metrics": all_metrics,
"packages": get_installed_packages(),
Expand All @@ -697,30 +691,41 @@ def _save_checkpoint(
if self._torch_lr_scheduler is not None:
state["torch_scheduler_state_dict"] = get_scheduler_state(self._torch_lr_scheduler)

if is_validation_calculated:
valid_metrics_titles = get_metrics_titles(self.valid_metrics)

state["metrics"]["valid"] = {metric_name: float(validation_results_dict[metric_name]) for metric_name in valid_metrics_titles}

# COMPUTE THE CURRENT metric
# IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST'S INDICES
curr_tracked_metric = float(validation_results_dict[self.metric_to_watch])
state["acc"] = curr_tracked_metric

# SAVES CURRENT MODEL AS ckpt_latest
self.sg_logger.add_checkpoint(tag="ckpt_latest.pth", state_dict=state, global_step=epoch)

# SAVE MODEL AT SPECIFIC EPOCHS DETERMINED BY save_ckpt_epoch_list
if epoch in self.training_params.save_ckpt_epoch_list:
self.sg_logger.add_checkpoint(tag=f"ckpt_epoch_{epoch}.pth", state_dict=state, global_step=epoch)

# OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
if (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
):
# STORE THE CURRENT metric AS BEST
self.best_metric = curr_tracked_metric
self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
if is_validation_calculated:
# OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
if (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
):
# STORE THE CURRENT metric AS BEST
self.best_metric = curr_tracked_metric
self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)

# RUN PHASE CALLBACKS
self.phase_callback_handler.on_validation_end_best_epoch(context)
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(curr_tracked_metric))
# RUN PHASE CALLBACKS
self.phase_callback_handler.on_validation_end_best_epoch(context)
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(curr_tracked_metric))

if self.training_params.average_best_models:
net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)
if self.training_params.average_best_models:
net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)

state["net"] = self.model_weight_averaging.get_average_model(net_for_averaging, validation_results_dict=validation_results_dict)
self.sg_logger.add_checkpoint(tag=self.average_model_checkpoint_filename, state_dict=state, global_step=epoch)
state["net"] = self.model_weight_averaging.get_average_model(net_for_averaging, validation_results_dict=validation_results_dict)
self.sg_logger.add_checkpoint(tag=self.average_model_checkpoint_filename, state_dict=state, global_step=epoch)

def _prep_net_for_train(self) -> None:
if self.arch_params is None:
Expand Down Expand Up @@ -1262,6 +1267,9 @@ def forward(self, inputs, targets):
logger.warning("[Warning] Checkpoint does not include EMA weights, continuing training without EMA.")

self.run_validation_freq = self.training_params.run_validation_freq

if self.max_epochs % self.run_validation_freq != 0:
logger.warning("max_epochs is not divisible by run_validation_freq. " "The model on the last epoch wouldn't be checked whether it is the best.")
self.run_test_freq = self.training_params.run_test_freq

inf_time = 0
Expand Down
16 changes: 16 additions & 0 deletions tests/end_to_end_tests/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,22 @@ def test_checkpoint_content(self):
weights_only = torch.load(os.path.join(trainer.checkpoints_dir_path, "ckpt_latest_weights_only.pth"))
self.assertListEqual(["net"], list(weights_only.keys()))

def test_validation_frequency_divisible(self):
    """Verify checkpoint metrics when run_validation_freq divides max_epochs.

    Trains for 4 epochs with validation every 2 epochs, so the final epoch is a
    validation epoch. Both ``ckpt_best.pth`` and ``ckpt_latest.pth`` must then
    carry the same saved ``metrics`` dict, since both were written from the
    same (last) validation results.
    """
    trainer, model = self.get_classification_trainer(self.experiment_names[0])
    training_params = self.training_params.copy()
    training_params["max_epochs"] = 4
    training_params["run_validation_freq"] = 2
    trainer.train(
        model=model, training_params=training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
    )
    # Load the "metrics" entry from both checkpoints written during training.
    ckpt_filenames = ["ckpt_best.pth", "ckpt_latest.pth"]
    ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, name) for name in ckpt_filenames]
    metrics = {ckpt_path: torch.load(ckpt_path)["metrics"] for ckpt_path in ckpt_paths}
    # Use a unittest assertion (not a bare `assert`, which is stripped under -O
    # and gives no diagnostic message on failure).
    self.assertEqual(metrics[ckpt_paths[0]], metrics[ckpt_paths[1]])
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__":
unittest.main()