deprecated: epoch indexing from 1 (#2206)
* epoch indexing from 1

* chlog

* fix tests

* fix tests

* self.min_epochs
Borda authored Jun 16, 2020
1 parent 8870a84 commit f94b919
Showing 9 changed files with 16 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -48,6 +48,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Raise an error when lightning replaces an existing sampler ([#2020](https://github.com/PyTorchLightning/pytorch-lightning/pull/2020))
 - Enabled prepare_data from correct processes - clarify local vs global rank ([#2166](https://github.com/PyTorchLightning/pytorch-lightning/pull/2166))
 - Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
+- Changed epoch/step indexing from 1 instead of 0 ([#2206](https://github.com/PyTorchLightning/pytorch-lightning/pull/2206))
 
 ### Deprecated
 
8 changes: 2 additions & 6 deletions pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
@@ -42,11 +42,9 @@ def __init__(self, scheduling: dict):
 
         for key in scheduling:
             if not isinstance(key, int) or not isinstance(scheduling[key], int):
-                raise TypeError("All epoches and accumulation factor must be integers")
+                raise TypeError("All epochs and accumulation factor must be integers")
 
         minimal_epoch = min(scheduling.keys())
-        # rank_zero_warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
-        #                ' but will start from "0" in v0.8.0.', DeprecationWarning)
         if minimal_epoch < 1:
             raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct")
         if minimal_epoch != 1:  # if user didnt define first epoch accumulation factor
@@ -56,9 +54,7 @@ def __init__(self, scheduling: dict):
         self.epochs = sorted(scheduling.keys())
 
     def on_epoch_start(self, trainer, pl_module):
-        # indexing epochs from 1 (until v0.6.x)
-        # In v0.8.0, ` + 1` should be removed.
-        epoch = trainer.current_epoch + 1
+        epoch = trainer.current_epoch
         for i in reversed(range(len(self.epochs))):
             if epoch >= self.epochs[i]:
                 trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
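
Note: the scheduling dict maps a 1-based epoch number to a gradient-accumulation factor, and on_epoch_start applies the factor of the largest scheduled epoch already reached. A standalone sketch of that lookup (the schedule values are invented for illustration, not part of this commit):

# Hypothetical schedule: accumulate 1 batch from epoch 1, 4 batches from epoch 5.
scheduling = {1: 1, 5: 4}
epochs = sorted(scheduling.keys())  # [1, 5]

def factor_for(epoch: int) -> int:
    # Mirrors on_epoch_start above: scan scheduled epochs from largest to
    # smallest and return the factor of the first one `epoch` has reached.
    for e in reversed(epochs):
        if epoch >= e:
            return scheduling[e]

assert factor_for(1) == 1  # the first epoch uses the factor defined for epoch 1
assert factor_for(4) == 1
assert factor_for(5) == 4  # "epoch 5" now means the fifth epoch, not the sixth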
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/progress.py
@@ -96,7 +96,7 @@ def total_val_batches(self) -> int:
         if trainer.fast_dev_run and trainer.val_dataloaders is not None:
             total_val_batches = len(trainer.val_dataloaders)
         elif not self.trainer.disable_validation:
-            is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0
+            is_val_epoch = trainer.current_epoch % trainer.check_val_every_n_epoch == 0
             total_val_batches = trainer.num_val_batches if is_val_epoch else 0
         return total_val_batches
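
Note: with epochs counted from 1, the cadence check reads directly off the epoch number: validation runs whenever it is a multiple of check_val_every_n_epoch. A small self-contained illustration (the setting of 2 is assumed):

check_val_every_n_epoch = 2  # hypothetical setting
# Under 1-based indexing, validation runs on the 2nd, 4th, 6th, ... epochs.
val_epochs = [e for e in range(1, 7) if e % check_val_every_n_epoch == 0]
assert val_epochs == [2, 4, 6]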

@@ -317,7 +317,7 @@ def on_epoch_start(self, trainer, pl_module):
         total_batches = total_train_batches + total_val_batches
         if not self.main_progress_bar.disable:
             self.main_progress_bar.reset(convert_inf(total_batches))
-        self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}')
+        self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')
 
     def on_batch_end(self, trainer, pl_module):
         super().on_batch_end(trainer, pl_module)
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_io.py
@@ -323,7 +323,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             structured dictionary
         """
         checkpoint = {
-            'epoch': self.current_epoch + 1,
+            'epoch': self.current_epoch,
             'global_step': self.global_step + 1,
             'pytorch-ligthning_version': pytorch_lightning.__version__,
         }
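
Note: the checkpoint now stores the 1-based epoch as-is rather than current_epoch + 1. A rough sketch of the resume arithmetic this implies (values invented; see the training-loop range change below):

# A run that has completed epoch 3 of 5 saves {'epoch': 3, ...}; on restore,
# range(current_epoch + 1, max_epochs + 1) picks up with epoch 4.
checkpoint = {'epoch': 3}
max_epochs = 5
remaining = list(range(checkpoint['epoch'] + 1, max_epochs + 1))
assert remaining == [4, 5]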
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/training_loop.py
@@ -336,8 +336,8 @@ def train(self):
             model.on_train_start()
 
             try:
-                # run all epochs
-                for epoch in range(self.current_epoch, self.max_epochs):
+                # run all epochs from actual + 1 till the maximal
+                for epoch in range(self.current_epoch + 1, self.max_epochs + 1):
                     # reset train dataloader
                     if self.reload_dataloaders_every_epoch:
                         self.reset_train_dataloader(model)
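
Note: both loop bounds shift by one so that epoch numbers run from 1 through max_epochs inclusive. A quick standalone check, assuming a fresh run starts with current_epoch at 0:

current_epoch, max_epochs = 0, 5  # assumed initial state of a fresh run
assert list(range(current_epoch + 1, max_epochs + 1)) == [1, 2, 3, 4, 5]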
@@ -372,7 +372,7 @@ def train(self):
                 self.update_learning_rates(interval='epoch')
 
                 # early stopping
-                met_min_epochs = epoch >= self.min_epochs - 1
+                met_min_epochs = epoch >= self.min_epochs
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
 
                 # TODO wrap this logic into the callback
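
Note: under 1-based numbering, epoch already equals the count of epochs run, so the guard drops its - 1 correction. A tiny check with an assumed min_epochs of 3:

min_epochs = 3  # hypothetical setting
# `epoch` is the 1-based number of the epoch that just finished.
assert not (2 >= min_epochs)  # after 2 epochs: too early for early stopping
assert 3 >= min_epochs        # after 3 epochs: early stopping may fire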
@@ -466,7 +466,7 @@ def run_training_epoch(self):
             # RUN VAL STEP
             # ---------------
             is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
-            can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+            can_check_epoch = self.current_epoch % self.check_val_every_n_epoch == 0
             can_check_val = not self.disable_validation and can_check_epoch
             should_check_val = is_val_check_batch or early_stop_epoch
             should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
2 changes: 1 addition & 1 deletion tests/callbacks/test_callbacks.py
@@ -248,7 +248,7 @@ def training_step(self, *args, **kwargs):
     result = trainer.fit(model)
 
     assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch < trainer.max_epochs
+    assert trainer.current_epoch <= trainer.max_epochs
 
 
 def test_pickling(tmpdir):
2 changes: 1 addition & 1 deletion tests/models/test_hooks.py
@@ -68,7 +68,7 @@ def training_epoch_end(self, outputs):
     # a metric shared in both methods gets overwritten by epoch_end
     assert metrics['shared_metric'] == 111
     # metrics are kept after each epoch
-    for i in range(num_epochs):
+    for i in range(1, num_epochs + 1):
         assert metrics[f'epoch_metric_{i}'] == i


2 changes: 1 addition & 1 deletion tests/models/test_restore.py
@@ -172,7 +172,7 @@ def test_dp_resume(tmpdir):
     result = trainer.fit(model)
 
     # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
-    real_global_epoch = trainer.current_epoch + 1
+    real_global_epoch = trainer.current_epoch
 
     # correct result and ok accuracy
     assert result == 1, 'amp + dp model failed to complete'
6 changes: 3 additions & 3 deletions tests/trainer/test_trainer.py
@@ -451,7 +451,7 @@ def test_trainer_max_steps_and_epochs(tmpdir):
 
     # check training stopped at max_epochs
     assert trainer.global_step == num_train_samples * trainer.max_epochs
-    assert trainer.current_epoch == trainer.max_epochs - 1, "Model did not stop at max_epochs"
+    assert trainer.current_epoch == trainer.max_epochs, "Model did not stop at max_epochs"
 
 
 def test_trainer_min_steps_and_epochs(tmpdir):
@@ -619,7 +619,7 @@ def validation_epoch_end(self, *args, **kwargs):
 
     # check that val_percent_check=0 turns off validation
     assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch == 1
+    assert trainer.current_epoch == 2
     assert not model.validation_step_invoked, \
         '`validation_step` should not run when `val_percent_check=0`'
     assert not model.validation_epoch_end_invoked, \
         '`validation_epoch_end` should not run when `val_percent_check=0`'
@@ -632,7 +632,7 @@ def validation_epoch_end(self, *args, **kwargs):
     result = trainer.fit(model)
 
     assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch == 0
+    assert trainer.current_epoch == 1
     assert model.validation_step_invoked, \
         'did not run `validation_step` with `fast_dev_run=True`'
     assert model.validation_epoch_end_invoked, \
         'did not run `validation_epoch_end` with `fast_dev_run=True`'
