Some errors when ensemble models #3075

Closed
Lg955 opened this issue May 8, 2021 · 5 comments · Fixed by #3082
Labels
bug Something isn't working

Comments


Lg955 commented May 8, 2021

🐛 Bug

When ensembling models, I sometimes get an error (for example, when yolo5x is in the ensemble).

To Reproduce (REQUIRED)

Input:

python detect.py --source test-A --weights runs/train/exp4_yolo5l6/weights/best.pt runs/train/exp5_yolo5x/weights/best.pt --view-img --conf-thres 0.000 --augment

Output:

In my 4 experiments (5x, 5x6+5l6, 5x6+5x, 5l6+5x), the first 2 succeeded and the others failed.

  • yolo5x:
    screenshot "5x - copy"

  • yolo5x6+5l6:
    screenshot "5x6+5l6 - copy"

  • yolo5x6+5x:
    screenshot "5x6+5x - copy"

  • yolo5l6+5x:
    screenshot "5l6+5x - copy"

Ensemble created with ['runs/train/exp4_yolo5l6/weights/best.pt', 'runs/train/exp5_yolo5x/weights/best.pt']

image 1/1200 /dataset/lg_datacode/code/yolov5/Underwater_optics/test-A/000001.jpg: Traceback (most recent call last):
  File "detect.py", line 199, in <module>
    detect(opt=opt)
  File "detect.py", line 78, in detect
    pred = model(img, augment=opt.augment)[0]
  File "/home/lg/anaconda3/envs/yolo5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/dataset/lg_datacode/code/yolov5/models/experimental.py", line 106, in forward
    y.append(module(x, augment)[0])
  File "/home/lg/anaconda3/envs/yolo5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/dataset/lg_datacode/code/yolov5/models/yolo.py", line 113, in forward
    yi = self.forward_once(xi)[0]  # forward
  File "/dataset/lg_datacode/code/yolov5/models/yolo.py", line 139, in forward_once
    x = m(x)  # run
  File "/home/lg/anaconda3/envs/yolo5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/dataset/lg_datacode/code/yolov5/models/common.py", line 210, in forward
    return torch.cat(x, self.d)
RuntimeError: Sizes of tensors must match except in dimension 2. Got 23 and 24 (The offending index is 0)

Expected behavior

I think there is a bug in the code, but I cannot find it.

Environment


  • OS: Ubuntu 18.04
  • GPU: 4× 2080 Ti
  • CUDA: 10.1
  • Python: 3.8.8
  • PyTorch: 1.7.0


Lg955 added the bug (Something isn't working) label May 8, 2021

github-actions bot commented May 8, 2021

👋 Hello @Lg955, thank you for your interest in 🚀 YOLOv5! Please visit our ⭐️ Tutorials to get started, where you can find quickstart guides for simple tasks like Custom Data Training all the way to advanced concepts like Hyperparameter Evolution.

If this is a 🐛 Bug Report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we cannot help you.

If this is a custom training ❓ Question, please provide as much information as possible, including dataset images, training logs, screenshots, and a public link to online W&B logging if available.

For business inquiries or professional support requests please visit https://www.ultralytics.com or email Glenn Jocher at glenn.jocher@ultralytics.com.

Requirements

Python 3.8 or later with all requirements.txt dependencies installed, including torch>=1.7. To install run:

$ pip install -r requirements.txt

Environments

YOLOv5 may be run in any of the following up-to-date verified environments (with all dependencies including CUDA/CUDNN, Python and PyTorch preinstalled):

Status

CI CPU testing

If this badge is green, all YOLOv5 GitHub Actions Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training (train.py), testing (test.py), inference (detect.py) and export (export.py) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.

glenn-jocher commented

@Lg955 thanks for the bug report! I am able to reproduce this using your example code, and I see the same error. I have a few ideas about what might be causing this: either the dataloader input stride is not synced to the P6 model in the ensemble, or the ensemble module itself is not flattening the outputs into a single list. I'll debug later and try to get a fix pushed this weekend.

Screenshot 2021-05-08 at 10 38 47
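The 23-vs-24 pair in the traceback is consistent with the stride idea. A minimal sketch of the arithmetic, assuming each of the six downsampling steps halves the side like a 3×3, stride-2, padding-1 convolution, and that the P6 head upsamples P6 by 2× before concatenating it with P5. The 736 px input side is a hypothetical value (a multiple of 32 but not of 64), and the helper names are illustrative, not YOLOv5 functions.

```python
from typing import List


def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output side of a Conv2d layer (PyTorch's floor formula, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_sizes(side: int, levels: int = 6) -> List[int]:
    """Side of the feature map after each stride-2 step, i.e. P1..P<levels>."""
    sizes = []
    for _ in range(levels):
        side = conv_out(side)
        sizes.append(side)
    return sizes


side = 736  # hypothetical input side: 23 * 32, not a multiple of 64
sizes = feature_sizes(side)  # P1..P6 of a stride-64 (P6) model
p5, p6 = sizes[4], sizes[5]
upsampled_p6 = 2 * p6  # nearest-neighbour 2x upsample in the head
print(sizes)  # [368, 184, 92, 46, 23, 12]
print(f"P5={p5}, upsampled P6={upsampled_p6}, cat ok: {p5 == upsampled_p6}")
# P5=23, upsampled P6=24, cat ok: False
```

With a side of 768 (a multiple of 64) the same calculation gives P5=24 and P6=12, so the upsampled P6 map lines up with P5.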

glenn-jocher commented

@Lg955 a short-term workaround seems to be to always pass the P6 model last if you have P6 models in the ensemble:

Screenshot 2021-05-08 at 10 41 12
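Why the order matters can be sketched as follows. This assumes, without verifying it here, that the ensemble copies its stride from the last model loaded and that detect.py rounds the image side up to a multiple of that stride. `make_divisible` below just implements that rounding, and `ensemble_stride_last` is a hypothetical helper standing in for the assumed behaviour; neither is presented as the YOLOv5 code.

```python
import math
from typing import List


def make_divisible(x: int, divisor: int) -> int:
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


def ensemble_stride_last(strides: List[int]) -> int:
    """Hypothetical: the ensemble inherits the stride of the last model loaded."""
    return strides[-1]


requested = 736  # hypothetical inference side
for strides in ([64, 32], [32, 64]):  # [P6 model, P5 model] vs [P5 model, P6 model]
    s = ensemble_stride_last(strides)
    side = make_divisible(requested, s)
    print(strides, "stride", s, "side", side, "P6-safe:", side % 64 == 0)
# [64, 32] stride 32 side 736 P6-safe: False
# [32, 64] stride 64 side 768 P6-safe: True
```

Under these assumptions, the failing runs (5x6+5x and 5l6+5x) put a P5 model last, so a side that is only a multiple of 32 can reach the P6 member; with the P6 model last, the side is rounded up to a multiple of 64.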


Lg955 commented May 8, 2021


Wow, very nice, I succeeded!!!
Looking forward to your future fix.

glenn-jocher linked a pull request May 8, 2021 that will close this issue

glenn-jocher commented May 8, 2021

@Lg955 good news 😃! Your original issue may now be fixed ✅ in PR #3082. To receive this update you can:

  • git pull from within your yolov5/ directory
  • git clone https://github.com/ultralytics/yolov5 again
  • Force-reload PyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'yolov5s', force_reload=True)
  • View our updated notebooks on Colab or Kaggle
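One order-independent way to avoid the mismatch is to size images by the largest stride of any ensemble member. This is a sketch of the idea only, using a hypothetical helper (`ensemble_stride_max`) and the same hypothetical 736 px side; it is not necessarily what PR #3082 changed.

```python
import math
from typing import List


def make_divisible(x: int, divisor: int) -> int:
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


def ensemble_stride_max(strides: List[int]) -> int:
    """Hypothetical helper: use the largest stride among the ensemble members."""
    return max(strides)


for strides in ([64, 32], [32, 64], [32, 32]):  # model order no longer matters
    s = ensemble_stride_max(strides)
    print(strides, "stride", s, "side", make_divisible(736, s))
# [64, 32] stride 64 side 768
# [32, 64] stride 64 side 768
# [32, 32] stride 32 side 736
```

Taking the maximum means an all-P5 ensemble keeps stride 32, while any ensemble containing a P6 model gets stride 64 regardless of the order of `--weights`.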

Thank you for spotting this issue and informing us of the problem. Please let us know if this update resolves the issue for you, and feel free to inform us of any other issues you discover or feature requests that come to mind. Happy trainings with YOLOv5 🚀!
