Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New smart_inference_mode() conditional decorator #8957

Merged
merged 1 commit into from
Aug 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, time_sync
from utils.torch_utils import select_device, smart_inference_mode, time_sync


@torch.no_grad()
@smart_inference_mode()
def run(
weights=ROOT / 'yolov5s.pt', # model.pt path(s)
source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam
Expand Down
4 changes: 2 additions & 2 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
from utils.dataloaders import LoadImages
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, check_yaml,
colorstr, file_size, print_args, url2file)
from utils.torch_utils import select_device
from utils.torch_utils import select_device, smart_inference_mode


def export_formats():
Expand Down Expand Up @@ -455,7 +455,7 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
LOGGER.info(f'\n{prefix} export failure: {e}')


@torch.no_grad()
@smart_inference_mode()
def run(
data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
weights=ROOT / 'yolov5s.pt', # weights path
Expand Down
4 changes: 2 additions & 2 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import copy_attr, time_sync
from utils.torch_utils import copy_attr, smart_inference_mode, time_sync


def autopad(k, p=None): # kernel, padding
Expand Down Expand Up @@ -578,7 +578,7 @@ def _apply(self, fn):
m.anchor_grid = list(map(fn, m.anchor_grid))
return self

@torch.no_grad()
@smart_inference_mode()
def forward(self, imgs, size=640, augment=False, profile=False):
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
# file: imgs = 'data/images/zidane.jpg' # str or PosixPath
Expand Down
4 changes: 2 additions & 2 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ def forward(self, x):

return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

def _make_grid(self, nx=20, ny=20, i=0):
def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):
d = self.anchors[i].device
t = self.anchors[i].dtype
shape = 1, self.na, ny, nx, 2 # grid shape
y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
if torch_1_10: # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
yv, xv = torch.meshgrid(y, x, indexing='ij')
else:
yv, xv = torch.meshgrid(y, x)
Expand Down
26 changes: 17 additions & 9 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')


def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
def decorate(fn):
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)

return decorate


def smart_DDP(model):
# Model DDP creation with checks
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
Expand Down Expand Up @@ -364,17 +372,17 @@ def __init__(self, model, decay=0.9999, tau=2000, updates=0):
for p in self.ema.parameters():
p.requires_grad_(False)

@smart_inference_mode()
def update(self, model):
# Update EMA parameters
with torch.no_grad():
self.updates += 1
d = self.decay(self.updates)

msd = de_parallel(model).state_dict() # model state_dict
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point:
v *= d
v += (1 - d) * msd[k].detach()
self.updates += 1
d = self.decay(self.updates)

msd = de_parallel(model).state_dict() # model state_dict
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point:
v *= d
v += (1 - d) * msd[k].detach()

def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes
Expand Down
4 changes: 2 additions & 2 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
scale_coords, xywh2xyxy, xyxy2xywh)
from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
from utils.plots import output_to_target, plot_images, plot_val_study
from utils.torch_utils import select_device, time_sync
from utils.torch_utils import select_device, smart_inference_mode, time_sync


def save_one_txt(predn, save_conf, shape, file):
Expand Down Expand Up @@ -93,7 +93,7 @@ def process_batch(detections, labels, iouv):
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)


@torch.no_grad()
@smart_inference_mode()
def run(
data,
weights=None, # model.pt path(s)
Expand Down