
[Feature] Support getting model from the name defined in the model-index file. #1236

Merged: 7 commits into open-mmlab:dev-1.x on Dec 6, 2022

Conversation

@mzr1996 (Member) commented Dec 2, 2022

Motivation

Previously, to use a pre-defined model in MMClassification, we had to clone the repository and specify the config file path to build the model.

This PR simplifies the process: a model can now be built directly from the name defined in the model-index file.

Modification

  1. Add get_model and list_models API.
  2. Import all modules in the apis package into the top-level package.
  3. Move init_model to mmcls/apis/hub.py.
  4. Remove the options argument from init_model and treat all other keyword arguments as config options of the model. (BC-breaking)
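Item 4 is the BC-breaking part: there is no options dict anymore, and any unrecognized keyword argument is folded into the model config. A minimal self-contained sketch of that merge behaviour, using plain dicts (the function name is hypothetical, and the real implementation works on config objects, not bare dicts):

```python
def init_model_sketch(config, checkpoint=None, device='cpu', **cfg_options):
    """Sketch of the new calling convention: known arguments are consumed,
    everything else is treated as a config override.

    Old style: init_model(cfg, ckpt, device, options={'head.num_classes': 10})
    New style: init_model(cfg, ckpt, device, head=dict(num_classes=10))
    """
    # Copy one level deep so the caller's config is not mutated.
    merged = {k: dict(v) if isinstance(v, dict) else v for k, v in config.items()}
    for key, value in cfg_options.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key].update(value)  # merge into the existing section
        else:
            merged[key] = value        # add or replace wholesale
    return merged
```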

Use cases (Optional)

Get a ResNet-50 model and extract image features:

>>> import torch
>>> from mmcls import get_model
>>> inputs = torch.rand(16, 3, 224, 224)
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = model.extract_feat(inputs)
>>> for feat in feats:
...     print(feat.shape)
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])
torch.Size([16, 2048])

Get a Swin Transformer model with pre-trained weights and run inference:

>>> from mmcls import get_model, inference_model
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> print(result['pred_class'])
'sea snake'

List all models:

>>> from mmcls import list_models
>>> print(list_models())

List ResNet-50 models on ImageNet-1k dataset:

>>> from mmcls import list_models
>>> print(list_models('resnet*in1k'))
['resnet50_8xb32_in1k',
 'resnet50_8xb32-fp16_in1k',
 'resnet50_8xb256-rsb-a1-600e_in1k',
 'resnet50_8xb256-rsb-a2-300e_in1k',
 'resnet50_8xb256-rsb-a3-100e_in1k']
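The 'resnet*in1k' pattern suggests Unix shell-style wildcards. A self-contained sketch of this kind of filtering (the real list_models draws names from the model-index file; here a hard-coded list stands in, and the function name is hypothetical):

```python
import fnmatch

def list_models_sketch(pattern=None, names=()):
    """Return model names matching a shell-style wildcard pattern."""
    if pattern is None:
        return sorted(names)
    return sorted(name for name in names if fnmatch.fnmatch(name, pattern))

# Stand-in for the names parsed from the model-index file.
registry = [
    'resnet50_8xb32_in1k',
    'resnet50_8xb32-fp16_in1k',
    'swin-base_16xb64_in1k',
]
```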

Checklist

Before PR:

  • Pre-commit or other linting tools are used to fix the potential lint issues.
  • Bug fixes are fully covered by unit tests, the case that causes the bug should be added in the unit tests.
  • The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness.
  • The documentation has been modified accordingly, like docstring or example tutorials.

After PR:

  • If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects, like MMDet or MMSeg.
  • CLA has been signed and all committers have signed the CLA in this PR.

codecov bot commented Dec 2, 2022

Codecov Report

Base: 0.02% // Head: 89.17% // Increases project coverage by +89.15% 🎉

Coverage data is based on head (5d52d90) compared to base (b8b31e9).
Patch has no changes to coverable lines.

Additional details and impacted files
@@             Coverage Diff              @@
##           dev-1.x    #1236       +/-   ##
============================================
+ Coverage     0.02%   89.17%   +89.15%     
============================================
  Files          121      150       +29     
  Lines         8217    11543     +3326     
  Branches      1368     1848      +480     
============================================
+ Hits             2    10294    +10292     
+ Misses        8215      967     -7248     
- Partials         0      282      +282     
Flag Coverage Δ
unittests 89.17% <ø> (+89.15%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
mmcls/datasets/transforms/compose.py
mmcls/engine/optimizers/adan_t.py 10.60% <0.00%> (ø)
mmcls/models/backbones/mobileone.py 94.47% <0.00%> (ø)
mmcls/models/necks/reduction.py 100.00% <0.00%> (ø)
mmcls/models/backbones/deit3.py 94.52% <0.00%> (ø)
mmcls/engine/hooks/retriever_hooks.py 72.72% <0.00%> (ø)
mmcls/models/backbones/efficientformer.py 95.08% <0.00%> (ø)
mmcls/models/classifiers/timm.py 25.97% <0.00%> (ø)
mmcls/models/retrievers/image2image.py 90.90% <0.00%> (ø)
mmcls/models/backbones/swin_transformer_v2.py 89.63% <0.00%> (ø)
... and 141 more

☔ View full report at Codecov.

@okotaku (Collaborator) commented Dec 5, 2022

This works well for classification inference, but when using it for detection, I would like to call it the same way for 21k models that have no config file, as shown below.

https://github.com/mzr1996/mmclassification/blob/1x-get-model/configs/swin_transformer_v2/metafile.yml#L194

class Model:
    def __init__(self):
        self.backbone = get_model('swin-large-21k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)), neck=None, head=None)
        self.backbone.head = nn.Identity()
        self.neck = FPN()
        self.head = Head()

Some users may want to use it like timm.

class Model:
    def __init__(self):
        self.backbone = get_model('resnet50_8xb32_in1k', pretrained=True, neck=None, head=None)
        self.neck = GAP()
        self.head = Linear()

A function like reset_classifier may be useful for this kind of usage.
https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/resnet.py#L704

class Model:
    def __init__(self):
        self.backbone = get_model('resnet50_8xb32_in1k', pretrained=True)
        self.backbone.reset_classifier()
        self.neck = GAP()
        self.head = Linear()

or

class Model:
    def __init__(self):
        self.backbone = get_model('resnet50_8xb32_in1k', pretrained=True)
        self.backbone.reset_classifier(num_classes=10)

The scope of support in this area will depend on how you expect your users to use the system.
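For reference, timm's reset_classifier either drops the classification head (num_classes=0) or rebuilds it for a new class count. A toy sketch of those semantics, with no real layers involved (the class and its attributes are invented for illustration):

```python
class ClassifierSketch:
    """Toy model illustrating timm-style reset_classifier semantics."""

    def __init__(self, feat_dim=2048, num_classes=1000):
        self.feat_dim = feat_dim
        self.head = ('linear', feat_dim, num_classes)

    def reset_classifier(self, num_classes=0):
        # num_classes=0 removes the head (pure feature extractor);
        # any other value builds a fresh head with the new class count.
        if num_classes == 0:
            self.head = None
        else:
            self.head = ('linear', self.feat_dim, num_classes)
```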

@mzr1996 (Member, Author) commented Dec 5, 2022

> I think it is good if I use inference of classification, but if I use it for detection, I would like to call the same for 21k models where no config exists [...] The scope of support in this area will depend on how you expect your users to use the system.

The first problem can be avoided by adding the missing config files.
As for reset_classifier, my expected usage is to modify num_classes in the head directly, like

from mmcls import get_model
model = get_model('resnet50_8xb32_in1k', pretrained=True, head=dict(num_classes=10))

Do you think we need to implement the reset_classifier API?

@okotaku (Collaborator) commented Dec 5, 2022

@mzr1996 It looks good. No need to add reset_classifier.
Do you expect users to add their own models to the hub, as on Hugging Face?
If so, it might be a good idea to document how to add them.

@Ezra-Yu (Collaborator) commented Dec 5, 2022

from mmcls import get_model
model = get_model('resnet50_8xb32_in1k', pretrained=True, head=dict(num_classes=10))

Suppose a user doesn't know the concepts of backbone, neck, and head. To understand head=dict(num_classes=10), they would need to know how the model is split into "backbone", "neck", and "head", and they would also need to understand the config registration mechanism as well as config inheritance.

Maybe we can simplify this as follows:

from mmcls import get_model
model = get_model('resnet50_8xb32_in1k', pretrained=True, num_classes=10)
# model = get_model('resnet50_8xb32_in1k', pretrained=True) # default IN1k or IN21k pretrain
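One way to realise this suggestion is a small table that routes user-friendly flat keywords onto the right config section, so callers never see the backbone/neck/head split. A sketch of the idea (the shortcut table, key names, and function name are assumptions, not part of this PR):

```python
# Flat keyword -> (config section, field). Hypothetical mapping.
SHORTCUTS = {'num_classes': ('head', 'num_classes')}

def expand_shortcuts(**kwargs):
    """Expand flat keywords into nested config overrides."""
    overrides = {}
    for key, value in kwargs.items():
        section, field = SHORTCUTS.get(key, (None, key))
        if section is None:
            overrides[key] = value  # unknown keys pass through unchanged
        else:
            overrides.setdefault(section, {})[field] = value
    return overrides
```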

@mzr1996 (Member, Author) commented Dec 5, 2022

> @mzr1996 It looks good. No need to add reset_classifier. Do you expect users adding models to hub, as in Huggingface? If so, it might be a good idea to write how to add them in the documentation.

Yes, but for now, users can only register their own config files through a metafile.yml, which is a little complex for users.

from mmcls.apis import ModelHub

ModelHub.register_model_index('my_metafile.yml')

I also want to design other ways to register new models, timm-style. I think it's better to add that documentation later.
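A timm-style registration usually means a decorator that records a factory function under its own name, so a new model needs no metafile at all. A self-contained sketch of that idea (all names here are hypothetical; this is not an API this PR adds):

```python
MODEL_REGISTRY = {}

def register_model(factory):
    """Record a model factory under its function name, timm-style."""
    MODEL_REGISTRY[factory.__name__] = factory
    return factory

@register_model
def my_tiny_net(pretrained=False):
    # A real factory would build the model and optionally load weights;
    # a plain dict stands in for the model object here.
    return {'name': 'my_tiny_net', 'pretrained': pretrained}

def get_model_sketch(name, **kwargs):
    """Look up a registered factory by name and call it."""
    return MODEL_REGISTRY[name](**kwargs)
```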

@mzr1996 mzr1996 requested a review from tonysy December 5, 2022 10:27
@tonysy (Collaborator) left a comment

LGTM

@mzr1996 mzr1996 merged commit c127c47 into open-mmlab:dev-1.x Dec 6, 2022
4 participants