Dynamic batch size support for TensorRT (#8526)
* Dynamic batch size support for TensorRT

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update export.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix optimization profile when batch size is 1

* Warn users if they use batch-size=1 with dynamic

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* More descriptive assertion error

* Fix syntax

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* pre-commit formatting sucked

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update export.py

Co-authored-by: Colin Wong <noreply@brains4drones.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
4 people authored Jul 29, 2022
1 parent 3e85863 commit 587a3a3
Showing 2 changed files with 31 additions and 12 deletions.
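
In short: passing --dynamic together with --include engine now exports the intermediate ONNX model with dynamic axes and builds a TensorRT engine with an optimization profile, where --batch-size sets the profile maximum. A hedged example invocation (flag names taken from parse_opt() below): `python export.py --weights yolov5s.pt --include engine --device 0 --dynamic --batch-size 16`.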
21 changes: 15 additions & 6 deletions export.py
@@ -216,8 +216,9 @@ def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
         return None, None
 
 
-def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
+def export_engine(model, im, file, train, half, dynamic, simplify, workspace=4, verbose=False):
     # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
+    prefix = colorstr('TensorRT:')
     try:
         assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
         try:
@@ -230,11 +231,11 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
         if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
             grid = model.model[-1].anchor_grid
             model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
-            export_onnx(model, im, file, 12, train, False, simplify)  # opset 12
+            export_onnx(model, im, file, 12, train, dynamic, simplify)  # opset 12
             model.model[-1].anchor_grid = grid
         else:  # TensorRT >= 8
             check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
-            export_onnx(model, im, file, 13, train, False, simplify)  # opset 13
+            export_onnx(model, im, file, 13, train, dynamic, simplify)  # opset 13
         onnx = file.with_suffix('.onnx')
 
         LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
@@ -263,6 +264,14 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
         for out in outputs:
             LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
 
+        if dynamic:
+            if im.shape[0] <= 1:
+                LOGGER.warning(f"{prefix}WARNING: --dynamic model requires maximum --batch-size argument")
+            profile = builder.create_optimization_profile()
+            for inp in inputs:
+                profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
+            config.add_optimization_profile(profile)
+
         LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}')
         if builder.platform_has_fast_fp16 and half:
             config.set_flag(trt.BuilderFlag.FP16)
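
For readers unfamiliar with TensorRT optimization profiles: each dynamic input must have (min, opt, max) shapes registered before the engine is built, which is what the added block above does. A minimal standalone sketch of the same pattern, assuming a TensorRT 8.x builder, config, and network already exist; the shape values are illustrative only:

    # Sketch only: register (min, opt, max) shapes for every network input.
    # Assumes `builder`, `config`, `network` exist (TensorRT 8.x Python API).
    def add_dynamic_profile(builder, config, network, max_shape=(16, 3, 640, 640)):
        profile = builder.create_optimization_profile()
        for i in range(network.num_inputs):
            inp = network.get_input(i)
            profile.set_shape(inp.name,
                              (1, *max_shape[1:]),                          # min: batch 1
                              (max(1, max_shape[0] // 2), *max_shape[1:]),  # opt: half of max
                              max_shape)                                    # max: as exported
        config.add_optimization_profile(profile)

TensorRT tunes kernels for the opt shape, so halving the max batch is a reasonable middle ground when the deployment batch size varies.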
@@ -460,7 +469,7 @@ def run(
         keras=False,  # use Keras
         optimize=False,  # TorchScript: optimize for mobile
         int8=False,  # CoreML/TF INT8 quantization
-        dynamic=False,  # ONNX/TF: dynamic axes
+        dynamic=False,  # ONNX/TF/TensorRT: dynamic axes
         simplify=False,  # ONNX: simplify model
         opset=12,  # ONNX: opset version
         verbose=False,  # TensorRT: verbose log
@@ -520,7 +529,7 @@ def run(
     if jit:
         f[0] = export_torchscript(model, im, file, optimize)
     if engine:  # TensorRT required before ONNX
-        f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
+        f[1] = export_engine(model, im, file, train, half, dynamic, simplify, workspace, verbose)
     if onnx or xml:  # OpenVINO requires ONNX
         f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
     if xml:  # OpenVINO
@@ -579,7 +588,7 @@ def parse_opt():
     parser.add_argument('--keras', action='store_true', help='TF: use Keras')
     parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
     parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
-    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
+    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
     parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
     parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
     parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
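
The same plumbing is reachable from Python. A hedged sketch of calling the updated exporter programmatically, assuming a yolov5 checkout on the path; the keyword names mirror the run() signature above:

    # Sketch: programmatic export (assumes a yolov5 repo checkout; arguments
    # mirror run() above, and batch_size becomes the optimization-profile max).
    import export  # yolov5/export.py

    export.run(weights='yolov5s.pt',
               imgsz=(640, 640),
               batch_size=16,      # maximum batch baked into the engine
               device='0',         # TensorRT export must run on GPU
               include=('engine',),
               half=True,
               dynamic=True)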
22 changes: 16 additions & 6 deletions models/common.py
@@ -384,19 +384,24 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
             logger = trt.Logger(trt.Logger.INFO)
             with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                 model = runtime.deserialize_cuda_engine(f.read())
+            context = model.create_execution_context()
             bindings = OrderedDict()
             fp16 = False  # default updated below
+            dynamic_input = False
             for index in range(model.num_bindings):
                 name = model.get_binding_name(index)
                 dtype = trt.nptype(model.get_binding_dtype(index))
-                shape = tuple(model.get_binding_shape(index))
+                if model.binding_is_input(index):
+                    if -1 in tuple(model.get_binding_shape(index)):  # dynamic
+                        dynamic_input = True
+                        context.set_binding_shape(index, tuple(model.get_profile_shape(0, index)[2]))
+                    if dtype == np.float16:
+                        fp16 = True
+                shape = tuple(context.get_binding_shape(index))
                 data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
                 bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
-                if model.binding_is_input(index) and dtype == np.float16:
-                    fp16 = True
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
-            context = model.create_execution_context()
-            batch_size = bindings['images'].shape[0]
+            batch_size = bindings['images'].shape[0]  # if dynamic, this is instead max batch size
         elif coreml:  # CoreML
             LOGGER.info(f'Loading {w} for CoreML inference...')
             import coremltools as ct
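
One detail worth calling out: engine.get_profile_shape(profile_index, binding) returns the (min, opt, max) shapes registered at build time, so indexing [2] above sizes the buffers at the profile maximum. A small sketch, assuming a deserialized TensorRT 8.x engine `model` with a single optimization profile:

    # Sketch: inspect profile 0 and pin the context to the max shape
    # (assumes a deserialized engine `model` with one optimization profile).
    idx = model.get_binding_index('images')
    min_shape, opt_shape, max_shape = model.get_profile_shape(0, idx)
    context = model.create_execution_context()
    context.set_binding_shape(idx, max_shape)  # get_binding_shape is now concrete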
@@ -466,7 +471,12 @@ def forward(self, im, augment=False, visualize=False, val=False):
             im = im.cpu().numpy()  # FP32
             y = self.executable_network([im])[self.output_layer]
         elif self.engine:  # TensorRT
-            assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
+            if im.shape != self.bindings['images'].shape and self.dynamic_input:
+                self.context.set_binding_shape(self.model.get_binding_index('images'), im.shape)  # reshape if dynamic
+                self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
+            assert im.shape == self.bindings['images'].shape, (
+                f"image shape {im.shape} exceeds model max shape {self.bindings['images'].shape}" if self.dynamic_input
+                else f"image shape {im.shape} does not match model shape {self.bindings['images'].shape}")
             self.binding_addrs['images'] = int(im.data_ptr())
             self.context.execute_v2(list(self.binding_addrs.values()))
             y = self.bindings['output'].data
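
Putting the two halves together at inference time: the execution context is reshaped to the incoming batch before execute_v2, while the pre-allocated buffers stay at the profile maximum. A hedged end-to-end sketch, assuming `model`, `context`, `bindings`, and `binding_addrs` were set up as in __init__ above and `im` is a CUDA tensor within the profile's max shape:

    # Sketch: run a smaller-than-max batch through a dynamic engine.
    # Assumes setup as in DetectMultiBackend.__init__ above.
    def infer(im, model, context, bindings, binding_addrs):
        idx = model.get_binding_index('images')
        if tuple(context.get_binding_shape(idx)) != tuple(im.shape):
            context.set_binding_shape(idx, tuple(im.shape))  # must stay within profile max
        binding_addrs['images'] = int(im.data_ptr())
        context.execute_v2(list(binding_addrs.values()))
        n = context.get_binding_shape(model.get_binding_index('output'))[0]
        return bindings['output'].data[:n]  # buffer is max-sized; first n rows are valid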
