Dynamic batch size support for TensorRT #8526

Merged · 21 commits · Jul 29, 2022
Changes from 3 commits
Commits:
d0238a3  Dynamic batch size support for TensorRT (Jul 8, 2022)
4c03ec2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 8, 2022)
5631f4c  Merge branch 'master' into patch-1 (glenn-jocher, Jul 11, 2022)
dc26f06  Merge branch 'master' into patch-1 (democat3457, Jul 12, 2022)
18ec9a9  Merge branch 'master' into patch-1 (democat3457, Jul 13, 2022)
10fdc3a  Merge branch 'master' into patch-1 (glenn-jocher, Jul 15, 2022)
9facb54  Update export.py (glenn-jocher, Jul 15, 2022)
42b2b60  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 15, 2022)
6ac75d6  Fix optimization profile when batch size is 1 (democat3457, Jul 15, 2022)
e6314be  Merge branch 'master' into patch-1 (democat3457, Jul 19, 2022)
6a7c398  Merge branch 'master' into patch-1 (glenn-jocher, Jul 21, 2022)
d7411ad  Merge branch 'master' into patch-1 (democat3457, Jul 26, 2022)
f9cad51  Warn users if they use batch-size=1 with dynamic (democat3457, Jul 27, 2022)
cbdc898  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 27, 2022)
0dc949d  More descriptive assertion error (democat3457, Jul 27, 2022)
f973a14  Fix syntax (democat3457, Jul 27, 2022)
8be2d36  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 27, 2022)
c709d49  pre-commit formatting sucked (democat3457, Jul 27, 2022)
52d0cfd  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 27, 2022)
52c9ace  Merge branch 'master' into patch-1 (glenn-jocher, Jul 27, 2022)
352d45a  Update export.py (glenn-jocher, Jul 27, 2022)
27 changes: 21 additions & 6 deletions export.py
```diff
@@ -217,7 +217,16 @@ def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
         return None, None


-def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
+def export_engine(model,
+                  im,
+                  file,
+                  train,
+                  half,
+                  dynamic,
+                  simplify,
+                  workspace=4,
+                  verbose=False,
+                  prefix=colorstr('TensorRT:')):
     # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
     try:
         assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
```
```diff
@@ -231,11 +240,11 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
         if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
             grid = model.model[-1].anchor_grid
             model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
-            export_onnx(model, im, file, 12, train, False, simplify)  # opset 12
+            export_onnx(model, im, file, 12, train, dynamic, simplify)  # opset 12
             model.model[-1].anchor_grid = grid
         else:  # TensorRT >= 8
             check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
-            export_onnx(model, im, file, 13, train, False, simplify)  # opset 13
+            export_onnx(model, im, file, 13, train, dynamic, simplify)  # opset 13
         onnx = file.with_suffix('.onnx')

         LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
```
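Both TensorRT branches now pass `dynamic` through to `export_onnx`, so the intermediate ONNX graph carries a symbolic batch dimension before TensorRT ever parses it. A minimal sketch of what dynamic axes mean at the ONNX level (toy model and file names are illustrative, not the repository's `export_onnx`):

```python
import torch
import torch.nn as nn

# Sketch: with dynamic_axes set, dim 0 of 'images' is recorded as a symbolic
# 'batch' dimension in the ONNX graph instead of the trace-time value 1.
model = nn.Conv2d(3, 16, 3, padding=1).eval()  # toy stand-in for YOLOv5
im = torch.zeros(1, 3, 640, 640)  # trace-time shape
torch.onnx.export(
    model, im, 'toy.onnx',
    opset_version=13,
    input_names=['images'],
    output_names=['output'],
    dynamic_axes={'images': {0: 'batch'}, 'output': {0: 'batch'}})
```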
```diff
@@ -264,6 +273,12 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
         for out in outputs:
             LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

+        if dynamic:
+            profile = builder.create_optimization_profile()
+            for inp in inputs:
+                profile.set_shape(inp.name, (1, *im.shape[1:]), (im.shape[0] // 2, *im.shape[1:]), im.shape)
+            config.add_optimization_profile(profile)
+
         LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}')
         if builder.platform_has_fast_fp16 and half:
             config.set_flag(trt.BuilderFlag.FP16)
```
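This hunk is the core of the export-side change: a dynamic engine must declare a (min, opt, max) shape range for each input through an optimization profile, and TensorRT tunes its kernels for the opt shape. A condensed sketch of that pattern, assuming `builder`, `config`, and `network` objects already created as in `export_engine` (setup elided; note that the later commit 6ac75d6 clamps the opt batch to at least 1 so `--batch-size 1` still yields a valid profile):

```python
def add_dynamic_profile(builder, config, network, im_shape):
    # im_shape: (batch, channels, height, width) of the export tensor; the
    # built engine then accepts any batch from 1 up to im_shape[0].
    profile = builder.create_optimization_profile()
    for i in range(network.num_inputs):
        inp = network.get_input(i)
        profile.set_shape(inp.name,
                          (1, *im_shape[1:]),                         # min
                          (max(1, im_shape[0] // 2), *im_shape[1:]),  # opt
                          tuple(im_shape))                            # max
    config.add_optimization_profile(profile)
```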
```diff
@@ -461,7 +476,7 @@ def run(
         keras=False,  # use Keras
         optimize=False,  # TorchScript: optimize for mobile
         int8=False,  # CoreML/TF INT8 quantization
-        dynamic=False,  # ONNX/TF: dynamic axes
+        dynamic=False,  # ONNX/TF/TensorRT: dynamic axes
         simplify=False,  # ONNX: simplify model
         opset=12,  # ONNX: opset version
         verbose=False,  # TensorRT: verbose log
@@ -519,7 +534,7 @@ def run(
     if jit:
         f[0] = export_torchscript(model, im, file, optimize)
     if engine:  # TensorRT required before ONNX
-        f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
+        f[1] = export_engine(model, im, file, train, half, dynamic, simplify, workspace, verbose)
     if onnx or xml:  # OpenVINO requires ONNX
         f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
     if xml:  # OpenVINO
```
```diff
@@ -578,7 +593,7 @@ def parse_opt():
     parser.add_argument('--keras', action='store_true', help='TF: use Keras')
     parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
     parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
-    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
+    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
     parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
     parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
     parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
```
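Taken together, the export.py changes let a dynamic-batch engine be built with an invocation along the lines of `python export.py --weights yolov5s.pt --include engine --device 0 --dynamic --batch-size 16` (illustrative: `--batch-size` here sets the maximum batch the engine will accept, and commit f9cad51 adds a warning when it is left at 1).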
18 changes: 13 additions & 5 deletions models/common.py
```diff
@@ -384,19 +384,24 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
             logger = trt.Logger(trt.Logger.INFO)
             with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                 model = runtime.deserialize_cuda_engine(f.read())
+            context = model.create_execution_context()
             bindings = OrderedDict()
             fp16 = False  # default updated below
+            dynamic_input = False
             for index in range(model.num_bindings):
                 name = model.get_binding_name(index)
                 dtype = trt.nptype(model.get_binding_dtype(index))
-                shape = tuple(model.get_binding_shape(index))
+                if model.binding_is_input(index):
+                    if -1 in tuple(model.get_binding_shape(index)):  # dynamic
+                        dynamic_input = True
+                        context.set_binding_shape(index, tuple(model.get_profile_shape(0, index)[2]))
+                    if dtype == np.float16:
+                        fp16 = True
+                shape = tuple(context.get_binding_shape(index))
                 data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
                 bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
-                if model.binding_is_input(index) and dtype == np.float16:
-                    fp16 = True
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
-            context = model.create_execution_context()
-            batch_size = bindings['images'].shape[0]
+            batch_size = bindings['images'].shape[0]  # if dynamic, this is instead max batch size
         elif coreml:  # CoreML
             LOGGER.info(f'Loading {w} for CoreML inference...')
             import coremltools as ct
```
```diff
@@ -464,6 +469,9 @@ def forward(self, im, augment=False, visualize=False, val=False):
             im = im.cpu().numpy()  # FP32
             y = self.executable_network([im])[self.output_layer]
         elif self.engine:  # TensorRT
+            if im.shape != self.bindings['images'].shape and self.dynamic_input:
+                self.context.set_binding_shape(self.model.get_binding_index('images'), im.shape)  # reshape if dynamic
+                self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
             assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
             self.binding_addrs['images'] = int(im.data_ptr())
             self.context.execute_v2(list(self.binding_addrs.values()))
```
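On the inference side, the loader now sizes buffers to the profile's max shape and records `dynamic_input`, and `forward()` shrinks the 'images' binding to the live batch before executing. A minimal standalone sketch of that per-call rebinding, with names mirroring models/common.py (the 'output' binding name and the buffer-reuse comment are assumptions based on the YOLOv5 ONNX export defaults):

```python
def trt_infer(model, context, bindings, binding_addrs, im):
    # model: deserialized ICudaEngine; im: CUDA tensor of shape (n, 3, h, w),
    # with n anywhere from 1 up to the profile's max batch.
    index = model.get_binding_index('images')
    if tuple(im.shape) != tuple(context.get_binding_shape(index)):
        context.set_binding_shape(index, im.shape)  # reshape the dynamic binding
        bindings['images'] = bindings['images']._replace(shape=tuple(im.shape))
    binding_addrs['images'] = int(im.data_ptr())
    context.execute_v2(list(binding_addrs.values()))  # synchronous execution
    # The output buffer was allocated for the max profile shape, so only the
    # first n rows of bindings['output'].data are meaningful for this call.
    return bindings['output'].data
```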