updated hooks (#2850)
* modified hooks
williamFalcon committed Aug 7, 2020
1 parent b39f479 commit f82d7fe
Showing 14 changed files with 152 additions and 97 deletions.
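
In short: setup/teardown, on_fit_start, and the new pretrain-routine hooks now receive the LightningModule, and every batch-level hook gains batch, batch_idx, and dataloader_idx parameters. A minimal before/after sketch of the signature change for a user-defined Callback method (illustrative only, not part of the diff):

    # before this commit: only trainer and pl_module were passed
    def on_train_batch_start(self, trainer, pl_module):
        ...

    # after this commit: the batch and its indices are passed in directly
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        ...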
28 changes: 18 additions & 10 deletions pytorch_lightning/callbacks/base.py
@@ -14,11 +14,11 @@ class Callback(abc.ABC):
    Abstract base class used to build new callbacks.
    """

-    def setup(self, trainer, stage: str):
+    def setup(self, trainer, pl_module, stage: str):
        """Called when fit or test begins"""
        pass

-    def teardown(self, trainer, stage: str):
+    def teardown(self, trainer, pl_module, stage: str):
        """Called when fit or test ends"""
        pass

@@ -30,11 +30,11 @@ def on_init_end(self, trainer):
        """Called when the trainer initialization ends, model has not yet been set."""
        pass

-    def on_fit_start(self, trainer):
+    def on_fit_start(self, trainer, pl_module):
        """Called when fit begins"""
        pass

-    def on_fit_end(self, trainer):
+    def on_fit_end(self, trainer, pl_module):
        """Called when fit ends"""
        pass

@@ -46,11 +46,11 @@ def on_sanity_check_end(self, trainer, pl_module):
        """Called when the validation sanity check ends."""
        pass

-    def on_train_batch_start(self, trainer, pl_module):
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        """Called when the training batch begins."""
        pass

-    def on_train_batch_end(self, trainer, pl_module):
+    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        """Called when the training batch ends."""
        pass

@@ -90,19 +90,19 @@ def on_batch_start(self, trainer, pl_module):
        """Called when the training batch begins."""
        pass

-    def on_validation_batch_start(self, trainer, pl_module):
+    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        """Called when the validation batch begins."""
        pass

-    def on_validation_batch_end(self, trainer, pl_module):
+    def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        """Called when the validation batch ends."""
        pass

-    def on_test_batch_start(self, trainer, pl_module):
+    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        """Called when the test batch begins."""
        pass

-    def on_test_batch_end(self, trainer, pl_module):
+    def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        """Called when the test batch ends."""
        pass

@@ -118,6 +118,14 @@ def on_train_end(self, trainer, pl_module):
        """Called when the train ends."""
        pass

+    def on_pretrain_routine_start(self, trainer, pl_module):
+        """Called when the pretrain routine begins."""
+        pass
+
+    def on_pretrain_routine_end(self, trainer, pl_module):
+        """Called when the pretrain routine ends."""
+        pass
+
    def on_validation_start(self, trainer, pl_module):
        """Called when the validation loop begins."""
        pass
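
For context, here is a sketch of a user callback updated to the new base-class signatures. The class name and prints are hypothetical; only the hook signatures come from this commit:

    from pytorch_lightning.callbacks import Callback

    class BatchInspector(Callback):
        """Hypothetical callback illustrating the new hook signatures."""

        def setup(self, trainer, pl_module, stage: str):
            # setup/teardown now receive the LightningModule as pl_module
            print(f'setup: stage={stage}, model={type(pl_module).__name__}')

        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
            # batch-level hooks now receive the batch and its indices directly
            if batch_idx == 0 and dataloader_idx == 0:
                print(f'epoch {trainer.current_epoch}: first training batch')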
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/lr_logger.py
@@ -64,7 +64,7 @@ def on_train_start(self, trainer, pl_module):
        # Initialize for storing values
        self.lrs = {name: [] for name in names}

-    def on_train_batch_start(self, trainer, pl_module):
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        latest_stat = self._extract_lr(trainer, 'step')
        if trainer.logger and latest_stat:
            trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
18 changes: 9 additions & 9 deletions pytorch_lightning/callbacks/progress.py
@@ -138,19 +138,19 @@ def on_train_start(self, trainer, pl_module):
    def on_epoch_start(self, trainer, pl_module):
        self._train_batch_idx = 0

-    def on_train_batch_end(self, trainer, pl_module):
+    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._train_batch_idx += 1

    def on_validation_start(self, trainer, pl_module):
        self._val_batch_idx = 0

-    def on_validation_batch_end(self, trainer, pl_module):
+    def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._val_batch_idx += 1

    def on_test_start(self, trainer, pl_module):
        self._test_batch_idx = 0

-    def on_test_batch_end(self, trainer, pl_module):
+    def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._test_batch_idx += 1


@@ -318,8 +318,8 @@ def on_epoch_start(self, trainer, pl_module):
        self.main_progress_bar.reset(convert_inf(total_batches))
        self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}')

-    def on_train_batch_end(self, trainer, pl_module):
-        super().on_train_batch_end(trainer, pl_module)
+    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+        super().on_train_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
        if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0:
            self.main_progress_bar.update(self.refresh_rate)
            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
@@ -329,8 +329,8 @@ def on_validation_start(self, trainer, pl_module):
        self.val_progress_bar = self.init_validation_tqdm()
        self.val_progress_bar.total = convert_inf(self.total_val_batches)

-    def on_validation_batch_end(self, trainer, pl_module):
-        super().on_validation_batch_end(trainer, pl_module)
+    def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+        super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
        if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0:
            self.val_progress_bar.update(self.refresh_rate)
            self.main_progress_bar.update(self.refresh_rate)
@@ -349,8 +349,8 @@ def on_test_start(self, trainer, pl_module):
        self.test_progress_bar = self.init_test_tqdm()
        self.test_progress_bar.total = convert_inf(self.total_test_batches)

-    def on_test_batch_end(self, trainer, pl_module):
-        super().on_test_batch_end(trainer, pl_module)
+    def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+        super().on_test_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
        if self.is_enabled and self.test_batch_idx % self.refresh_rate == 0:
            self.test_progress_bar.update(self.refresh_rate)

35 changes: 33 additions & 2 deletions pytorch_lightning/core/hooks.py
@@ -77,20 +77,51 @@ def on_train_end(self) -> None:
        """
        # do something at the end of training

-    def on_train_batch_start(self, batch: Any) -> None:
+    def on_pretrain_routine_start(self) -> None:
+        """
+        Called at the beginning of the pretrain routine (between fit and train start).
+
+        - fit
+        - pretrain_routine start
+        - pretrain_routine end
+        - training_start
+        """
+        # do something at the start of the pretrain routine
+
+    def on_pretrain_routine_end(self) -> None:
+        """
+        Called at the end of the pretrain routine (between fit and train start).
+
+        - fit
+        - pretrain_routine start
+        - pretrain_routine end
+        - training_start
+        """
+        # do something at the end of the pretrain routine
+
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """
        Called in the training loop before anything happens for that batch.

        If you return -1 here, you will skip training for the rest of the current epoch.

+        Args:
+            batch: The batched data as it is returned by the training DataLoader.
+            batch_idx: the index of the batch
+            dataloader_idx: the index of the dataloader
        """
        # do something when the batch starts

-    def on_train_batch_end(self) -> None:
+    def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """
        Called in the training loop after the batch.

+        Args:
+            batch: The batched data as it is returned by the training DataLoader.
+            batch_idx: the index of the batch
+            dataloader_idx: the index of the dataloader
        """
        # do something when the batch ends
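
Likewise, a hedged sketch of a LightningModule overriding the new module-level hooks; the class name and prints are illustrative, the signatures come from the diff above:

    import pytorch_lightning as pl

    class MyModel(pl.LightningModule):
        def on_pretrain_routine_start(self) -> None:
            # runs inside fit(), after setup but before the training loop
            print('pretrain routine started')

        def on_train_batch_end(self, batch, batch_idx, dataloader_idx) -> None:
            # module-level batch hooks take the batch and indices, but no trainer
            if batch_idx % 100 == 0:
                print(f'finished batch {batch_idx} (dataloader {dataloader_idx})')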

44 changes: 27 additions & 17 deletions pytorch_lightning/trainer/callback_hook.py
@@ -14,12 +14,12 @@ class TrainerCallbackHookMixin(ABC):
    def setup(self, stage: str):
        """Called at the beginning of fit and test."""
        for callback in self.callbacks:
-            callback.setup(self, stage)
+            callback.setup(self, self.get_model(), stage)

    def teardown(self, stage: str):
        """Called at the end of fit and test."""
        for callback in self.callbacks:
-            callback.teardown(self, stage)
+            callback.teardown(self, self.get_model(), stage)

    def on_init_start(self):
        """Called when the trainer initialization begins, model has not yet been set."""
@@ -31,15 +31,15 @@ def on_init_end(self):
        for callback in self.callbacks:
            callback.on_init_end(self)

-    def on_fit_start(self):
+    def on_fit_start(self, model):
        """Called when fit begins."""
        for callback in self.callbacks:
-            callback.on_fit_start(self)
+            callback.on_fit_start(self, model)

    def on_fit_end(self):
        """Called when fit ends."""
        for callback in self.callbacks:
-            callback.on_fit_end(self)
+            callback.on_fit_end(self, self.get_model())

    def on_sanity_check_start(self):
        """Called when the validation sanity check starts."""
@@ -101,6 +101,16 @@ def on_train_end(self):
        for callback in self.callbacks:
            callback.on_train_end(self, self.get_model())

+    def on_pretrain_routine_start(self, model):
+        """Called when the pretrain routine begins."""
+        for callback in self.callbacks:
+            callback.on_pretrain_routine_start(self, model)
+
+    def on_pretrain_routine_end(self, model):
+        """Called when the pretrain routine ends."""
+        for callback in self.callbacks:
+            callback.on_pretrain_routine_end(self, model)
+
    def on_batch_start(self):
        """Called when the training batch begins."""
        for callback in self.callbacks:
@@ -111,35 +111,35 @@ def on_batch_end(self):
        for callback in self.callbacks:
            callback.on_batch_end(self, self.get_model())

-    def on_train_batch_start(self):
+    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        """Called when the training batch begins."""
        for callback in self.callbacks:
-            callback.on_train_batch_start(self, self.get_model())
+            callback.on_train_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)

-    def on_train_batch_end(self):
+    def on_train_batch_end(self, batch, batch_idx, dataloader_idx):
        """Called when the training batch ends."""
        for callback in self.callbacks:
-            callback.on_train_batch_end(self, self.get_model())
+            callback.on_train_batch_end(self, self.get_model(), batch, batch_idx, dataloader_idx)

-    def on_validation_batch_start(self):
+    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
        """Called when the validation batch begins."""
        for callback in self.callbacks:
-            callback.on_validation_batch_start(self, self.get_model())
+            callback.on_validation_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)

-    def on_validation_batch_end(self):
+    def on_validation_batch_end(self, batch, batch_idx, dataloader_idx):
        """Called when the validation batch ends."""
        for callback in self.callbacks:
-            callback.on_validation_batch_end(self, self.get_model())
+            callback.on_validation_batch_end(self, self.get_model(), batch, batch_idx, dataloader_idx)

-    def on_test_batch_start(self):
+    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
        """Called when the test batch begins."""
        for callback in self.callbacks:
-            callback.on_test_batch_start(self, self.get_model())
+            callback.on_test_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)

-    def on_test_batch_end(self):
+    def on_test_batch_end(self, batch, batch_idx, dataloader_idx):
        """Called when the test batch ends."""
        for callback in self.callbacks:
-            callback.on_test_batch_end(self, self.get_model())
+            callback.on_test_batch_end(self, self.get_model(), batch, batch_idx, dataloader_idx)

    def on_validation_start(self):
        """Called when the validation loop begins."""
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -312,9 +312,9 @@ def _evaluate(

                # callbacks
                if test_mode:
-                    self.on_test_batch_start()
+                    self.on_test_batch_start(batch, batch_idx, dataloader_idx)
                else:
-                    self.on_validation_batch_start()
+                    self.on_validation_batch_start(batch, batch_idx, dataloader_idx)

                # -----------------
                # RUN EVALUATION STEP
@@ -336,13 +336,13 @@ def _evaluate(
                        model_ref = self.get_model()
                        with self.profiler.profile('test_step_end'):
                            output = model_ref.test_step_end(output)
-                    self.on_test_batch_end()
+                    self.on_test_batch_end(batch, batch_idx, dataloader_idx)
                else:
                    if self.is_overridden('validation_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('validation_step_end'):
                            output = model_ref.validation_step_end(output)
-                    self.on_validation_batch_end()
+                    self.on_validation_batch_end(batch, batch_idx, dataloader_idx)

                # track outputs for collation
                if output is not None:
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/lr_finder.py
@@ -384,7 +384,7 @@ def on_batch_start(self, trainer, pl_module):

        self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])

-    def on_train_batch_end(self, trainer, pl_module):
+    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        """ Called when the training batch ends, logs the calculated loss """
        if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            return
17 changes: 13 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -956,7 +956,7 @@ def fit(
        self.config_validator.verify_loop_configurations(model)

        # callbacks
-        self.on_fit_start()
+        self.on_fit_start(model)
        if self.is_function_implemented('on_fit_start', model):
            model.on_fit_start()
@@ -1053,13 +1053,12 @@ def fit(
            self.accelerator_backend.setup(model)
            results = self.accelerator_backend.train(model)

-        # callbacks
+        # on fit end callback
        self.on_fit_end()

        # model hooks
        if self.is_function_implemented('on_fit_end'):
            model.on_fit_end()

+        # teardown callback
        self.teardown('fit')
        if self.is_function_implemented('teardown'):
            model.teardown('fit')
@@ -1154,6 +1153,11 @@ def run_pretrain_routine(self, model: LightningModule):
        # register auto-resubmit when on SLURM
        self.register_slurm_signal_handlers()

+        # on pretrain routine start
+        self.on_pretrain_routine_start(ref_model)
+        if self.is_function_implemented('on_pretrain_routine_start'):
+            ref_model.on_pretrain_routine_start()
+
        # print model summary
        if self.is_global_zero and self.weights_summary is not None and not self.testing:
            if self.weights_summary in ModelSummary.MODES:
@@ -1196,6 +1200,11 @@ def run_pretrain_routine(self, model: LightningModule):
            with torch.cuda.device(f'cuda:{self.root_gpu}'):
                torch.cuda.empty_cache()

+        # on pretrain routine end
+        self.on_pretrain_routine_end(ref_model)
+        if self.is_function_implemented('on_pretrain_routine_end'):
+            ref_model.on_pretrain_routine_end()
+
        # CORE TRAINING LOOP
        self.train()
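
Putting the trainer changes together, the hook order around fit is now on_fit_start, then on_pretrain_routine_start, then on_pretrain_routine_end, then training, matching the ordering listed in the core/hooks.py docstrings above. A sketch tracing it with a hypothetical callback:

    import pytorch_lightning as pl

    class HookOrder(pl.Callback):
        # prints the call order introduced by this commit:
        # on_fit_start -> on_pretrain_routine_start -> on_pretrain_routine_end -> on_train_start
        def on_fit_start(self, trainer, pl_module):
            print('on_fit_start')

        def on_pretrain_routine_start(self, trainer, pl_module):
            print('on_pretrain_routine_start')

        def on_pretrain_routine_end(self, trainer, pl_module):
            print('on_pretrain_routine_end')

        def on_train_start(self, trainer, pl_module):
            print('on_train_start')

    # usage (hypothetical): pl.Trainer(callbacks=[HookOrder()]).fit(model)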
