Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[blocked by #1043] enabled early stopping/checkpoint even without val step #1041

Closed
wants to merge 28 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
fa7410e
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
d4d9fe7
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
afd79b9
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
3aaaece
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
55e6444
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
11be425
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
726de59
Merge branch 'callback_on_train' of https://github.com/PyTorchLightni…
williamFalcon Mar 4, 2020
7758b16
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
11a1679
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
756472b
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
d2b9bb5
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
c60a392
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
f36993d
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
da0a21d
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
3279fa5
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
10b683b
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
64e528b
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
481d457
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
a43959e
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
a4e724e
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
cb182a2
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
49c1f22
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
b91fc01
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
76c7872
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
1703950
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
e701594
Merge branch 'callback_on_train' of https://github.com/PyTorchLightni…
williamFalcon Mar 4, 2020
d68cf70
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
1f9c824
enabled early stopping/checkpooiunt even without val step
williamFalcon Mar 4, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
- Checkpoint and early stopping now work without val step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041))

### Changed

Expand Down
18 changes: 13 additions & 5 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@ class ModelCheckpoint(Callback):

Example::

# no path
ModelCheckpoint()
# saves like /my/path/epoch_0.ckpt

# save epoch and val_loss in name
            ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.ckpt')

# saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt


            # if model already exists, the file will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt


Expand Down Expand Up @@ -146,10 +151,13 @@ def check_monitor_top_k(self, current: float) -> bool:
return self.monitor_op(current, self.best_k_models[self.kth_best_model])

def _get_available_filepath(self, current: float, epoch: int) -> str:
current_str = f'{current:.2f}' if current else 'NaN'
fname = f'{self.prefix}_epoch={epoch}_{self.monitor}={current_str}'
filepath = os.path.join(self.dirpath, fname + self.EXTENSION)
assert not os.path.isfile(filepath)
try:
current_str = f'{current:.2f}' if current else 'NaN'
fname = f'{self.prefix}_epoch={epoch}_{self.monitor}={current_str}'
filepath = os.path.join(self.dirpath, fname + self.EXTENSION)
assert not os.path.isfile(filepath)
except Exception as e:
import pdb; pdb.set_trace()
return filepath

def on_validation_end(self, trainer, pl_module) -> None:
Expand Down
7 changes: 6 additions & 1 deletion pytorch_lightning/trainer/callback_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,14 @@ def configure_checkpoint_callback(self):
else:
ckpt_path = os.path.join(self.default_save_path, "checkpoints")

# when no val step is defined, use 'loss' otherwise 'val_loss'
train_step_only = not self.is_overriden('validation_step')
monitor_key = 'loss' if train_step_only else 'val_loss'

self.ckpt_path = ckpt_path
self.checkpoint_callback = ModelCheckpoint(
dirpath=ckpt_path
dirpath=ckpt_path,
monitor=monitor_key
)
elif self.checkpoint_callback is False:
self.checkpoint_callback = None
Expand Down
6 changes: 0 additions & 6 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ class TrainerEvaluationLoopMixin(ABC):
process_output: ...
training_tqdm_dict: ...
proc_rank: int
checkpoint_callback: ...
current_epoch: int
callback_metrics: ...
test_dataloaders: DataLoader
Expand Down Expand Up @@ -377,11 +376,6 @@ def run_evaluation(self, test_mode: bool = False):
# Validation/Test end callbacks
if test_mode:
self.on_test_end()
else:
# model checkpointing
if self.checkpoint_callback is not None:
self.checkpoint_callback.on_validation_end(self, self.get_model())
self.on_validation_end()

def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
# make dataloader_idx arg in validation_step optional
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,9 +1132,6 @@ def run_pretrain_routine(self, model: LightningModule):
# wait for all processes to catch up
torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine")

# set up checkpoint callback
self.configure_checkpoint_callback()

# register auto-resubmit when on SLURM
self.register_slurm_signal_handlers()

Expand All @@ -1151,6 +1148,9 @@ def run_pretrain_routine(self, model: LightningModule):
# if cluster resets state, the model will update with the saved weights
self.model = model

# set up checkpoint callback
self.configure_checkpoint_callback()

# restore training and model before hpc call
self.restore_weights(model)

Expand Down
32 changes: 29 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class TrainerTrainLoopMixin(ABC):
max_steps: int
max_steps: int
total_batch_idx: int
checkpoint_callback: ...
Copy link
Member

@Borda Borda Mar 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this for?

Suggested change
checkpoint_callback: ...


# Callback system
callbacks: List[Callback]
Expand All @@ -212,6 +213,7 @@ class TrainerTrainLoopMixin(ABC):
on_batch_end: Callable
on_epoch_start: Callable
on_epoch_end: Callable
on_validation_end: Callable

@property
def max_nb_epochs(self):
Expand Down Expand Up @@ -454,9 +456,6 @@ def run_training_epoch(self):
if self.fast_dev_run or should_check_val:
self.run_evaluation(test_mode=self.testing)

if self.enable_early_stop:
self.early_stop_callback.check_metrics(self.callback_metrics)

# when logs should be saved
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
if should_save_log or self.fast_dev_run:
Expand All @@ -469,11 +468,33 @@ def run_training_epoch(self):
# logs user requested information to logger
self.log_metrics(batch_step_metrics, grad_norm_dic)

# ---------------
# CHECKPOINTING, EARLY STOPPING
# ---------------
# save checkpoint even when no test or val step are defined
train_step_only = not self.is_overriden('validation_step')
if self.fast_dev_run or should_check_val or train_step_only:
self.call_checkpoint_callback()

if self.enable_early_stop:
self.early_stop_callback.check_metrics(self.callback_metrics)

# progress global step according to grads progress
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
self.global_step += 1
self.total_batch_idx += 1

# ---------------
# CHECKPOINTING, EARLY STOPPING
# ---------------
# save checkpoint even when no test or val step are defined
train_step_only = not self.is_overriden('validation_step')
if self.fast_dev_run or should_check_val or train_step_only:
self.call_checkpoint_callback()

if self.enable_early_stop:
self.early_stop_callback.check_metrics(self.callback_metrics)

# max steps reached, end training
if self.max_steps is not None and self.max_steps == self.global_step:
break
Expand Down Expand Up @@ -705,3 +726,8 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
output = self.process_output(output, train=True)

return output

def call_checkpoint_callback(self):
    """Fire the checkpoint callback's validation-end hook from the training loop.

    Lets checkpointing run even when no validation step is defined
    (this hook previously fired only from ``run_evaluation``).
    """
    # Skip entirely when checkpointing is disabled (checkpoint_callback is None).
    if self.checkpoint_callback is not None:
        self.checkpoint_callback.on_validation_end(self, self.get_model())
    # NOTE(review): indentation was stripped in this extract — presumably the
    # user-callback dispatch below runs unconditionally, outside the `if`;
    # confirm against the upstream file.
    self.on_validation_end()
46 changes: 23 additions & 23 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,27 @@
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.debugging import MisconfigurationException

def test_hparams_save_load(tmpdir):
    """Verify hparams given as a dict survive a fit + checkpoint-load round trip."""
    model = DictHparamsModel({'in_features': 28 * 28, 'out_features': 10})

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=2,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1

    # try to load the model now
    pretrained_model = tutils.load_model_from_checkpoint(
        trainer.checkpoint_callback.dirpath,
        module_class=DictHparamsModel
    )
    # the original left the loaded model unchecked — assert the load succeeded
    assert pretrained_model is not None


def test_no_val_module(tmpdir):
"""Tests use case where trainer saves the model, and user loads it from tags independently."""
Expand Down Expand Up @@ -126,7 +147,8 @@ def test_gradient_accumulation_scheduling(tmpdir):
assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})

# test optimizer call freq matches scheduler
def _optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
def _optimizer_step(self, epoch, batch_idx, optimizer,
optimizer_idx, second_order_closure=None):
# only test the first 12 batches in epoch
if batch_idx < 12:
if epoch == 0:
Expand Down Expand Up @@ -620,25 +642,3 @@ def test_default_args(tmpdir):

assert isinstance(trainer, Trainer)
assert trainer.max_epochs == 5


def test_hparams_save_load(tmpdir):
model = DictHparamsModel({'in_features': 28 * 28, 'out_features': 10})

# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
max_epochs=2,
)

# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)

assert result == 1

# try to load the model now
pretrained_model = tutils.load_model_from_checkpoint(
trainer.checkpoint_callback.dirpath,
module_class=DictHparamsModel
)