Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move result teardown to loops #8245

Merged
merged 9 commits (branch names lost in extraction)
Jul 2, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,7 @@ def on_evaluation_epoch_end(self) -> None:
self.trainer.call_hook(hook_name)
self.trainer.call_hook("on_epoch_end")
self.trainer.logger_connector.on_epoch_end()

def teardown(self) -> None:
self._results.cpu()
self.epoch_loop.teardown()
carmocca marked this conversation as resolved.
Show resolved Hide resolved
5 changes: 4 additions & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,10 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:

def teardown(self) -> None:
"""Frees memory of tracked epoch outputs."""
self.epoch_output = None
self._epoch_output = None
carmocca marked this conversation as resolved.
Show resolved Hide resolved
self._results.cpu()
self.batch_loop.teardown()
self.val_loop.teardown()

def _run_validation(self):
# reload dataloaders
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,6 @@ def state_dict(self) -> Dict:

def load_state_dict(self, state_dict: Dict) -> None:
    """Restore this loop's progress by delegating to the inner epoch loop.

    Expects ``state_dict`` to contain an ``"epoch_loop"`` entry produced by
    the matching ``state_dict`` method.
    """
    epoch_state = state_dict["epoch_loop"]
    self.epoch_loop.load_state_dict(epoch_state)

def teardown(self) -> None:
self.epoch_loop.teardown()
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,3 @@ def progress_bar_metrics(self) -> Dict[str, float]:
metrics = self.metrics[MetricSource.PBAR]
self._progress_bar_metrics.update(metrics)
return self._progress_bar_metrics

def teardown(self):
self.trainer.fit_loop.epoch_loop._results.cpu()
self.trainer.fit_loop.epoch_loop.val_loop._results.cpu()
self.trainer.validate_loop._results.cpu()
self.trainer.test_loop._results.cpu()
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,6 @@ def _pre_dispatch(self):
def _post_dispatch(self):
self.accelerator.post_dispatch(self)
self.accelerator.teardown()
self.logger_connector.teardown()

def _dispatch(self):
if self.evaluating:
Expand Down