From f94b919b96def9ffd46124f00c67afde5ddaa80e Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Tue, 16 Jun 2020 12:33:41 +0200
Subject: [PATCH] deprecated: epoch indexing from 1 (#2206)

* epoch indexing from 1

* chlog

* fix tests

* fix tests

* self.min_epochs
---
 CHANGELOG.md | 1 +
 .../callbacks/gradient_accumulation_scheduler.py | 8 ++------
 pytorch_lightning/callbacks/progress.py | 4 ++--
 pytorch_lightning/trainer/training_io.py | 2 +-
 pytorch_lightning/trainer/training_loop.py | 8 ++++----
 tests/callbacks/test_callbacks.py | 2 +-
 tests/models/test_hooks.py | 2 +-
 tests/models/test_restore.py | 2 +-
 tests/trainer/test_trainer.py | 6 +++---
 9 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 89ab2d7d3cee7..11978430c6b86 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -48,6 +48,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Raise an error when lightning replaces an existing sampler ([#2020](https://github.com/PyTorchLightning/pytorch-lightning/pull/2020))
 - Enabled prepare_data from correct processes - clarify local vs global rank ([#2166](https://github.com/PyTorchLightning/pytorch-lightning/pull/2166))
 - Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
+- Changed epoch/step indexing from 1 instead of 0 ([#2206](https://github.com/PyTorchLightning/pytorch-lightning/pull/2206))

 ### Deprecated

diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
index bc1cd79e96f63..15b67e6c33560 100644
--- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
+++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
@@ -42,11 +42,9 @@ def __init__(self, scheduling: dict):

         for key in scheduling:
             if not isinstance(key, int) or not isinstance(scheduling[key], int):
-                raise TypeError("All epoches and accumulation factor must be integers")
+                raise TypeError("All epochs and accumulation factor must be integers")

         minimal_epoch = min(scheduling.keys())
-        # rank_zero_warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
-        #                ' but will start from "0" in v0.8.0.', DeprecationWarning)
         if minimal_epoch < 1:
             raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct")
         if minimal_epoch != 1:  # if user didnt define first epoch accumulation factor
@@ -56,9 +54,7 @@ def __init__(self, scheduling: dict):
         self.epochs = sorted(scheduling.keys())

     def on_epoch_start(self, trainer, pl_module):
-        # indexing epochs from 1 (until v0.6.x)
-        # In v0.8.0, ` + 1` should be removed.
-        epoch = trainer.current_epoch + 1
+        epoch = trainer.current_epoch
         for i in reversed(range(len(self.epochs))):
             if epoch >= self.epochs[i]:
                 trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py
index 5d2c6f742ed19..fb3dafb6e6019 100644
--- a/pytorch_lightning/callbacks/progress.py
+++ b/pytorch_lightning/callbacks/progress.py
@@ -96,7 +96,7 @@ def total_val_batches(self) -> int:
         if trainer.fast_dev_run and trainer.val_dataloaders is not None:
             total_val_batches = len(trainer.val_dataloaders)
         elif not self.trainer.disable_validation:
-            is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0
+            is_val_epoch = trainer.current_epoch % trainer.check_val_every_n_epoch == 0
             total_val_batches = trainer.num_val_batches if is_val_epoch else 0
         return total_val_batches

@@ -317,7 +317,7 @@ def on_epoch_start(self, trainer, pl_module):
         total_batches = total_train_batches + total_val_batches
         if not self.main_progress_bar.disable:
             self.main_progress_bar.reset(convert_inf(total_batches))
-        self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}')
+        self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')

     def on_batch_end(self, trainer, pl_module):
         super().on_batch_end(trainer, pl_module)
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 955f6d768d3b6..cdd6f1ce84043 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -323,7 +323,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             structured dictionary
         """
         checkpoint = {
-            'epoch': self.current_epoch + 1,
+            'epoch': self.current_epoch,
             'global_step': self.global_step + 1,
             'pytorch-ligthning_version': pytorch_lightning.__version__,
         }
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index a80f3b62837f0..ed98192eaa260 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -336,8 +336,8 @@ def train(self):
         model.on_train_start()

         try:
-            # run all epochs
-            for epoch in range(self.current_epoch, self.max_epochs):
+            # run all epochs, from the current epoch + 1 up to and including max_epochs
+            for epoch in range(self.current_epoch + 1, self.max_epochs + 1):
                 # reset train dataloader
                 if self.reload_dataloaders_every_epoch:
                     self.reset_train_dataloader(model)
@@ -372,7 +372,7 @@ def train(self):
                 self.update_learning_rates(interval='epoch')

                 # early stopping
-                met_min_epochs = epoch >= self.min_epochs - 1
+                met_min_epochs = epoch >= self.min_epochs
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                 # TODO wrap this logic into the callback
@@ -466,7 +466,7 @@ def run_training_epoch(self):
             # RUN VAL STEP
             # ---------------
             is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
-            can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+            can_check_epoch = self.current_epoch % self.check_val_every_n_epoch == 0
             can_check_val = not self.disable_validation and can_check_epoch
             should_check_val = is_val_check_batch or early_stop_epoch
             should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 3341ff51229b6..399f6ba3cb06d 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -248,7 +248,7 @@ def training_step(self, *args, **kwargs):
     result = trainer.fit(model)

     assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch < trainer.max_epochs
+    assert trainer.current_epoch <= trainer.max_epochs


 def test_pickling(tmpdir):
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 47b73eb9e715b..d4cfefab29f46 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -68,7 +68,7 @@ def training_epoch_end(self, outputs):
     # a metric shared in both methods gets overwritten by epoch_end
     assert metrics['shared_metric'] == 111
     # metrics are kept after each epoch
-    for i in range(num_epochs):
+    for i in range(1, num_epochs + 1):
         assert metrics[f'epoch_metric_{i}'] == i


diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 99919ce402600..c2219548dacac 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -172,7 +172,7 @@ def test_dp_resume(tmpdir):
     result = trainer.fit(model)

     # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
-    real_global_epoch = trainer.current_epoch + 1
+    real_global_epoch = trainer.current_epoch

     # correct result and ok accuracy
     assert result == 1, 'amp + dp model failed to complete'
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index e3701a1f9418f..ba0f500844dff 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -451,7 +451,7 @@ def test_trainer_max_steps_and_epochs(tmpdir):

     # check training stopped at max_epochs
     assert trainer.global_step == num_train_samples * trainer.max_epochs
-    assert trainer.current_epoch == trainer.max_epochs - 1, "Model did not stop at max_epochs"
+    assert trainer.current_epoch == trainer.max_epochs, "Model did not stop at max_epochs"


 def test_trainer_min_steps_and_epochs(tmpdir):
@@ -619,7 +619,7 @@ def validation_epoch_end(self, *args, **kwargs):

     # check that val_percent_check=0 turns off validation
     assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch == 1
+    assert trainer.current_epoch == 2
     assert not model.validation_step_invoked, \
         '`validation_step` should not run when `val_percent_check=0`'
     assert not model.validation_epoch_end_invoked, \
@@ -632,7 +632,7 @@ def validation_epoch_end(self, *args, **kwargs):
     result = trainer.fit(model)

     assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch == 0
+    assert trainer.current_epoch == 1
     assert model.validation_step_invoked, \
         'did not run `validation_step` with `fast_dev_run=True`'
     assert model.validation_epoch_end_invoked, \
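
As a usage illustration (not part of the patch above), the sketch below shows how the 1-based epoch indexing introduced here is expected to interact with GradientAccumulationScheduler. The classes and Trainer arguments are the ones touched in the diff; the concrete scheduling dict, max_epochs value, and variable names are made up for the example.

# Minimal sketch of the 1-based indexing this patch introduces; values are illustrative only.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# Keys are epoch numbers, values are accumulation factors. With 1-based indexing
# the schedule starts at epoch 1; keys below 1 now raise an IndexError in __init__.
accumulator = GradientAccumulationScheduler(scheduling={1: 1, 5: 4})

trainer = Trainer(max_epochs=10, callbacks=[accumulator])
# After this change trainer.current_epoch runs from 1 to max_epochs inclusive, so
# on_epoch_start can compare it to the scheduling keys without the former `+ 1` offset,
# and dump_checkpoint stores 'epoch' without adding 1.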