Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 30, 2022
1 parent d8615d3 commit a401029
Showing 1 changed file with 68 additions and 29 deletions.
97 changes: 68 additions & 29 deletions notebooks/build_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@
import sys

from numpy.core.fromnumeric import trace

sys.path.append("./")

import logging
import argparse
import logging
import traceback

import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import traceback

import pycuda.driver as cuda
import tensorrt as trt
from yolort.v5.utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages

logging.basicConfig(level=logging.INFO)
Expand All @@ -54,8 +54,10 @@ def __init__(self, calib_shape=None, calib_dtype=None) -> None:
self.shape = (self.batch_size, 3, *calib_shape)
self.num_images = len(self.dataset)
self.image_index = 0

def get_batch(self, ):

def get_batch(
self,
):
return iter(self.dataset)


Expand All @@ -73,7 +75,7 @@ def __init__(self, cache_file):
self.image_batcher: ImageBatcher = None
self.batch_allocation = None
self.batch_generator = None

def set_image_batcher(self, image_batcher: ImageBatcher):
"""
Define the image batcher to use, if any. If using only the cache file, an image batcher doesn't need
Expand Down Expand Up @@ -111,16 +113,20 @@ def get_batch(self, names):
image = image[np.newaxis, :, :, :]
batch, _, _, _ = image.shape
self.image_batcher.image_index += 1

log.info("Calibrating image {} / {}".format(self.image_batcher.image_index, self.image_batcher.num_images))

log.info(
"Calibrating image {} / {}".format(
self.image_batcher.image_index, self.image_batcher.num_images
)
)
cuda.memcpy_htod(self.batch_allocation, np.ascontiguousarray(batch))
return [int(self.batch_allocation)]
except StopIteration:
log.info("Finished calibration batches")
return None
except Exception:
traceback.print_exc()

def read_calibration_cache(self):
"""
Overrides from trt.IInt8EntropyCalibrator2.
Expand Down Expand Up @@ -171,7 +177,7 @@ def create_network(self, onnx_path):
Parse the ONNX graph and create the corresponding TensorRT network definition.
:param onnx_path: The path to the ONNX graph to load.
"""
network_flags = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

self.network = self.builder.create_network(network_flags)
self.parser = trt.OnnxParser(self.network, self.trt_logger)
Expand All @@ -196,8 +202,16 @@ def create_network(self, onnx_path):
assert self.batch_size > 0
self.builder.max_batch_size = self.batch_size

def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=25000,
calib_batch_size=8, calib_preprocessor=None):
def create_engine(
self,
engine_path,
precision,
calib_input=None,
calib_cache=None,
calib_num_images=25000,
calib_batch_size=8,
calib_preprocessor=None,
):
"""
Build the TensorRT engine and serialize it to disk.
:param engine_path: The path where to serialize the engine to.
Expand Down Expand Up @@ -229,9 +243,7 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No
if not os.path.exists(calib_cache):
calib_shape = [calib_batch_size] + list(inputs[0].shape[1:])
calib_dtype = trt.nptype(inputs[0].dtype)
self.config.int8_calibrator.set_image_batcher(
ImageBatcher(calib_shape, calib_dtype)
)
self.config.int8_calibrator.set_image_batcher(ImageBatcher(calib_shape, calib_dtype))

with self.builder.build_engine(self.network, self.config) as engine:
with open(engine_path, "wb") as f:
Expand All @@ -242,26 +254,53 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No
def main(args):
builder = EngineBuilder(args.verbose)
builder.create_network(args.onnx)
builder.create_engine(args.engine, args.precision, args.calib_input, args.calib_cache, args.calib_num_images,
args.calib_batch_size, args.calib_preprocessor)
builder.create_engine(
args.engine,
args.precision,
args.calib_input,
args.calib_cache,
args.calib_num_images,
args.calib_batch_size,
args.calib_preprocessor,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--onnx", help="The input ONNX model file to load")
parser.add_argument("-e", "--engine", help="The output path for the TRT engine")
parser.add_argument("-p", "--precision", default="fp16", choices=["fp32", "fp16", "int8"],
help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'")
parser.add_argument(
"-p",
"--precision",
default="fp16",
choices=["fp32", "fp16", "int8"],
help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'",
)
parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output")
parser.add_argument("--calib_input", help="The directory holding images to use for calibration")
parser.add_argument("--calib_cache", default="./calibration.cache",
help="The file path for INT8 calibration cache to use, default: ./calibration.cache")
parser.add_argument("--calib_num_images", default=10, type=int,
help="The maximum number of images to use for calibration, default: 25000")
parser.add_argument("--calib_batch_size", default=1, type=int,
help="The batch size for the calibration process, default: 1")
parser.add_argument("--calib_preprocessor", default="V2", choices=["V1", "V1MS", "V2"],
help="Set the calibration image preprocessor to use, either 'V2', 'V1' or 'V1MS', default: V2")
parser.add_argument(
"--calib_cache",
default="./calibration.cache",
help="The file path for INT8 calibration cache to use, default: ./calibration.cache",
)
parser.add_argument(
"--calib_num_images",
default=10,
type=int,
help="The maximum number of images to use for calibration, default: 25000",
)
parser.add_argument(
"--calib_batch_size",
default=1,
type=int,
help="The batch size for the calibration process, default: 1",
)
parser.add_argument(
"--calib_preprocessor",
default="V2",
choices=["V1", "V1MS", "V2"],
help="Set the calibration image preprocessor to use, either 'V2', 'V1' or 'V1MS', default: V2",
)
args = parser.parse_args()
if not all([args.onnx, args.engine]):
parser.print_help()
Expand Down

0 comments on commit a401029

Please sign in to comment.