From 4376b7f35f70ae1cf8e72e7cd1358dab9e3a5889 Mon Sep 17 00:00:00 2001 From: Andrei Ionut Damian <44048963+andreiionutdamian@users.noreply.github.com> Date: Sat, 9 Oct 2021 10:58:29 +0300 Subject: [PATCH 1/3] update detect.py in order to support torch script This change assumes the torchscrip file was previously saved with `export.py` --- detect.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/detect.py b/detect.py index 75ec3ecc5ff3..f5cc31d44655 100644 --- a/detect.py +++ b/detect.py @@ -77,7 +77,15 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes) # backend booleans stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults if pt: - model = attempt_load(weights, map_location=device) # load FP32 model + if 'torchscript' in w: + # this is torchscript saved not a pickle that can be loaded with th.load + # we assume the file exist in the target folder + if Path(w).is_file(): + model = torch.jit.load(w) + else: + raise ValueError('Cannot find torchscript file {}'.format(Path(w)) + else: + model = attempt_load(weights, map_location=device) # load FP32 model stride = int(model.stride.max()) # model stride names = model.module.names if hasattr(model, 'module') else model.names # get class names if half: From e2379fa3d8b0cb568e1df37b81841ba008384fbf Mon Sep 17 00:00:00 2001 From: Andrei Ionut Damian <44048963+andreiionutdamian@users.noreply.github.com> Date: Sat, 9 Oct 2021 11:13:23 +0300 Subject: [PATCH 2/3] update `detect.py` for torchscript support Simple update for torchscript support. Assumes the torchscript file has been generated with `export.py` --- detect.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/detect.py b/detect.py index f5cc31d44655..bbe079e0e16f 100644 --- a/detect.py +++ b/detect.py @@ -80,10 +80,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) if 'torchscript' in w: # this is torchscript saved not a pickle that can be loaded with th.load # we assume the file exist in the target folder - if Path(w).is_file(): - model = torch.jit.load(w) - else: - raise ValueError('Cannot find torchscript file {}'.format(Path(w)) + model = torch.jit.load(w) else: model = attempt_load(weights, map_location=device) # load FP32 model stride = int(model.stride.max()) # model stride From 48827272b1b830544f0669e321a1e89e7b96a0ae Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 11 Oct 2021 21:25:47 -0700 Subject: [PATCH 3/3] Cleanup --- detect.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/detect.py b/detect.py index 689cfd1250d8..4e497305668c 100644 --- a/detect.py +++ b/detect.py @@ -79,12 +79,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes) # backend booleans stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults if pt: - if 'torchscript' in w: - # this is torchscript saved not a pickle that can be loaded with th.load - # we assume the file exist in the target folder - model = torch.jit.load(w) - else: - model = attempt_load(weights, map_location=device) # load FP32 model + model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device) stride = int(model.stride.max()) # model stride names = model.module.names if hasattr(model, 'module') else model.names # get class names if half: