Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

additions and tests for torchscript inference (#5161) #5215

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,30 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
check_suffix(w, suffixes) # check weights have acceptable suffix
pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes) # backend booleans
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
jit = pt and 'torchscript' in w
if pt:
model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
stride = int(model.stride.max()) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
if half:
model.half() # to FP16
if classify: # second-stage classifier
modelc = load_classifier(name='resnet50', n=2) # initialize
modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
if jit:
import json
extra_files = {'config.txt': ''} # model metadata
model = torch.jit.load(w, _extra_files=extra_files)
try:
d = json.loads(extra_files['config.txt']) # extra_files dict
ts_h, ts_w = d['HW']
stride = int(d['STRIDE'])
ts_params = True # torchscript predefined params
print(f"Using saved graph params HW: {(ts_h, ts_w)}, stride: {stride}")
except:
ts_params = False
print(f"Failed to load default jit graph params from {w}")
else:
model = attempt_load(weights, map_location=device)
stride = int(model.stride.max()) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
if half:
model.half() # to FP16
if classify: # second-stage classifier
modelc = load_classifier(name='resnet50', n=2) # initialize
modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
elif onnx:
if dnn:
check_requirements(('opencv-python>=4.5.4',))
Expand Down Expand Up @@ -114,15 +129,18 @@ def wrap_frozen_graph(gd, inputs, outputs):
output_details = interpreter.get_output_details() # outputs
int8 = input_details[0]['dtype'] == np.uint8 # is TFLite quantized uint8 model
imgsz = check_img_size(imgsz, s=stride) # check image size
if jit and ts_params and imgsz != [ts_h, ts_w]:
print(f"Changing resolution from {imgsz} to {[ts_h, ts_w]} based to graph defaults")
imgsz = ts_h, ts_w

# Dataloader
if webcam:
view_img = check_imshow()
cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt and not jit)
bs = len(dataset) # batch_size
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt and not jit)
bs = 1 # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs

Expand All @@ -146,7 +164,10 @@ def wrap_frozen_graph(gd, inputs, outputs):
# Inference
if pt:
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
pred = model(img, augment=augment, visualize=visualize)[0]
if jit:
pred = model(img)[0]
else:
pred = model(img, augment=augment, visualize=visualize)[0]
elif onnx:
if dnn:
net.setInput(img)
Expand Down
13 changes: 11 additions & 2 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,21 @@

def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
# YOLOv5 TorchScript model export
import json # required to save graph props
try:
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
f = file.with_suffix('.torchscript.pt')

h, w = im.shape[-2:]
batch_size = im.shape[0]
stride = int(max(model.stride))
dev_type = next(model.parameters()).device.type
dct_params = {"HW": [h, w], "BATCH": batch_size, "STRIDE": stride, "DEVICE": dev_type}
str_config = json.dumps(dct_params)
extra_files = {'config.txt': str_config} # torch._C.ExtraFilesMap()

ts = torch.jit.trace(model, im, strict=False)
(optimize_for_mobile(ts) if optimize else ts).save(f)
(optimize_for_mobile(ts) if optimize else ts).save(f, _extra_files=extra_files)

LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
Expand Down Expand Up @@ -333,7 +342,7 @@ def parse_opt():
parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
parser.add_argument('--train', action='store_true', help='model.train() mode')
Expand Down