
About the use of MixConv2d #5403

Closed
Zengyf-CVer opened this issue Oct 29, 2021 · 7 comments · Fixed by #5410
Labels
question Further information is requested

Comments

@Zengyf-CVer
Contributor

@glenn-jocher
How should I use the MixConv2d module? I ran an experiment replacing Conv directly with MixConv2d and found that it fails to run:

backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, MixConv2d, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, MixConv2d, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, MixConv2d, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, MixConv2d, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

Console output:
[screenshot: ksnip_20211030-000732]

I found that the problem is with this k. How should I set it?

groups = len(k)

@Zengyf-CVer Zengyf-CVer added the question Further information is requested label Oct 29, 2021
@glenn-jocher
Member

@Zengyf-CVer see MixConv2d module for input requirements:

class MixConv2d(nn.Module):
    # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
        super().__init__()
        groups = len(k)
        if equal_ch:  # equal c_ per group
            i = torch.linspace(0, groups - 1E-6, c2).floor()  # c2 indices
            c_ = [(i == g).sum() for g in range(groups)]  # intermediate channels
        else:  # equal weight.numel() per group
            b = [c2] + [0] * groups
            a = np.eye(groups + 1, groups, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b
        self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
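The equal_ch branch splits the c2 output channels evenly across the len(k) kernel groups. A minimal pure-Python sketch of that bucketing (equal_ch_split is a hypothetical helper mirroring the torch.linspace(0, groups - 1e-6, c2).floor() logic, not part of YOLOv5):

```python
import math

def equal_ch_split(c2, k):
    # Hypothetical sketch: assign each of the c2 output channels to one of
    # len(k) kernel-size groups, mimicking torch.linspace(...).floor().
    groups = len(k)
    # j-th linspace point is (groups - 1e-6) * j / (c2 - 1); floor() buckets it.
    idx = [math.floor((groups - 1e-6) * j / (c2 - 1)) for j in range(c2)]
    return [idx.count(g) for g in range(groups)]

print(equal_ch_split(256, (3, 5)))     # [128, 128]
print(equal_ch_split(256, (1, 3, 5)))  # three near-equal groups summing to 256
```

For c2 = 256 and k = (3, 5) this yields two groups of 128 channels each; the nn.ModuleList then builds one nn.Conv2d per group with the corresponding kernel size.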

@glenn-jocher
Member

@Zengyf-CVer Conv() module for comparison:

yolov5/models/common.py

Lines 36 to 46 in ed887b5

class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))
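Conv() calls autopad(), which is defined elsewhere in models/common.py; a sketch of the 'same'-padding logic it implements, assuming the usual k // 2 rule for odd kernels:

```python
def autopad(k, p=None):
    # Sketch of YOLOv5-style 'same' padding: when no explicit padding p is
    # given, default to k // 2 per dimension so stride-1 convs keep spatial size.
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p

print(autopad(3))       # 1
print(autopad((3, 5)))  # [1, 2]
print(autopad(3, 0))    # 0 (explicit padding wins)
```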

@Zengyf-CVer
Contributor Author

@glenn-jocher
I ran a standalone test and found a problem. Is this a bug?

import torch
from utils.torch_utils import profile
from models.experimental import MixConv2d

m = MixConv2d(128, 256, (3, 5), 1)
results = profile(input=torch.randn(16, 128, 80, 80), ops=[m], n=1)

Error message:

The size of tensor a (128) must match the size of tensor b (256) at non-singleton dimension 1

Is this a bug in MixConv2d?

@glenn-jocher
Member

@Zengyf-CVer I am able to reproduce this error (thank you for the code to reproduce!). It may be a bug; I will investigate.

@glenn-jocher
Member

TODO: Possible MixConv2d() bug

@glenn-jocher
Member

glenn-jocher commented Oct 30, 2021

@Zengyf-CVer based on a re-read of the MixConv paper, I don't understand why I put a shortcut into the current module. The cause of the error is that the forward function adds the input to the output, and the input and output of course have different channel counts (128 and 256). I will remove the shortcut and push a fix.

    def forward(self, x):
        return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
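The mismatch can be sketched at the shape level without running torch (both functions below are hypothetical illustrations, not YOLOv5 code):

```python
def shortcut_forward_shape(in_shape, c2):
    # Shape-level sketch of the buggy forward: the residual `x +` forces the
    # input channel count c1 to equal the output channel count c2.
    n, c1, h, w = in_shape
    if c1 != c2:
        raise ValueError(f"The size of tensor a ({c1}) must match "
                         f"the size of tensor b ({c2})")
    return (n, c2, h, w)

def fixed_forward_shape(in_shape, c2):
    # Without the shortcut, any c1 -> c2 mapping is valid.
    n, c1, h, w = in_shape
    return (n, c2, h, w)

print(fixed_forward_shape((16, 128, 80, 80), 256))  # (16, 256, 80, 80)
```

With the shortcut, MixConv2d(128, 256, ...) fails exactly as in the report above, while a 256-to-256 configuration would have passed silently.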

[Screenshot 2021-10-30 at 13:01:46]

[Screenshot 2021-10-30 at 13:01:09]

@glenn-jocher
Member

glenn-jocher commented Oct 30, 2021

@Zengyf-CVer good news 😃! Your original issue may now be fixed ✅ in PR #5410. To verify this fix:

import torch

from utils.torch_utils import profile
from models.experimental import MixConv2d
from models.common import Conv

m1 = MixConv2d(128, 256, (3, 5), 1)
m2 = Conv(128, 256, 3, 1)
results = profile(input=torch.randn(16, 128, 80, 80), ops=[m1, m2], n=3)

YOLOv5 🚀 v6.0-39-g3d9a368 torch 1.9.0+cu111 CUDA:0 (Tesla P100-PCIE-16GB, 16280.875MB)

      Params      GFLOPs  GPU_mem (GB)  forward (ms) backward (ms)                   input                  output
        4864      0.9961         0.684         4.922         12.76       (16, 128, 80, 80)       (16, 256, 80, 80)
      295424        60.5         0.990         9.727         8.917       (16, 128, 80, 80)       (16, 256, 80, 80)

To receive this update:

  • Git – git pull from within your yolov5/ directory, or git clone https://github.com/ultralytics/yolov5 again
  • PyTorch Hub – force-reload with model = torch.hub.load('ultralytics/yolov5', 'yolov5s', force_reload=True)
  • Notebooks – view the updated notebooks on Colab or Kaggle
  • Docker – sudo docker pull ultralytics/yolov5:latest to update your image

Thank you for spotting this issue and informing us of the problem. Please let us know if this update resolves the issue for you, and feel free to inform us of any other issues you discover or feature requests that come to mind. Happy trainings with YOLOv5 🚀!

@glenn-jocher glenn-jocher removed the TODO label Nov 5, 2021