Skip to content

Commit

Permalink
New smart_inference_mode() conditional decorator (ultralytics#8957)
Browse files Browse the repository at this point in the history
New smart_inference_mode()
  • Loading branch information
glenn-jocher authored and Clay Januhowski committed Sep 8, 2022
1 parent 92c81dc commit 57abf8e
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 19 deletions.
4 changes: 2 additions & 2 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, time_sync
from utils.torch_utils import select_device, smart_inference_mode, time_sync


@torch.no_grad()
@smart_inference_mode()
def run(
weights=ROOT / 'yolov5s.pt', # model.pt path(s)
source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam
Expand Down
4 changes: 2 additions & 2 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
from utils.dataloaders import LoadImages
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, check_yaml,
colorstr, file_size, print_args, url2file)
from utils.torch_utils import select_device
from utils.torch_utils import select_device, smart_inference_mode


def export_formats():
Expand Down Expand Up @@ -574,7 +574,7 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
LOGGER.info(f'\n{prefix} export failure: {e}')


@torch.no_grad()
@smart_inference_mode()
def run(
data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
weights=ROOT / 'yolov5s.pt', # weights path
Expand Down
4 changes: 2 additions & 2 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import copy_attr, time_sync
from utils.torch_utils import copy_attr, smart_inference_mode, time_sync


def autopad(k, p=None): # kernel, padding
Expand Down Expand Up @@ -578,7 +578,7 @@ def _apply(self, fn):
m.anchor_grid = list(map(fn, m.anchor_grid))
return self

@torch.no_grad()
@smart_inference_mode()
def forward(self, imgs, size=640, augment=False, profile=False):
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
# file: imgs = 'data/images/zidane.jpg' # str or PosixPath
Expand Down
4 changes: 2 additions & 2 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ def forward(self, x):

return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

def _make_grid(self, nx=20, ny=20, i=0):
def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):
d = self.anchors[i].device
t = self.anchors[i].dtype
shape = 1, self.na, ny, nx, 2 # grid shape
y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
if torch_1_10: # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
yv, xv = torch.meshgrid(y, x, indexing='ij')
else:
yv, xv = torch.meshgrid(y, x)
Expand Down
26 changes: 17 additions & 9 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')


def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
def decorate(fn):
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)

return decorate


def smart_DDP(model):
# Model DDP creation with checks
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
Expand Down Expand Up @@ -364,17 +372,17 @@ def __init__(self, model, decay=0.9999, tau=2000, updates=0):
for p in self.ema.parameters():
p.requires_grad_(False)

@smart_inference_mode()
def update(self, model):
# Update EMA parameters
with torch.no_grad():
self.updates += 1
d = self.decay(self.updates)

msd = de_parallel(model).state_dict() # model state_dict
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point:
v *= d
v += (1 - d) * msd[k].detach()
self.updates += 1
d = self.decay(self.updates)

msd = de_parallel(model).state_dict() # model state_dict
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point:
v *= d
v += (1 - d) * msd[k].detach()

def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes
Expand Down
4 changes: 2 additions & 2 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
scale_coords, xywh2xyxy, xyxy2xywh)
from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
from utils.plots import output_to_target, plot_images, plot_val_study
from utils.torch_utils import select_device, time_sync
from utils.torch_utils import select_device, smart_inference_mode, time_sync


def save_one_txt(predn, save_conf, shape, file):
Expand Down Expand Up @@ -93,7 +93,7 @@ def process_batch(detections, labels, iouv):
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)


@torch.no_grad()
@smart_inference_mode()
def run(
data,
weights=None, # model.pt path(s)
Expand Down

0 comments on commit 57abf8e

Please sign in to comment.