From 4fca994d0e4f42878df05fdd7b599b0e121852f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 20 Apr 2020 13:02:53 +0200 Subject: [PATCH] Fix callback default (horror bug!) (#1534) * fix horror bug * update changelog * fix doctest * liine too long --- CHANGELOG.md | 2 +- pytorch_lightning/trainer/trainer.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index de9d59741bbe9..c0d8bdfc84c67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,7 +50,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a missing call to the `on_before_zero_grad` model hook ([#1493](https://github.com/PyTorchLightning/pytorch-lightning/pull/1493)). -- +- Fixed a bug that caused the `callbacks` Trainer argument to reference a global variable ([#1534](https://github.com/PyTorchLightning/pytorch-lightning/pull/1534)). ## [0.7.3] - 2020-04-09 diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3c9da89ca5ec7..998a19bc16d32 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -87,7 +87,7 @@ def __init__( logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, checkpoint_callback: Union[ModelCheckpoint, bool] = True, early_stop_callback: Optional[Union[EarlyStopping, bool]] = False, - callbacks: List[Callback] = [], + callbacks: Optional[List[Callback]] = None, default_root_dir: Optional[str] = None, gradient_clip_val: float = 0, process_position: int = 0, @@ -293,7 +293,7 @@ def __init__( """ # Init callbacks - self.callbacks = callbacks + self.callbacks = callbacks or [] self.on_init_start() # benchmarking @@ -546,7 +546,10 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: (, typing.Dict[int, int], typing.List[list]), 1), ... - ('callbacks', (,), []), + ('callbacks', + (typing.List[pytorch_lightning.callbacks.base.Callback], + ), + None), ('check_val_every_n_epoch', (,), 1), ... ('max_epochs', (,), 1000),