Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updated hooks #2850

Merged
merged 9 commits into from
Aug 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ class Callback(abc.ABC):
Abstract base class used to build new callbacks.
"""

def setup(self, trainer, stage: str):
def setup(self, trainer, pl_module, stage: str):
"""Called when fit or test begins"""
pass

def teardown(self, trainer, stage: str):
def teardown(self, trainer, pl_module, stage: str):
"""Called when fit or test ends"""
pass

Expand All @@ -30,11 +30,11 @@ def on_init_end(self, trainer):
"""Called when the trainer initialization ends, model has not yet been set."""
pass

def on_fit_start(self, trainer):
def on_fit_start(self, trainer, pl_module):
"""Called when fit begins"""
pass

def on_fit_end(self, trainer):
def on_fit_end(self, trainer, pl_module):
"""Called when fit ends"""
pass

Expand All @@ -46,11 +46,11 @@ def on_sanity_check_end(self, trainer, pl_module):
"""Called when the validation sanity check ends."""
pass

def on_train_batch_start(self, trainer, pl_module):
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
"""Called when the validation batch begins."""
pass

def on_train_batch_end(self, trainer, pl_module):
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
"""Called when the validation batch ends."""
pass

Expand Down Expand Up @@ -90,19 +90,19 @@ def on_batch_start(self, trainer, pl_module):
"""Called when the training batch begins."""
pass

def on_validation_batch_start(self, trainer, pl_module):
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
"""Called when the validation batch begins."""
pass

def on_validation_batch_end(self, trainer, pl_module):
def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
"""Called when the validation batch ends."""
pass

def on_test_batch_start(self, trainer, pl_module):
def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
"""Called when the test batch begins."""
pass

def on_test_batch_end(self, trainer, pl_module):
def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
"""Called when the test batch ends."""
pass

Expand All @@ -118,6 +118,14 @@ def on_train_end(self, trainer, pl_module):
"""Called when the train ends."""
pass

def on_pretrain_routine_start(self, trainer, pl_module):
"""Called when the pretrain routine begins."""
pass

def on_pretrain_routine_end(self, trainer, pl_module):
"""Called when the pretrain routine ends."""
pass

def on_validation_start(self, trainer, pl_module):
"""Called when the validation loop begins."""
pass
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/lr_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def on_train_start(self, trainer, pl_module):
# Initialize for storing values
self.lrs = {name: [] for name in names}

def on_train_batch_start(self, trainer, pl_module):
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
latest_stat = self._extract_lr(trainer, 'step')
if trainer.logger and latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
Expand Down
18 changes: 9 additions & 9 deletions pytorch_lightning/callbacks/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,19 @@ def on_train_start(self, trainer, pl_module):
def on_epoch_start(self, trainer, pl_module):
self._train_batch_idx = 0

def on_train_batch_end(self, trainer, pl_module):
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
self._train_batch_idx += 1

def on_validation_start(self, trainer, pl_module):
self._val_batch_idx = 0

def on_validation_batch_end(self, trainer, pl_module):
def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
self._val_batch_idx += 1

def on_test_start(self, trainer, pl_module):
self._test_batch_idx = 0

def on_test_batch_end(self, trainer, pl_module):
def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
self._test_batch_idx += 1


Expand Down Expand Up @@ -318,8 +318,8 @@ def on_epoch_start(self, trainer, pl_module):
self.main_progress_bar.reset(convert_inf(total_batches))
self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}')

def on_train_batch_end(self, trainer, pl_module):
super().on_train_batch_end(trainer, pl_module)
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
super().on_train_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0:
self.main_progress_bar.update(self.refresh_rate)
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
Expand All @@ -329,8 +329,8 @@ def on_validation_start(self, trainer, pl_module):
self.val_progress_bar = self.init_validation_tqdm()
self.val_progress_bar.total = convert_inf(self.total_val_batches)

def on_validation_batch_end(self, trainer, pl_module):
super().on_validation_batch_end(trainer, pl_module)
def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0:
self.val_progress_bar.update(self.refresh_rate)
self.main_progress_bar.update(self.refresh_rate)
Expand All @@ -349,8 +349,8 @@ def on_test_start(self, trainer, pl_module):
self.test_progress_bar = self.init_test_tqdm()
self.test_progress_bar.total = convert_inf(self.total_test_batches)

def on_test_batch_end(self, trainer, pl_module):
super().on_test_batch_end(trainer, pl_module)
def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
super().on_test_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
if self.is_enabled and self.test_batch_idx % self.refresh_rate == 0:
self.test_progress_bar.update(self.refresh_rate)

Expand Down
35 changes: 33 additions & 2 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,51 @@ def on_train_end(self) -> None:
"""
# do something at the end of training

def on_train_batch_start(self, batch: Any) -> None:
def on_pretrain_routine_start(self) -> None:
"""
Called at the beginning of the pretrain routine (between fit and train start).
- fit
- pretrain_routine start
- pretrain_routine end
- training_start
"""
# do something at the start of the pretrain routine

def on_pretrain_routine_end(self) -> None:
"""
Called at the end of the pretrain routine (between fit and train start).
- fit
- pretrain_routine start
- pretrain_routine end
- training_start
"""
# do something at the end of the pretrain routine

def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
"""
Called in the training loop before anything happens for that batch.
If you return -1 here, you will skip training for the rest of the current epoch.
Args:
batch: The batched data as it is returned by the training DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
"""
# do something when the batch starts

def on_train_batch_end(self) -> None:
def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
"""
Called in the training loop after the batch.
Args:
batch: The batched data as it is returned by the training DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
"""
# do something when the batch end

Expand Down
44 changes: 27 additions & 17 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ class TrainerCallbackHookMixin(ABC):
def setup(self, stage: str):
"""Called in the beginning of fit and test"""
for callback in self.callbacks:
callback.setup(self, stage)
callback.setup(self, self.get_model(), stage)

def teardown(self, stage: str):
"""Called at the end of fit and test"""
for callback in self.callbacks:
callback.teardown(self, stage)
callback.teardown(self, self.get_model(), stage)

def on_init_start(self):
"""Called when the trainer initialization begins, model has not yet been set."""
Expand All @@ -31,15 +31,15 @@ def on_init_end(self):
for callback in self.callbacks:
callback.on_init_end(self)

def on_fit_start(self):
def on_fit_start(self, model):
"""Called when the trainer initialization begins, model has not yet been set."""
for callback in self.callbacks:
callback.on_fit_start(self)
callback.on_fit_start(self, model)

def on_fit_end(self):
"""Called when the trainer initialization begins, model has not yet been set."""
for callback in self.callbacks:
callback.on_fit_end(self)
callback.on_fit_end(self, self.get_model())

def on_sanity_check_start(self):
"""Called when the validation sanity check starts."""
Expand Down Expand Up @@ -101,6 +101,16 @@ def on_train_end(self):
for callback in self.callbacks:
callback.on_train_end(self, self.get_model())

def on_pretrain_routine_start(self, model):
"""Called when the train begins."""
for callback in self.callbacks:
callback.on_pretrain_routine_start(self, model)

def on_pretrain_routine_end(self, model):
"""Called when the train ends."""
for callback in self.callbacks:
callback.on_pretrain_routine_end(self, model)

def on_batch_start(self):
"""Called when the training batch begins."""
for callback in self.callbacks:
Expand All @@ -111,35 +121,35 @@ def on_batch_end(self):
for callback in self.callbacks:
callback.on_batch_end(self, self.get_model())

def on_train_batch_start(self):
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
"""Called when the training batch begins."""
for callback in self.callbacks:
callback.on_train_batch_start(self, self.get_model())
callback.on_train_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)

def on_train_batch_end(self):
def on_train_batch_end(self, batch, batch_idx, dataloader_idx):
"""Called when the training batch ends."""
for callback in self.callbacks:
callback.on_train_batch_end(self, self.get_model())
callback.on_train_batch_end(self, self.get_model(), batch, batch_idx, dataloader_idx)

def on_validation_batch_start(self):
def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
"""Called when the validation batch begins."""
for callback in self.callbacks:
callback.on_validation_batch_start(self, self.get_model())
callback.on_validation_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)

def on_validation_batch_end(self):
def on_validation_batch_end(self, batch, batch_idx, dataloader_idx):
"""Called when the validation batch ends."""
for callback in self.callbacks:
callback.on_validation_batch_end(self, self.get_model())
callback.on_validation_batch_end(self, self.get_model(), batch, batch_idx, dataloader_idx)

def on_test_batch_start(self):
def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
"""Called when the test batch begins."""
for callback in self.callbacks:
callback.on_test_batch_start(self, self.get_model())
callback.on_test_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)

def on_test_batch_end(self):
def on_test_batch_end(self, batch, batch_idx, dataloader_idx):
"""Called when the test batch ends."""
for callback in self.callbacks:
callback.on_test_batch_end(self, self.get_model())
callback.on_test_batch_end(self, self.get_model(), batch, batch_idx, dataloader_idx)

def on_validation_start(self):
"""Called when the validation loop begins."""
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,9 @@ def _evaluate(

# callbacks
if test_mode:
self.on_test_batch_start()
self.on_test_batch_start(batch, batch_idx, dataloader_idx)
else:
self.on_validation_batch_start()
self.on_validation_batch_start(batch, batch_idx, dataloader_idx)

# -----------------
# RUN EVALUATION STEP
Expand All @@ -336,13 +336,13 @@ def _evaluate(
model_ref = self.get_model()
with self.profiler.profile('test_step_end'):
output = model_ref.test_step_end(output)
self.on_test_batch_end()
self.on_test_batch_end(batch, batch_idx, dataloader_idx)
else:
if self.is_overridden('validation_step_end'):
model_ref = self.get_model()
with self.profiler.profile('validation_step_end'):
output = model_ref.validation_step_end(output)
self.on_validation_batch_end()
self.on_validation_batch_end(batch, batch_idx, dataloader_idx)

# track outputs for collation
if output is not None:
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def on_batch_start(self, trainer, pl_module):

self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])

def on_train_batch_end(self, trainer, pl_module):
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
""" Called when the training batch ends, logs the calculated loss """
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
return
Expand Down
17 changes: 13 additions & 4 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ def fit(
self.config_validator.verify_loop_configurations(model)

# callbacks
self.on_fit_start()
self.on_fit_start(model)
if self.is_function_implemented('on_fit_start', model):
model.on_fit_start()

Expand Down Expand Up @@ -1055,13 +1055,12 @@ def fit(
self.accelerator_backend.setup(model)
results = self.accelerator_backend.train(model)

# callbacks
# on fit end callback
self.on_fit_end()

# model hooks
if self.is_function_implemented('on_fit_end'):
model.on_fit_end()

# teardown callback
self.teardown('fit')
if self.is_function_implemented('teardown'):
model.teardown('fit')
Expand Down Expand Up @@ -1156,6 +1155,11 @@ def run_pretrain_routine(self, model: LightningModule):
# register auto-resubmit when on SLURM
self.register_slurm_signal_handlers()

# on pretrain routine start
self.on_pretrain_routine_start(ref_model)
if self.is_function_implemented('on_pretrain_routine_start'):
ref_model.on_pretrain_routine_start()

# print model summary
if self.is_global_zero and self.weights_summary is not None and not self.testing:
if self.weights_summary in ModelSummary.MODES:
Expand Down Expand Up @@ -1198,6 +1202,11 @@ def run_pretrain_routine(self, model: LightningModule):
with torch.cuda.device(f'cuda:{self.root_gpu}'):
torch.cuda.empty_cache()

# on pretrain routine end
self.on_pretrain_routine_end(ref_model)
if self.is_function_implemented('on_pretrain_routine_end'):
ref_model.on_pretrain_routine_end()

# CORE TRAINING LOOP
self.train()

Expand Down
Loading