support num_sanity_val_steps=-1 #2246
Merged
Commits
19 commits
e1e9680 support sanity_val_step=-1 (awaelchli)
d375e2a fix list size (awaelchli)
0be420c simplification (awaelchli)
66f4e91 simplify (awaelchli)
d7b56e9 add test for num_sanity_val_steps=-1 (awaelchli)
bd47b70 update test (awaelchli)
dd3862d update docs (awaelchli)
a15fd8d extend tests to multiple dataloaders (awaelchli)
cef85c3 changelog (awaelchli)
e644a29 Merge branch 'master' into feature/sanity-val-full-data (awaelchli)
e7c2a8d Update tests/trainer/test_trainer.py (awaelchli)
a26fe29 Merge branch 'master' into feature/sanity-val-full-data (awaelchli)
231c0c6 improve test (awaelchli)
6dc3d8b refactor the sanity check decision (awaelchli)
b7a973f Merge branch 'master' into feature/sanity-val-full-data (awaelchli)
83c03df Merge branch 'master' into feature/sanity-val-full-data (awaelchli)
637ba3a Merge branch 'master' into feature/sanity-val-full-data (awaelchli)
8713ccd fix merge (awaelchli)
ed62155 Update trainer.py (williamFalcon)
@@ -336,7 +336,8 @@ def __init__(

     amp_level: The optimization level to use (O1, O2, etc...).

-    num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
+    num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
+        Set it to `-1` to run all batches in all validation dataloaders. Default: 2

     truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of

@@ -408,7 +409,6 @@ def __init__(
         # training state
         self.model = None
         self.testing = False
-        self.disable_validation = False
         self.prepare_data_per_node = prepare_data_per_node
         self.lr_schedulers = []
         self.optimizers = None

@@ -488,7 +488,7 @@ def __init__(
         self.max_steps = max_steps
         self.min_steps = min_steps

-        self.num_sanity_val_steps = num_sanity_val_steps
+        self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
         # Backward compatibility, TODO: remove in v0.9.0
         if print_nan_grads:
             rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."

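To illustrate the normalization in the hunk above, here is a minimal standalone sketch rather than the Trainer's actual code. The helper names `normalize_sanity_steps` and `batches_to_run` are hypothetical, and capping the limit by each dataloader's length with `min` is an assumption about how the value is consumed later.

```python
import math


def normalize_sanity_steps(num_sanity_val_steps):
    """Map the -1 sentinel to infinity, mirroring the changed line in the diff."""
    return float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps


def batches_to_run(num_sanity_val_steps, dataloader_lengths):
    """Hypothetical helper: number of sanity batches per validation dataloader.

    Assumption: the step limit is capped by each dataloader's length.
    """
    limit = normalize_sanity_steps(num_sanity_val_steps)
    # min(inf, n) returns n, so -1 means "every batch of every dataloader".
    return [min(limit, n) for n in dataloader_lengths]
```

Using infinity as the internal value means any later `min` or `>` comparison keeps working without a special case for `-1`.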
@@ -883,6 +883,18 @@ def progress_bar_dict(self) -> dict:
         ref_model = self.model if not self.data_parallel else self.model.module
         return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics)

+    @property
+    def disable_validation(self) -> bool:
+        """ Check if validation is disabled during training. """
+        disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
+            and not self.fast_dev_run
+        return disable_validation
+
+    @property
+    def enable_validation(self) -> bool:
+        """ Check if we should run validation during training. """
+        return not self.disable_validation
+
     # -----------------------------
     # MODEL TRAINING
     # -----------------------------

Review comment on the `disable_validation = not (...)` expression:
when all are not, would it be easier to write
Reply:
done

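The boolean logic of the two new properties can be checked in isolation with a small stand-in class. This is only a sketch: `ValidationFlags` and its field `validation_step_overridden` are hypothetical names replacing the Trainer and its `is_overridden('validation_step')` call.

```python
from dataclasses import dataclass


@dataclass
class ValidationFlags:
    """Hypothetical stand-in for the Trainer attributes used by the properties."""
    validation_step_overridden: bool  # stands in for is_overridden('validation_step')
    limit_val_batches: float
    fast_dev_run: bool = False

    @property
    def disable_validation(self) -> bool:
        # Same expression as in the diff: validation is disabled when it cannot
        # run (no validation_step or no batches) and fast_dev_run is off.
        can_validate = self.validation_step_overridden and self.limit_val_batches > 0
        return not can_validate and not self.fast_dev_run

    @property
    def enable_validation(self) -> bool:
        return not self.disable_validation
```

Because the value is computed on access, it no longer has to be assigned inside `run_pretrain_routine` before other code reads it.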
@@ -1186,10 +1198,6 @@ def run_pretrain_routine(self, model: LightningModule):

             return eval_loop_results

-        # check if we should run validation during training
-        self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
-            and not self.fast_dev_run
-
         # run a few val batches before training starts
         self._run_sanity_check(ref_model, model)

@@ -1204,9 +1212,12 @@ def run_pretrain_routine(self, model: LightningModule):
         self.train()

     def _run_sanity_check(self, ref_model, model):
+        should_sanity_check = self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 \
+            and self.limit_val_batches > 0
+
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
-        if not self.disable_validation and self.num_sanity_val_steps > 0:
+        if should_sanity_check:
             self.reset_val_dataloader(ref_model)

             # hook and callback

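The refactored decision can be compared with the previous condition using plain functions. This is a sketch under assumptions: both function names are hypothetical, and the boolean parameter `overridden` stands in for `self.is_overridden('validation_step')`.

```python
def old_should_sanity_check(overridden, num_sanity_val_steps, limit_val_batches, fast_dev_run):
    """Previous condition: derived from disable_validation, so fast_dev_run took part."""
    disable_validation = not (overridden and limit_val_batches > 0) and not fast_dev_run
    return bool(not disable_validation and num_sanity_val_steps > 0)


def new_should_sanity_check(overridden, num_sanity_val_steps, limit_val_batches):
    """Refactored condition from the hunk above: fast_dev_run is not consulted."""
    return bool(overridden and num_sanity_val_steps > 0 and limit_val_batches > 0)


# float("inf") > 0 is True, so the -1 setting (stored as inf) passes the check.
full_run = new_should_sanity_check(True, float("inf"), 1.0)
```

The two conditions differ only when `fast_dev_run` is set while `validation_step` is missing or `limit_val_batches` is zero. Whether that combination arises in practice depends on how `fast_dev_run` adjusts `num_sanity_val_steps` elsewhere, which this excerpt does not show.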
what is this?
tqdm does not understand float("inf"), so we have to convert it to None in the case num_sanity_val_steps=inf or dataloader has inf length.
right but where is this imported?
oh i see... just kind of stuck at the bottom lol
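The conversion described in the reply about tqdm could look roughly like the sketch below. The function name `convert_inf_to_none` is hypothetical, since the helper under review is not shown in this excerpt, and the only library behaviour assumed is that tqdm accepts `total=None` for an unknown length.

```python
import math


def convert_inf_to_none(total):
    """Return None for an infinite (or already unknown) total, else the value unchanged.

    Hypothetical helper: the result is meant to be passed as tqdm's `total`,
    where None means "length unknown".
    """
    if total is None or math.isinf(total):
        return None
    return total
```

A progress bar would then be created with something like `tqdm(total=convert_inf_to_none(n))`, which covers both `num_sanity_val_steps=inf` and a dataloader of infinite length.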