Way to fix bug with validation frequency #1601

Merged: 14 commits, Nov 8, 2023
78 changes: 52 additions & 26 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -656,17 +656,13 @@ def _save_checkpoint(
self.sg_logger.add_checkpoint(tag="ckpt_latest_weights_only.pth", state_dict={"net": self.net.state_dict()}, global_step=epoch)
return

# COMPUTE THE CURRENT metric
# IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST'S INDICES
curr_tracked_metric = float(validation_results_dict[self.metric_to_watch])
# Check whether validation results should have been calculated for this epoch
# (subtract 1 because epoch + 1 is passed to this function)
is_validation_calculated = self._if_need_to_calc_validation(epoch - 1)
curr_tracked_metric = None

# create metrics dict to save
valid_metrics_titles = get_metrics_titles(self.valid_metrics)

all_metrics = {
"tracked_metric_name": self.metric_to_watch,
"valid": {metric_name: float(validation_results_dict[metric_name]) for metric_name in valid_metrics_titles},
}
# Create metrics dict to save
all_metrics = {"tracked_metric_name": self.metric_to_watch}

if train_metrics_dict is not None:
train_metrics_titles = get_metrics_titles(self.train_metrics)
@@ -675,7 +671,6 @@ def _save_checkpoint(
# BUILD THE state_dict
state = {
"net": unwrap_model(self.net).state_dict(),
"acc": curr_tracked_metric,
"epoch": epoch,
"metrics": all_metrics,
"packages": get_installed_packages(),
@@ -697,30 +692,41 @@ def _save_checkpoint(
if self._torch_lr_scheduler is not None:
state["torch_scheduler_state_dict"] = get_scheduler_state(self._torch_lr_scheduler)

if is_validation_calculated:
valid_metrics_titles = get_metrics_titles(self.valid_metrics)

state["metrics"]["valid"] = {metric_name: float(validation_results_dict[metric_name]) for metric_name in valid_metrics_titles}

# COMPUTE THE CURRENT metric
# IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST'S INDICES
curr_tracked_metric = float(validation_results_dict[self.metric_to_watch])
state["acc"] = curr_tracked_metric

# SAVES CURRENT MODEL AS ckpt_latest
self.sg_logger.add_checkpoint(tag="ckpt_latest.pth", state_dict=state, global_step=epoch)

# SAVE MODEL AT SPECIFIC EPOCHS DETERMINED BY save_ckpt_epoch_list
if epoch in self.training_params.save_ckpt_epoch_list:
self.sg_logger.add_checkpoint(tag=f"ckpt_epoch_{epoch}.pth", state_dict=state, global_step=epoch)

# OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
if (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
):
# STORE THE CURRENT metric AS BEST
self.best_metric = curr_tracked_metric
self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
if is_validation_calculated:
# OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
if (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
):
# STORE THE CURRENT metric AS BEST
self.best_metric = curr_tracked_metric
self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)

# RUN PHASE CALLBACKS
self.phase_callback_handler.on_validation_end_best_epoch(context)
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(curr_tracked_metric))
# RUN PHASE CALLBACKS
self.phase_callback_handler.on_validation_end_best_epoch(context)
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(curr_tracked_metric))

if self.training_params.average_best_models:
net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)
if self.training_params.average_best_models:
net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)

state["net"] = self.model_weight_averaging.get_average_model(net_for_averaging, validation_results_dict=validation_results_dict)
self.sg_logger.add_checkpoint(tag=self.average_model_checkpoint_filename, state_dict=state, global_step=epoch)
state["net"] = self.model_weight_averaging.get_average_model(net_for_averaging, validation_results_dict=validation_results_dict)
self.sg_logger.add_checkpoint(tag=self.average_model_checkpoint_filename, state_dict=state, global_step=epoch)
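
In short, the hunk above defers the tracked-metric computation: `state["acc"]` and `state["metrics"]["valid"]` are only written, and the best-checkpoint comparison only runs, when validation was actually calculated for the epoch being saved. A minimal sketch of that pattern, with purely illustrative names (`build_checkpoint_state` is a hypothetical helper, not SuperGradients API):

```python
from typing import Dict, Optional


def build_checkpoint_state(epoch: int, metric_to_watch: str, validation_results: Optional[Dict[str, float]]) -> Dict:
    """Build a checkpoint dict; validation entries are added only when validation ran."""
    state = {"epoch": epoch, "metrics": {"tracked_metric_name": metric_to_watch}}
    if validation_results is not None:
        state["metrics"]["valid"] = dict(validation_results)
        # The tracked metric ("acc") exists only for validated epochs,
        # so the best-checkpoint comparison can simply be skipped otherwise.
        state["acc"] = float(validation_results[metric_to_watch])
    return state


if __name__ == "__main__":
    print(build_checkpoint_state(0, "Accuracy", None))                # no "acc", no "valid"
    print(build_checkpoint_state(1, "Accuracy", {"Accuracy": 0.91}))  # both present
```

On epochs without validation, `ckpt_latest.pth` is still written; it simply carries no `acc` or `valid` entries.
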

def _prep_net_for_train(self) -> None:
if self.arch_params is None:
@@ -756,6 +762,13 @@ def _init_arch_params(self) -> None:
if arch_params is not None:
self.arch_params.override(**arch_params.to_dict())

def _if_need_to_calc_validation(self, epoch: int) -> bool:
is_run_val_freq_divisible = not bool((epoch + 1) % self.run_validation_freq)
is_last_epoch = (epoch + 1) == self.max_epochs
is_in_checkpoint_list = (epoch + 1) in self.training_params.save_ckpt_epoch_list

return is_run_val_freq_divisible or is_last_epoch or is_in_checkpoint_list
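
To make the three conditions concrete, here is a standalone sketch of the same predicate as a free function (parameter names are assumptions, this is not the `Trainer` method), evaluated with the settings used in `test_validation_frequency_and_save_ckpt_list` below (`max_epochs=5`, `run_validation_freq=3`, `save_ckpt_epoch_list=[1]`):

```python
from typing import List


def need_validation(epoch: int, run_validation_freq: int, max_epochs: int, save_ckpt_epoch_list: List[int]) -> bool:
    """Mirror of _if_need_to_calc_validation for a 0-based epoch index."""
    is_freq_divisible = (epoch + 1) % run_validation_freq == 0
    is_last_epoch = (epoch + 1) == max_epochs
    is_in_checkpoint_list = (epoch + 1) in save_ckpt_epoch_list
    return is_freq_divisible or is_last_epoch or is_in_checkpoint_list


if __name__ == "__main__":
    validated = [epoch for epoch in range(5) if need_validation(epoch, 3, 5, [1])]
    print(validated)  # [0, 2, 4] -> checkpoint-list epoch, frequency epoch, last epoch
```
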

# FIXME - we need to resolve flake8's 'function is too complex' for this function
def train(
self,
Expand Down Expand Up @@ -1262,6 +1275,12 @@ def forward(self, inputs, targets):
logger.warning("[Warning] Checkpoint does not include EMA weights, continuing training without EMA.")

self.run_validation_freq = self.training_params.run_validation_freq

if self.max_epochs % self.run_validation_freq != 0:
logger.warning(
"max_epochs is not divisible by run_validation_freq. "
"Please check the training parameters and ensure that run_validation_freq has been set correctly."
)
self.run_test_freq = self.training_params.run_test_freq

inf_time = 0
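
The new warning above only flags a configuration smell; it does not change behaviour, because the last-epoch condition in `_if_need_to_calc_validation` still guarantees a final validation pass. A tiny illustrative reproduction of the check (the values are assumed examples):

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

max_epochs, run_validation_freq = 5, 3  # assumed example values
if max_epochs % run_validation_freq != 0:
    logger.warning(
        "max_epochs is not divisible by run_validation_freq. "
        "Please check the training parameters and ensure that run_validation_freq has been set correctly."
    )
```
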
Expand Down Expand Up @@ -1489,7 +1508,14 @@ def forward(self, inputs, targets):

# RUN TEST ON VALIDATION SET EVERY self.run_validation_freq EPOCHS
valid_metrics_dict = {}
if (epoch + 1) % self.run_validation_freq == 0:
# We need to run validation if:
# 1) (epoch + 1) is divisible by run_validation_freq
# 2) the epoch is the last epoch
# 3) (epoch + 1) is in self.training_params.save_ckpt_epoch_list

is_validation_calculated = self._if_need_to_calc_validation(epoch)

if is_validation_calculated:
self.phase_callback_handler.on_validation_loader_start(context)
timer.start()
valid_metrics_dict = self._validate_epoch(context=context, silent_mode=silent_mode)
31 changes: 31 additions & 0 deletions tests/end_to_end_tests/trainer_test.py
@@ -93,6 +93,37 @@ def test_checkpoint_content(self):
weights_only = torch.load(os.path.join(trainer.checkpoints_dir_path, "ckpt_latest_weights_only.pth"))
self.assertListEqual(["net"], list(weights_only.keys()))

def test_validation_frequency_divisible(self):
trainer, model = self.get_classification_trainer(self.experiment_names[0])
training_params = self.training_params.copy()
training_params["max_epochs"] = 4
training_params["run_validation_freq"] = 2
trainer.train(
model=model, training_params=training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
)
ckpt_filename = ["ckpt_best.pth", "ckpt_latest.pth"]
ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
metrics = {}
for ckpt_path in ckpt_paths:
ckpt = torch.load(ckpt_path)
metrics[ckpt_path] = ckpt["metrics"]
self.assertEqual(metrics[ckpt_paths[0]], metrics[ckpt_paths[1]])

def test_validation_frequency_and_save_ckpt_list(self):
trainer, model = self.get_classification_trainer(self.experiment_names[0])
training_params = self.training_params.copy()
training_params["max_epochs"] = 5
training_params["run_validation_freq"] = 3
training_params["save_ckpt_epoch_list"] = [1]
trainer.train(
model=model, training_params=training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
)
ckpt_filename = ["ckpt_epoch_1.pth", "ckpt_latest.pth"]
ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
for ckpt_path in ckpt_paths:
ckpt = torch.load(ckpt_path)
self.assertTrue("valid" in ckpt["metrics"].keys())


if __name__ == "__main__":
unittest.main()