YOLOv5 (6.0/6.1) brief summary #6998

Open
WZMIAOMIAO opened this issue Mar 16, 2022 · 63 comments · Fixed by #7146
@WZMIAOMIAO
WZMIAOMIAO commented Mar 16, 2022

Content

1. Model Structure

YOLOv5 (v6.0/6.1) consists of:

  • Backbone: New CSP-Darknet53
  • Neck: SPPF, New CSP-PAN
  • Head: YOLOv3 Head

Model structure (yolov5l.yaml):

[Figure: YOLOv5l model structure diagram]

Some minor changes compared to previous versions:

  1. Replace the Focus structure with a 6x6 Conv2d (more efficient; see "Is the Focus layer equivalent to a simple Conv layer?" #4825)
  2. Replace the SPP structure with SPPF (more than twice as fast)
Test code:
import time
import torch
import torch.nn as nn


class SPP(nn.Module):
    """Parallel max pools with 5x5, 9x9 and 13x13 windows (conv layers omitted)."""

    def __init__(self):
        super().__init__()
        self.maxpool1 = nn.MaxPool2d(5, 1, padding=2)
        self.maxpool2 = nn.MaxPool2d(9, 1, padding=4)
        self.maxpool3 = nn.MaxPool2d(13, 1, padding=6)

    def forward(self, x):
        o1 = self.maxpool1(x)
        o2 = self.maxpool2(x)
        o3 = self.maxpool3(x)
        return torch.cat([x, o1, o2, o3], dim=1)


class SPPF(nn.Module):
    """One 5x5 max pool applied three times in sequence (conv layers omitted)."""

    def __init__(self):
        super().__init__()
        self.maxpool = nn.MaxPool2d(5, 1, padding=2)

    def forward(self, x):
        o1 = self.maxpool(x)  # same as a 5x5 pool of x
        o2 = self.maxpool(o1)  # same as a 9x9 pool of x
        o3 = self.maxpool(o2)  # same as a 13x13 pool of x
        return torch.cat([x, o1, o2, o3], dim=1)


def main():
    input_tensor = torch.rand(8, 32, 16, 16)
    spp = SPP()
    sppf = SPPF()
    output1 = spp(input_tensor)
    output2 = sppf(input_tensor)

    print(torch.equal(output1, output2))

    t_start = time.time()
    for _ in range(100):
        spp(input_tensor)
    print(f"spp time: {time.time() - t_start}")

    t_start = time.time()
    for _ in range(100):
        sppf(input_tensor)
    print(f"sppf time: {time.time() - t_start}")


if __name__ == '__main__':
    main()

result:

True
spp time: 0.5373051166534424
sppf time: 0.20780706405639648
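
The `True` in the result comes from a property of stride-1 max pooling: two chained 5-wide windows cover the same 9 inputs as one 9-wide window, and three cover 13. Here is a minimal pure-Python sketch of this in 1-D. `max_pool_1d` is a hypothetical helper written only for this illustration (it is not part of YOLOv5 or PyTorch); it pads with negative infinity, as `nn.MaxPool2d` does implicitly.

```python
import random


def max_pool_1d(values, kernel):
    """Stride-1 max pooling with 'same' padding (pads with -inf)."""
    pad = kernel // 2
    padded = [float("-inf")] * pad + list(values) + [float("-inf")] * pad
    return [max(padded[i:i + kernel]) for i in range(len(values))]


random.seed(0)
x = [random.random() for _ in range(32)]

# Sequential 5-wide pools, as in SPPF
o1 = max_pool_1d(x, 5)
o2 = max_pool_1d(o1, 5)
o3 = max_pool_1d(o2, 5)

# Compare with single pools using wider windows, as in SPP
same_9 = o2 == max_pool_1d(x, 9)
same_13 = o3 == max_pool_1d(x, 13)
print(same_9, same_13)
```

SPPF therefore reuses the cheap 5x5 result instead of running the more expensive 9x9 and 13x13 pools, which is consistent with the timing difference measured above.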

2. Data Augmentation

  • Mosaic

  • Copy paste

  • Random affine(Rotation, Scale, Translation and Shear)

  • MixUp

  • Albumentations
  • Augment HSV(Hue, Saturation, Value)

  • Random horizontal flip
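
Geometric augmentations in this list have to transform the labels together with the pixels. As a minimal sketch, take the simplest case, a horizontal flip, and assume labels are stored as normalized `(cls, x_center, y_center, w, h)` tuples. Both helper names below are hypothetical, and the image is a plain nested list rather than a real array.

```python
def flip_image_horizontally(image):
    """Reverse each row of a row-major image (list of rows)."""
    return [row[::-1] for row in image]


def flip_labels_horizontally(labels):
    """Mirror normalized (cls, xc, yc, w, h) boxes: only x_center changes."""
    return [(cls, 1.0 - xc, yc, w, h) for cls, xc, yc, w, h in labels]


image = [[0, 1, 2, 3],
         [4, 5, 6, 7]]
labels = [(0, 0.25, 0.5, 0.2, 0.4)]  # illustrative box, not real data

flipped_image = flip_image_horizontally(image)
flipped_labels = flip_labels_horizontally(labels)
print(flipped_image)
print(flipped_labels)
```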

3. Training Strategies

  • Multi-scale training(0.5~1.5x)
  • AutoAnchor(For training custom data)
  • Warmup and Cosine LR scheduler
  • EMA(Exponential Moving Average)
  • Mixed precision
  • Evolve hyper-parameters
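
To illustrate the "Warmup and Cosine LR scheduler" item, here is a generic per-epoch sketch of linear warmup followed by cosine decay. It is not YOLOv5's exact schedule. The function name, the parameter names (`lr0`, `lrf`, `warmup_epochs`) and all values are assumptions chosen for illustration.

```python
import math


def lr_at_epoch(epoch, epochs=100, lr0=0.01, lrf=0.1, warmup_epochs=3):
    """Linear warmup to lr0, then cosine decay from lr0 down to lr0 * lrf."""
    if epoch < warmup_epochs:
        return lr0 * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, epochs - warmup_epochs - 1)
    cosine = (1 + math.cos(math.pi * progress)) / 2  # goes from 1 to 0
    return lr0 * (lrf + (1 - lrf) * cosine)


schedule = [lr_at_epoch(e) for e in range(100)]
print(f"first={schedule[0]:.5f} peak={max(schedule):.5f} last={schedule[-1]:.5f}")
```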

4. Others

4.1 Compute Losses

The YOLOv5 loss consists of three parts:

  • Classes loss(BCE loss)
  • Objectness loss(BCE loss)
  • Location loss(CIoU loss)

$$Loss = \lambda_1 L_{cls} + \lambda_2 L_{obj} + \lambda_3 L_{loc}$$

4.2 Balance Losses

The objectness losses of the three prediction layers(P3, P4, P5) are weighted differently. The balance weights are [4.0, 1.0, 0.4] respectively.

$$L_{obj} = 4.0 \cdot L_{obj}^{P3} + 1.0 \cdot L_{obj}^{P4} + 0.4 \cdot L_{obj}^{P5}$$
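
A minimal sketch of how the per-layer weights combine, in plain Python. The loss values and the component gains (`box`, `obj`, `cls`) are illustrative assumptions, not measured numbers or confirmed defaults. Only the balance weights [4.0, 1.0, 0.4] come from the text above.

```python
# Per-layer objectness losses (illustrative values, not measured)
obj_losses = {"P3": 0.020, "P4": 0.015, "P5": 0.010}
# Balance weights for the three prediction layers, as listed above
balance = {"P3": 4.0, "P4": 1.0, "P5": 0.4}


def weighted_obj_loss(losses, weights):
    """Sum the per-layer objectness losses with their balance weights."""
    return sum(weights[name] * value for name, value in losses.items())


def total_loss(box, obj, cls, gains):
    """Combine the three loss components with per-component gains."""
    return gains["box"] * box + gains["obj"] * obj + gains["cls"] * cls


total_obj = weighted_obj_loss(obj_losses, balance)
gains = {"box": 0.05, "obj": 1.0, "cls": 0.5}  # assumed gains for this sketch
total = total_loss(box=0.4, obj=total_obj, cls=0.2, gains=gains)
print(round(total_obj, 4), round(total, 4))
```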

4.3 Eliminate Grid Sensitivity

In YOLOv2 and YOLOv3, the formula for calculating the predicted target information is:

$$b_x = \sigma(t_x) + c_x$$
$$b_y = \sigma(t_y) + c_y$$
$$b_w = p_w \cdot e^{t_w}$$
$$b_h = p_h \cdot e^{t_h}$$

In YOLOv5, the formula is:

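$$b_x = (2 \cdot \sigma(t_x) - 0.5) + c_x$$
$$b_y = (2 \cdot \sigma(t_y) - 0.5) + c_y$$
$$b_w = p_w \cdot (2 \cdot \sigma(t_w))^2$$
$$b_h = p_h \cdot (2 \cdot \sigma(t_h))^2$$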

Compare the center point offset before and after scaling: the offset range is widened from (0, 1) to (-0.5, 1.5).
Therefore, the offset can easily reach 0 or 1, values that a plain sigmoid only approaches as its input goes to negative or positive infinity.

Compare the height and width scaling ratio (relative to the anchor) before and after adjustment. The original YOLO/Darknet box equations have a serious flaw: width and height are completely unbounded, as they are simply out=exp(in). This is dangerous, as it can lead to runaway gradients, instabilities, NaN losses and ultimately a complete loss of training. Refer to the linked issue for details.
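
The two decodings can be compared numerically. This sketch assumes the YOLOv5 forms `2 * sigmoid(t) - 0.5` for the center offset and `(2 * sigmoid(t)) ** 2` for the size ratio, which match the ranges stated above. The function names are hypothetical.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def center_offset_v3(t):
    return sigmoid(t)                # range (0, 1)


def center_offset_v5(t):
    return 2.0 * sigmoid(t) - 0.5    # range (-0.5, 1.5)


def size_ratio_v3(t):
    return math.exp(t)               # unbounded


def size_ratio_v5(t):
    return (2.0 * sigmoid(t)) ** 2   # range (0, 4)


for t in (-6.0, 0.0, 6.0):
    print(f"t={t:+.0f}  "
          f"offset v3={center_offset_v3(t):.3f} v5={center_offset_v5(t):.3f}  "
          f"ratio v3={size_ratio_v3(t):.3f} v5={size_ratio_v5(t):.3f}")

# Raw output needed for an offset of exactly 0 (center on a cell border):
# the v5 form reaches it at a finite t, the v3 form only as t -> -inf.
t_zero = math.log(1 / 3)  # solves 2 * sigmoid(t) - 0.5 == 0
print(round(center_offset_v5(t_zero), 6))
```

With the bounded form, a predicted box can be at most 4x the anchor size, however large the raw output gets.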

4.4 Build Targets

Match positive samples:

  • Calculate the width and height ratios between the GT box and the Anchor Templates

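$$r_w = w_{gt} / w_{at}$$

$$r_h = h_{gt} / h_{at}$$

$$r_w^{max} = \max(r_w, 1 / r_w)$$

$$r_h^{max} = \max(r_h, 1 / r_h)$$

$$r^{max} = \max(r_w^{max}, r_h^{max})$$

$$r^{max} < \text{anchor\_t}$$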

  • Assign the successfully matched Anchor Templates to the corresponding cells

  • Because the center point offset range is adjusted from (0, 1) to (-0.5, 1.5), a GT box can be assigned to more anchors (in neighbouring cells as well as the cell that contains its center).
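
The matching steps above can be sketched in plain Python. This is a simplified illustration, not the actual `build_targets` code. The helper names, the threshold value `ANCHOR_T = 4.0`, the anchor sizes and the grid size are all assumptions. Grid-border handling is reduced to dropping cells that fall outside the grid.

```python
ANCHOR_T = 4.0  # assumed ratio threshold for this sketch


def anchor_matches(gt_wh, anchor_wh, anchor_t=ANCHOR_T):
    """True if GT and anchor differ by less than anchor_t in width and height."""
    r_w = gt_wh[0] / anchor_wh[0]
    r_h = gt_wh[1] / anchor_wh[1]
    r_max = max(r_w, 1 / r_w, r_h, 1 / r_h)
    return r_max < anchor_t


def assigned_cells(gx, gy, grid_size):
    """Cell containing the GT center plus its two nearest neighbours.

    gx, gy are the center in grid units; the fractional part (gx % 1)
    tells which half of the cell the center lies in.
    """
    cx, cy = int(gx), int(gy)
    candidates = [
        (cx, cy),
        (cx - 1, cy) if gx % 1 < 0.5 else (cx + 1, cy),
        (cx, cy - 1) if gy % 1 < 0.5 else (cx, cy + 1),
    ]
    return [(x, y) for x, y in candidates
            if 0 <= x < grid_size and 0 <= y < grid_size]


anchors = [(10, 13), (16, 30), (33, 23)]  # example anchor templates (pixels)
gt_wh = (50, 20)                          # example GT box size (pixels)
matches = [anchor_matches(gt_wh, a) for a in anchors]
print(matches)
print(assigned_cells(10.3, 5.8, grid_size=80))
```

In this example the center at x = 10.3 is also assigned to cell 9, where its offset is 1.3, and the center at y = 5.8 is also assigned to cell 6, where its offset is -0.2. Both values only fit because the offset range is (-0.5, 1.5).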

Environments

YOLOv5 may be run in any of the following up-to-date verified environments (with all dependencies including CUDA/CUDNN, Python and PyTorch preinstalled):

Status

YOLOv5 CI

If this badge is green, all YOLOv5 GitHub Actions Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training, validation, inference, export and benchmarks on MacOS, Windows, and Ubuntu every 24 hours and on every commit.

@WZMIAOMIAO
Author

@glenn-jocher hi, today I briefly summarized yolov5(v6.0). Please help to see if there are any problems or put forward better suggestions. Some schematic diagrams or contents will be added later. Thank you for your great work.

@zlj-ky

zlj-ky commented Mar 16, 2022

hi, 'prediction layers(P3, P4, P5) are weighted differently', how do I find it in the code, and further, modify it?

@WZMIAOMIAO
Author

self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7

and

lobj += obji * self.balance[i] # obj loss

@zlj-ky

zlj-ky commented Mar 16, 2022

@WZMIAOMIAO thx!

@glenn-jocher
Member

glenn-jocher commented Mar 17, 2022

@WZMIAOMIAO awesome summary, nice work!

@zlj-ky yes the balancing parameters are there, we tuned these manually on COCO. The idea is to balance losses from each layer (just like we balance losses across loss components (box, obj, class)). The reason I didn't turn these into learnable weights is that as absolute values the gradient would always want to drag them to zero to minimize the loss. I suppose we could constantly normalize them so they all sum to 1 to avoid this effect. Might be an interesting experiment, and this might help the balancing adapt better to different datasets and image sizes etc.
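
The effect described here can be shown with a toy gradient-descent loop in plain Python. It is a sketch under strong assumptions: the per-layer losses are held constant (in real training they change every step), the step size is exaggerated, and softmax is just one possible way to keep the weights summing to 1.

```python
import math

layer_losses = [0.020, 0.015, 0.010]  # illustrative obj losses (P3, P4, P5)
lr = 1.0  # exaggerated step size so the effect shows in a few steps


def softmax(logits):
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# (a) Raw weights: d(sum_i w_i * l_i) / d w_i = l_i > 0,
#     so every step pushes each weight down.
raw = [4.0, 1.0, 0.4]
for _ in range(100):
    raw = [w - lr * l for w, l in zip(raw, layer_losses)]

# (b) Normalized weights w = softmax(z): they always sum to 1.
logits = [0.0, 0.0, 0.0]
for _ in range(100):
    w = softmax(logits)
    total = sum(wi * li for wi, li in zip(w, layer_losses))
    # gradient of sum_i w_i * l_i with respect to z_i is w_i * (l_i - total)
    logits = [z - lr * wi * (li - total)
              for z, wi, li in zip(logits, w, layer_losses)]
normalized = softmax(logits)

print([round(v, 3) for v in raw])
print([round(v, 3) for v in normalized], round(sum(normalized), 6))
```

In this toy setting the raw weights keep shrinking (some even turn negative), while the normalized weights stay on the simplex. Even so, the normalized weights drift toward the layer with the lowest loss, so normalization alone does not guarantee a useful balance.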

@WZMIAOMIAO
Author

@glenn-jocher Could we add this brief summary to the document?

@glenn-jocher
Member

@WZMIAOMIAO yes maybe it's a good idea to document this somewhere. Which document do you mean though?

@WZMIAOMIAO
Author

@glenn-jocher I think it could be added to the Tutorials. What do you think?

glenn-jocher added a commit that referenced this issue Mar 25, 2022
* Add Architecture Summary to README Tutorials

Per #6998 (comment)

* Update README.md
@glenn-jocher
Member

@WZMIAOMIAO all done in #7146! Thank you for your contributions to YOLOv5 🚀 and Vision AI ⭐

@glenn-jocher
Member

@HERIUN build_targets() implements an anchor-label assignment strategy so we can calculate the losses between assigned anchor-label pairs.

@xinxin342

@glenn-jocher what's the adjustment strategy for the balancing parameters? How can they be changed to learnable weights?

@glenn-jocher
Member

@xinxin342 the balance params are here; you'd have to convert them to nn.Parameter types assigned to an existing module class and make sure their requires_grad is set to True:

self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7

@zlj-ky

zlj-ky commented Apr 16, 2022

@glenn-jocher
I tried to convert the weights to a learnable parameter like this (my experience is limited):
[Image: screenshot of the modified code]
However, this parameter was not updated during training, and I don't know why or how to fix my approach. Could you help me, even though it may be a very simple question?

@glenn-jocher
Member

glenn-jocher commented Apr 17, 2022

@zlj-ky that seems like a good approach, but you might need to place self.w inside the model so it's affected by model.train(), model.eval(), etc. You can just place it inside models.yolo.Detect and then access it like this. (Note your code is out of date):

class ComputeLoss:
    sort_obj_iou = False

    def __init__(self, model, autobalance=False):
        device = next(model.parameters()).device  # get model device
        h = model.hyp  # hyperparameters

        # Define criteria
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))

        # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
        self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))  # positive, negative BCE targets

        # Focal loss
        g = h['fl_gamma']  # focal loss gamma
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

        m = de_parallel(model).model[-1]  # Detect() module
        self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02])  # P3-P7
        self.ssi = list(m.stride).index(16) if autobalance else 0  # stride 16 index
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
        self.na = m.na  # number of anchors
        self.nc = m.nc  # number of classes
        self.nl = m.nl  # number of layers
        self.anchors = m.anchors
        self.w = m.w  # <------------------------ NEW CODE 
        self.device = device

This might or might not work as I don't know if this will create a copy or access the Detect parameter.

Even if you get this to work, though, it's not clear that these are learnable parameters, as I'm not sure they can be correlated to the gradient directly; i.e. the optimizer seeks to reduce loss, so the rebalance may just weigh the lower loss components higher to reduce loss, which may not have the desired effect.

The same concept applies to anchors, which don't seem learnable either during training.
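
Whether such a weight is updated at all depends on two things: it must be registered on a module as an `nn.Parameter`, and it must be handed to the optimizer. Here is a minimal sketch of both. `TinyDetect` is a hypothetical stand-in rather than the real `models.yolo.Detect`, and the loss values are illustrative constants.

```python
import torch
import torch.nn as nn


class TinyDetect(nn.Module):
    """Hypothetical stand-in for a detection head with balance weights."""

    def __init__(self, nl=3):
        super().__init__()
        self.nl = nl
        self.conv = nn.Conv2d(4, 4, 1)
        # Registered as a parameter, so it appears in model.parameters()
        self.w = nn.Parameter(torch.tensor([4.0, 1.0, 0.4]))


model = TinyDetect()
print([name for name, _ in model.named_parameters()])  # includes 'w'

# The optimizer only updates tensors it was given
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

before = model.w.detach().clone()
layer_losses = torch.tensor([0.020, 0.015, 0.010])  # illustrative constants
loss = (model.w * layer_losses).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(model.w.grad)                           # gradient equals layer_losses
print(torch.equal(before, model.w.detach()))  # False: the parameter moved
```

If a training script builds its optimizer parameter groups by hand, a newly added parameter may be missing from them, which would be one possible reason for a value that never updates.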

@zlj-ky

zlj-ky commented Apr 18, 2022

@glenn-jocher Thank you for sharing your views on this matter and for your patient guidance. I will try it later.

@Cong-Wan

@WZMIAOMIAO @glenn-jocher Hi, thanks for your nice work! I have two questions. First, how can I print the output of every layer? (I'd like to change the first layer's kernel to a smaller size, which might help small object detection.) Second, I also want to add an output for object tracking ([x, y, w, h, nc] -> [x, y, w, h, nc, id]), but I don't know which loss function to use for it.

@kadirnar

kadirnar commented Oct 7, 2022

@engrjav FPN and PANet are just two head architectures. Earlier versions of YOLOv5 used FPN and newer versions use PANet. CSP is a type of repeating module which has evolved into the current C3 modules.

Hi @glenn-jocher
Why did you choose PANet? Is there a comparison chart? Would you prefer the Light-BiFPN module for small models?
Light-Yolov5: https://arxiv.org/pdf/2208.13422.pdf

@glenn-jocher
Member

glenn-jocher commented Oct 9, 2022

@kadirnar BiFPN and PANet are nearly identical; in a P3-P5 output model the only difference is a single shortcut. There are versions of all 3 heads available here:
https://github.com/ultralytics/yolov5/tree/master/models/hub

As always all design decisions are based on empirical results.

@divided7

Hello, can we get the results of the ablation experiments, such as the mAP results for SPP to SPPF and Focus to Conv on big datasets?

@glenn-jocher
Member

@divided-by-7 I'm sorry, we don't have this R&D saved in a presentable manner.

@dlod-openvino

@WZMIAOMIAO Could you please summarize the YOLOv5 Instance Segmentation Model Structure? especially the keywords definition of output0 float32[1,25200,117] and output1 float32[1,32,160,160]. Thank you very much in advance!

@ishakpacal

Dear @glenn-jocher @WZMIAOMIAO
The segmentation part is excellent. What has changed in the model architecture for it? Could you provide an example model architecture? Thanks in advance.

@tayahiyukon

Hi! What do k, s, p, and c represent in the model structure, respectively?

@XueZ-phd

This is a simple question. k = kernel size, s = stride, p = padding, c = channel dims

@tayahiyukon

Okay, thank you very much!

@karl-gardner

Hello @glenn-jocher or anyone who knows the answer. I am trying to understand the build targets process a little more. When you say GTx%1>0.5 and GTy%1>0.5 is the % just the modulus? If it is the modulo operator, then why is this used?

Thanks,

Karl Gardner

@scraus

scraus commented Dec 27, 2022

@WZMIAOMIAO @glenn-jocher or anyone who knows. I am trying to understand more about the model structure. Is there an article that discusses and explains the YOLOv5 structure? Thanks!

@gracesmrngkr

Hi @glenn-jocher, what is the formula by which a 640x640x3 input image becomes 320x320x64 with k=3, s=2, p=1?

@glenn-jocher
Member

@gracesmrngkr this transformation is governed by the following formula:

$$\text{output\_size} = \left\lfloor \frac{\text{input\_size} - \text{kernel\_size} + 2 \times \text{padding}}{\text{stride}} \right\rfloor + 1$$

So in this case, with an input size of 640 and a kernel size of 3, a stride of 2, and padding of 1, the output size would be 320.
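
The same formula as a small Python helper, for checking other layer configurations. `conv_output_size` is a hypothetical name used only here. The second call assumes a 6x6 convolution with stride 2 and padding 2.

```python
def conv_output_size(input_size, kernel_size, stride, padding):
    """Spatial output size of a convolution along one dimension."""
    return (input_size - kernel_size + 2 * padding) // stride + 1


print(conv_output_size(640, 3, 2, 1))  # k=3, s=2, p=1
print(conv_output_size(640, 6, 2, 2))  # assumed 6x6 conv, s=2, p=2
```

The formula covers only width and height. The channel count (3 to 64 in the question) is set by the number of filters in the layer.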
