Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FPS counter to TimerCallback #12

Merged
merged 2 commits into from
Dec 3, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tox.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- name: Install Tox
run: pip install tox
- name: Code quality checks
run: tox -e black,isort,flake8,pylint,mypy
run: tox -e black,isort,flake8,pylint,mypy,pydocstyle
- name: Coverage
run: tox -e coverage
- name: Upload coverage result
Expand Down
70 changes: 59 additions & 11 deletions anomalib/core/callbacks/timer.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,75 @@
"""Callback to measure training and testing time of a PyTorch Lightning module."""
import time

from pytorch_lightning import Callback, LightningModule
from pytorch_lightning import Callback, LightningModule, Trainer


class TimerCallback(Callback):
"""Callback that measures the training and testing time of a PyTorch Lightning module."""

def __init__(self):
self.start = None
self.start: float
self.num_images: float = 0
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved

def on_fit_start(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when fit begins."""
def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when fit begins.

Sets the start time to the time training started.

Args:
trainer (Trainer): PyTorch Lightning trainer.
pl_module (LightningModule): Current training module.

Returns:
None
"""
self.start = time.time()

def on_fit_end(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when fit ends."""
def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when fit ends.

Prints the time taken for trining.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo


Args:
trainer (Trainer): PyTorch Lightning trainer.
pl_module (LightningModule): Current training module.

Returns:
None
"""
print(f"Training took {time.time() - self.start} seconds")

def on_test_start(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when the test begins."""
def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when the test begins.

Sets the start time to the time testing started.
Goes over all the test dataloaders and adds the number of images in each.

Args:
trainer (Trainer): PyTorch Lightning trainer.
pl_module (LightningModule): Current training module.

Returns:
None
"""
self.start = time.time()
self.num_images = 0

if trainer.test_dataloaders is not None: # Check to placate Mypy.
for dataloader in trainer.test_dataloaders:
self.num_images += len(dataloader)
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved

def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when the test ends.

Prints the time taken for testing and the throughput in frames per second.

Args:
trainer (Trainer): PyTorch Lightning trainer.
pl_module (LightningModule): Current training module.

def on_test_end(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
"""Call when the test ends."""
print(f"Testing took {time.time() - self.start} seconds.")
Returns:
None
"""
testing_time = time.time() - self.start
print(f"Testing took {testing_time} seconds.\nThroughput: {self.num_images/testing_time} FPS")