
Input channel yaml['ch'] addition #1741

Merged: 1 commit merged into master on Dec 19, 2020
Conversation

@glenn-jocher (Member) commented Dec 19, 2020

This PR adds support for optional input channel definition in model yaml files, i.e.

# parameters
nc: 80  # number of classes
ch: 10  # input channels  <------------------------
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple

Tested with 1, 3 and 10 channel models.
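
A minimal construction check along the lines of the testing described above, assuming a YOLOv5 checkout (adding a ch: entry to yolov5s.yaml and the 640 image size are illustrative choices, not part of this PR):

import torch
from models.yolo import Model

model = Model('models/yolov5s.yaml')   # assumed to carry an added 'ch:' entry as above
ch = model.yaml.get('ch', 3)           # channel count the model was built with
x = torch.zeros(1, ch, 640, 640)       # blank ch-channel image
y = model(x)                           # forward pass runs for 1-, 3- or 10-channel models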

🛠️ PR Summary

Made with ❤️ by Ultralytics Actions

🌟 Summary

Refinements in YOLOv5 model initialization and FLOPS calculation.

📊 Key Changes

  • math module import has been moved for better code organization.
  • The self.yaml['ch'] entry is now explicitly initialized from the input channels parameter ch (illustrated in the sketch after this list).
  • Removed the redundant ch_out comment in model parsing as it was confusing.
  • Model's FLOPS calculation now dynamically uses the ch value from model.yaml.
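
A hedged illustration of the initialization change described above (a standalone helper, not the verbatim YOLOv5 source): the yaml ch key, when present, takes precedence over the constructor default.

def resolve_channels(yaml_dict, ch=3):
    """Return the input channel count, preferring yaml['ch'] over the default."""
    yaml_dict['ch'] = yaml_dict.get('ch', ch)  # write back so later code sees the value
    return yaml_dict['ch']

print(resolve_channels({'nc': 80, 'ch': 10}))  # -> 10 (yaml wins)
print(resolve_channels({'nc': 80}))            # -> 3  (falls back to the default)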

🎯 Purpose & Impact

  • 💡 Better code structure provides clarity for future maintenance and development.
  • 🚀 Enhances flexibility by ensuring model initialization aligns with configuration files, potentially improving model adaptability to different input channels.
  • 🔍 Improves technical documentation within the code for better understanding.
  • ✅ More accurate FLOPS calculation leads to better performance insights, impacting developers' ability to optimize models.

glenn-jocher linked an issue on Dec 19, 2020 that may be closed by this pull request
glenn-jocher merged commit 394d1c8 into master on Dec 19, 2020
@thhart commented Dec 19, 2020

I did specify ch: 1 in the yaml but unfortunately it fails to execute; ch: 3 works. Did I miss a further setting? All my inputs are single-channel images.

$: python ./train.py --img 1024 --batch 8 --epochs 256 --data dataDisc.yaml --cfg ./models/disc_yolov5m.yaml --weights '' --name yolo5m_disc --cache
Using torch 1.7.0 CUDA:0 (GeForce GTX 1070, 8114MB)

Namespace(adam=False, batch_size=8, bucket='', cache_images=True, cfg='./models/disc_yolov5m.yaml', data='dataDisc.yaml', device='', epochs=256, evolve=False, exist_ok=False, global_rank=-1, hyp='data/hyp.scratch.yaml', image_weights=False, img_size=[1024, 1024], local_rank=-1, log_artifacts=False, log_imgs=16, multi_scale=False, name='yolo5m_disc', noautoanchor=False, nosave=False, notest=False, project='runs/train', rect=False, resume=False, save_dir='runs/train/yolo5m_disc11', single_cls=False, sync_bn=False, total_batch_size=8, weights='', workers=8, world_size=1)
Start Tensorboard with "tensorboard --logdir runs/train", view at http://localhost:6006/
Hyperparameters {'lr0': 0.01, 'lrf': 0.2, 'momentum': 0.937, 'weight_decay': 0.0005, 'warmup_epochs': 3.0, 'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1, 'box': 0.05, 'cls': 0.5, 'cls_pw': 1.0, 'obj': 1.0, 'obj_pw': 1.0, 'iou_t': 0.2, 'anchor_t': 4.0, 'fl_gamma': 0.0, 'hsv_h': 0.015, 'hsv_s': 0.7, 'hsv_v': 0.4, 'degrees': 0.0, 'translate': 0.1, 'scale': 0.5, 'shear': 0.0, 'perspective': 0.0, 'flipud': 0.0, 'fliplr': 0.5, 'mosaic': 1.0, 'mixup': 0.0}

                 from  n    params  module                                  arguments                     
  0                -1  1      1824  models.common.Focus                     [1, 48, 3]                    
  1                -1  1     41664  models.common.Conv                      [48, 96, 3, 2]                
  2                -1  1     67680  models.common.BottleneckCSP             [96, 96, 2]                   
  3                -1  1    166272  models.common.Conv                      [96, 192, 3, 2]               
  4                -1  1    639168  models.common.BottleneckCSP             [192, 192, 6]                 
  5                -1  1    664320  models.common.Conv                      [192, 384, 3, 2]              
  6                -1  1   2550144  models.common.BottleneckCSP             [384, 384, 6]                 
  7                -1  1   2655744  models.common.Conv                      [384, 768, 3, 2]              
  8                -1  1   1476864  models.common.SPP                       [768, 768, [5, 9, 13]]        
  9                -1  1   4283136  models.common.BottleneckCSP             [768, 768, 2, False]          
 10                -1  1    295680  models.common.Conv                      [768, 384, 1, 1]              
 11                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 12           [-1, 6]  1         0  models.common.Concat                    [1]                           
 13                -1  1   1219968  models.common.BottleneckCSP             [768, 384, 2, False]          
 14                -1  1     74112  models.common.Conv                      [384, 192, 1, 1]              
 15                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 16           [-1, 4]  1         0  models.common.Concat                    [1]                           
 17                -1  1    305856  models.common.BottleneckCSP             [384, 192, 2, False]          
 18                -1  1    332160  models.common.Conv                      [192, 192, 3, 2]              
 19          [-1, 14]  1         0  models.common.Concat                    [1]                           
 20                -1  1   1072512  models.common.BottleneckCSP             [384, 384, 2, False]          
 21                -1  1   1327872  models.common.Conv                      [384, 384, 3, 2]              
 22          [-1, 10]  1         0  models.common.Concat                    [1]                           
 23                -1  1   4283136  models.common.BottleneckCSP             [768, 768, 2, False]          
 24      [17, 20, 23]  1     24246  models.yolo.Detect                      [1, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [192, 384, 768]]
Model Summary: 391 layers, 21482358 parameters, 21482358 gradients, 50.7 GFLOPS

Optimizer groups: 86 .bias, 94 conv.weight, 83 other
Scanning 'disc/train/labels.cache' for images and labels... 459 found, 0 missing, 0 empty, 0 corrupted: 100%|██████████| 459/459 [00:00
    train(hyp, opt, device, tb_writer, wandb)
  File "./train.py", line 289, in train
    pred = model(imgs)  # forward
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/th/dev/itth/okraLearn/yolov5/models/yolo.py", line 123, in forward
    return self.forward_once(x, profile)  # single-scale inference, train
  File "/home/th/dev/itth/okraLearn/yolov5/models/yolo.py", line 139, in forward_once
    x = m(x)  # run
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/th/dev/itth/okraLearn/yolov5/models/common.py", line 109, in forward
    return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/th/dev/itth/okraLearn/yolov5/models/common.py", line 35, in forward
    return self.act(self.bn(self.conv(x)))
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py", line 423, in forward
    return self._conv_forward(input, self.weight)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py", line 420, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size [48, 4, 3, 3], expected input[8, 12, 512, 512] to have 4 channels, but got 12 channels instead

@glenn-jocher (Member Author)

@thhart ah, the PR merely allows you to specify the model channel count via the yaml. This works successfully in your case: as you can see, the first layer's input channel count is 1, and the model forward pass is evaluated successfully to give you a FLOPS value.

                 from  n    params  module                                  arguments                     
  0                -1  1      1824  models.common.Focus                     [1, 48, 3]  

The training is another matter. As I said before, you'd need to customize the dataloader to suit your requirements, as the current dataloader uses cv2.imread(), which loads RGB and greyscale images as 3-ch arrays.
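
To make the point above concrete, a small demonstration of cv2's default behaviour (the file path is a placeholder): imread returns a 3-channel BGR array even for greyscale files unless a read flag is passed.

import cv2

img_default = cv2.imread('greyscale_sample.png')                        # shape (H, W, 3), BGR
img_single = cv2.imread('greyscale_sample.png', cv2.IMREAD_GRAYSCALE)   # shape (H, W), one channel
print(img_default.shape, img_single.shape)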

@glenn-jocher (Member Author)

The training dataloader is here:

yolov5/utils/datasets.py

Lines 336 to 338 in 394d1c8

class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
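
For anyone attempting that customization, a minimal sketch of one possible approach (not part of the repository; the helper name is made up): read the file as single-channel and restore a channel axis so the HWC layout downstream code expects is preserved.

import cv2

def load_grayscale(path):
    """Load an image as a single-channel array with an explicit channel axis."""
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)   # (H, W)
    assert img is not None, f'Image Not Found {path}'
    return img[..., None]                          # (H, W, 1)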

@thhart commented Dec 19, 2020

Thanks for the clarification. I will see if I can enhance the dataloader to support this.

@thhart commented Dec 20, 2020

I checked datasets.py, but unfortunately it is harder than expected. The whole code is dependent on the colour scheme. Even if I get the image handled as greyscale through it, it finally fails in yolo.py. When you say you tested with different channels, do you mean inference only?

@glenn-jocher (Member Author) commented Dec 20, 2020

@thhart yes. After construction, the model.info() method runs, which passes a ch-channel blank image through the model to trace it for FLOPS, so if you see a FLOPS value, inference runs correctly for the ch-channel model.
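
A rough sketch of that check, assuming a YOLOv5 checkout with the thop profiler installed (which the repository uses for its FLOPS estimate); the call below is illustrative rather than the verbatim internal implementation.

import torch
from thop import profile
from models.yolo import Model

model = Model('models/yolov5s.yaml')                  # yaml may define ch: 1, 3, 10, ...
ch = model.yaml.get('ch', 3)
img = torch.zeros(1, ch, 640, 640)                    # blank ch-channel image
macs, params = profile(model, inputs=(img,), verbose=False)
print(f'{macs / 1e9 * 2:.1f} GFLOPS')                 # x2 converts multiply-accumulates to FLOPs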

cv2 defaults to loading greyscale as 3-ch images, and the increased FLOPS are negligible, as you can see for yourself. The COCO dataset already includes a few greyscale images that train correctly using this method, so the default implementation is already capable of mixed or isolated greyscale and 3-ch (and 4-ch PNG with alpha) image training and inference.

Unless you have hyperspectral imagery, I would just train normally with the default repository.

@thhart commented Dec 20, 2020

Currently all images are converted to BGR via the cv2 default, so yes, this makes it work for all image types of course.

The problem is that cv2 returns a different image structure for single-band images, which makes it a bit dirty to hack.

You are right that I am heading for increased FLOPS and would like to check how much it matters. However, I am asking whether it is worth the effort. I am running inference on CPU on large inputs, but I don't need that much performance yet. So if you are saying it is negligible, I will drop this for the moment.

@glenn-jocher (Member Author)

@thhart the #1 issue people encounter here, and one I would advise against, is premature optimization. If possible, train normally with default settings first; then, once you have a speed and mAP baseline to compare against, you can start optimizing your approach if the results are unsatisfactory.

@thhart commented Dec 20, 2020

Thanks for the input, appreciated.

KMint1819 pushed a commit to KMint1819/yolov5 that referenced this pull request May 12, 2021
taicaile pushed a commit to taicaile/yolov5 that referenced this pull request Oct 12, 2021
@Zhiying-Li-dot

Hello, I encountered a problem when I used yolo.py to test my model. I changed "ch" to 1, 3 and 10, but it was useless. I can't resolve this problem and have no idea how to handle this channel error.

                from  n    params  module                                  arguments                     
  0                -1  1      2432  models.common.Focus                     [1, 64, 3]                    
  1                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               
  2                -1  3    156928  models.common.C3                        [128, 128, 3]                 
  3                -1  1    295424  models.common.Conv                      [128, 256, 3, 2]              
  4                -1  9   1611264  models.common.C3                        [256, 256, 9]                 
  5                -1  1   1180672  models.common.Conv                      [256, 512, 3, 2]              
  6                -1  9   6433792  models.common.C3                        [512, 512, 9]                 
  7                -1  1   3540480  models.common.Conv                      [512, 768, 3, 2]              
  8                -1  3   5611008  models.common.C3                        [768, 768, 3]                 
  9                -1  1   7079936  models.common.Conv                      [768, 1024, 3, 2]             
 10                -1  1   2624512  models.common.SPP                       [1024, 1024, [3, 5, 7]]       
 11                -1  3   9971712  models.common.C3                        [1024, 1024, 3, False]        
 12                -1  3    301344  models.common.CoordAtt                  [1024, 1024]                  
 13                -1  1    787968  models.common.Conv                      [1024, 768, 1, 1]             
 14                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 15           [-1, 8]  1    148229  models.common.Concat_bifpn              [384, 384]                    
 16                -1  3   5611008  models.common.C3                        [768, 768, 3, False]          
 17                -1  1    394240  models.common.Conv                      [768, 512, 1, 1]              
 18                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 19           [-1, 6]  1     66053  models.common.Concat_bifpn              [256, 256]                    
 20                -1  3   2495488  models.common.C3                        [512, 512, 3, False]          
 21                -1  1    131584  models.common.Conv                      [512, 256, 1, 1]              
 22                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
 23           [-1, 4]  1     16645  models.common.Concat_bifpn              [128, 128]                    
 24                -1  3    625152  models.common.C3                        [256, 256, 3, False]          
 25                -1  1   1180672  models.common.Conv                      [256, 512, 3, 2]              
 26       [-1, 6, 20]  1     66053  models.common.Concat_bifpn              [256, 256]                    
 27                -1  3   2495488  models.common.C3                        [512, 512, 3, False]          
 28                -1  1   3540480  models.common.Conv                      [512, 768, 3, 2]              
 29       [-1, 8, 16]  1    148229  models.common.Concat_bifpn              [384, 384]                    
 30                -1  3   5611008  models.common.C3                        [768, 768, 3, False]          
 31                -1  1   7079936  models.common.Conv                      [768, 1024, 3, 2]             
 32          [-1, 11]  1    263173  models.common.Concat_bifpn              [512, 512]                    
 33                -1  3   9971712  models.common.C3                        [1024, 1024, 3, False]        
 34  [24, 27, 30, 33]  1    830736  Detect                                  [103, [[19, 27, 44, 40, 38, 94], [96, 68, 86, 152, 180, 137], [140, 301, 303, 264, 238, 542], [436, 615, 739, 380, 925, 792]], [256, 512, 768, 1024]]
Traceback (most recent call last):
  File "models/yolo.py", line 317, in <module>
    model = Model(opt.cfg).to(device)
  File "models/yolo.py", line 112, in __init__
    m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
  File "models/yolo.py", line 126, in forward
    return self._forward_once(x, profile, visualize)  # single-scale inference, train
  File "models/yolo.py", line 149, in _forward_once
    x = m(x)  # run
  File "/home/lzy/anaconda3/envs/yolov5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lzy/exper/yolo/models/common.py", line 542, in forward
    x = self.conv(self.act(weight[0] * x[0] + weight[1] * x[1]))
  File "/home/lzy/anaconda3/envs/yolov5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lzy/exper/yolo/models/common.py", line 45, in forward
    return self.act(self.bn(self.conv(x)))
  File "/home/lzy/anaconda3/envs/yolov5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lzy/anaconda3/envs/yolov5/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 443, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/lzy/anaconda3/envs/yolov5/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 440, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size [384, 384, 1, 1], expected input[1, 768, 8, 8] to have 384 channels, but got 768 channels instead

@glenn-jocher (Member Author)

@lizhiying-2019 yes you can use the ch: 10 attribute in your model.yaml to create a YOLOv5 model with 10 input channels. This works correctly, just tested.

Using this model for training is not currently supported, and would require multiple updates to the dataloader code.

@Zhiying-Li-dot

Emm... I added ch: 10 in my model.yaml, but it was useless. Where else should I modify? Please tell me what to do. Thanks.

BjarneKuehl pushed a commit to fhkiel-mlaip/yolov5 that referenced this pull request Aug 26, 2022
Successfully merging this pull request may close these issues.

How to configure channel size