diff --git a/docs/en/models/yolov9.md b/docs/en/models/yolov9.md
index 07eab68e227..3252c1d0b0d 100644
--- a/docs/en/models/yolov9.md
+++ b/docs/en/models/yolov9.md
@@ -76,15 +76,62 @@ The YOLOv9-C model, in particular, highlights the effectiveness of the architect
 These results showcase YOLOv9's strategic advancements in model design, emphasizing its enhanced efficiency without compromising on the precision essential for real-time object detection tasks. The model not only pushes the boundaries of performance metrics but also emphasizes the importance of computational efficiency, making it a pivotal development in the field of computer vision.
 
-## Integration and Future Directions
-
-YOLOv9 embodies the spirit of open-source collaboration that is central to the advancement of AI technology. With plans for future integration into the Ultralytics package, YOLOv9 is poised to become an accessible tool for researchers and practitioners alike, further enhancing its impact on the field of computer vision.
-
 ## Conclusion
 
 YOLOv9 represents a pivotal development in real-time object detection, offering significant improvements in terms of efficiency, accuracy, and adaptability. By addressing critical challenges through innovative solutions like PGI and GELAN, YOLOv9 sets a new precedent for future research and application in the field. As the AI community continues to evolve, YOLOv9 stands as a testament to the power of collaboration and innovation in driving technological progress.
 
-Stay tuned for updates on Ultralytics package integration and explore the possibilities that YOLOv9 brings to the realm of computer vision.
+
+## Usage Examples
+
+This section provides simple YOLOv9 training and inference examples. For full documentation on these and other [modes](../modes/index.md), see the [Predict](../modes/predict.md), [Train](../modes/train.md), [Val](../modes/val.md) and [Export](../modes/export.md) docs pages.
+
+!!! Example
+
+    === "Python"
+
+        PyTorch pretrained `*.pt` models as well as configuration `*.yaml` files can be passed to the `YOLO()` class to create a model instance in Python:
+
+        ```python
+        from ultralytics import YOLO
+
+        # Build a YOLOv9c model from scratch
+        model = YOLO('yolov9c.yaml')
+
+        # Build a YOLOv9c model from pretrained weights
+        model = YOLO('yolov9c.pt')
+
+        # Display model information (optional)
+        model.info()
+
+        # Train the model on the COCO8 example dataset for 100 epochs
+        results = model.train(data='coco8.yaml', epochs=100, imgsz=640)
+
+        # Run inference with the YOLOv9c model on the 'bus.jpg' image
+        results = model('path/to/bus.jpg')
+        ```
+
+    === "CLI"
+
+        CLI commands are available to directly run the models:
+
+        ```bash
+        # Build a YOLOv9c model from scratch and train it on the COCO8 example dataset for 100 epochs
+        yolo train model=yolov9c.yaml data=coco8.yaml epochs=100 imgsz=640
+
+        # Run inference with a pretrained YOLOv9c model on the 'bus.jpg' image
+        yolo predict model=yolov9c.pt source=path/to/bus.jpg
+        ```
+
+## Supported Tasks and Modes
+
+The YOLOv9 series offers a range of models, each optimized for high-performance [Object Detection](../tasks/detect.md). These models cater to varying computational needs and accuracy requirements, making them versatile for a wide array of applications.
+
+| Model Type | Pre-trained Weights                                                                     | Tasks Supported                        | Inference | Validation | Training | Export |
+|------------|-----------------------------------------------------------------------------------------|----------------------------------------|-----------|------------|----------|--------|
+| YOLOv9-C   | [yolov9c.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov9c.pt) | [Object Detection](../tasks/detect.md) | ✅        | ✅         | ✅       | ✅     |
+| YOLOv9-E   | [yolov9e.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov9e.pt) | [Object Detection](../tasks/detect.md) | ✅        | ✅         | ✅       | ✅     |
+
+This table provides a detailed overview of the YOLOv9 model variants, highlighting their capabilities in object detection tasks and their compatibility with operational modes such as [Inference](../modes/predict.md), [Validation](../modes/val.md), [Training](../modes/train.md), and [Export](../modes/export.md). This comprehensive support ensures that users can fully leverage the capabilities of YOLOv9 models in a broad range of object detection scenarios.
 
 ## Citations and Acknowledgements
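The usage section added above covers only `train` and `predict`, while the mode table also advertises Validation and Export. A minimal sketch of those two modes with the pretrained yolov9c checkpoint, assuming the same `YOLO` API used on the other Ultralytics model pages:

```python
from ultralytics import YOLO

# Load the pretrained YOLOv9c weights
model = YOLO('yolov9c.pt')

# Validate on the COCO8 example dataset
metrics = model.val(data='coco8.yaml', imgsz=640)
print(metrics.box.map)  # mAP50-95

# Export to ONNX; returns the path of the exported file
onnx_path = model.export(format='onnx')
print(onnx_path)
```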
diff --git a/docs/en/reference/nn/modules/block.md b/docs/en/reference/nn/modules/block.md
index cd2c146bd9b..e94da5ae800 100644
--- a/docs/en/reference/nn/modules/block.md
+++ b/docs/en/reference/nn/modules/block.md
@@ -106,3 +106,35 @@ keywords: YOLO, Ultralytics, neural network, nn.modules.block, Proto, HGBlock, S
 ## ::: ultralytics.nn.modules.block.BNContrastiveHead
 
 <br><br>
+
+## ::: ultralytics.nn.modules.block.RepBottleneck
+
+<br><br>
+
+## ::: ultralytics.nn.modules.block.RepCSP
+
+<br><br>
+
+## ::: ultralytics.nn.modules.block.RepNCSPELAN4
+
+<br><br>
+
+## ::: ultralytics.nn.modules.block.ADown
+
+<br><br>
+
+## ::: ultralytics.nn.modules.block.SPPELAN
+
+<br><br>
+
+## ::: ultralytics.nn.modules.block.Silence
+
+<br><br>
+
+## ::: ultralytics.nn.modules.block.CBLinear
+
+<br><br>
+
+## ::: ultralytics.nn.modules.block.CBFuse
+
+<br><br>
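The eight reference entries above document the new GELAN building blocks added to `block.py` further down this diff. As a quick orientation, here is a shape-level smoke test of the three blocks the YOLOv9 configs lean on most; the channel sizes are illustrative, not taken from the configs:

```python
import torch

from ultralytics.nn.modules.block import ADown, RepNCSPELAN4, SPPELAN

x = torch.randn(1, 256, 64, 64)  # dummy feature map: batch 1, 256 channels, 64x64

# RepNCSPELAN4(c1, c2, c3, c4): 1x1 conv to c3, two RepCSP+Conv branches of width c4,
# then a 1x1 fuse back to c2 -- spatial size is preserved
elan = RepNCSPELAN4(256, 512, 256, 128)
print(elan(x).shape)  # torch.Size([1, 512, 64, 64])

# ADown(c1, c2): avg-pool/conv and max-pool/conv halves concatenated -- a stride-2 downsample
down = ADown(256, 256)
print(down(x).shape)  # torch.Size([1, 256, 32, 32])

# SPPELAN(c1, c2, c3): three stride-1 max-pools over a c3-wide stem, concatenated and fused
spp = SPPELAN(256, 512, 256)
print(spp(x).shape)  # torch.Size([1, 512, 64, 64])
```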
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index e65cb8a892d..8f76a568342 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.1.22"
+__version__ = "8.1.23"
 
 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
diff --git a/ultralytics/cfg/models/v9/yolov9c.yaml b/ultralytics/cfg/models/v9/yolov9c.yaml
new file mode 100644
index 00000000000..66c02d64719
--- /dev/null
+++ b/ultralytics/cfg/models/v9/yolov9c.yaml
@@ -0,0 +1,36 @@
+# YOLOv9
+
+# parameters
+nc: 80 # number of classes
+
+# gelan backbone
+backbone:
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]] # 2
+  - [-1, 1, ADown, [256]] # 3-P3/8
+  - [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]] # 4
+  - [-1, 1, ADown, [512]] # 5-P4/16
+  - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 6
+  - [-1, 1, ADown, [512]] # 7-P5/32
+  - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 8
+  - [-1, 1, SPPELAN, [512, 256]] # 9
+
+head:
+  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 12
+
+  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]] # 15 (P3/8-small)
+
+  - [-1, 1, ADown, [256]]
+  - [[-1, 12], 1, Concat, [1]] # cat head P4
+  - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 18 (P4/16-medium)
+
+  - [-1, 1, ADown, [512]]
+  - [[-1, 9], 1, Concat, [1]] # cat head P5
+  - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 21 (P5/32-large)
+
+  - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
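Each row in these configs follows the Ultralytics `[from, repeats, module, args]` schema, and for the new blocks `parse_model` (patched in `tasks.py` below) prepends the resolved input channels before constructing the module. A rough sketch of the channel bookkeeping for one backbone row, mirroring the handling shown later in this diff and ignoring optional width scaling:

```python
# How parse_model turns a yolov9c.yaml row into module arguments (sketch).
# Backbone row 2: [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]]
f, n, m, args = -1, 1, "RepNCSPELAN4", [256, 128, 64, 1]

ch = [3, 64, 128]  # output channels recorded so far: input image, P1/2 Conv, P2/4 Conv
c1 = ch[f]         # 128, taken from the previous layer
c2 = args[0]       # 256, this block's output channels

# RepNCSPELAN4 joins the base-modules set, so args become [c1, c2, *args[1:]]
args = [c1, c2, *args[1:]]
print(args)  # [128, 256, 128, 64, 1] -> RepNCSPELAN4(c1=128, c2=256, c3=128, c4=64, n=1)
```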
diff --git a/ultralytics/cfg/models/v9/yolov9e.yaml b/ultralytics/cfg/models/v9/yolov9e.yaml
new file mode 100644
index 00000000000..8e15a42bb94
--- /dev/null
+++ b/ultralytics/cfg/models/v9/yolov9e.yaml
@@ -0,0 +1,60 @@
+# YOLOv9
+
+# parameters
+nc: 80 # number of classes
+
+# gelan backbone
+backbone:
+  - [-1, 1, Silence, []]
+  - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 2-P2/4
+  - [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 3
+  - [-1, 1, ADown, [256]] # 4-P3/8
+  - [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]] # 5
+  - [-1, 1, ADown, [512]] # 6-P4/16
+  - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 7
+  - [-1, 1, ADown, [1024]] # 8-P5/32
+  - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 9
+
+  - [1, 1, CBLinear, [[64]]] # 10
+  - [3, 1, CBLinear, [[64, 128]]] # 11
+  - [5, 1, CBLinear, [[64, 128, 256]]] # 12
+  - [7, 1, CBLinear, [[64, 128, 256, 512]]] # 13
+  - [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]] # 14
+
+  - [0, 1, Conv, [64, 3, 2]] # 15-P1/2
+  - [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]] # 16
+  - [-1, 1, Conv, [128, 3, 2]] # 17-P2/4
+  - [[11, 12, 13, 14, -1], 1, CBFuse, [[1, 1, 1, 1]]] # 18
+  - [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 19
+  - [-1, 1, ADown, [256]] # 20-P3/8
+  - [[12, 13, 14, -1], 1, CBFuse, [[2, 2, 2]]] # 21
+  - [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]] # 22
+  - [-1, 1, ADown, [512]] # 23-P4/16
+  - [[13, 14, -1], 1, CBFuse, [[3, 3]]] # 24
+  - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 25
+  - [-1, 1, ADown, [1024]] # 26-P5/32
+  - [[14, -1], 1, CBFuse, [[4]]] # 27
+  - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 28
+  - [-1, 1, SPPELAN, [512, 256]] # 29
+
+# gelan head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+  - [[-1, 25], 1, Concat, [1]] # cat backbone P4
+  - [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 32
+
+  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+  - [[-1, 22], 1, Concat, [1]] # cat backbone P3
+  - [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]] # 35 (P3/8-small)
+
+  - [-1, 1, ADown, [256]]
+  - [[-1, 32], 1, Concat, [1]] # cat head P4
+  - [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 38 (P4/16-medium)
+
+  - [-1, 1, ADown, [512]]
+  - [[-1, 29], 1, Concat, [1]] # cat head P5
+  - [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]] # 41 (P5/32-large)
+
+  # detect
+  - [[35, 38, 41], 1, Detect, [nc]] # Detect(P3, P4, P5)
diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index 173983e2107..fd3aea3afa6 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -201,7 +201,6 @@ def __call__(self, model=None):
             assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
         if edgetpu and not LINUX:
             raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/")
-        print(type(model))
         if isinstance(model, WorldModel):
             LOGGER.warning(
                 "WARNING ⚠️ YOLOWorld (original version) export is not supported to any format.\n"
diff --git a/ultralytics/nn/modules/__init__.py b/ultralytics/nn/modules/__init__.py
index 34df22215c7..d785c008c8b 100644
--- a/ultralytics/nn/modules/__init__.py
+++ b/ultralytics/nn/modules/__init__.py
@@ -40,6 +40,12 @@
     ResNetLayer,
     ContrastiveHead,
     BNContrastiveHead,
+    RepNCSPELAN4,
+    ADown,
+    SPPELAN,
+    CBFuse,
+    CBLinear,
+    Silence,
 )
 from .conv import (
     CBAM,
@@ -123,4 +129,10 @@
     "ImagePoolingAttn",
     "ContrastiveHead",
    "BNContrastiveHead",
+    "RepNCSPELAN4",
+    "ADown",
+    "SPPELAN",
+    "CBFuse",
+    "CBLinear",
+    "Silence",
 )
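The `block.py` hunk below opens with the two Rep primitives: `RepBottleneck` is the existing `Bottleneck` pattern with its first 3×3 `Conv` swapped for a re-parameterizable `RepConv` (RepVGG-style, fusable to a single conv for deployment), and `RepCSP` stacks `n` of them inside the familiar C3/CSP layout. A minimal shape check, assuming the definitions below:

```python
import torch

from ultralytics.nn.modules.block import RepBottleneck, RepCSP

x = torch.randn(1, 64, 32, 32)

# Residual path is active because c1 == c2 and shortcut=True
rb = RepBottleneck(64, 64)
print(rb(x).shape)  # torch.Size([1, 64, 32, 32])

# RepCSP follows the C3 pattern: two 1x1 stems, n RepBottlenecks on one branch, 1x1 fuse
csp = RepCSP(64, 128, n=2)
print(csp(x).shape)  # torch.Size([1, 128, 32, 32])
```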
+ """ + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = RepConv(c1, c_, k[0], 1) + self.cv2 = Conv(c_, c2, k[1], 1, g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + """Forward pass through RepBottleneck layer.""" + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class RepCSP(nn.Module): + """Rep CSP Bottleneck with 3 convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initializes RepCSP layer with given channels, repetitions, shortcut, groups and expansion ratio.""" + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2) + self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + """Forward pass through RepCSP layer.""" + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) + + +class RepNCSPELAN4(nn.Module): + """CSP-ELAN.""" + + def __init__(self, c1, c2, c3, c4, n=1): + """Initializes CSP-ELAN layer with specified channel sizes, repetitions, and convolutions.""" + super().__init__() + self.c = c3 // 2 + self.cv1 = Conv(c1, c3, 1, 1) + self.cv2 = nn.Sequential(RepCSP(c3 // 2, c4, n), Conv(c4, c4, 3, 1)) + self.cv3 = nn.Sequential(RepCSP(c4, c4, n), Conv(c4, c4, 3, 1)) + self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1) + + def forward(self, x): + """Forward pass through RepNCSPELAN4 layer.""" + y = list(self.cv1(x).chunk(2, 1)) + y.extend((m(y[-1])) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + def forward_split(self, x): + """Forward pass using split() instead of chunk().""" + y = list(self.cv1(x).split((self.c, self.c), 1)) + y.extend(m(y[-1]) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + +class ADown(nn.Module): + """ADown.""" + + def __init__(self, c1, c2): + """Initializes ADown module with convolution layers to downsample input from channels c1 to c2.""" + super().__init__() + self.c = c2 // 2 + self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1) + self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0) + + def forward(self, x): + """Forward pass through ADown layer.""" + x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True) + x1, x2 = x.chunk(2, 1) + x1 = self.cv1(x1) + x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1) + x2 = self.cv2(x2) + return torch.cat((x1, x2), 1) + + +class SPPELAN(nn.Module): + """SPP-ELAN.""" + + def __init__(self, c1, c2, c3, k=5): + """Initializes SPP-ELAN block with convolution and max pooling layers for spatial pyramid pooling.""" + super().__init__() + self.c = c3 + self.cv1 = Conv(c1, c3, 1, 1) + self.cv2 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + self.cv3 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + self.cv4 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + self.cv5 = Conv(4 * c3, c2, 1, 1) + + def forward(self, x): + """Forward pass through SPPELAN layer.""" + y = [self.cv1(x)] + y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4]) + return self.cv5(torch.cat(y, 1)) + + +class Silence(nn.Module): + """Silence.""" + + def __init__(self): + """Initializes the Silence module.""" + super(Silence, self).__init__() + + def forward(self, x): + """Forward pass through Silence layer.""" + return x + + +class CBLinear(nn.Module): + """CBLinear.""" + + def __init__(self, c1, c2s, k=1, s=1, p=None, g=1): + """Initializes the CBLinear module, passing inputs unchanged.""" + super(CBLinear, self).__init__() + 
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 86203e44fb3..64ee7f5031c 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -43,6 +43,12 @@
     RTDETRDecoder,
     Segment,
     WorldDetect,
+    RepNCSPELAN4,
+    ADown,
+    SPPELAN,
+    CBFuse,
+    CBLinear,
+    Silence,
 )
 from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
 from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
@@ -570,7 +576,7 @@ def set_classes(self, text):
         text_token = clip.tokenize(text).to(device)
         txt_feats = model.encode_text(text_token).to(dtype=torch.float32)
         txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
-        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
+        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
         self.model[-1].nc = len(text)
 
     def init_criterion(self):
@@ -850,6 +856,9 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             C1,
             C2,
             C2f,
+            RepNCSPELAN4,
+            ADown,
+            SPPELAN,
             C2fAttn,
             C3,
             C3TR,
@@ -892,6 +901,12 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
                 args[2] = make_divisible(min(args[2], max_channels) * width, 8)
         elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
             args.insert(1, [ch[x] for x in f])
+        elif m is CBLinear:
+            c2 = args[0]
+            c1 = ch[f]
+            args = [c1, c2, *args[1:]]
+        elif m is CBFuse:
+            c2 = ch[f[-1]]
         else:
             c2 = ch[f]
diff --git a/ultralytics/utils/downloads.py b/ultralytics/utils/downloads.py
index 470fa83dc0e..f02e44c4e21 100644
--- a/ultralytics/utils/downloads.py
+++ b/ultralytics/utils/downloads.py
@@ -22,6 +22,7 @@
     + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
     + [f"yolov8{k}-world.pt" for k in "smlx"]
     + [f"yolov8{k}-worldv2.pt" for k in "smlx"]
+    + [f"yolov9{k}.pt" for k in "ce"]
     + [f"yolo_nas_{k}.pt" for k in "sml"]
     + [f"sam_{k}.pt" for k in "bl"]
     + [f"FastSAM-{k}.pt" for k in "sx"]
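With the modules exported, `parse_model` extended, and the checkpoint names registered in `downloads.py`, both variants should build straight from their new config files. A quick end-to-end sanity check (a sketch; it constructs randomly initialized models and downloads nothing):

```python
from ultralytics import YOLO

# Parse each new YOLOv9 config into a model and print its layer/parameter summary
for cfg in ("yolov9c.yaml", "yolov9e.yaml"):
    model = YOLO(cfg)
    model.info()
```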