Why not freeze layers for finetuning? #1264

Closed
mphillips-valleyit opened this issue Nov 3, 2020 · 13 comments
Labels
question (Further information is requested), Stale

Comments

@mphillips-valleyit

❔Question

Hi @glenn-jocher, I'm just wondering if it was a conscious decision not to freeze lower layers in the model (e.g. some or all of the backbone) when finetuning. My own experience (though not tested here yet) is that it is not beneficial to allow lower layers to be retrained from a fine-tuning dataset, particularly when that dataset is small--not to mention training is faster when you only train top layers. Thoughts?

mphillips-valleyit added the question (Further information is requested) label Nov 3, 2020
@github-actions
Contributor

github-actions bot commented Nov 3, 2020

Hello @mphillips-valleyit, thank you for your interest in our work! Please visit our Custom Training Tutorial to get started, and see our Jupyter Notebook, Docker Image, and Google Cloud Quickstart Guide for example environments.

If this is a bug report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we can not help you.

If this is a custom model or data training question, please note Ultralytics does not provide free personal support. As a leader in vision ML and AI, we do offer professional consulting, from simple expert advice up to delivery of fully customized, end-to-end production solutions for our clients, such as:

  • Cloud-based AI systems operating on hundreds of HD video streams in realtime.
  • Edge AI integrated into custom iOS and Android apps for realtime 30 FPS video inference.
  • Custom data training, hyperparameter evolution, and model exportation to any destination.

For more information please visit https://www.ultralytics.com.

@glenn-jocher
Member

@mphillips-valleyit I've never observed layer freezing improving performance. This was widely discussed in YOLOv3, i.e. ultralytics/yolov3#106

@mphillips-valleyit
Author

Thanks for that link @glenn-jocher, that was a very helpful discussion. I note, though, that the only freezing data you showed involved freezing all but the last layer.

ultralytics/yolov3#106 (comment)

Did you also experiment with freezing, say, the backbone only? I'm checking into this now, but if you have additional results there it would be great to know.

@glenn-jocher
Member

glenn-jocher commented Nov 4, 2020

@mphillips-valleyit right. It would be interesting to see a more complete study, perhaps finetuning VOC, as that can be done on a Colab notebook. Though I suspect in general the results will vary greatly by dataset.

If you want, you can try 3 scenarios:

  • finetune YOLOv5m VOC with default settings.
  • finetune YOLOv5m VOC freezing backbone.
  • finetune YOLOv5m VOC freezing all layers except outputs.

The command for finetuning VOC with evolved hyperparameters is below. It reaches about 0.90 mAP@0.50 after 50 epochs.

python train.py --batch 64 --weights yolov5m.pt --data voc.yaml --img 512 --epochs 50 --hyp hyp.finetune.yaml

Make sure you visualize your results with Tensorboard, W&B (pip install wandb), or at the minimum by plotting all 3 runs on the same results.png with utils/general.plot_results().

@glenn-jocher
Member

@mphillips-valleyit if you look at the model.yaml you can build a list of the layers to freeze. The YOLOv5 backbone is contained in layers 0-9. Note that range(9) as written here covers layers 0-8 only; use range(10) to include layer 9 as well:

    freeze = ['model.%s.' % x for x in range(9)]  # parameter names to freeze (full or partial)
    # ['model.0.', 'model.1.', 'model.2.', 'model.3.', 'model.4.', 'model.5.', 'model.6.', 'model.7.', 'model.8.']
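
The same prefix list can be generalized to the three scenarios suggested earlier in this thread. This is only a minimal sketch: the helper names `freeze_prefixes`, `is_frozen` and the `SCENARIOS` dict are hypothetical, and the layer counts (backbone = layers 0-9, output layer = index 24) are assumptions about the model.yaml layout, so check them against your own model before relying on them. It uses `startswith` rather than a substring test so that a prefix can only match at the start of a parameter name.

```python
def freeze_prefixes(num_layers):
    """Return name prefixes 'model.0.' .. 'model.<num_layers-1>.'."""
    return ['model.%s.' % i for i in range(num_layers)]


# Assumed layer counts (not read from the model): backbone = 0-9, output = 24.
SCENARIOS = {
    'default': freeze_prefixes(0),          # freeze nothing
    'backbone': freeze_prefixes(10),        # freeze layers 0-9
    'all_but_output': freeze_prefixes(24),  # freeze layers 0-23
}


def is_frozen(param_name, prefixes):
    """True if the parameter name starts with one of the freeze prefixes."""
    return any(param_name.startswith(p) for p in prefixes)


for scenario, prefixes in SCENARIOS.items():
    print(scenario, len(prefixes),
          is_frozen('model.9.cv1.conv.weight', prefixes),
          is_frozen('model.24.m.0.weight', prefixes))
```

With a real model you would loop over `model.named_parameters()` and set `v.requires_grad = not is_frozen(k, prefixes)` for the chosen scenario.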

@glenn-jocher
Member

>>> freeze = ['model.%s.' % x for x in range(9)]  # parameter names to freeze (full or partial)
>>> for k, v in model.named_parameters():
...     v.requires_grad = True  # train all layers
...     if any(x in k for x in freeze):
...         print('freezing %s' % k)
...         v.requires_grad = False
...
freezing model.0.conv.conv.weight
freezing model.0.conv.bn.weight
freezing model.0.conv.bn.bias
freezing model.1.conv.weight
freezing model.1.bn.weight
freezing model.1.bn.bias
freezing model.2.cv1.conv.weight
freezing model.2.cv1.bn.weight
freezing model.2.cv1.bn.bias
freezing model.2.cv2.weight
freezing model.2.cv3.weight
freezing model.2.cv4.conv.weight
freezing model.2.cv4.bn.weight
freezing model.2.cv4.bn.bias
freezing model.2.bn.weight
freezing model.2.bn.bias
freezing model.2.m.0.cv1.conv.weight
freezing model.2.m.0.cv1.bn.weight
freezing model.2.m.0.cv1.bn.bias
freezing model.2.m.0.cv2.conv.weight
freezing model.2.m.0.cv2.bn.weight
freezing model.2.m.0.cv2.bn.bias
freezing model.3.conv.weight
freezing model.3.bn.weight
freezing model.3.bn.bias
freezing model.4.cv1.conv.weight
freezing model.4.cv1.bn.weight
freezing model.4.cv1.bn.bias
freezing model.4.cv2.weight
freezing model.4.cv3.weight
freezing model.4.cv4.conv.weight
freezing model.4.cv4.bn.weight
freezing model.4.cv4.bn.bias
freezing model.4.bn.weight
freezing model.4.bn.bias
freezing model.4.m.0.cv1.conv.weight
freezing model.4.m.0.cv1.bn.weight
freezing model.4.m.0.cv1.bn.bias
freezing model.4.m.0.cv2.conv.weight
freezing model.4.m.0.cv2.bn.weight
freezing model.4.m.0.cv2.bn.bias
freezing model.4.m.1.cv1.conv.weight
freezing model.4.m.1.cv1.bn.weight
freezing model.4.m.1.cv1.bn.bias
freezing model.4.m.1.cv2.conv.weight
freezing model.4.m.1.cv2.bn.weight
freezing model.4.m.1.cv2.bn.bias
freezing model.4.m.2.cv1.conv.weight
freezing model.4.m.2.cv1.bn.weight
freezing model.4.m.2.cv1.bn.bias
freezing model.4.m.2.cv2.conv.weight
freezing model.4.m.2.cv2.bn.weight
freezing model.4.m.2.cv2.bn.bias
freezing model.5.conv.weight
freezing model.5.bn.weight
freezing model.5.bn.bias
freezing model.6.cv1.conv.weight
freezing model.6.cv1.bn.weight
freezing model.6.cv1.bn.bias
freezing model.6.cv2.weight
freezing model.6.cv3.weight
freezing model.6.cv4.conv.weight
freezing model.6.cv4.bn.weight
freezing model.6.cv4.bn.bias
freezing model.6.bn.weight
freezing model.6.bn.bias
freezing model.6.m.0.cv1.conv.weight
freezing model.6.m.0.cv1.bn.weight
freezing model.6.m.0.cv1.bn.bias
freezing model.6.m.0.cv2.conv.weight
freezing model.6.m.0.cv2.bn.weight
freezing model.6.m.0.cv2.bn.bias
freezing model.6.m.1.cv1.conv.weight
freezing model.6.m.1.cv1.bn.weight
freezing model.6.m.1.cv1.bn.bias
freezing model.6.m.1.cv2.conv.weight
freezing model.6.m.1.cv2.bn.weight
freezing model.6.m.1.cv2.bn.bias
freezing model.6.m.2.cv1.conv.weight
freezing model.6.m.2.cv1.bn.weight
freezing model.6.m.2.cv1.bn.bias
freezing model.6.m.2.cv2.conv.weight
freezing model.6.m.2.cv2.bn.weight
freezing model.6.m.2.cv2.bn.bias
freezing model.7.conv.weight
freezing model.7.bn.weight
freezing model.7.bn.bias
freezing model.8.cv1.conv.weight
freezing model.8.cv1.bn.weight
freezing model.8.cv1.bn.bias
freezing model.8.cv2.conv.weight
freezing model.8.cv2.bn.weight
freezing model.8.cv2.bn.bias

>>> model.info()
Model Summary: 191 layers, 7.46816e+06 parameters, 4.41625e+06 gradients
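
In the summary above, the gap between the two counts (7.46816e+06 - 4.41625e+06, about 3.05e+06) should correspond to the frozen parameters, assuming "gradients" counts the elements of parameters that still have `requires_grad` set. A minimal sketch of that bookkeeping follows. `count_parameters` and `FakeParam` are hypothetical names; `FakeParam` is only a stand-in so the example runs without PyTorch, and the sizes are made up.

```python
def count_parameters(named_parameters):
    """Return (total, trainable) element counts over (name, param) pairs."""
    total = trainable = 0
    for _name, p in named_parameters:
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    return total, trainable


class FakeParam:
    """Stand-in exposing the two attributes used above (hypothetical)."""

    def __init__(self, n, requires_grad=True):
        self._n = n
        self.requires_grad = requires_grad

    def numel(self):
        return self._n


demo = [
    ('model.0.conv.conv.weight', FakeParam(100, requires_grad=False)),
    ('model.8.cv2.conv.weight', FakeParam(50, requires_grad=False)),
    ('model.24.m.0.weight', FakeParam(30, requires_grad=True)),
]
total, trainable = count_parameters(demo)
print('%g parameters, %g gradients' % (total, trainable))
# prints: 180 parameters, 30 gradients
```

With a real model, passing `model.named_parameters()` in place of `demo` should work the same way, since PyTorch parameters provide `numel()` and `requires_grad`.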

@mphillips-valleyit
Author

Yup exactly, I found the leftover 'freeze' code in train.py as well. Looking forward to trying this out.

@glenn-jocher
Member

glenn-jocher commented Nov 6, 2020

@mphillips-valleyit I've performed a full analysis of transfer learning with varying degrees of layer freezing. I've created an official Tutorial now from my results. See https://docs.ultralytics.com/yolov5/tutorials/transfer_learning_with_frozen_layers

@mphillips-valleyit
Author

Hi @glenn-jocher thanks a lot for your work here. I have since looked at this a bit but not yet truly systematically.

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@pragyan430

pragyan430 commented Apr 5, 2021

@glenn-jocher Hi, I am trying to use the weights of my own previously trained model for transfer learning on a new model. I've loaded my weights and frozen the first 23 layers, leaving only the last one trainable. When I try to train, I keep running into this error:

Traceback (most recent call last):
  File "train.py", line 536, in <module>
    train(hyp, opt, device, tb_writer, wandb)
  File "train.py", line 151, in train
    optimizer.load_state_dict(ckpt['optimizer'])
  File "/usr/local/lib/python3.7/dist-packages/torch/optim/optimizer.py", line 148, in load_state_dict
    raise ValueError("loaded state dict contains a parameter group "
ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group

I don't know what is causing this problem; both the original and new model use yolov5x.pt. Why is this happening?
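
One way to narrow this down: PyTorch raises this ValueError when a parameter group in the saved optimizer state holds a different number of parameters than the matching group in the newly built optimizer. The sketch below compares group sizes on plain dicts shaped like the output of `optimizer.state_dict()`. The helper names and the example values are hypothetical, and it does not establish whether freezing or something else causes the mismatch in this setup.

```python
def group_sizes(optim_state):
    """Number of parameters in each param group of an optimizer state dict."""
    return [len(group['params']) for group in optim_state['param_groups']]


def compare_groups(saved_state, current_state):
    """Return (compatible, saved_sizes, current_sizes)."""
    saved = group_sizes(saved_state)
    current = group_sizes(current_state)
    return saved == current, saved, current


# Made-up dicts mimicking the layout of optimizer.state_dict().
saved = {'state': {}, 'param_groups': [{'params': [0, 1, 2]}, {'params': [3, 4]}]}
current = {'state': {}, 'param_groups': [{'params': [0, 1]}, {'params': [2, 3]}]}

ok, saved_sizes, current_sizes = compare_groups(saved, current)
print(ok, saved_sizes, current_sizes)
# prints: False [3, 2] [2, 2]
```

In a real run the two inputs would be `ckpt['optimizer']` (as in the traceback) and `optimizer.state_dict()` from the new optimizer, printed just before the failing `load_state_dict` call.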

@glenn-jocher
Member

glenn-jocher commented Apr 5, 2021

@pragyan430 for a tutorial on freezing layers see Transfer Learning Tutorial:

YOLOv5 Tutorials

If you believe you have a reproducible issue, we suggest you close this issue and raise a new one using the 🐛 Bug Report template, providing screenshots and a minimum reproducible example to help us better understand and diagnose your problem. Thank you!

@pragyan430

@glenn-jocher Yes, I followed every step given in the tutorial but still get the error above. I cannot find the problem. I just supplied the best.pt of the previous model to the weights argument of train.py. Then I froze the layers up to 23. However, I still get this error:

Traceback (most recent call last):
  File "train.py", line 536, in <module>
    train(hyp, opt, device, tb_writer, wandb)
  File "train.py", line 151, in train
    optimizer.load_state_dict(ckpt['optimizer'])
  File "/usr/local/lib/python3.7/dist-packages/torch/optim/optimizer.py", line 148, in load_state_dict
    raise ValueError("loaded state dict contains a parameter group "
ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group
