Skip to content

Commit

Permalink
torch.jit.trace() fix (#9363)
Browse files Browse the repository at this point in the history
* Update common.py

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Update ci-testing.yml

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
glenn-jocher committed Sep 10, 2022
1 parent cafdd18 commit 23d0456
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,12 @@ jobs:
python export.py --weights $m.pt --img 64 --include torchscript # export
python - <<EOF
import torch
im = torch.zeros([1, 3, 64, 64])
for path in '$m', '$b':
model = torch.hub.load('.', 'custom', path=path, source='local')
print(model('data/images/bus.jpg'))
model(im) # warmup, build grids for trace
torch.jit.trace(model, [im])
EOF
- name: Test classification
shell: bash # for Windows compatibility
Expand Down
1 change: 1 addition & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ def __init__(self, model, verbose=True):
if self.pt:
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
m.inplace = False # Detect.inplace=False for safe multithread inference
m.export = True # do not output loss values

def _apply(self, fn):
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
Expand Down

0 comments on commit 23d0456

Please sign in to comment.