Fix wandb logger on resume (#766)
* fix
* deepcopy
* fix according to comments
* make private
* remove deepcopy

---------

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Co-authored-by: Shay Aharon <80472096+shaydeci@users.noreply.github.com>
3 people committed Apr 18, 2023
1 parent 505f646 commit 4d1f1f3
Showing 1 changed file with 17 additions and 6 deletions.
src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -195,6 +195,8 @@ def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[Mu
         self.max_train_batches = None
         self.max_valid_batches = None
 
+        self._epoch_start_logging_values = {}
+
     @property
     def device(self) -> str:
         return device_config.device
@@ -443,9 +445,8 @@ def _train_epoch(self, epoch: int, silent_mode: bool = False) -> tuple:
             context.update_context(preds=outputs, loss_log_items=loss_log_items)
             self.phase_callback_handler.on_train_batch_loss_end(context)
 
-            # LOG LR THAT WILL BE USED IN CURRENT EPOCH AND AFTER FIRST WARMUP/LR_SCHEDULER UPDATE BEFORE WEIGHT UPDATE
             if not self.ddp_silent_mode and batch_idx == 0:
-                self._write_lrs(epoch)
+                self._epoch_start_logging_values = self._get_epoch_start_logging_values()
 
             self._backward_step(loss, epoch, batch_idx, context)
 
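The diff does not spell out why the LR write is deferred, but the commit title ("Fix wandb logger on resume") points at Weights & Biases' step-ordering rule: steps passed to a run must never decrease, and rows logged with a step lower than the run's current counter are dropped with a warning. The snippet below is a minimal, assumed repro of that constraint (hypothetical project name, requires a wandb login; not part of this commit):

```python
import wandb

run = wandb.init(project="step-ordering-demo")  # hypothetical project name
run.log({"train_loss": 1.0}, step=10)

# On a resumed run the step counter starts where the previous run left off, so an
# early, separate write like the old _write_lrs() can land on a lower step; wandb
# warns ("Step must only increase") and drops the row.
run.log({"LR": 0.1}, step=5)
run.finish()
```

Caching the values in `_epoch_start_logging_values` instead, and writing them together with the end-of-epoch metrics, keeps every scalar for an epoch on a single step.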
@@ -1294,7 +1295,14 @@ def forward(self, inputs, targets):
 
             if not self.ddp_silent_mode:
                 # SAVING AND LOGGING OCCURS ONLY IN THE MAIN PROCESS (IN CASES THERE ARE SEVERAL PROCESSES - DDP)
-                self._write_to_disk_operations(train_metrics_tuple, validation_results_tuple, inf_time, epoch, context)
+                self._write_to_disk_operations(
+                    train_metrics=train_metrics_tuple,
+                    validation_results=validation_results_tuple,
+                    lr_dict=self._epoch_start_logging_values,
+                    inf_time=inf_time,
+                    epoch=epoch,
+                    context=context,
+                )
                 self.sg_logger.upload()
 
             # Evaluating the average model and removing snapshot averaging file if training is completed
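In isolation, the pattern these hunks adopt is: snapshot the values once at batch 0, then flush them in the same call path as the end-of-epoch metrics. A minimal sketch of that pattern, assuming a logger object that exposes `add_scalars(tag_scalar_dict, global_step)` as `sg_logger` does; the class name is hypothetical:

```python
class EpochStartCache:
    """Snapshot values at epoch start; write them alongside end-of-epoch metrics."""

    def __init__(self):
        self.values = {}

    def capture(self, optimizer) -> None:
        # Taken at batch 0, before warmup/LR schedulers mutate the param groups.
        self.values = {f"LR/Param_group_{i}": g["lr"] for i, g in enumerate(optimizer.param_groups)}

    def flush(self, logger, epoch: int) -> None:
        # One write per epoch: every scalar shares the same global_step.
        logger.add_scalars(tag_scalar_dict=self.values, global_step=epoch)
```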
@@ -1649,24 +1657,27 @@ def _get_hyper_param_config(self):
         }
         return hyper_param_config
 
-    def _write_to_disk_operations(self, train_metrics: tuple, validation_results: tuple, inf_time: float, epoch: int, context: PhaseContext):
+    def _write_to_disk_operations(self, train_metrics: tuple, validation_results: tuple, lr_dict: dict, inf_time: float, epoch: int, context: PhaseContext):
         """Run the various logging operations, e.g.: log file, Tensorboard, save checkpoint etc."""
         # STORE VALUES IN A TENSORBOARD FILE
         train_results = list(train_metrics) + list(validation_results) + [inf_time]
         all_titles = self.results_titles + ["Inference Time"]
 
         result_dict = {all_titles[i]: train_results[i] for i in range(len(train_results))}
         self.sg_logger.add_scalars(tag_scalar_dict=result_dict, global_step=epoch)
+        self.sg_logger.add_scalars(tag_scalar_dict=lr_dict, global_step=epoch)
 
         # SAVE THE CHECKPOINT
         if self.training_params.save_model:
             self._save_checkpoint(self.optimizer, epoch + 1, validation_results, context)
 
-    def _write_lrs(self, epoch):
+    def _get_epoch_start_logging_values(self) -> dict:
+        """Get all the values that should be logged at the start of each epoch.
+        This is useful for values like Learning Rate that can change over an epoch."""
         lrs = [self.optimizer.param_groups[i]["lr"] for i in range(len(self.optimizer.param_groups))]
         lr_titles = ["LR/Param_group_" + str(i) for i in range(len(self.optimizer.param_groups))] if len(self.optimizer.param_groups) > 1 else ["LR"]
         lr_dict = {lr_titles[i]: lrs[i] for i in range(len(lrs))}
-        self.sg_logger.add_scalars(tag_scalar_dict=lr_dict, global_step=epoch)
+        return lr_dict
 
     def test(
         self,
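For reference, `_get_epoch_start_logging_values` names each learning rate after its parameter group, falling back to a single "LR" key when the optimizer has only one group. A self-contained illustration of the resulting dictionary, using an assumed two-group optimizer (not from the repo):

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(
    [
        {"params": [model.weight], "lr": 0.1},
        {"params": [model.bias], "lr": 0.01},
    ]
)

# Mirrors the logic of _get_epoch_start_logging_values():
lrs = [group["lr"] for group in optimizer.param_groups]
titles = [f"LR/Param_group_{i}" for i in range(len(lrs))] if len(lrs) > 1 else ["LR"]
print(dict(zip(titles, lrs)))  # {'LR/Param_group_0': 0.1, 'LR/Param_group_1': 0.01}
```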
