Skip to content

Commit

Permalink
ultralytics 8.1.4 RTDETR TensorBoard graph visualization fix (ultra…
Browse files Browse the repository at this point in the history
…lytics#7725)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
  • Loading branch information
2 people authored and gkinman committed May 30, 2024
1 parent c608e3d commit 58afd8a
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 26 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ dev = [
"mkdocstrings[python]",
"mkdocs-jupyter", # for notebooks
"mkdocs-redirects", # for 301 redirects
"mkdocs-ultralytics-plugin>=0.0.34", # for meta descriptions and images, dates and authors
"mkdocs-ultralytics-plugin>=0.0.38", # for meta descriptions and images, dates and authors
]
export = [
"onnx>=1.12.0", # ONNX export
Expand Down
18 changes: 15 additions & 3 deletions ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

__version__ = "8.1.3"
__version__ = "8.1.4"

from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM
from ultralytics.models.nas import NAS
from ultralytics.utils import SETTINGS as settings
from ultralytics.utils import ASSETS, SETTINGS as settings
from ultralytics.utils.checks import check_yolo as checks
from ultralytics.utils.downloads import download

__all__ = "__version__", "YOLO", "NAS", "SAM", "FastSAM", "RTDETR", "checks", "download", "settings", "Explorer"
__all__ = (
"__version__",
"ASSETS",
"YOLO",
"NAS",
"SAM",
"FastSAM",
"RTDETR",
"checks",
"download",
"settings",
"Explorer",
)
4 changes: 3 additions & 1 deletion ultralytics/cfg/datasets/Argoverse.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ download: |
# Download 'https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip' (deprecated S3 link)
dir = Path(yaml['path']) # dataset root dir
urls = ['https://drive.google.com/file/d/1st9qW3BeIwQsnR0t8mRpvbsSWIo16ACi/view?usp=drive_link']
download(urls, dir=dir)
print("\n\nWARNING: Argoverse dataset MUST be downloaded manually, autodownload will NOT work.")
print(f"WARNING: Manually download Argoverse dataset '{urls[0]}' to '{dir}' and re-run your command.\n\n")
# download(urls, dir=dir)
# Convert
annotations_dir = 'Argoverse-HD/annotations/'
Expand Down
4 changes: 3 additions & 1 deletion ultralytics/engine/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,9 @@ def _apply(self, fn):
@property
def names(self):
"""Returns class names of the loaded model."""
return self.model.names if hasattr(self.model, "names") else None
from ultralytics.nn.autobackend import check_class_names

return check_class_names(self.model.names) if hasattr(self.model, "names") else None

@property
def device(self):
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/nn/modules/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def _get_encoder_input(self, x):

def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
"""Generates and prepares the input required for the decoder from the provided features and shapes."""
bs = len(feats)
bs = feats.shape[0]
# Prepare input for decoder
anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
features = self.enc_output(valid_mask * feats) # bs, h*w, 256
Expand Down
6 changes: 3 additions & 3 deletions ultralytics/nn/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def forward(self, x):
@staticmethod
def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
"""Builds 2D sine-cosine position embedding."""
grid_w = torch.arange(int(w), dtype=torch.float32)
grid_h = torch.arange(int(h), dtype=torch.float32)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
grid_w = torch.arange(w, dtype=torch.float32)
grid_h = torch.arange(h, dtype=torch.float32)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = 1.0 / (temperature**omega)
Expand Down
53 changes: 38 additions & 15 deletions ultralytics/utils/callbacks/tensorboard.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib

from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr

try:
# WARNING: do not move import due to protobuf issue in https://github.com/ultralytics/ultralytics/pull/4674
# WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674
from torch.utils.tensorboard import SummaryWriter

assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["tensorboard"] is True # verify integration is enabled
WRITER = None # TensorBoard SummaryWriter instance
PREFIX = colorstr("TensorBoard: ")

# Imports below only required if TensorBoard enabled
import warnings
from copy import deepcopy
from ultralytics.utils.torch_utils import de_parallel, torch

except (ImportError, AssertionError, TypeError, AttributeError):
# TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
Expand All @@ -25,20 +32,37 @@ def _log_scalars(scalars, step=0):

def _log_tensorboard_graph(trainer):
"""Log model graph to TensorBoard."""
try:
import warnings

from ultralytics.utils.torch_utils import de_parallel, torch
# Input image
imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning

imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
# Try simple method first (YOLO)
with contextlib.suppress(Exception):
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ TensorBoard graph visualization failure {e}")
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
return

# Fallback to TorchScript export steps (RTDETR)
try:
model = deepcopy(de_parallel(trainer.model))
model.eval()
model = model.fuse(verbose=False)
for m in model.modules():
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
m.export = True
m.format = "torchscript"
model(im) # dry run
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")


def on_pretrain_routine_start(trainer):
Expand All @@ -47,10 +71,9 @@ def on_pretrain_routine_start(trainer):
try:
global WRITER
WRITER = SummaryWriter(str(trainer.save_dir))
prefix = colorstr("TensorBoard: ")
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")


def on_train_start(trainer):
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def non_max_suppression(

# Settings
# min_wh = 2 # (pixels) minimum box width and height
time_limit = 0.5 + max_time_img * bs # seconds to quit after
time_limit = 2.0 + max_time_img * bs # seconds to quit after
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)

prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
Expand Down

0 comments on commit 58afd8a

Please sign in to comment.