From eecafc524bbb2165cc41ecb868a624e6dac32d9d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 18 Sep 2022 12:49:15 +0200 Subject: [PATCH 01/13] TensorRT SegmentationModel fix --- models/common.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/models/common.py b/models/common.py index 825a4c4e2633..baed9476765c 100644 --- a/models/common.py +++ b/models/common.py @@ -392,16 +392,16 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, bindings = OrderedDict() fp16 = False # default updated below dynamic = False - for index in range(model.num_bindings): - name = model.get_binding_name(index) - dtype = trt.nptype(model.get_binding_dtype(index)) - if model.binding_is_input(index): - if -1 in tuple(model.get_binding_shape(index)): # dynamic + for i in range(model.num_bindings): + name = model.get_binding_name(i) + dtype = trt.nptype(model.get_binding_dtype(i)) + if model.binding_is_input(i): + if -1 in tuple(model.get_binding_shape(i)): # dynamic dynamic = True - context.set_binding_shape(index, tuple(model.get_profile_shape(0, index)[2])) + context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2])) if dtype == np.float16: fp16 = True - shape = tuple(context.get_binding_shape(index)) + shape = tuple(context.get_binding_shape(i)) im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) From 3eb91b9d874cd3ce35a94e7404088390af2a73f9 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 18 Sep 2022 13:15:51 +0200 Subject: [PATCH 02/13] TensorRT SegmentationModel fix --- export.py | 43 ++++++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/export.py b/export.py index a575c73e375f..b881d8f21dee 100644 --- a/export.py +++ b/export.py @@ -66,7 +66,7 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative from models.experimental import attempt_load -from models.yolo import ClassificationModel, Detect +from models.yolo import ClassificationModel, DetectionModel, SegmentationModel, Detect from utils.dataloaders import LoadImages from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version, check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save) @@ -134,6 +134,35 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...') f = file.with_suffix('.onnx') + y = model(im) + n_out = len(y) if isinstance(y, (list, tuple)) else 1 # number of outputs + + if dynamic: + if isinstance(model, SegmentationModel): + dynamic_axes = { + 'images': { + 0: 'batch', + 2: 'height', + 3: 'width'}, # shape(1,3,640,640) + 'output0': { + 0: 'batch', + 1: 'anchors'}, # shape(1,25200,85) + 'output1': { + 0: 'batch', + 2: 'mask_height', + 3: 'mask_width'}} # shape(1,32,160,160) + elif isinstance(model, DetectionModel): + dynamic_axes = { + 'images': { + 0: 'batch', + 2: 'height', + 3: 'width'}, # shape(1,3,640,640) + 'output0': { + 0: 'batch', + 1: 'anchors'}} # shape(1,25200,85) + else: + dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640) + torch.onnx.export( model.cpu() if dynamic else model, # --dynamic only compatible with cpu im.cpu() if dynamic else im, @@ -142,16 +171,8 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX opset_version=opset, do_constant_folding=True, input_names=['images'], - output_names=['output'], - dynamic_axes={ - 'images': { - 0: 'batch', - 2: 'height', - 3: 'width'}, # shape(1,3,640,640) - 'output': { - 0: 'batch', - 1: 'anchors'} # shape(1,25200,85) - } if dynamic else None) + output_names=[f'output{i}' for i in range(n_out)], + dynamic_axes=dynamic_axes if dynamic else None) # Checks model_onnx = onnx.load(f) # load onnx model From b894812813f3c691f30ffc4cb1a4b11005c81254 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Sep 2022 11:20:07 +0000 Subject: [PATCH 03/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/export.py b/export.py index b881d8f21dee..d5c1bd74c009 100644 --- a/export.py +++ b/export.py @@ -66,7 +66,7 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative from models.experimental import attempt_load -from models.yolo import ClassificationModel, DetectionModel, SegmentationModel, Detect +from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel from utils.dataloaders import LoadImages from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version, check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save) From 6bbb93bb4a7955267585c90effda71b5ac5d42d6 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 18 Sep 2022 13:26:08 +0200 Subject: [PATCH 04/13] TensorRT SegmentationModel fix --- export.py | 42 +++++++++++++----------------------------- 1 file changed, 13 insertions(+), 29 deletions(-) diff --git a/export.py b/export.py index d5c1bd74c009..5de97998ff70 100644 --- a/export.py +++ b/export.py @@ -134,34 +134,18 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...') f = file.with_suffix('.onnx') - y = model(im) - n_out = len(y) if isinstance(y, (list, tuple)) else 1 # number of outputs - + output_names = ['output0'] if dynamic: if isinstance(model, SegmentationModel): - dynamic_axes = { - 'images': { - 0: 'batch', - 2: 'height', - 3: 'width'}, # shape(1,3,640,640) - 'output0': { - 0: 'batch', - 1: 'anchors'}, # shape(1,25200,85) - 'output1': { - 0: 'batch', - 2: 'mask_height', - 3: 'mask_width'}} # shape(1,32,160,160) + output_names = ['output0', 'output1'] + dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640) + 'output0': {0: 'batch', 1: 'anchors'}, # shape(1,25200,85) + 'output1': {0: 'batch', 2: 'mask_height', 3: 'mask_width'}} # shape(1,32,160,160) elif isinstance(model, DetectionModel): - dynamic_axes = { - 'images': { - 0: 'batch', - 2: 'height', - 3: 'width'}, # shape(1,3,640,640) - 'output0': { - 0: 'batch', - 1: 'anchors'}} # shape(1,25200,85) + dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640) + 'output0': {0: 'batch', 1: 'anchors'}} # shape(1,25200,85) else: - dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640) + dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640) torch.onnx.export( model.cpu() if dynamic else model, # --dynamic only compatible with cpu @@ -171,8 +155,8 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX opset_version=opset, do_constant_folding=True, input_names=['images'], - output_names=[f'output{i}' for i in range(n_out)], - dynamic_axes=dynamic_axes if dynamic else None) + output_names=output_names, + dynamic_axes=dynamic) # Checks model_onnx = onnx.load(f) # load onnx model @@ -462,9 +446,9 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')): r'"Identity.?.?": {"name": "Identity.?.?"}, ' r'"Identity.?.?": {"name": "Identity.?.?"}, ' r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, ' - r'"Identity_1": {"name": "Identity_1"}, ' - r'"Identity_2": {"name": "Identity_2"}, ' - r'"Identity_3": {"name": "Identity_3"}}}', json) + r'"Identity_1": {"name": "Identity_1"}, ' + r'"Identity_2": {"name": "Identity_2"}, ' + r'"Identity_3": {"name": "Identity_3"}}}', json) j.write(subst) return f, None From 6d848579862f489643aab65c82c5ea0189be0b35 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Sep 2022 11:26:33 +0000 Subject: [PATCH 05/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- export.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/export.py b/export.py index 5de97998ff70..51d9a8b0508b 100644 --- a/export.py +++ b/export.py @@ -138,12 +138,27 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX if dynamic: if isinstance(model, SegmentationModel): output_names = ['output0', 'output1'] - dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640) - 'output0': {0: 'batch', 1: 'anchors'}, # shape(1,25200,85) - 'output1': {0: 'batch', 2: 'mask_height', 3: 'mask_width'}} # shape(1,32,160,160) + dynamic = { + 'images': { + 0: 'batch', + 2: 'height', + 3: 'width'}, # shape(1,3,640,640) + 'output0': { + 0: 'batch', + 1: 'anchors'}, # shape(1,25200,85) + 'output1': { + 0: 'batch', + 2: 'mask_height', + 3: 'mask_width'}} # shape(1,32,160,160) elif isinstance(model, DetectionModel): - dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640) - 'output0': {0: 'batch', 1: 'anchors'}} # shape(1,25200,85) + dynamic = { + 'images': { + 0: 'batch', + 2: 'height', + 3: 'width'}, # shape(1,3,640,640) + 'output0': { + 0: 'batch', + 1: 'anchors'}} # shape(1,25200,85) else: dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640) @@ -446,9 +461,9 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')): r'"Identity.?.?": {"name": "Identity.?.?"}, ' r'"Identity.?.?": {"name": "Identity.?.?"}, ' r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, ' - r'"Identity_1": {"name": "Identity_1"}, ' - r'"Identity_2": {"name": "Identity_2"}, ' - r'"Identity_3": {"name": "Identity_3"}}}', json) + r'"Identity_1": {"name": "Identity_1"}, ' + r'"Identity_2": {"name": "Identity_2"}, ' + r'"Identity_3": {"name": "Identity_3"}}}', json) j.write(subst) return f, None From 43654e95bbe6c749c357998ac8d2d4faff058235 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 18 Sep 2022 13:29:05 +0200 Subject: [PATCH 06/13] TensorRT SegmentationModel fix --- export.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/export.py b/export.py index 5de97998ff70..fa37aab5d7c5 100644 --- a/export.py +++ b/export.py @@ -136,16 +136,13 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX output_names = ['output0'] if dynamic: + dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640) if isinstance(model, SegmentationModel): output_names = ['output0', 'output1'] - dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640) - 'output0': {0: 'batch', 1: 'anchors'}, # shape(1,25200,85) - 'output1': {0: 'batch', 2: 'mask_height', 3: 'mask_width'}} # shape(1,32,160,160) + dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85) + dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160) elif isinstance(model, DetectionModel): - dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640) - 'output0': {0: 'batch', 1: 'anchors'}} # shape(1,25200,85) - else: - dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640) + dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85) torch.onnx.export( model.cpu() if dynamic else model, # --dynamic only compatible with cpu From fa53bd2044a033003bc3b33aa639582ead487afe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Sep 2022 11:30:26 +0000 Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- export.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/export.py b/export.py index fa37aab5d7c5..74fb8a0f07b2 100644 --- a/export.py +++ b/export.py @@ -443,9 +443,9 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')): r'"Identity.?.?": {"name": "Identity.?.?"}, ' r'"Identity.?.?": {"name": "Identity.?.?"}, ' r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, ' - r'"Identity_1": {"name": "Identity_1"}, ' - r'"Identity_2": {"name": "Identity_2"}, ' - r'"Identity_3": {"name": "Identity_3"}}}', json) + r'"Identity_1": {"name": "Identity_1"}, ' + r'"Identity_2": {"name": "Identity_2"}, ' + r'"Identity_3": {"name": "Identity_3"}}}', json) j.write(subst) return f, None From 11b056cdd4a8150d02b7e014c10d4382397a4362 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 18 Sep 2022 13:32:11 +0200 Subject: [PATCH 08/13] TensorRT SegmentationModel fix --- export.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/export.py b/export.py index fa37aab5d7c5..5bde35b36aea 100644 --- a/export.py +++ b/export.py @@ -134,11 +134,10 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...') f = file.with_suffix('.onnx') - output_names = ['output0'] + output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0'] if dynamic: dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640) if isinstance(model, SegmentationModel): - output_names = ['output0', 'output1'] dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85) dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160) elif isinstance(model, DetectionModel): From a956b13cf977e0f4e7fb1354534b3293ff02dbf3 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 18 Sep 2022 13:50:57 +0200 Subject: [PATCH 09/13] TensorRT SegmentationModel fix --- models/common.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/models/common.py b/models/common.py index baed9476765c..b1e893223db8 100644 --- a/models/common.py +++ b/models/common.py @@ -405,6 +405,7 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) + output_names = [x for x in list(bindings.keys()) if x.startswith('output')] batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size elif coreml: # CoreML LOGGER.info(f'Loading {w} for CoreML inference...') @@ -495,15 +496,17 @@ def forward(self, im, augment=False, visualize=False): y = list(self.executable_network([im]).values()) elif self.engine: # TensorRT if self.dynamic and im.shape != self.bindings['images'].shape: - i_in, i_out = (self.model.get_binding_index(x) for x in ('images', 'output')) - self.context.set_binding_shape(i_in, im.shape) # reshape if dynamic + i = self.model.get_binding_index('images') + self.context.set_binding_shape(i, im.shape) # reshape if dynamic self.bindings['images'] = self.bindings['images']._replace(shape=im.shape) - self.bindings['output'].data.resize_(tuple(self.context.get_binding_shape(i_out))) + for name in self.output_names: + i = self.model.get_binding_index(name) + self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i))) s = self.bindings['images'].shape assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}" self.binding_addrs['images'] = int(im.data_ptr()) self.context.execute_v2(list(self.binding_addrs.values())) - y = self.bindings['output'].data + y = [self.bindings[x].data for x in self.output_names] elif self.coreml: # CoreML im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) im = Image.fromarray((im[0] * 255).astype('uint8')) From 520e86cd46ead66f88265bd7b6a09242f043b6fc Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 18 Sep 2022 13:55:01 +0200 Subject: [PATCH 10/13] fix --- export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/export.py b/export.py index 8f9a951d22ff..9955870e9e43 100644 --- a/export.py +++ b/export.py @@ -152,7 +152,7 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX do_constant_folding=True, input_names=['images'], output_names=output_names, - dynamic_axes=dynamic) + dynamic_axes=dynamic or None) # Checks model_onnx = onnx.load(f) # load onnx model From 10e15351ef7a5f6c350e3c16270a9b7e37e43373 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 18 Sep 2022 14:16:11 +0200 Subject: [PATCH 11/13] sort output names --- models/common.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/models/common.py b/models/common.py index b1e893223db8..d0bc65e02f91 100644 --- a/models/common.py +++ b/models/common.py @@ -390,6 +390,7 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, model = runtime.deserialize_cuda_engine(f.read()) context = model.create_execution_context() bindings = OrderedDict() + output_names = [] fp16 = False # default updated below dynamic = False for i in range(model.num_bindings): @@ -401,11 +402,12 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2])) if dtype == np.float16: fp16 = True + else: # output + output_names.append(name) shape = tuple(context.get_binding_shape(i)) im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) - output_names = [x for x in list(bindings.keys()) if x.startswith('output')] batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size elif coreml: # CoreML LOGGER.info(f'Loading {w} for CoreML inference...') @@ -506,7 +508,7 @@ def forward(self, im, augment=False, visualize=False): assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}" self.binding_addrs['images'] = int(im.data_ptr()) self.context.execute_v2(list(self.binding_addrs.values())) - y = [self.bindings[x].data for x in self.output_names] + y = [self.bindings[x].data for x in sorted(self.output_names)] elif self.coreml: # CoreML im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) im = Image.fromarray((im[0] * 255).astype('uint8')) From 99e7a36a0c420bdde8046d025763fbde56ae0576 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 18 Sep 2022 14:22:25 +0200 Subject: [PATCH 12/13] Update ci-testing.yml Signed-off-by: Glenn Jocher --- .github/workflows/ci-testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 537ba96e7225..fffc92d1b72f 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -43,7 +43,7 @@ jobs: python benchmarks.py --data coco128.yaml --weights ${{ matrix.model }}.pt --img 320 --hard-fail 0.29 - name: Benchmark SegmentationModel run: | - python benchmarks.py --data coco128-seg.yaml --weights ${{ matrix.model }}-seg.pt --img 320 + python benchmarks.py --data coco128-seg.yaml --weights ${{ matrix.model }}-seg.pt --img 320 --hard-fail 0.22 Tests: timeout-minutes: 60 From ae276d494fd04a81b5a8a2b7f0a14eaa8057eaa2 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 18 Sep 2022 14:44:42 +0200 Subject: [PATCH 13/13] Update ci-testing.yml Signed-off-by: Glenn Jocher --- .github/workflows/ci-testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index fffc92d1b72f..537ba96e7225 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -43,7 +43,7 @@ jobs: python benchmarks.py --data coco128.yaml --weights ${{ matrix.model }}.pt --img 320 --hard-fail 0.29 - name: Benchmark SegmentationModel run: | - python benchmarks.py --data coco128-seg.yaml --weights ${{ matrix.model }}-seg.pt --img 320 --hard-fail 0.22 + python benchmarks.py --data coco128-seg.yaml --weights ${{ matrix.model }}-seg.pt --img 320 Tests: timeout-minutes: 60