
Enable HF pretrained backbones #31145

Merged
merged 11 commits into huggingface:main from hf-pretrained-backbones on Jun 6, 2024

Conversation

@amyeroberts (Collaborator) commented May 30, 2024

What does this PR do?

Enables loading HF pretrained model weights for backbones.

from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation

config = MaskFormerConfig(backbone="microsoft/resnet-50", use_pretrained_backbone=True)
model = MaskFormerForInstanceSegmentation(config)
  • Updates documentation to show correct initialization
  • Removes check in verify_backbone_config_arguments
  • Adds test for loading HF checkpoints

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts changed the title from "Hf pretrained backbones" to "Enable HF pretrained backbones" on May 30, 2024
@@ -333,16 +333,6 @@ config = MaskFormerConfig(backbone="microsoft/resnet50", use_pretrained_backbone
model = MaskFormerForInstanceSegmentation(config) # head
```

You could also load the backbone config separately and then pass it to the model config.
Collaborator Author:

Removed as it's repeated and included in the section about loading pretrained backbones, but the example loads randomly initialized weights

@@ -50,7 +50,7 @@ def __init__(self, config, **kwargs):
if config.backbone is None:
raise ValueError("backbone is not set in the config. Please set it to a timm model name.")

if config.backbone not in timm.list_models():
if config.backbone.split(".")[0] not in timm.list_models():
Collaborator Author:

This change is needed because certain timm checkpoint names contain the base model plus extra model specifications, e.g. vit_large_patch14_dinov2.lvd142m

Member:

Ok, maybe worth adding this comment to the code, or doing the split outside the if statement with a clear variable name?
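
For illustration, a minimal sketch of the behaviour being discussed (the variable name and error message below are assumptions, not the merged code):

```python
import timm

# timm checkpoint names can carry a pretrained-tag suffix after the dot,
# e.g. "vit_large_patch14_dinov2.lvd142m", while timm.list_models() only
# lists base architecture names such as "vit_large_patch14_dinov2".
backbone = "vit_large_patch14_dinov2.lvd142m"
base_model_name = backbone.split(".")[0]  # strip the ".<tag>" suffix
if base_model_name not in timm.list_models():
    raise ValueError(f"Backbone {backbone} is not supported by timm.")
```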


config = MaskFormerConfig(backbone="microsoft/resnet50", use_pretrained_backbone=True) # backbone and neck config
config = MaskFormerConfig(backbone="microsoft/resnet-50", use_pretrained_backbone=True) # backbone and neck config
Collaborator Author:

Changed to the correct checkpoint

self.backbone = AutoBackbone.from_config(
config.backbone_config, attn_implementation=config._attn_implementation
)
self.backbone = load_backbone(config)
Collaborator Author:

All backbones should be loaded through load_backbone. Being able to propagate attn_implementation should be done in a follow-up (and it wasn't being used for Depth Anything).
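
A minimal usage sketch, assuming a transformers version that includes this PR (the checkpoint name is illustrative):

```python
from transformers import MaskFormerConfig
from transformers.utils.backbone_utils import load_backbone

# load_backbone reads the backbone-related fields of the model config
# (backbone, backbone_config, use_timm_backbone, use_pretrained_backbone,
# backbone_kwargs) and returns the corresponding backbone instance.
config = MaskFormerConfig(backbone="microsoft/resnet-50", use_pretrained_backbone=True)
backbone = load_backbone(config)
```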

@qubvel (Member) left a comment:

Great, thanks for working on this! Just a few comments, nothing critical

Comment on lines +381 to +385
backbone_model_type = None
if config.backbone is not None:
backbone_model_type = config.backbone
elif config.backbone_config is not None:
backbone_model_type = config.backbone_config.model_type
Member:

Is there a chance that backbone and backbone_config are both None by misconfiguration? Should we throw an error here or check in verify_backbone_config_arguments?

If backbone_model_type remains None, then if "resnet" in backbone_model_type: will raise an error:

TypeError: argument of type 'NoneType' is not iterable

Collaborator Author:

Is there a chance that backbone and backbone_config are both None by misconfiguration?

It shouldn't happen from a fresh config, as verify_backbone_config_arguments will check that.

It's possible something like this could happen:

from transformers import ConditionalDetrConfig, ConditionalDetrModel

config = ConditionalDetrConfig()
config.backbone = None
config.backbone_config = None

model = ConditionalDetrModel(config)

as it's not uncommon to modify configs post-creation.

If backbone_model_type remains None, then if "resnet" in backbone_model_type: will raise an error:

Good point. I'll add an exception if neither is set here, and for the other similar bits of logic.
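
For example, one possible shape of that guard, extending the snippet quoted above (names and wording are assumptions, not the merged change):

```python
# Fail loudly instead of hitting `TypeError: argument of type 'NoneType' is not iterable` below.
if config.backbone is None and config.backbone_config is None:
    raise ValueError(
        "Either `backbone` or `backbone_config` must be set in the config to determine the backbone type."
    )
```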

@@ -182,7 +182,7 @@ def __init__(

use_autobackbone = False
if self.is_hybrid:
if backbone_config is None and backbone is None:
if backbone_config is None:
Member:

Nit: if we use if-if instead of if-elif, we can remove the logging and config initialization from the first if; they will be handled by the second if:

if backbone_config is None:
-  logger.info("Initializing the config with a `BiT` backbone.")
  backbone_config = {
      "global_padding": "same",
      "layer_type": "bottleneck",
      "depths": [3, 4, 9],
      "out_features": ["stage1", "stage2", "stage3"],
      "embedding_dynamic_padding": True,
  }
-  backbone_config = BitConfig(**backbone_config)
if isinstance(backbone_config, dict):
  logger.info("Initializing the config with a `BiT` backbone.")
  backbone_config = BitConfig(**backbone_config)

Comment on lines +363 to +381
Use `use_timm_backbone=True` and `use_pretrained_backbone=True` to load pretrained timm weights for the backbone.

```python
from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation

config = MaskFormerConfig(backbone="resnet50", use_pretrained_backbone=True, use_timm_backbone=True) # backbone and neck config
model = MaskFormerForInstanceSegmentation(config) # head
```

Set `use_timm_backbone=True` and `use_pretrained_backbone=False` to load a randomly initialized timm backbone.

```python
from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation

config = MaskFormerConfig(backbone="resnet50", use_pretrained_backbone=False, use_timm_backbone=True) # backbone and neck config
model = MaskFormerForInstanceSegmentation(config) # head
```

You could also load the backbone config and use it to create a `TimmBackbone` or pass it to the model config. Timm backbones will load pretrained weights by default. Set `use_pretrained_backbone=False` to load randomly initialized weights.
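
As a sketch of that last sentence (values are illustrative; see the backbones docs for the exact API):

```python
from transformers import TimmBackboneConfig, TimmBackbone, MaskFormerConfig, MaskFormerForInstanceSegmentation

# Create a standalone timm backbone from its config...
backbone_config = TimmBackboneConfig("resnet50", use_pretrained_backbone=False)
backbone = TimmBackbone(config=backbone_config)

# ...or pass the backbone config to the model config instead.
config = MaskFormerConfig(backbone_config=backbone_config)
model = MaskFormerForInstanceSegmentation(config)
```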
Member:

❤️

@amyeroberts (Collaborator Author) left a comment:

Accidentally left comments as review comments and I'm too lazy to store / delete / re-add them.


self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
@amyeroberts (Collaborator Author) commented Jun 5, 2024:

We don't default to setting these to None if use_autobackbone is True because:

  • Even if is_hybrid is False, some of these values are needed, e.g. image_size and patch_size are needed for the patch embeddings used in DPTViTEmbeddings, initialized here.
  • It can lead to surprising behaviour, e.g. I pass in DPTConfig(is_hybrid=False, num_hidden_layers=5) and find that num_hidden_layers is then set to None (sketched below).
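
A tiny illustration of that second point, assuming DPTConfig keeps the user-provided value (a hypothetical check, not a test from this PR):

```python
from transformers import DPTConfig

# A user sets num_hidden_layers explicitly and expects it to be respected,
# rather than silently overridden to None.
config = DPTConfig(is_hybrid=False, num_hidden_layers=5)
print(config.num_hidden_layers)  # expected: 5
```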

@@ -208,9 +208,8 @@ def __init__(
if readout_type != "project":
raise ValueError("Readout type must be 'project' when using `DPT-hybrid` mode.")

elif backbone_config is not None:
elif backbone is not None or backbone_config is not None:
Collaborator Author:

To specify which backbone to load from a config, the user can do one of two things, either:

  • Specify the checkpoint e.g. backbone=microsoft/resnet-10
  • Specify a config e.g. backbone_config=BitConfig()

We need to support both to enable load_backbone, i.e. loading timm or HF backbones that are either pretrained or randomly initialized.
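
A short sketch of the two routes (the checkpoint name and config classes here are illustrative, following the examples above):

```python
from transformers import DPTConfig, BitConfig

# Route 1: name a Hub checkpoint and load its pretrained weights.
config = DPTConfig(backbone="microsoft/resnet-50", use_pretrained_backbone=True)

# Route 2: pass a backbone config; the backbone is randomly initialized.
config = DPTConfig(backbone_config=BitConfig())
```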

@amyeroberts amyeroberts merged commit bdf36dc into huggingface:main Jun 6, 2024
23 checks passed
@amyeroberts amyeroberts deleted the hf-pretrained-backbones branch June 6, 2024 21:02
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jun 11, 2024
* Enable loading HF or timm backbone checkpoints

* Fix up

* Fix test - pass in proper out_indices

* Update docs

* Fix tvp tests

* Fix doc examples

* Fix doc examples

* Try to resolve DPT backbone param init

* Don't conditionally set to None

* Add condition based on whether backbone is defined

* Address review comments
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jun 14, 2024