Skip to content

Commit

Permalink
Add FPS counter to TimerCallback (#12)
Browse files Browse the repository at this point in the history
* Add FPS counter to `TimerCallback`

- Fix docs strings
- Add frames per second calculation to `timer.py`

* Address PR comments

Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>
  • Loading branch information
ashwinvaidya17 and Ashwin Vaidya committed Dec 3, 2021
1 parent 310b6e4 commit cf8bdf6
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tox.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- name: Install Tox
run: pip install tox
- name: Code quality checks
run: tox -e black,isort,flake8,pylint,mypy
run: tox -e black,isort,flake8,pylint,mypy,pydocstyle
- name: Coverage
run: tox -e coverage
- name: Upload coverage result
Expand Down
70 changes: 59 additions & 11 deletions anomalib/core/callbacks/timer.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,75 @@
"""Callback to measure training and testing time of a PyTorch Lightning module."""
import time

from pytorch_lightning import Callback, LightningModule
from pytorch_lightning import Callback, LightningModule, Trainer


class TimerCallback(Callback):
"""Callback that measures the training and testing time of a PyTorch Lightning module."""

def __init__(self):
self.start = None
self.start: float
self.num_images: int = 0

def on_fit_start(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when fit begins."""
def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when fit begins.
Sets the start time to the time training started.
Args:
trainer (Trainer): PyTorch Lightning trainer.
pl_module (LightningModule): Current training module.
Returns:
None
"""
self.start = time.time()

def on_fit_end(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when fit ends."""
def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when fit ends.
Prints the time taken for training.
Args:
trainer (Trainer): PyTorch Lightning trainer.
pl_module (LightningModule): Current training module.
Returns:
None
"""
print(f"Training took {time.time() - self.start} seconds")

def on_test_start(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when the test begins."""
def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when the test begins.
Sets the start time to the time testing started.
Goes over all the test dataloaders and adds the number of images in each.
Args:
trainer (Trainer): PyTorch Lightning trainer.
pl_module (LightningModule): Current training module.
Returns:
None
"""
self.start = time.time()
self.num_images = 0

if trainer.test_dataloaders is not None: # Check to placate Mypy.
for dataloader in trainer.test_dataloaders:
self.num_images += len(dataloader.dataset)

def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when the test ends.
Prints the time taken for testing and the throughput in frames per second.
Args:
trainer (Trainer): PyTorch Lightning trainer.
pl_module (LightningModule): Current training module.
def on_test_end(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when the test ends."""
print(f"Testing took {time.time() - self.start} seconds.")
Returns:
None
"""
testing_time = time.time() - self.start
print(f"Testing took {testing_time} seconds\nThroughput: {self.num_images/testing_time} FPS")

0 comments on commit cf8bdf6

Please sign in to comment.