Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Way to fix bug with validation frequency #1601

Merged
merged 14 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 60 additions & 40 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,22 @@ def _init_arch_params(self) -> None:
if arch_params is not None:
self.arch_params.override(**arch_params.to_dict())

def _should_run_validation_for_epoch(self, epoch: int) -> bool:
"""
Method returns true if the validation should to be calculated on this epoch (starting from 0).

We need to calculate validation if
1) the epoch is divisible by #run_validation_freq
2) if epoch is last
3) if epoch is in self.save_ckpt_epoch_list
"""

is_run_val_freq_divisible = ((epoch + 1) % self.run_validation_freq) == 0
is_last_epoch = (epoch + 1) == self.max_epochs
is_in_checkpoint_list = (epoch + 1) in self.training_params.save_ckpt_epoch_list

return is_run_val_freq_divisible or is_last_epoch or is_in_checkpoint_list

# FIXME - we need to resolve flake8's 'function is too complex' for this function
def train(
self,
Expand Down Expand Up @@ -1262,9 +1278,14 @@ def forward(self, inputs, targets):
logger.warning("[Warning] Checkpoint does not include EMA weights, continuing training without EMA.")

self.run_validation_freq = self.training_params.run_validation_freq

if self.max_epochs % self.run_validation_freq != 0:
logger.warning(
"max_epochs is not divisible by run_validation_freq. "
"Please check the training parameters and ensure that run_validation_freq has been set correctly."
)
self.run_test_freq = self.training_params.run_test_freq

inf_time = 0
timer = core_utils.Timer(device_config.device)

# IF THE LR MODE IS NOT DEFAULT TAKE IT FROM THE TRAINING PARAMS
Expand Down Expand Up @@ -1438,6 +1459,7 @@ def forward(self, inputs, targets):
for epoch in range(self.start_epoch, self.max_epochs):
# broadcast_from_master is necessary here, since in DDP mode, only the master node will
# receive the Ctrl-C signal, and we want all nodes to stop training.
timer.start()
if broadcast_from_master(context.stop_training):
logger.info("Request to stop training has been received, stopping training")
break
Expand Down Expand Up @@ -1485,13 +1507,18 @@ def forward(self, inputs, targets):
keep_model = self.net
self.net = self.ema_model.ema

train_inf_time = timer.stop()
self._write_scalars_to_logger(metrics=train_metrics_dict, epoch=1 + epoch, inference_time=train_inf_time, tag="Train")

# RUN TEST ON VALIDATION SET EVERY self.run_validation_freq EPOCHS
valid_metrics_dict = {}
if (epoch + 1) % self.run_validation_freq == 0:
should_run_validation = self._should_run_validation_for_epoch(epoch)

if should_run_validation:
self.phase_callback_handler.on_validation_loader_start(context)
timer.start()
valid_metrics_dict = self._validate_epoch(context=context, silent_mode=silent_mode)
inf_time = timer.stop()
val_inf_time = timer.stop()

self.valid_monitored_values = sg_trainer_utils.update_monitored_values_dict(
monitored_values_dict=self.valid_monitored_values,
Expand All @@ -1502,11 +1529,16 @@ def forward(self, inputs, targets):
context.update_context(metrics_dict=valid_metrics_dict)
self.phase_callback_handler.on_validation_loader_end(context)

self._write_scalars_to_logger(metrics=valid_metrics_dict, epoch=1 + epoch, inference_time=val_inf_time, tag="Valid")

test_metrics_dict = {}
if (epoch + 1) % self.run_test_freq == 0:
self.phase_callback_handler.on_test_loader_start(context)
test_inf_time = 0.0
for dataset_name, dataloader in self.test_loaders.items():
timer.start()
dataset_metrics_dict = self._test_epoch(data_loader=dataloader, context=context, silent_mode=silent_mode, dataset_name=dataset_name)
test_inf_time += timer.stop()
dataset_metrics_dict_with_name = {
f"{dataset_name}:{metric_name}": metric_value for metric_name, metric_value in dataset_metrics_dict.items()
}
Expand All @@ -1519,20 +1551,23 @@ def forward(self, inputs, targets):
context.update_context(metrics_dict=test_metrics_dict)
self.phase_callback_handler.on_test_loader_end(context)

self._write_scalars_to_logger(metrics=test_metrics_dict, epoch=1 + epoch, inference_time=test_inf_time, tag="Test")

if self.ema:
self.net = keep_model

if not self.ddp_silent_mode:
self.sg_logger.add_scalars(tag_scalar_dict=self._epoch_start_logging_values, global_step=1 + epoch)

# SAVING AND LOGGING OCCURS ONLY IN THE MAIN PROCESS (IN CASES THERE ARE SEVERAL PROCESSES - DDP)
self._write_to_disk_operations(
train_metrics_dict=train_metrics_dict,
validation_results_dict=valid_metrics_dict,
test_metrics_dict=test_metrics_dict,
lr_dict=self._epoch_start_logging_values,
inf_time=inf_time,
epoch=epoch,
context=context,
)
if should_run_validation and self.training_params.save_model:
self._save_checkpoint(
optimizer=self.optimizer,
epoch=epoch + 1,
train_metrics_dict=train_metrics_dict,
validation_results_dict=valid_metrics_dict,
context=context,
)
self.sg_logger.upload()

if not silent_mode:
Expand Down Expand Up @@ -1938,35 +1973,20 @@ def _get_hyper_param_config(self):
}
return hyper_param_config

def _write_to_disk_operations(
self,
train_metrics_dict: dict,
validation_results_dict: dict,
test_metrics_dict: dict,
lr_dict: dict,
inf_time: float,
epoch: int,
context: PhaseContext,
):
"""Run the various logging operations, e.g.: log file, Tensorboard, save checkpoint etc."""
result_dict = {
"Inference Time": inf_time,
**{f"Train_{k}": v for k, v in train_metrics_dict.items()},
**{f"Valid_{k}": v for k, v in validation_results_dict.items()},
**{f"Test_{k}": v for k, v in test_metrics_dict.items()},
}
self.sg_logger.add_scalars(tag_scalar_dict=result_dict, global_step=epoch)
self.sg_logger.add_scalars(tag_scalar_dict=lr_dict, global_step=epoch)
def _write_scalars_to_logger(self, metrics: dict, epoch: int, inference_time: float, tag: str) -> None:
"""
Method for writing metrics and LR info to logger.

# SAVE THE CHECKPOINT
if self.training_params.save_model:
self._save_checkpoint(
optimizer=self.optimizer,
epoch=epoch + 1,
train_metrics_dict=train_metrics_dict,
validation_results_dict=validation_results_dict,
context=context,
)
:param metrics: (dict) dict of metrics..
:param epoch: (inf) 1-based number of epoch.
:param inference_time: (float) time of inference.
:param tag: (str) tag for writing to logger (rule of thumb: Train/Test/Valid)
"""

if not self.ddp_silent_mode:
info_dict = {f"{tag} Inference Time": inference_time, **{f"{tag}_{k}": v for k, v in metrics.items()}}

self.sg_logger.add_scalars(tag_scalar_dict=info_dict, global_step=epoch)

def _get_epoch_start_logging_values(self) -> dict:
"""Get all the values that should be logged at the start of each epoch.
Expand Down
31 changes: 31 additions & 0 deletions tests/end_to_end_tests/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,37 @@ def test_checkpoint_content(self):
weights_only = torch.load(os.path.join(trainer.checkpoints_dir_path, "ckpt_latest_weights_only.pth"))
self.assertListEqual(["net"], list(weights_only.keys()))

def test_validation_frequency_divisible(self):
trainer, model = self.get_classification_trainer(self.experiment_names[0])
training_params = self.training_params.copy()
training_params["max_epochs"] = 4
training_params["run_validation_freq"] = 2
trainer.train(
model=model, training_params=training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
)
ckpt_filename = ["ckpt_best.pth", "ckpt_latest.pth"]
ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
metrics = {}
for ckpt_path in ckpt_paths:
ckpt = torch.load(ckpt_path)
metrics[ckpt_path] = ckpt["metrics"]
self.assertTrue(metrics[ckpt_paths[0]] == metrics[ckpt_paths[1]])

def test_validation_frequency_and_save_ckpt_list(self):
trainer, model = self.get_classification_trainer(self.experiment_names[0])
training_params = self.training_params.copy()
training_params["max_epochs"] = 5
training_params["run_validation_freq"] = 3
training_params["save_ckpt_epoch_list"] = [1]
trainer.train(
model=model, training_params=training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
)
ckpt_filename = ["ckpt_epoch_1.pth", "ckpt_latest.pth"]
ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
for ckpt_path in ckpt_paths:
ckpt = torch.load(ckpt_path)
self.assertTrue("valid" in ckpt["metrics"].keys())


if __name__ == "__main__":
unittest.main()