Move result teardown to loops (#8245)
* Move result teardown to loops

* Update CHANGELOG

* Remove teardown from run

* Move previous teardown to on_run_end

* Add comment

* Merge 8250

* Remove stage set to None where it shouldn't
carmocca authored Jul 2, 2021
1 parent f3e74ab commit 0e19d16
Showing 9 changed files with 28 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -170,6 +170,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Refactored trainer `_run_*` functions and separate evaluation loops ([#8065](https://github.com/PyTorchLightning/pytorch-lightning/pull/8065))
 * Refactored prediction loop interface; added new classes `PredictionLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077))
 * Removed `pytorch_lightning/trainer/predict_loop.py` ([#8094](https://github.com/PyTorchLightning/pytorch-lightning/pull/8094))
+* Moved result teardown to the loops ([#8245](https://github.com/PyTorchLightning/pytorch-lightning/pull/8245))


 - Refactored logging
3 changes: 1 addition & 2 deletions pytorch_lightning/loops/base.py
@@ -114,7 +114,6 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
                 break

         output = self.on_run_end()
-        self.teardown()
         return output

     def restore(self) -> None:
@@ -149,7 +148,7 @@ def on_run_end(self) -> Any:
         """Hook to be called at the end of the run. Its return argument is returned from :attr:`run`."""

     def teardown(self) -> None:
-        """The very last method called inside :meth:`run`. Use to release memory etc."""
+        """Use to release memory etc."""

     def load_state_dict(self, state_dict: Dict) -> None:
         """Restore the loop state from the provided state_dict."""
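For orientation, here is a minimal, hypothetical sketch of the contract this hunk leaves behind (simplified names, not the real Lightning classes): `run()` only returns whatever `on_run_end()` produced, and `teardown()` is now invoked explicitly by whoever owns the loop, typically long after `run()` has returned.

```python
from typing import Any, Optional


class MiniLoop:
    """Simplified stand-in for the base ``Loop`` after this change (illustrative only)."""

    def run(self) -> Optional[Any]:
        # ... advance() iterations would happen here ...
        output = self.on_run_end()
        # note: no self.teardown() call anymore -- the loop's owner triggers it later
        return output

    def on_run_end(self) -> Any:
        """Hook called at the end of the run; its return value is returned from ``run()``."""

    def teardown(self) -> None:
        """Use to release memory etc.; called explicitly once everything has finished."""


loop = MiniLoop()
result = loop.run()  # running no longer implies tearing down
loop.teardown()      # the owner decides when cleanup happens
```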
4 changes: 4 additions & 0 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -263,3 +263,7 @@ def on_evaluation_epoch_end(self) -> None:
         self.trainer.call_hook(hook_name)
         self.trainer.call_hook("on_epoch_end")
         self.trainer.logger_connector.on_epoch_end()
+
+    def teardown(self) -> None:
+        self._results.cpu()
+        self.epoch_loop.teardown()
7 changes: 3 additions & 4 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -119,11 +119,10 @@ def advance(

     def on_run_end(self) -> List[STEP_OUTPUT]:
         """Returns the outputs of the whole run"""
-        return self.outputs
-
-    def teardown(self) -> None:
-        """Frees memory of tracked outputs"""
+        outputs = self.outputs
+        # free memory
         self.outputs = []
+        return outputs

     def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
         """The evaluation step (validation_step or test_step depending on the trainer's state).
8 changes: 4 additions & 4 deletions pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -88,12 +88,12 @@ def advance(

     def on_run_end(self) -> Tuple[Any, Any]:
         """Returns the predictions and the corresponding batch indices"""
-        return self.predictions, self._all_batch_indices
-
-    def teardown(self) -> None:
-        """Frees memory of collected predictions."""
+        predictions = self.predictions
+        all_batch_indices = self._all_batch_indices
+        # free memory
         self.predictions = []
         self._all_batch_indices = []
+        return predictions, all_batch_indices

     def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """Runs the actual predict step together with all the
11 changes: 8 additions & 3 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -208,11 +208,16 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         self._on_train_epoch_end_hook(processed_outputs)
         self.trainer.call_hook('on_epoch_end')
         self.trainer.logger_connector.on_epoch_end()
-        return self._epoch_output
+
+        epoch_output = self._epoch_output
+        # free memory
+        self._epoch_output = None
+        return epoch_output

     def teardown(self) -> None:
-        """Frees memory of tracked epoch outputs."""
-        self.epoch_output = None
+        self._results.cpu()
+        self.batch_loop.teardown()
+        self.val_loop.teardown()

     def _run_validation(self):
         # reload dataloaders
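Putting the epoch-loop hunks together: because `teardown()` is no longer called right after `run()`, each epoch loop now frees its tracked outputs inside `on_run_end()`, while `teardown()` is repurposed to move the logged results to CPU and cascade into its child loops. A hypothetical, simplified sketch of that pattern (class and attribute names are placeholders, not the actual implementation):

```python
from typing import Any, List


class FakeResultCollection:
    """Placeholder for the loop's ``_results`` object."""

    def cpu(self) -> None:
        # in Lightning this would move the logged tensors off the accelerator
        print("results moved to CPU")


class MiniChildLoop:
    def teardown(self) -> None:
        print("child loop torn down")


class MiniTrainingEpochLoop:
    def __init__(self) -> None:
        self._epoch_output: List[Any] = []
        self._results = FakeResultCollection()
        self.batch_loop = MiniChildLoop()
        self.val_loop = MiniChildLoop()

    def on_run_end(self) -> List[Any]:
        epoch_output = self._epoch_output
        # free memory here, since teardown no longer runs right after the loop
        self._epoch_output = []
        return epoch_output

    def teardown(self) -> None:
        # teardown now releases results and cascades to the child loops
        self._results.cpu()
        self.batch_loop.teardown()
        self.val_loop.teardown()
```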
8 changes: 4 additions & 4 deletions pytorch_lightning/loops/fit_loop.py
@@ -236,7 +236,7 @@ def on_advance_end(self) -> None:
         self.global_step += 1

     def on_run_end(self) -> None:
-        """Runs teardown logic and calls the ``on_train_end`` hook"""
+        """Calls the ``on_train_end`` hook"""
         # NOTE: the iteration_count/current_epoch is already incremented
         # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
         # To simulate that current behavior, we decrement here.
@@ -265,9 +265,6 @@ def on_run_end(self) -> None:
         # give accelerators a chance to finish
         self.trainer.accelerator.on_train_end()

-        # reset bookkeeping
-        self.trainer._running_stage = None
-
     def should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated"""
         return self.epoch_loop.batch_loop.should_accumulate()
@@ -291,3 +288,6 @@ def state_dict(self) -> Dict:

     def load_state_dict(self, state_dict: Dict) -> None:
         self.epoch_loop.load_state_dict(state_dict["epoch_loop"])
+
+    def teardown(self) -> None:
+        self.epoch_loop.teardown()
6 changes: 0 additions & 6 deletions pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -312,9 +312,3 @@ def progress_bar_metrics(self) -> Dict[str, float]:
         metrics = self.metrics[MetricSource.PBAR]
         self._progress_bar_metrics.update(metrics)
         return self._progress_bar_metrics
-
-    def teardown(self):
-        self.trainer.fit_loop.epoch_loop._results.cpu()
-        self.trainer.fit_loop.epoch_loop.val_loop._results.cpu()
-        self.trainer.validate_loop._results.cpu()
-        self.trainer.test_loop._results.cpu()
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -900,8 +900,10 @@ def _pre_dispatch(self):

     def _post_dispatch(self):
         self.accelerator.post_dispatch(self)
+        # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
+        # which need to happen before.
         self.accelerator.teardown()
-        self.logger_connector.teardown()
+        self._active_loop.teardown()

     def _dispatch(self):
         if self.evaluating:
@@ -977,7 +979,6 @@ def _run_train(self) -> None:
                 self.on_keyboard_interrupt()
                 # same treatment as below
                 self.accelerator.on_train_end()
-                self.state.stage = None
         except BaseException:
             self.state.status = TrainerStatus.INTERRUPTED
             if distributed_available() and self.world_size > 1:
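The comment added in `_post_dispatch` is about ordering: the accelerator's and the active loop's internal teardown must run before the user-facing `teardown` hook. A rough, hypothetical sketch of that ordering (names simplified; `_call_teardown_hook` is the existing trainer method the comment refers to):

```python
class MiniTrainer:
    """Illustrative ordering only -- not the actual Trainer implementation."""

    def __init__(self, accelerator, active_loop) -> None:
        self.accelerator = accelerator
        self._active_loop = active_loop

    def _post_dispatch(self) -> None:
        # internal teardowns happen first, while loop and accelerator state still exist
        self.accelerator.teardown()
        self._active_loop.teardown()

    def _call_teardown_hook(self) -> None:
        # the user-facing ``teardown`` hook (LightningModule / Callbacks) runs afterwards
        print("calling user teardown hooks")

    def finish(self) -> None:
        self._post_dispatch()
        self._call_teardown_hook()
```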
