-
Notifications
You must be signed in to change notification settings - Fork 640
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add FPS counter to
TimerCallback
(#12)
* Add FPS counter to `TimerCallback` - Fix docs strings - Add frames per second calculation to `timer.py` * Address PR comments Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>
- Loading branch information
1 parent
310b6e4
commit cf8bdf6
Showing
2 changed files
with
60 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,75 @@ | ||
"""Callback to measure training and testing time of a PyTorch Lightning module.""" | ||
import time | ||
|
||
from pytorch_lightning import Callback, LightningModule | ||
from pytorch_lightning import Callback, LightningModule, Trainer | ||
|
||
|
||
class TimerCallback(Callback): | ||
"""Callback that measures the training and testing time of a PyTorch Lightning module.""" | ||
|
||
def __init__(self): | ||
self.start = None | ||
self.start: float | ||
self.num_images: int = 0 | ||
|
||
def on_fit_start(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613 | ||
"""Call when fit begins.""" | ||
def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613 | ||
"""Call when fit begins. | ||
Sets the start time to the time training started. | ||
Args: | ||
trainer (Trainer): PyTorch Lightning trainer. | ||
pl_module (LightningModule): Current training module. | ||
Returns: | ||
None | ||
""" | ||
self.start = time.time() | ||
|
||
def on_fit_end(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613 | ||
"""Call when fit ends.""" | ||
def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613 | ||
"""Call when fit ends. | ||
Prints the time taken for training. | ||
Args: | ||
trainer (Trainer): PyTorch Lightning trainer. | ||
pl_module (LightningModule): Current training module. | ||
Returns: | ||
None | ||
""" | ||
print(f"Training took {time.time() - self.start} seconds") | ||
|
||
def on_test_start(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613 | ||
"""Call when the test begins.""" | ||
def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613 | ||
"""Call when the test begins. | ||
Sets the start time to the time testing started. | ||
Goes over all the test dataloaders and adds the number of images in each. | ||
Args: | ||
trainer (Trainer): PyTorch Lightning trainer. | ||
pl_module (LightningModule): Current training module. | ||
Returns: | ||
None | ||
""" | ||
self.start = time.time() | ||
self.num_images = 0 | ||
|
||
if trainer.test_dataloaders is not None: # Check to placate Mypy. | ||
for dataloader in trainer.test_dataloaders: | ||
self.num_images += len(dataloader.dataset) | ||
|
||
def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613 | ||
"""Call when the test ends. | ||
Prints the time taken for testing and the throughput in frames per second. | ||
Args: | ||
trainer (Trainer): PyTorch Lightning trainer. | ||
pl_module (LightningModule): Current training module. | ||
def on_test_end(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613 | ||
"""Call when the test ends.""" | ||
print(f"Testing took {time.time() - self.start} seconds.") | ||
Returns: | ||
None | ||
""" | ||
testing_time = time.time() - self.start | ||
print(f"Testing took {testing_time} seconds\nThroughput: {self.num_images/testing_time} FPS") |