Skip to content

Commit

Permalink
Merge pull request #18 from tensorpix/feauture/inferenceBenchmark
Browse files Browse the repository at this point in the history
Feauture/inference benchmark
  • Loading branch information
bfreskura committed Jun 11, 2024
2 parents 4f4f263 + 4a98dae commit 6395fbc
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 28 deletions.
39 changes: 25 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

<p align="center" >
<img width="400" src="https://cdn.tensorpix.ai/TensorPix-Logo-color.svg" alt="Tensorpix logo"/>
</p>
Expand Down Expand Up @@ -38,15 +37,15 @@ You can use this benchmark repo to:

Please open an issue if you need support for a new architecture.

* ResNet50
* ConvNext (base)
* VGG16
* Efficient Net v2
* MobileNet V3
* ResNeXt50
* SWIN
* VIT
* UNet with ResNet50 backbone
- ResNet50
- ConvNext (base)
- VGG16
- Efficient Net v2
- MobileNet V3
- ResNeXt50
- SWIN
- VIT
- UNet with ResNet50 backbone

## 📖 How to benchmark

Expand All @@ -58,19 +57,31 @@ In order to run benchmark docker containers you must have the following installe
- NVIDIA drivers. See [Versions](#versions) when choosing the docker image.
- [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) - required in order to use CUDA inside docker containers

### Training vs Inference

To benchmark model training, append `src.train` to the `docker run` command. To benchmark model inference, append `src.inference` instead. See the examples below for more details.

### Examples

**Minimal**

`docker run --rm --ipc=host --ulimit memlock=-1 --gpus all ghcr.io/tensorpix/benchmarking-cv-models --batch-size 32`
`docker run --rm --ipc=host --ulimit memlock=-1 --gpus all ghcr.io/tensorpix/benchmarking-cv-models src.train --batch-size 32`

**Advanced**

`docker run --rm --ipc=host --ulimit memlock=-1 --gpus '"device=0,1"' -v ./benchmarks:/workdir/benchmarks ghcr.io/tensorpix/benchmarking-cv-models --batch-size 32 --n-iters 1000 --warmup-steps 100 --model resnext50 --precision 16-mixed --width 320 --height 320`
`docker run --rm --ipc=host --ulimit memlock=-1 --gpus '"device=0,1"' -v ./benchmarks:/workdir/benchmarks ghcr.io/tensorpix/benchmarking-cv-models src.train --batch-size 32 --n-iters 1000 --warmup-steps 100 --model resnext50 --precision 16-mixed --width 320 --height 320`

**Benchmark Inference**

`docker run --rm --ipc=host --ulimit memlock=-1 --gpus all ghcr.io/tensorpix/benchmarking-cv-models src.inference --batch-size 32 --n-iters 1000 --model resnext50 --precision 16 --width 256 --height 256`

**List all train options:**

`docker run --rm ghcr.io/tensorpix/benchmarking-cv-models src.train --help`

**List all options:**
**List all inference options:**

`docker run --rm ghcr.io/tensorpix/benchmarking-cv-models --help`
`docker run --rm ghcr.io/tensorpix/benchmarking-cv-models src.inference --help`

### How to select particular GPUs

Expand Down
2 changes: 1 addition & 1 deletion dockerfiles/cuda118/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ RUN pip3 install --no-cache-dir -r /tmp/requirements.txt --extra-index-url https
COPY ./src /workdir/src
WORKDIR /workdir

ENTRYPOINT [ "python3", "-m", "src.train" ]
ENTRYPOINT [ "python3", "-m" ]
4 changes: 2 additions & 2 deletions dockerfiles/cuda120/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ RUN apt update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install
rm -rf /var/lib/apt/lists/*

COPY requirements.txt /tmp/requirements.txt
RUN pip3 install --no-cache-dir -r /tmp/requirements.txt
RUN pip3 install --no-cache-dir -r /tmp/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121

COPY ./src /workdir/src
WORKDIR /workdir

ENTRYPOINT [ "python3", "-m", "src.train" ]
ENTRYPOINT [ "python3", "-m" ]
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
lightning==2.1.4
lightning==2.2.5
protobuf==3.20.*
segmentation-models-pytorch==0.3.3
six==1.16.0
torch==2.1.2
torchvision==0.16.2
torch==2.3.1
torchvision==0.18.1
159 changes: 159 additions & 0 deletions src/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import argparse

import torch
import torch.utils.benchmark as benchmark

from src import log
from src.log import print_requirements

# Reuse the logger instance exposed by the src.log module.
logger = log.logger

# Maps each supported ``--model`` value to the name of the torchvision model
# builder used to construct the network.
ARCHITECTURES = dict(
    resnet50="resnet50",
    convnext="convnext_base",
    vgg16="vgg16",
    efficient_net_v2="efficientnet_v2_m",
    mobilenet_v3="mobilenet_v3_large",
    resnext50="resnext50_32x4d",
    swin="swin_b",
    vit="vit_b_16",
    ssd_vgg16="ssd300_vgg16",
    fasterrcnn_resnet50_v2="fasterrcnn_resnet50_fpn_v2",
)


def benchmark_inference(
    stmt: str,
    setup: str,
    input: torch.Tensor,
    n_runs: int = 100,
    num_threads: int = 1,
):
    """
    Benchmark a model using torch.utils.benchmark and log the results.

    ``stmt`` is timed ``n_runs`` times. The mean and median time per batch are
    logged, together with the corresponding throughput in megapixels per
    second (MP/s).

    When evaluating model throughput in MP/s only the image height, width and
    batch size are taken into account. The number of channels is ignored as it
    is fixed to 3 channels in most cases (RGB images). Speed evaluation
    measures how fast we can process an arbitrary input image, so channels
    don't affect the reported number.

    Args:
        stmt: Code snippet to time. It can refer to the input tensor as ``x``.
        setup: Code snippet handed to the timer as its setup code.
        input: Input tensor exposed to ``stmt`` as ``x``. Its first dimension
            is read as the batch size and its last two dimensions as height
            and width.
        n_runs: Number of times ``stmt`` is executed for the measurement.
        num_threads: Thread count forwarded to the timer.
    """
    # NOTE: ``input`` shadows the builtin, but the parameter name is part of
    # the public signature (callers pass it by keyword), so it is kept.
    timer = benchmark.Timer(
        stmt=stmt,
        setup=setup,
        num_threads=num_threads,
        globals={"x": input},
    )

    logger.info(
        f"Running benchmark on sample of {n_runs} runs with {num_threads} thread(s)..."
    )
    result = timer.timeit(n_runs)

    batch, height, width = input.size(0), input.size(-2), input.size(-1)
    total_pixels = batch * width * height

    logger.info(f"Batch size: {batch}")
    logger.info(f"Input resolution: {width}x{height} pixels\n")

    # Both values are seconds spent on one execution of ``stmt`` (one batch).
    mean_per_batch = result.mean
    median_per_batch = result.median

    mean_speed_mpx = (total_pixels / 1e6) / mean_per_batch
    median_speed_mpx = (total_pixels / 1e6) / median_per_batch

    # These two lines report a duration, so they are labelled as time rather
    # than throughput.
    logger.info(
        f"Mean time per batch of {batch} {width}x{height} px frames: {mean_per_batch:.4f} s"
    )
    logger.info(
        f"Median time per batch of {batch} {width}x{height} px frames: {median_per_batch:.4f} s\n"
    )

    logger.info(
        f"Model mean throughput in megapixels per second: {mean_speed_mpx:.3f} MP/s"
    )
    logger.info(
        f"Model median throughput in megapixels per second: {median_speed_mpx:.3f} MP/s\n"
    )


def main(args):
    """
    Assemble the timed statement and setup code for the requested model and
    run the inference benchmark on the selected CUDA device.

    Args:
        args: Parsed command-line arguments. The attributes read here are
            ``list_requirements``, ``model``, ``batch_size``, ``height``,
            ``width``, ``precision``, ``gpu_device_index``, ``n_iters`` and
            ``n_workers``.

    Raises:
        ValueError: If ``args.model`` is not a key of ``ARCHITECTURES``.
    """
    if args.list_requirements:
        print_requirements()

    model_name = args.model.lower()
    if model_name not in ARCHITECTURES:
        raise ValueError("Architecture not supported.")

    # Code executed on every timed run. Classification models return a tensor
    # that is cast and copied back to the host; detection models return a
    # list of per-image dicts instead, so the tensor-only post-processing is
    # guarded rather than applied unconditionally.
    stmt = "\n".join(
        (
            "with torch.inference_mode():",
            "    out = model(x)",
            "    if isinstance(out, torch.Tensor):",
            "        out = out.float().cpu()",
        )
    )

    arch = ARCHITECTURES[model_name]
    # Look the builder up through torchvision's model registry instead of
    # importing it by name from ``torchvision.models``: the detection builders
    # (ssd300_vgg16, fasterrcnn_resnet50_fpn_v2) are not exported from that
    # namespace, so a direct ``from torchvision.models import ...`` fails.
    setup = (
        "from torchvision.models import get_model; "
        f"model = get_model({arch!r}); model.eval()"
    )

    input_shape = [args.batch_size, 3, args.height, args.width]
    use_half = args.precision == "16"
    precision = torch.float16 if use_half else torch.float32

    x = torch.rand(*input_shape, dtype=precision)
    x = x.cuda(args.gpu_device_index, non_blocking=True)
    setup = f"{setup}; model.cuda({args.gpu_device_index})"

    if use_half:
        # The model weights must match the half-precision input.
        setup = f"{setup}; model.half()"

    benchmark_inference(
        stmt=stmt,
        setup=setup,
        input=x,
        n_runs=args.n_iters,
        num_threads=args.n_workers,
    )


if __name__ == "__main__":
    # Command-line entry point: parse and validate the options, report the
    # selected CUDA device, then run the inference benchmark.
    parser = argparse.ArgumentParser(
        description="Benchmark CV models inference on GPU."
    )

    # The option is required, so a default would never be used.
    parser.add_argument("--batch-size", type=int, required=True)
    parser.add_argument(
        "--n-iters",
        type=int,
        default=100,
        help="Number of inference iterations to benchmark for. One iteration = one batch forward pass",
    )
    parser.add_argument("--precision", choices=["32", "16"], default="16")
    parser.add_argument("--n-workers", type=int, default=1)
    parser.add_argument("--gpu-device-index", type=int, default=0)

    parser.add_argument("--width", type=int, default=224, help="Input width")
    parser.add_argument("--height", type=int, default=224, help="Input height")

    parser.add_argument(
        "--model",
        default="resnet50",
        choices=list(ARCHITECTURES.keys()),
        help="Architecture to benchmark.",
    )
    parser.add_argument("--list-requirements", action="store_true")

    args = parser.parse_args()

    if args.n_iters <= 0:
        raise ValueError("Number of iterations must be > 0")

    # Reject sizes that cannot describe a real input batch before any GPU
    # work is attempted.
    for option, value in (
        ("--batch-size", args.batch_size),
        ("--width", args.width),
        ("--height", args.height),
    ):
        if value <= 0:
            raise ValueError(f"{option} must be > 0")

    logger.info("########## STARTING NEW INFERENCE BENCHMARK RUN ###########")

    if not torch.cuda.is_available():
        raise ValueError("CUDA device not found on this system.")

    n_devices = torch.cuda.device_count()
    if not 0 <= args.gpu_device_index < n_devices:
        raise ValueError(
            f"GPU device index {args.gpu_device_index} is out of range; "
            f"{n_devices} CUDA device(s) available."
        )

    logger.info(
        f"CUDA Device Name: {torch.cuda.get_device_name(args.gpu_device_index)}"
    )
    logger.info(f"CUDNN version: {torch.backends.cudnn.version()}")
    logger.info(
        "CUDA Device Total Memory: "
        + f"{(torch.cuda.get_device_properties(args.gpu_device_index).total_memory / 1e9):.2f} GB"
    )

    main(args=args)
11 changes: 11 additions & 0 deletions src/log.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging

from pip._internal.operations import freeze


def setup_custom_logger(name: str = "benchmark"):
logger = logging.getLogger(name)
Expand All @@ -14,3 +16,12 @@ def setup_custom_logger(name: str = "benchmark"):
logger.setLevel(level=logging.DEBUG)

return logger


def print_requirements():
    """Log every entry yielded by pip's freeze operation, one per log line."""
    for requirement in freeze.freeze():
        logger.info(requirement)


# Module-level logger instance; other modules access it as ``log.logger``.
logger = setup_custom_logger()
10 changes: 2 additions & 8 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import segmentation_models_pytorch as smp
import torch
from lightning import Trainer
from pip._internal.operations import freeze
from torch.utils.data import DataLoader
from torchvision.models import (
convnext_base,
Expand All @@ -19,9 +18,10 @@
from src import log
from src.callbacks import BenchmarkCallback
from src.data.in_memory_dataset import InMemoryDataset
from src.log import print_requirements
from src.models.lightning_modules import LitClassification

logger = log.setup_custom_logger()
logger = log.logger

ARCHITECTURES = {
"resnet50": resnet50,
Expand All @@ -38,12 +38,6 @@
}


def print_requirements():
    """Log each installed requirement reported by pip's freeze operation."""
    for installed in freeze.freeze():
        logger.info(installed)


def main(args):
if args.list_requirements:
print_requirements()
Expand Down

0 comments on commit 6395fbc

Please sign in to comment.