
Some tricks to improve v5 (including BiFPN, ASFF, and all kinds of data augmentations) ~ Call for everyone~ #3993

Closed
SpongeBab opened this issue Jul 13, 2021 · 23 comments · Fixed by #4195 or #4208
Labels: enhancement (New feature or request), Stale

Comments

SpongeBab (Contributor) commented Jul 13, 2021

🚀 Feature

@glenn-jocher
Hi, do you have a plan to implement YOLOv5 with BiFPN?

Motivation

We can see the advantages of BiFPN in this paper: https://arxiv.org/pdf/1911.09070.pdf. It is better than PANet.

Additional context

I think it could improve the mAP a lot.

Edit: a list of tricks (TBC) that I want to work through.

@SpongeBab SpongeBab added the enhancement New feature or request label Jul 13, 2021
SpongeBab (Contributor, Author) commented:

🚀 Feature

@glenn-jocher
And what about adaptive average pooling?

glenn-jocher (Member) commented:

@SpongeBab BiFPN is used in EfficientDet models; the YOLOv5 heads are PANets.

Adaptive average pooling is used in classification models to reduce the spatial dimensions down to 1. I use this in the YOLOv5 classifier branch, for example in the Classify module (the classification equivalent of the Detect() module):

yolov5/models/common.py

Lines 378 to 388 in 8ee9fd1

    class Classify(nn.Module):
        # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
        def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
            super(Classify, self).__init__()
            self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
            self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
            self.flat = nn.Flatten()

        def forward(self, x):
            z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1)  # cat if list
            return self.flat(self.conv(z))  # flatten to x(b,c2)
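
As a quick illustration of what the adaptive pooling does on its own (a standalone sketch, not repo code):

    import torch
    import torch.nn as nn

    aap = nn.AdaptiveAvgPool2d(1)    # always pools to 1x1, whatever the input spatial size
    x = torch.randn(2, 256, 20, 20)  # x(b,c,20,20)
    print(aap(x).shape)              # torch.Size([2, 256, 1, 1]) -> x(b,c,1,1)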

SpongeBab (Contributor, Author) commented Jul 14, 2021

@glenn-jocher Oh, I missed that. Thank you!
BTW, although BiFPN belongs to EfficientDet, I think it plays an important role there: the mAP increase is mainly due to BiFPN, which I think is applicable to v5 as well. It might be able to bring a big improvement for v5.

SpongeBab (Contributor, Author) commented:

I would like to try this; if I get a big improvement, I will tell you.
@glenn-jocher

glenn-jocher (Member) commented:

@SpongeBab yeah absolutely, if you see improvements please let us know!

SpongeBab (Contributor, Author) commented Jul 16, 2021

First, I tried a four-output YOLOv5 FPN.
I created a new yaml:

# Parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
anchors:
#  - [ 10,13, 16,30, 33,23 ]  # P3/8
#  - [ 30,61, 62,45, 59,119 ]  # P4/16
#  - [ 116,90, 156,198, 373,326 ]  # P5/32
  - [ 5,7, 8,12, 15,13 ]  # P2/4  # custom
  - [ 10,13, 16,30, 33,23 ]  # P3/8
  - [ 30,61, 62,45, 59,119 ]  # P4/16
  - [ 116,90, 156,198, 373,326 ]  # P5/32

# YOLOv5 backbone
backbone:
  # [from, number, module, args]
  [ [ -1, 1, Focus, [ 64, 3 ] ],  # 0-P1/2
    [ -1, 1, Conv, [ 128, 3, 2 ] ],  # 1-P2/4
    [ -1, 3, Bottleneck, [ 128 ] ],
    [ -1, 1, Conv, [ 256, 3, 2 ] ],  # 3-P3/8
    [ -1, 9, BottleneckCSP, [ 256 ] ],
    [ -1, 1, Conv, [ 512, 3, 2 ] ],  # 5-P4/16
    [ -1, 9, BottleneckCSP, [ 512 ] ],
    [ -1, 1, Conv, [ 1024, 3, 2 ] ],  # 7-P5/32
    [ -1, 1, SPP, [ 1024, [ 5, 9, 13 ] ] ],
    [ -1, 6, BottleneckCSP, [ 1024 ] ],  # 9
  ]

# YOLOv5 FPN head
head:
  [ [ -1, 3, BottleneckCSP, [ 1024, False ] ],  # 10 (P5/32-large)

    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
    [ [ -1, 6 ], 1, Concat, [ 1 ] ],  # cat backbone P4
    [ -1, 1, Conv, [ 512, 1, 1 ] ],
    [ -1, 3, BottleneckCSP, [ 512, False ] ],  # 14 (P4/16-medium)

    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
    [ [ -1, 4 ], 1, Concat, [ 1 ] ],  # cat backbone P3
    [ -1, 1, Conv, [ 256, 1, 1 ] ],
    [ -1, 3, BottleneckCSP, [ 256, False ] ],  # 18 (P3/8-small)

    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
    [ [ -1, 2 ], 1, Concat, [ 1 ] ],  # cat backbone P2
    [ -1, 1, Conv, [ 128, 1, 1 ] ],
    [ -1, 3, BottleneckCSP, [ 128, False ] ],  # 22 (P2/4-small)

    [ [ 18, 14, 10, 6 ], 1, Detect, [ nc, anchors ] ],  # Detect(P2, P3, P4, P5)
  ]
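
To sanity-check a yaml like this before training, the model can be built directly with the same Model(cfg) constructor that models/yolo.py uses (the filename below is a hypothetical save location inside a yolov5/ checkout):

    import torch
    from models.yolo import Model  # run from the yolov5/ repo root

    model = Model('models/yolov5-p2-fpn.yaml')  # hypothetical name for the yaml above
    _ = model(torch.zeros(1, 3, 640, 640))      # dry run to check layer wiring and strides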

Then I began retraining on my custom dataset.
So far the results look likely to be lower than expected (the bolded column is mAP@.5:.95).
Original:

    12/299     1.72G    0.1028    0.2015         0    0.3043        85       640     0.396    0.3296    0.3256   **0.08686**   0.09502    0.2732         0
    13/299     1.72G   0.09832    0.1952         0    0.2935        80       640   0.08477   0.04569      0.01  **0.002103**     0.143     1.027         0
    14/299     1.72G   0.09651    0.2055         0     0.302        44       640    0.4327    0.4439    0.4303     **0.128**   0.09133    0.2394         0
    15/299     1.72G   0.09085    0.1977         0    0.2886        24       640    0.4973    0.4347    0.4324    **0.1315**   0.09125    0.2948         0

Modified:

    12/299     2.34G    0.1029    0.1509         0    0.2538        46       640    0.3257    0.2285    0.1548   **0.04483**    0.1016    0.2063         0
    13/299     2.34G   0.09988    0.1454         0    0.2452        85       640    0.2229    0.1991   0.07238   **0.01867**    0.1062    0.1958         0
    14/299     2.34G   0.09809    0.1433         0    0.2414        99       640     0.329    0.2369    0.1658   **0.04505**    0.1014    0.1966         0
    15/299     2.34G   0.09579    0.1469         0    0.2427       159       640    0.4225    0.2317    0.2051   **0.06759**   0.09982    0.1883         0

Unfortunately, I only have a 2070 with 8 GB of graphics memory, so the batch size is 2 for both runs.
Anything bigger overflows memory.
But maybe I could get better results for the modified model with a bigger batch size.
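
One possible workaround when memory caps the batch size is gradient accumulation, which simulates a larger effective batch without raising per-step memory. A minimal generic PyTorch sketch with toy stand-in objects (yolov5's train.py has its own accumulate logic; this is not it):

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 1)  # toy stand-in for the detector
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()  # toy stand-in for the detection loss
    data = [(torch.randn(2, 10), torch.randn(2, 1)) for _ in range(64)]  # batch-size 2

    accumulate = 32  # step the optimizer every 32 mini-batches -> effective batch size 64
    optimizer.zero_grad()
    for i, (x, y) in enumerate(data):
        loss = criterion(model(x), y)
        (loss / accumulate).backward()  # scale so the summed gradients average over the window
        if (i + 1) % accumulate == 0:
            optimizer.step()            # one weight update per accumulated window
            optimizer.zero_grad()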

SpongeBab (Contributor, Author) commented Jul 16, 2021

I tried to use BiFPN. I think I must modify the Concat class:

class Concat(nn.Module):
    # Concatenate a list of tensors along dimension
    def __init__(self, dimension=1):
        super(Concat, self).__init__()
        self.d = dimension

    def forward(self, x):
        return torch.cat(x, self.d)

because BiFPN fuses each node's inputs with learned weights and repeats the block, which is different from FPN and PANet.
[Figure: FPN, PANet, and BiFPN feature-network topologies, from the EfficientDet paper]

And I want to make it compatible with YOLOv5 s/m/l/x and the P6 models.
I haven't achieved this yet.

SpongeBab (Contributor, Author) commented Jul 16, 2021

Some results for the four-output v5 P5 model:
Modified:

               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 25/25 [00:05<00:00,  4.64it/s]
                 all         50       1532      0.961      0.916      0.963      0.665
Speed: 0.4ms pre-process, 32.8ms inference, 5.3ms NMS per image at shape (2, 3, 640, 640)

Original:

               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 25/25 [00:04<00:00,  5.25it/s]
                 all         50       1532      0.986      0.944      0.973      0.728
Speed: 0.5ms pre-process, 27.9ms inference, 3.0ms NMS per image at shape (2, 3, 640, 640)

Unfortunately, the mAP is even lower.

@SpongeBab SpongeBab changed the title About biFPN~ Some tricks to improve v5~(Including biFPN, ASFF, and all kinds of data augmentations~)Call for everyone~ Jul 16, 2021
glenn-jocher (Member) commented:

@SpongeBab the YOLOv5-P6 models use 4 outputs; you may want to start with them if you are interested in a 4-output model, i.e.:

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [ 19,27, 44,40, 38,94 ]  # P3/8
  - [ 96,68, 86,152, 180,137 ]  # P4/16
  - [ 140,301, 303,264, 238,542 ]  # P5/32
  - [ 436,615, 739,380, 925,792 ]  # P6/64

# YOLOv5 backbone
backbone:
  # [from, number, module, args]
  [ [ -1, 1, Focus, [ 64, 3 ] ],  # 0-P1/2
    [ -1, 1, Conv, [ 128, 3, 2 ] ],  # 1-P2/4
    [ -1, 3, C3, [ 128 ] ],
    [ -1, 1, Conv, [ 256, 3, 2 ] ],  # 3-P3/8
    [ -1, 9, C3, [ 256 ] ],
    [ -1, 1, Conv, [ 512, 3, 2 ] ],  # 5-P4/16
    [ -1, 9, C3, [ 512 ] ],
    [ -1, 1, Conv, [ 768, 3, 2 ] ],  # 7-P5/32
    [ -1, 3, C3, [ 768 ] ],
    [ -1, 1, Conv, [ 1024, 3, 2 ] ],  # 9-P6/64
    [ -1, 1, SPP, [ 1024, [ 3, 5, 7 ] ] ],
    [ -1, 3, C3, [ 1024, False ] ],  # 11
  ]

# YOLOv5 head
head:
  [ [ -1, 1, Conv, [ 768, 1, 1 ] ],
    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
    [ [ -1, 8 ], 1, Concat, [ 1 ] ],  # cat backbone P5
    [ -1, 3, C3, [ 768, False ] ],  # 15

    [ -1, 1, Conv, [ 512, 1, 1 ] ],
    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
    [ [ -1, 6 ], 1, Concat, [ 1 ] ],  # cat backbone P4
    [ -1, 3, C3, [ 512, False ] ],  # 19

    [ -1, 1, Conv, [ 256, 1, 1 ] ],
    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
    [ [ -1, 4 ], 1, Concat, [ 1 ] ],  # cat backbone P3
    [ -1, 3, C3, [ 256, False ] ],  # 23 (P3/8-small)

    [ -1, 1, Conv, [ 256, 3, 2 ] ],
    [ [ -1, 20 ], 1, Concat, [ 1 ] ],  # cat head P4
    [ -1, 3, C3, [ 512, False ] ],  # 26 (P4/16-medium)

    [ -1, 1, Conv, [ 512, 3, 2 ] ],
    [ [ -1, 16 ], 1, Concat, [ 1 ] ],  # cat head P5
    [ -1, 3, C3, [ 768, False ] ],  # 29 (P5/32-large)

    [ -1, 1, Conv, [ 768, 3, 2 ] ],
    [ [ -1, 12 ], 1, Concat, [ 1 ] ],  # cat head P6
    [ -1, 3, C3, [ 1024, False ] ],  # 32 (P6/64-xlarge)

    [ [ 23, 26, 29, 32 ], 1, Detect, [ nc, anchors ] ],  # Detect(P3, P4, P5, P6)
  ]

SpongeBab (Contributor, Author) commented Jul 16, 2021

Yeah, I am following the v5 method to modify the network. I'm conducting more experiments.
And BTW, I found that the impact of batch size on AP cannot be ignored. In my previous experiment, increasing the batch size from 1 to 2 raised my mAP by nearly 6%. The effect seems to be more significant as the network gets deeper.

Cassieyy commented:

Hiiii, have you tried replacing the NMS with softer-NMS?

SpongeBab (Contributor, Author) commented:

@Cassieyy
Hi, I tried it on v4; softer-NMS was useless there. I think this is because YOLO predicts objects in each grid cell, with 9 anchor boxes per grid, and ends up with only two or three final detection boxes. Traditional NMS and its variants (soft-NMS, softer-NMS, ...) have obvious effects on two-stage detectors like Fast R-CNN, because that detection approach generates a large number of proposal boxes.
Maybe it is useful for v5; after all, I didn't try it on v5. By the way, DIoU-NMS is very suitable for YOLO.
Thank you~
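
For reference, DIoU-NMS (Zheng et al., 2020) suppresses a candidate box using DIoU = IoU - ρ²(b, b_M)/c², where ρ is the distance between the two box centers and c is the diagonal of the smallest box enclosing both, so overlapping boxes with distant centers survive. A minimal standalone sketch, not yolov5's built-in NMS:

    import torch

    def diou_nms(boxes, scores, iou_thres=0.45):
        # boxes: (N, 4) in xyxy format, scores: (N,). Returns indices of kept boxes.
        # A box is suppressed when DIoU = IoU - rho^2 / c^2 exceeds the threshold.
        order = scores.argsort(descending=True)
        keep = []
        while order.numel() > 0:
            i = order[0]
            keep.append(i.item())
            if order.numel() == 1:
                break
            rest = order[1:]
            b1, b2 = boxes[i], boxes[rest]
            lt = torch.max(b1[:2], b2[:, :2])              # intersection top-left
            rb = torch.min(b1[2:], b2[:, 2:])              # intersection bottom-right
            inter = (rb - lt).clamp(min=0).prod(1)
            area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
            area2 = (b2[:, 2] - b2[:, 0]) * (b2[:, 3] - b2[:, 1])
            iou = inter / (area1 + area2 - inter + 1e-9)
            c1, c2 = (b1[:2] + b1[2:]) / 2, (b2[:, :2] + b2[:, 2:]) / 2
            rho2 = ((c1 - c2) ** 2).sum(1)                 # squared center distance
            elt = torch.min(b1[:2], b2[:, :2])             # enclosing-box top-left
            erb = torch.max(b1[2:], b2[:, 2:])             # enclosing-box bottom-right
            diag2 = ((erb - elt) ** 2).sum(1) + 1e-9       # squared enclosing diagonal
            order = rest[iou - rho2 / diag2 <= iou_thres]  # keep boxes below the DIoU threshold
        return torch.tensor(keep)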

@glenn-jocher glenn-jocher linked a pull request Jul 28, 2021 that will close this issue
glenn-jocher (Member) commented:

@SpongeBab good news 😃! Your original issue may now be fixed ✅ in PR #4195. This PR adds a new YOLOv5-BiFPN yaml to the models/hub directory. It turns out the current heads are very similar to BiFPN, and the only remaining difference was to add a shortcut to the backbone layers on the intermediate outputs (P4 only for P5 models, or P4 and P5 for P6 models). One key difference is that I think EfficientDet weighted-sums its shortcuts whereas we concatenate ours together, so some study needs to be done on the effects there.
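
For reference, EfficientDet's "fast normalized fusion" weighted sum looks roughly like this (a sketch of the paper's formulation with a hypothetical module name; the inputs must already share shape and channel count):

    import torch
    import torch.nn as nn

    class WeightedSum2(nn.Module):
        # Fast normalized fusion of two equal-shape feature maps:
        # out = sum_i w_i' * x_i, with w_i' = relu(w_i) / (eps + sum_j relu(w_j))
        def __init__(self):
            super().__init__()
            self.w = nn.Parameter(torch.ones(2))
            self.eps = 1e-4

        def forward(self, x):  # x: list of two tensors with identical shapes
            w = torch.relu(self.w)
            w = w / (w.sum() + self.eps)
            return w[0] * x[0] + w[1] * x[1]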

To receive this update:

  • Git – run git pull from within your yolov5/ directory, or git clone https://github.com/ultralytics/yolov5 again
  • PyTorch Hub – force-reload with model = torch.hub.load('ultralytics/yolov5', 'yolov5s', force_reload=True)
  • Notebooks – view the updated notebooks on Colab or Kaggle
  • Docker – run sudo docker pull ultralytics/yolov5:latest to update your image

Thank you for spotting this issue and informing us of the problem. Please let us know if this update resolves the issue for you, and feel free to inform us of any other issues you discover or feature requests that come to mind. Happy trainings with YOLOv5 🚀!

@glenn-jocher glenn-jocher reopened this Jul 28, 2021
SpongeBab (Contributor, Author) commented Jul 28, 2021

@glenn-jocher Oh! Great! I was also working on this, but it is not finished. Ah, I don't need to do it now. :)
Have you trained it on a dataset? How are the results?

SpongeBab (Contributor, Author) commented:

I have done this:

common.py:

class Concat(nn.Module):
    # BiFPN-style weighted fusion of 2 or 3 equal-shape inputs (replaces the plain concatenation)
    def __init__(self, c1, c2):
        super(Concat, self).__init__()
        # self.relu = nn.ReLU()
        self.w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
        self.epsilon = 0.0001
        self.conv = nn.Conv2d(c1, c2, kernel_size=1, stride=1, padding=0)
        self.swish = MemoryEfficientSwish()  # from the EfficientDet-PyTorch implementation

    def forward(self, x):
        outs = self._forward(x)
        return outs

    def _forward(self, x):
        if len(x) == 2:  # intermediate node, two inputs
            # w = self.relu(self.w1)
            w = self.w1
            weight = w / (torch.sum(w, dim=0) + self.epsilon)  # normalize the fusion weights
            x = self.conv(self.swish(weight[0] * x[0] + weight[1] * x[1]))
        elif len(x) == 3:  # output node, three inputs
            # w = self.relu(self.w2)
            w = self.w2
            weight = w / (torch.sum(w, dim=0) + self.epsilon)
            x = self.conv(self.swish(weight[0] * x[0] + weight[1] * x[1] + weight[2] * x[2]))
        return x  # the fused feature map
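
A quick shape check for the module above (substituting nn.SiLU(), which computes the same swish activation, so the snippet runs without the EfficientDet-PyTorch dependency):

    import torch
    import torch.nn as nn

    m = Concat(640, 640)  # fuse two 640-channel maps into a 640-channel output
    m.swish = nn.SiLU()   # stand-in for MemoryEfficientSwish
    x = [torch.randn(1, 640, 40, 40), torch.randn(1, 640, 40, 40)]
    print(m(x).shape)     # expected: torch.Size([1, 640, 40, 40])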

yolo.py:

    # elif m is Concat:
    #     c2 = sum([ch[x] for x in f])
    elif m is Concat:
        c2 = max([ch[x] for x in f])  # weighted-sum fusion keeps the channel count, so take max instead of sum

I used yolov5x.yaml:

# parameters
nc: 80  # number of classes
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple


# anchors
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Focus, [64, 3]],  # 0-P1/2  320  3,80
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4  160  80,160
   [-1, 3, C3, [128]],  # 160, 160
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8  80  # 160, 320
   [-1, 9, C3, [256]],  # 320, 320
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16  40  # 320, 640
   [-1, 9, C3, [512]],  # 640, 640
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32  20 # 640, 1280
   [-1, 1, SPP, [1024, [5, 9, 13]]],  # 1280, 1280
   [-1, 3, C3, [1024, False]],  # 9 1280, 1280
  ]

# YOLOv5 head
head:
  [[-1, 1, Conv, [512, 1, 1]],  # 10 1280, 640
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],  # 11 40 upsample
   [[-1, 6], 1, Concat, [640, 640]],  # 12 cat backbone P4  # cat 40,40
   [-1, 3, C3, [512, False]],  # 13  # 640, 640

   [-1, 1, Conv, [256, 1, 1]], # 640, 320
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],  # 80  640, 320 upsample
   [[-1, 4], 1, Concat, [320, 320]],  # cat backbone P3  # cat 80,80
   [-1, 4, C3, [256, False]],  # 17 (P3/8-small) # 320, 320

#   [-1, 1, Conv, [256, 1, 1]],   # 320, 320
   [-1, 1, Conv, [512, 3, 2]],   # 320, 640 # down to 40
   [[-1, 6, 13], 1, Concat, [640, 640]],  # cat head P4  # cat 40,40
   [-1, 3, C3, [512, False]],  # 21 (P4/16-medium)  # 640, 640  #20

#   [-1, 1, Conv, [512, 1, 1]], # 640, 640
   [-1, 1, Conv, [1024, 3, 2]], # 640, 1280 # down to 20  #21
   [[-1, 9], 1, Concat, [1280, 1280]],  # cat head P5  cat 20,20 #22
   [-1, 3, C3, [1024, False]],  # 25 (P5/32-large) # 1280, 1280  #23

#   [[17, 21, 25], 1, Detect, [nc, anchors]] # Detect(P3, P4, P5)
   [[17, 20, 23], 1, Detect, [nc, anchors]] # Detect(P3, P4, P5)
  ]

   # layer 2
#   [-1, 1, Conv, [512, 1, 1]],  # 1280, 640
#   [-1, 1, nn.Upsample, [None, 2, 'nearest']],  # 40 upsample
#   [[-1, 21], 1, Concat, [640, 640]],  # cat backbone P4  # cat 40,40
#   [-1, 3, C3, [512, False]],  # 29  # 640, 640
#
#   [-1, 1, Conv, [256, 1, 1]], # 640, 320
#   [-1, 1, nn.Upsample, [None, 2, 'nearest']],  # 80  640, 320 upsample
#   [[-1, 17], 1, Concat, [320, 320]],  # cat backbone P3  # cat 80,80
#   [-1, 3, C3, [256, False]],  # 33 (P3/8-small) # 320, 320
#
#   [-1, 1, Conv, [256, 1, 1]],   # 320, 320
#   [-1, 1, Conv, [512, 3, 2]],   # 320, 640 # down to 40
#   [[-1, 21, 29], 1, Concat, [640, 640]],  # cat head P4  # cat 40,40
#   [-1, 3, C3, [512, False]],  # 37 (P4/16-medium)  # 640, 640
#
#   [-1, 1, Conv, [512, 1, 1]], # 640, 640
#   [-1, 1, Conv, [1024, 3, 2]], # 640, 1280 # down to 20
#   [[-1, 25], 1, Concat, [1280, 1280]],  # cat head P5  cat 20,20
#   [-1, 3, C3, [1024, False]],  # 41 (P5/32-large) # 640, 1280
#
#   [[33, 37, 41], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
#]

Although the code runs, I think it is incomplete; it is not a full implementation of BiFPN.

glenn-jocher (Member) commented:

@SpongeBab the new yolov5-bifpn.yaml file implements a BiFPN head:
https://github.com/ultralytics/yolov5/blob/master/models/hub/yolov5-bifpn.yaml

SpongeBab (Contributor, Author) commented Jul 28, 2021

@glenn-jocher yeah, I saw it. Have you trained it?

glenn-jocher (Member) commented:

@SpongeBab yes, it performs similarly to the baseline model.

SpongeBab (Contributor, Author) commented Jul 28, 2021

@glenn-jocher And I think yolov5l-bifpn would be the more accurate name. Can I try v5s-, m-, and x-bifpn?

glenn-jocher (Member) commented:

@SpongeBab sure, it's easy to apply to any size. Just compare yolov5l.yaml with yolov5-bifpn.yaml, and apply the same change to the other sizes.

Adlith commented Jul 28, 2021

> @SpongeBab sure, it's easy to apply to any size. Just compare yolov5l.yaml with yolov5-bifpn.yaml, and apply the same change to the other sizes.

> @SpongeBab the new yolov5-bifpn.yaml file implements a BiFPN head:
> https://github.com/ultralytics/yolov5/blob/master/models/hub/yolov5-bifpn.yaml

When I run

    python models/yolo.py --cfg models/hub/yolov5-bifpn.yaml

the console shows:

0 -1 1 7040 models.common.Focus [3, 64, 3]
1 -1 1 73984 models.common.Conv [64, 128, 3, 2]
2 -1 3 246912 models.common.Bottleneck [128, 128]
3 -1 1 295424 models.common.Conv [128, 256, 3, 2]
4 -1 1 1627904 models.common.BottleneckCSP [256, 256, 9]
5 -1 1 1180672 models.common.Conv [256, 512, 3, 2]
6 -1 1 6499840 models.common.BottleneckCSP [512, 512, 9]
7 -1 1 4720640 models.common.Conv [512, 1024, 3, 2]
8 -1 1 2624512 models.common.SPP [1024, 1024, [5, 9, 13]]
9 -1 1 18105344 models.common.BottleneckCSP [1024, 1024, 6]
10 -1 1 10234880 models.common.BottleneckCSP [1024, 1024, 3, False]
11 -1 1 525312 models.common.Conv [1024, 512, 1, 1]
12 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
Traceback (most recent call last):
  File "models/yolo.py", line 286, in <module>
    model = Model(opt.cfg).to(device)
  File "models/yolo.py", line 97, in __init__
    self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
  File "models/yolo.py", line 251, in parse_model
    c2 = sum([ch[x] for x in f])
  File "models/yolo.py", line 251, in <listcomp>
    c2 = sum([ch[x] for x in f])
IndexError: list index out of range

I don't know how to handle it.

@glenn-jocher glenn-jocher linked a pull request Jul 28, 2021 that will close this issue
glenn-jocher (Member) commented:

@SpongeBab good news 😃! Your original issue may now be fixed ✅ in PR #4208. To receive this update:

  • Git – run git pull from within your yolov5/ directory, or git clone https://github.com/ultralytics/yolov5 again
  • PyTorch Hub – force-reload with model = torch.hub.load('ultralytics/yolov5', 'yolov5s', force_reload=True)
  • Notebooks – view the updated notebooks on Colab or Kaggle
  • Docker – run sudo docker pull ultralytics/yolov5:latest to update your image

Thank you for spotting this issue and informing us of the problem. Please let us know if this update resolves the issue for you, and feel free to inform us of any other issues you discover or feature requests that come to mind. Happy trainings with YOLOv5 🚀!

github-actions bot commented Aug 28, 2021

👋 Hello, this issue has been automatically marked as stale because it has not had recent activity. Please note it will be closed if no further activity occurs.


Feel free to inform us of any other issues you discover or feature requests that come to mind in the future. Pull Requests (PRs) are also always welcomed!

Thank you for your contributions to YOLOv5 🚀 and Vision AI ⭐!
