Commit
Fix inconsistent benchmarking throughput/time (#221)
* Suppress model outputs and print values from the current script

* Address PR comments

* Print batch size in timer callback

* Fix mypy issue

Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>
ashwinvaidya17 and Ashwin Vaidya committed Apr 12, 2022
1 parent aae3d62 commit 487ff45
Showing 2 changed files with 46 additions and 1 deletion.
6 changes: 5 additions & 1 deletion anomalib/utils/callbacks/timer.py
@@ -73,4 +73,8 @@ def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: #
             None
         """
         testing_time = time.time() - self.start
-        print(f"Testing took {testing_time} seconds\nThroughput: {self.num_images/testing_time} FPS")
+        output = f"Testing took {testing_time} seconds\nThroughput "
+        if trainer.test_dataloaders is not None:
+            output += f"(batch_size={trainer.test_dataloaders[0].batch_size})"
+        output += f" : {self.num_images/testing_time} FPS"
+        print(output)
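
For illustration only (not part of the commit), here is a minimal standalone sketch of the message the updated `on_test_end` hook now prints. The numbers are hypothetical; the real callback derives them from `self.num_images`, `time.time() - self.start`, and `trainer.test_dataloaders[0].batch_size`.

```python
# Hypothetical values standing in for what the callback reads from the trainer.
testing_time = 12.5  # seconds
num_images = 1000
batch_size = 32

output = f"Testing took {testing_time} seconds\nThroughput "
if batch_size is not None:
    output += f"(batch_size={batch_size})"
output += f" : {num_images / testing_time} FPS"
print(output)
# Testing took 12.5 seconds
# Throughput (batch_size=32) : 80.0 FPS
```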
41 changes: 41 additions & 0 deletions tools/benchmarking/benchmark.py
@@ -15,10 +15,14 @@
 # and limitations under the License.


+import functools
+import io
 import logging
 import math
 import multiprocessing
+import sys
 import time
+import warnings
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from pathlib import Path
 from tempfile import TemporaryDirectory
@@ -41,9 +45,41 @@
     set_in_nested_config,
 )

+warnings.filterwarnings("ignore")
+
+logger = logging.getLogger(__file__)
+for logger_name in ["pytorch_lightning", "torchmetrics", "os"]:
+    logging.getLogger(logger_name).setLevel(logging.ERROR)
+
+
+def hide_output(func):
+    """Decorator to hide output of the function.
+    Args:
+        func (function): Hides output of this function.
+    Raises:
+        Exception: In case the execution of the function fails, it raises an exception.
+    Returns:
+        object of the called function
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        std_out = sys.stdout
+        sys.stdout = buf = io.StringIO()
+        try:
+            value = func(*args, **kwargs)
+        except Exception as exp:
+            raise Exception(buf.getvalue()) from exp
+        sys.stdout = std_out
+        return value
+
+    return wrapper
+
+
+@hide_output
 def get_single_model_metrics(model_config: Union[DictConfig, ListConfig], openvino_metrics: bool = False) -> Dict:
     """Collects metrics for `model_name` and returns a dict of results.
@@ -65,6 +101,7 @@ def get_single_model_metrics(model_config: Union[DictConfig, ListConfig], openvi
     trainer = Trainer(**model_config.trainer, logger=None, callbacks=callbacks)

     start_time = time.time()
+
     trainer.fit(model=model, datamodule=datamodule)

     # get start time
@@ -218,6 +255,10 @@ def sweep(run_config: Union[DictConfig, ListConfig], device: int = 0, seed: int

     # Run benchmarking for current config
     model_metrics = get_single_model_metrics(model_config=model_config, openvino_metrics=convert_openvino)
+    output = f"One sweep run complete for model {model_config.model.name}"
+    output += f" On category {model_config.dataset.category}" if model_config.dataset.category is not None else ""
+    output += str(model_metrics)
+    print(output)

     # Append configuration of current run to the collected metrics
     for key, value in run_config.items():
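
Finally, a short sketch (hypothetical model name, category, and metrics dict) of the per-run summary line that `sweep` now prints once a configuration finishes:

```python
# Hypothetical values standing in for model_config and the returned metrics dict.
model_name = "padim"
category = "bottle"
model_metrics = {"image_AUROC": 0.98, "testing_time": 12.5}

output = f"One sweep run complete for model {model_name}"
output += f" On category {category}" if category is not None else ""
output += str(model_metrics)
print(output)
# One sweep run complete for model padim On category bottle{'image_AUROC': 0.98, 'testing_time': 12.5}
```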
