Skip to content

Commit

Permalink
Detect.inplace=False for multithread-safe inference (#8801)
Browse files Browse the repository at this point in the history
Detect.inplace=False for safe multithread inference
  • Loading branch information
glenn-jocher committed Jul 30, 2022
1 parent 7921351 commit 1e89807
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 1 addition & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute
if autoshape:
model.model.model[-1].inplace = False # Detect.inplace=False for safe multithread inference
model = AutoShape(model) # for file/URI/PIL/cv2/np inputs and NMS
if not verbose:
LOGGER.setLevel(logging.INFO) # reset to default
Expand Down
2 changes: 1 addition & 1 deletion models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
self.inplace = inplace # use in-place ops (e.g. slice assignment)
self.inplace = inplace # use inplace ops (e.g. slice assignment)

def forward(self, x):
z = [] # inference output
Expand Down

0 comments on commit 1e89807

Please sign in to comment.