
Failure in Training due to Empty Validation Metrics Dictionary when run_validation_freq is greater than 1 #1324

Closed
AlfredQin opened this issue Jul 26, 2023 · 2 comments
Labels
🐛 Bug Something isn't working

Comments

@AlfredQin
🐛 Describe the bug

Bug: Failure in Training due to Empty Validation Metrics Dictionary

Description

When the run_validation_freq parameter in training_params of Trainer is set to a value greater than 1 (e.g., run_validation_freq: 5), the training process fails. This appears to be due to an issue with the validation metrics dictionary (valid_metrics_dict).

Details

The valid_metrics_dict is initialized as an empty dictionary. However, on every epoch where (epoch + 1) % run_validation_freq != 0 (i.e., validation is skipped), the dictionary remains empty and is still passed to the _write_to_disk_operations function. This causes a KeyError when _save_checkpoint accesses validation_results_dict[self.metric_to_watch], because validation_results_dict is empty.

Here is the problematic code:

# RUN TEST ON VALIDATION SET EVERY self.run_validation_freq EPOCHS
valid_metrics_dict = {}
if (epoch + 1) % self.run_validation_freq == 0:
    ...
    valid_metrics_dict = self._validate_epoch(context=context, silent_mode=silent_mode)
    ...
...
self._write_to_disk_operations(
    train_metrics_dict=train_metrics_dict,
    validation_results_dict=valid_metrics_dict,
    ...
)
...
...
def _write_to_disk_operations(
        self,
        train_metrics_dict: dict,
        validation_results_dict: dict,
        test_metrics_dict: dict,
        lr_dict: dict,
        inf_time: float,
        epoch: int,
        context: PhaseContext,
    ):
     ...
        # SAVE THE CHECKPOINT
        if self.training_params.save_model:
            self._save_checkpoint(self.optimizer, epoch + 1, validation_results_dict, context)
...
def _save_checkpoint(
    self,
    optimizer: torch.optim.Optimizer = None,
    epoch: int = None,
    validation_results_dict: Optional[Dict[str, float]] = None,
    context: PhaseContext = None,
) -> None:
    ...
    metric = validation_results_dict[self.metric_to_watch]
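
The interaction above can be reproduced in isolation. A minimal sketch (hypothetical names only; `METRIC_TO_WATCH` stands in for `self.metric_to_watch`, and the guard is one possible mitigation, not the fix merged into super-gradients):

```python
from typing import Dict, Optional

METRIC_TO_WATCH = "Accuracy"  # hypothetical stand-in for Trainer.metric_to_watch

# With run_validation_freq=5 and 0-based epochs, only every 5th epoch satisfies
# (epoch + 1) % run_validation_freq == 0; on all other epochs valid_metrics_dict
# stays {} but is still passed down to _write_to_disk_operations.
run_validation_freq = 5
skipped = [epoch for epoch in range(10) if (epoch + 1) % run_validation_freq != 0]
print(skipped)  # [0, 1, 2, 3, 5, 6, 7, 8] -- dict stays empty on these epochs

def get_watched_metric(validation_results_dict: Optional[Dict[str, float]]) -> Optional[float]:
    """Guarded lookup: return None when no validation ran this epoch,
    instead of letting the bare key access raise KeyError."""
    if not validation_results_dict or METRIC_TO_WATCH not in validation_results_dict:
        return None
    return validation_results_dict[METRIC_TO_WATCH]

# Unguarded access, as in _save_checkpoint, fails on the empty dict:
try:
    _ = {}[METRIC_TO_WATCH]
except KeyError:
    print("KeyError on empty valid_metrics_dict")

print(get_watched_metric({}))                  # None -> caller can skip best-metric update
print(get_watched_metric({"Accuracy": 0.91}))  # 0.91
```

With such a guard, the checkpoint path could skip the best-metric comparison on epochs where validation did not run, rather than crashing.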


### Steps to Reproduce

1. Set run_validation_freq > 1 in the training parameters.
2. Run the training process.
3. Observe the error when the training tries to save a checkpoint.
@BloodAxe BloodAxe added the 🐛 Bug Something isn't working label Aug 10, 2023
@BloodAxe
Collaborator

Thanks for the bug report. We will investigate this behavior.

@hakuryuu96
Contributor

@BloodAxe please see this PR. I've attempted to fix this.
