Skip to content

Commit

Permalink
Way to fix bug with validation frequency (#1601)
Browse files Browse the repository at this point in the history
* Way to fix bug with validation frequency

* Fixed test, the state of net was rewritten

* Added validating the latest epoch and epochs from save_ckpt_epoch_list

* Added one more testcase to check wether latest notdivisible epoch has valid in metrics

* Following the SRP recommendation...

* Which inference time exactly

* Fixed incorrect keyword in writing function

* Missing brackets around epoch+1 in valid run check function.

* Final fixes hopefully :)

* Fixed trainer to add scalars only in main process

---------

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
  • Loading branch information
2 people authored and Louis-Dupont committed Nov 8, 2023
1 parent 0292e05 commit e68a3a2
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 40 deletions.
100 changes: 60 additions & 40 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,22 @@ def _init_arch_params(self) -> None:
if arch_params is not None:
self.arch_params.override(**arch_params.to_dict())

def _should_run_validation_for_epoch(self, epoch: int) -> bool:
"""
Method returns true if the validation should to be calculated on this epoch (starting from 0).
We need to calculate validation if
1) the epoch is divisible by #run_validation_freq
2) if epoch is last
3) if epoch is in self.save_ckpt_epoch_list
"""

is_run_val_freq_divisible = ((epoch + 1) % self.run_validation_freq) == 0
is_last_epoch = (epoch + 1) == self.max_epochs
is_in_checkpoint_list = (epoch + 1) in self.training_params.save_ckpt_epoch_list

return is_run_val_freq_divisible or is_last_epoch or is_in_checkpoint_list

# FIXME - we need to resolve flake8's 'function is too complex' for this function
def train(
self,
Expand Down Expand Up @@ -1262,9 +1278,14 @@ def forward(self, inputs, targets):
logger.warning("[Warning] Checkpoint does not include EMA weights, continuing training without EMA.")

self.run_validation_freq = self.training_params.run_validation_freq

if self.max_epochs % self.run_validation_freq != 0:
logger.warning(
"max_epochs is not divisible by run_validation_freq. "
"Please check the training parameters and ensure that run_validation_freq has been set correctly."
)
self.run_test_freq = self.training_params.run_test_freq

inf_time = 0
timer = core_utils.Timer(device_config.device)

# IF THE LR MODE IS NOT DEFAULT TAKE IT FROM THE TRAINING PARAMS
Expand Down Expand Up @@ -1438,6 +1459,7 @@ def forward(self, inputs, targets):
for epoch in range(self.start_epoch, self.max_epochs):
# broadcast_from_master is necessary here, since in DDP mode, only the master node will
# receive the Ctrl-C signal, and we want all nodes to stop training.
timer.start()
if broadcast_from_master(context.stop_training):
logger.info("Request to stop training has been received, stopping training")
break
Expand Down Expand Up @@ -1485,13 +1507,18 @@ def forward(self, inputs, targets):
keep_model = self.net
self.net = self.ema_model.ema

train_inf_time = timer.stop()
self._write_scalars_to_logger(metrics=train_metrics_dict, epoch=1 + epoch, inference_time=train_inf_time, tag="Train")

# RUN TEST ON VALIDATION SET EVERY self.run_validation_freq EPOCHS
valid_metrics_dict = {}
if (epoch + 1) % self.run_validation_freq == 0:
should_run_validation = self._should_run_validation_for_epoch(epoch)

if should_run_validation:
self.phase_callback_handler.on_validation_loader_start(context)
timer.start()
valid_metrics_dict = self._validate_epoch(context=context, silent_mode=silent_mode)
inf_time = timer.stop()
val_inf_time = timer.stop()

self.valid_monitored_values = sg_trainer_utils.update_monitored_values_dict(
monitored_values_dict=self.valid_monitored_values,
Expand All @@ -1502,11 +1529,16 @@ def forward(self, inputs, targets):
context.update_context(metrics_dict=valid_metrics_dict)
self.phase_callback_handler.on_validation_loader_end(context)

self._write_scalars_to_logger(metrics=valid_metrics_dict, epoch=1 + epoch, inference_time=val_inf_time, tag="Valid")

test_metrics_dict = {}
if (epoch + 1) % self.run_test_freq == 0:
self.phase_callback_handler.on_test_loader_start(context)
test_inf_time = 0.0
for dataset_name, dataloader in self.test_loaders.items():
timer.start()
dataset_metrics_dict = self._test_epoch(data_loader=dataloader, context=context, silent_mode=silent_mode, dataset_name=dataset_name)
test_inf_time += timer.stop()
dataset_metrics_dict_with_name = {
f"{dataset_name}:{metric_name}": metric_value for metric_name, metric_value in dataset_metrics_dict.items()
}
Expand All @@ -1519,20 +1551,23 @@ def forward(self, inputs, targets):
context.update_context(metrics_dict=test_metrics_dict)
self.phase_callback_handler.on_test_loader_end(context)

self._write_scalars_to_logger(metrics=test_metrics_dict, epoch=1 + epoch, inference_time=test_inf_time, tag="Test")

if self.ema:
self.net = keep_model

if not self.ddp_silent_mode:
self.sg_logger.add_scalars(tag_scalar_dict=self._epoch_start_logging_values, global_step=1 + epoch)

# SAVING AND LOGGING OCCURS ONLY IN THE MAIN PROCESS (IN CASES THERE ARE SEVERAL PROCESSES - DDP)
self._write_to_disk_operations(
train_metrics_dict=train_metrics_dict,
validation_results_dict=valid_metrics_dict,
test_metrics_dict=test_metrics_dict,
lr_dict=self._epoch_start_logging_values,
inf_time=inf_time,
epoch=epoch,
context=context,
)
if should_run_validation and self.training_params.save_model:
self._save_checkpoint(
optimizer=self.optimizer,
epoch=epoch + 1,
train_metrics_dict=train_metrics_dict,
validation_results_dict=valid_metrics_dict,
context=context,
)
self.sg_logger.upload()

if not silent_mode:
Expand Down Expand Up @@ -1938,35 +1973,20 @@ def _get_hyper_param_config(self):
}
return hyper_param_config

def _write_to_disk_operations(
self,
train_metrics_dict: dict,
validation_results_dict: dict,
test_metrics_dict: dict,
lr_dict: dict,
inf_time: float,
epoch: int,
context: PhaseContext,
):
"""Run the various logging operations, e.g.: log file, Tensorboard, save checkpoint etc."""
result_dict = {
"Inference Time": inf_time,
**{f"Train_{k}": v for k, v in train_metrics_dict.items()},
**{f"Valid_{k}": v for k, v in validation_results_dict.items()},
**{f"Test_{k}": v for k, v in test_metrics_dict.items()},
}
self.sg_logger.add_scalars(tag_scalar_dict=result_dict, global_step=epoch)
self.sg_logger.add_scalars(tag_scalar_dict=lr_dict, global_step=epoch)
def _write_scalars_to_logger(self, metrics: dict, epoch: int, inference_time: float, tag: str) -> None:
"""
Method for writing metrics and LR info to logger.
# SAVE THE CHECKPOINT
if self.training_params.save_model:
self._save_checkpoint(
optimizer=self.optimizer,
epoch=epoch + 1,
train_metrics_dict=train_metrics_dict,
validation_results_dict=validation_results_dict,
context=context,
)
:param metrics: (dict) dict of metrics..
:param epoch: (inf) 1-based number of epoch.
:param inference_time: (float) time of inference.
:param tag: (str) tag for writing to logger (rule of thumb: Train/Test/Valid)
"""

if not self.ddp_silent_mode:
info_dict = {f"{tag} Inference Time": inference_time, **{f"{tag}_{k}": v for k, v in metrics.items()}}

self.sg_logger.add_scalars(tag_scalar_dict=info_dict, global_step=epoch)

def _get_epoch_start_logging_values(self) -> dict:
"""Get all the values that should be logged at the start of each epoch.
Expand Down
31 changes: 31 additions & 0 deletions tests/end_to_end_tests/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,37 @@ def test_checkpoint_content(self):
weights_only = torch.load(os.path.join(trainer.checkpoints_dir_path, "ckpt_latest_weights_only.pth"))
self.assertListEqual(["net"], list(weights_only.keys()))

def test_validation_frequency_divisible(self):
trainer, model = self.get_classification_trainer(self.experiment_names[0])
training_params = self.training_params.copy()
training_params["max_epochs"] = 4
training_params["run_validation_freq"] = 2
trainer.train(
model=model, training_params=training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
)
ckpt_filename = ["ckpt_best.pth", "ckpt_latest.pth"]
ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
metrics = {}
for ckpt_path in ckpt_paths:
ckpt = torch.load(ckpt_path)
metrics[ckpt_path] = ckpt["metrics"]
self.assertTrue(metrics[ckpt_paths[0]] == metrics[ckpt_paths[1]])

def test_validation_frequency_and_save_ckpt_list(self):
trainer, model = self.get_classification_trainer(self.experiment_names[0])
training_params = self.training_params.copy()
training_params["max_epochs"] = 5
training_params["run_validation_freq"] = 3
training_params["save_ckpt_epoch_list"] = [1]
trainer.train(
model=model, training_params=training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
)
ckpt_filename = ["ckpt_epoch_1.pth", "ckpt_latest.pth"]
ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
for ckpt_path in ckpt_paths:
ckpt = torch.load(ckpt_path)
self.assertTrue("valid" in ckpt["metrics"].keys())


if __name__ == "__main__":
unittest.main()

0 comments on commit e68a3a2

Please sign in to comment.