diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 0382477768cbf..2cf9aeb40ed09 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -266,76 +266,67 @@ def evaluate(self, model, dataloaders, max_batches, test=False):
 
     def run_evaluation(self, test=False):
         # when testing make sure user defined a test step
-        can_run_test_step = False
+        if test and not (self.is_overriden('test_step') and self.is_overriden('test_end')):
+            m = '''You called `.test()` without defining model's `.test_step()` or `.test_end()`.
+            Please define and try again'''
+            raise MisconfigurationException(m)
+
+        # hook
+        model = self.get_model()
+        model.on_pre_performance_check()
+
+        # select dataloaders
         if test:
-            can_run_test_step = self.is_overriden('test_step') and self.is_overriden('test_end')
-            if not can_run_test_step:
-                m = '''You called .test() without defining a test step or test_end.
-                Please define and try again'''
-                raise MisconfigurationException(m)
-
-        # validate only if model has validation_step defined
-        # test only if test_step or validation_step are defined
-        run_val_step = self.is_overriden('validation_step')
-
-        if run_val_step or can_run_test_step:
-
-            # hook
-            model = self.get_model()
-            model.on_pre_performance_check()
-
-            # select dataloaders
-            if test:
-                dataloaders = self.get_test_dataloaders()
-                max_batches = self.num_test_batches
-            else:
-                # val
-                dataloaders = self.get_val_dataloaders()
-                max_batches = self.num_val_batches
-
-            # cap max batches to 1 when using fast_dev_run
-            if self.fast_dev_run:
-                max_batches = 1
-
-            # init validation or test progress bar
-            # main progress bar will already be closed when testing so initial position is free
-            position = 2 * self.process_position + (not test)
-            desc = 'Testing' if test else 'Validating'
-            pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
-                             disable=not self.show_progress_bar, dynamic_ncols=True,
-                             unit='batch', file=sys.stdout)
-            setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
-
-            # run evaluation
-            eval_results = self.evaluate(self.model,
-                                         dataloaders,
-                                         max_batches,
-                                         test)
-            _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
-                eval_results)
-
-            # add metrics to prog bar
-            self.add_tqdm_metrics(prog_bar_metrics)
-
-            # log metrics
-            self.log_metrics(log_metrics, {})
-
-            # track metrics for callbacks
-            self.callback_metrics.update(callback_metrics)
-
-            # hook
-            model.on_post_performance_check()
-
-            # add model specific metrics
-            tqdm_metrics = self.training_tqdm_dict
-            if not test:
-                self.main_progress_bar.set_postfix(**tqdm_metrics)
-
-            # close progress bar
-            if test:
-                self.test_progress_bar.close()
-            else:
-                self.val_progress_bar.close()
+            dataloaders = self.get_test_dataloaders()
+            max_batches = self.num_test_batches
+        else:
+            # val
+            dataloaders = self.get_val_dataloaders()
+            max_batches = self.num_val_batches
+
+        # cap max batches to 1 when using fast_dev_run
+        if self.fast_dev_run:
+            max_batches = 1
+
+        # init validation or test progress bar
+        # main progress bar will already be closed when testing so initial position is free
+        position = 2 * self.process_position + (not test)
+        desc = 'Testing' if test else 'Validating'
+        pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
+                         disable=not self.show_progress_bar, dynamic_ncols=True,
+                         unit='batch', file=sys.stdout)
+        setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
+
+        # run evaluation
+        eval_results = self.evaluate(self.model,
+                                     dataloaders,
+                                     max_batches,
+                                     test)
+        _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
+            eval_results)
+
+        # add metrics to prog bar
+        self.add_tqdm_metrics(prog_bar_metrics)
+
+        # log metrics
+        self.log_metrics(log_metrics, {})
+
+        # track metrics for callbacks
+        self.callback_metrics.update(callback_metrics)
+
+        # hook
+        model.on_post_performance_check()
+
+        # add model specific metrics
+        tqdm_metrics = self.training_tqdm_dict
+        if not test:
+            self.main_progress_bar.set_postfix(**tqdm_metrics)
+
+        # close progress bar
+        if test:
+            self.test_progress_bar.close()
+        else:
+            self.val_progress_bar.close()
 
         # model checkpointing
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 50282866e2b93..ca90da2660335 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -213,6 +213,7 @@ def __init__(
         # training state
         self.model = None
         self.testing = False
+        self.disable_validation = False
         self.lr_schedulers = []
         self.optimizers = None
         self.global_step = 0
@@ -486,11 +487,16 @@ def run_pretrain_routine(self, model):
             self.run_evaluation(test=True)
             return
 
+        # check if we should run validation during training
+        self.disable_validation = ((self.num_val_batches == 0 or
+                                    not self.is_overriden('validation_step')) and
+                                   not self.fast_dev_run)
+
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
         ref_model.on_sanity_check_start()
         ref_model.on_train_start()
-        if self.get_val_dataloaders() is not None and self.num_sanity_val_steps > 0:
+        if not self.disable_validation and self.num_sanity_val_steps > 0:
             # init progress bars for validation sanity check
             pbar = tqdm.tqdm(desc='Validation sanity check',
                              total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index d2a98780aa766..f9860dc044b50 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -184,6 +184,7 @@ def __init__(self):
         self.num_training_batches = None
         self.val_check_batch = None
         self.num_val_batches = None
+        self.disable_validation = None
         self.fast_dev_run = None
         self.is_iterable_train_dataloader = None
         self.main_progress_bar = None
@@ -294,14 +295,16 @@ def train(self):
             model.current_epoch = epoch
             self.current_epoch = epoch
 
-            # val can be checked multiple times in epoch
-            is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-            val_checks_per_epoch = self.num_training_batches // self.val_check_batch
-            val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
+            total_val_batches = 0
+            if not self.disable_validation:
+                # val can be checked multiple times in epoch
+                is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+                val_checks_per_epoch = self.num_training_batches // self.val_check_batch
+                val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
+                total_val_batches = self.num_val_batches * val_checks_per_epoch
 
             # total batches includes multiple val checks
-            self.total_batches = (self.num_training_batches +
-                                  self.num_val_batches * val_checks_per_epoch)
+            self.total_batches = self.num_training_batches + total_val_batches
             self.batch_loss_value = 0  # accumulated grads
 
             if self.fast_dev_run:
@@ -386,7 +389,8 @@ def run_training_epoch(self):
             # ---------------
             is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
             can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-            should_check_val = ((is_val_check_batch or early_stop_epoch) and can_check_epoch)
+            should_check_val = (not self.disable_validation and can_check_epoch and
+                                (is_val_check_batch or early_stop_epoch))
 
             # fast_dev_run always forces val checking after train batch
             if self.fast_dev_run or should_check_val:
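
Taken together, these hunks replace the old "wrap the whole evaluation in `if run_val_step or can_run_test_step`" guard with a single `disable_validation` flag, computed once in `run_pretrain_routine` and then consulted by the sanity check, the progress-bar batch count, and the per-batch val check. The sketch below is a minimal illustration of that gating logic only, not the actual Trainer code; `TrainerSketch` and its constructor arguments (e.g. `has_validation_step` in place of `is_overriden('validation_step')`) are hypothetical stand-ins for the attributes touched above.

# Minimal sketch of the validation-gating logic these hunks introduce.
# Names here mirror the diff but this is NOT the real Trainer class.
class TrainerSketch:
    def __init__(self, num_val_batches, has_validation_step, fast_dev_run,
                 check_val_every_n_epoch=1, val_check_batch=10,
                 num_training_batches=100):
        self.num_val_batches = num_val_batches
        self.has_validation_step = has_validation_step
        self.fast_dev_run = fast_dev_run
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.val_check_batch = val_check_batch
        self.num_training_batches = num_training_batches

        # trainer.py hunk: validation is disabled when there are no val
        # batches or no validation step, unless fast_dev_run forces it
        self.disable_validation = ((self.num_val_batches == 0 or
                                    not self.has_validation_step) and
                                   not self.fast_dev_run)

    def total_batches(self, current_epoch):
        # training_loop.py train() hunk: only count val batches toward
        # the progress-bar total when validation can actually run
        total_val_batches = 0
        if not self.disable_validation:
            is_val_epoch = (current_epoch + 1) % self.check_val_every_n_epoch == 0
            val_checks_per_epoch = self.num_training_batches // self.val_check_batch
            val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
            total_val_batches = self.num_val_batches * val_checks_per_epoch
        return self.num_training_batches + total_val_batches

    def should_check_val(self, batch_idx, current_epoch, early_stop_epoch=False):
        # run_training_epoch() hunk: the flag short-circuits the per-batch
        # val check, while fast_dev_run still forces one val pass
        is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
        can_check_epoch = (current_epoch + 1) % self.check_val_every_n_epoch == 0
        should_check = (not self.disable_validation and can_check_epoch and
                        (is_val_check_batch or early_stop_epoch))
        return self.fast_dev_run or should_check


# e.g. a model without a validation step never triggers validation:
t = TrainerSketch(num_val_batches=0, has_validation_step=False, fast_dev_run=False)
assert t.disable_validation
assert t.total_batches(current_epoch=0) == t.num_training_batches
assert not t.should_check_val(batch_idx=9, current_epoch=0)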