
Add a simple classy transformer wrapper to load the ViT models trained in classy vision (#505)

Summary:
Pull Request resolved: #505

as title

Reviewed By: iseessel, QuentinDuval

Differential Revision: D33795085

fbshipit-source-id: f40c9a5c92bf44a5377a361b254eef326ee97ef8
prigoyal authored and facebook-github-bot committed Feb 1, 2022
1 parent 722a7cc commit dc59f07
Showing 1 changed file with 52 additions and 0 deletions.
vissl/models/trunks/classy_vision_transformer.py
@@ -0,0 +1,52 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
from typing import List, Optional

import torch
import torch.nn as nn
from classy_vision.models import VisionTransformer as ClassyVisionTransformer
from vissl.config import AttrDict
from vissl.models.trunks import register_model_trunk


@register_model_trunk("classy_vit")
class ClassyViT(nn.Module):
"""
    Simple wrapper for the ClassyVision Vision Transformer model.
This model is defined on the fly from a Vision Transformer base class and
a configuration file.
"""

def __init__(self, model_config: AttrDict, model_name: str):
super().__init__()
self.model_config = model_config

assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
trunk_config = copy.deepcopy(model_config.TRUNK.VISION_TRANSFORMERS)

        logging.info("Building model: Vision Transformer from yaml config")
        # Config keys are upper-case in the VISSL YAML; lower-case them to
        # match the keyword arguments of the ClassyVision constructor below.
        trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()})

self.model = ClassyVisionTransformer(
image_size=trunk_config.image_size,
patch_size=trunk_config.patch_size,
num_layers=trunk_config.num_layers,
num_heads=trunk_config.num_heads,
hidden_dim=trunk_config.hidden_dim,
mlp_dim=trunk_config.mlp_dim,
dropout_rate=trunk_config.dropout_rate,
attention_dropout_rate=trunk_config.attention_dropout_rate,
classifier=trunk_config.classifier,
)

    def forward(
        self, x: torch.Tensor, out_feat_keys: Optional[List[str]] = None
    ) -> List[torch.Tensor]:
        # out_feat_keys is part of the common VISSL trunk interface; this
        # wrapper ignores it and always returns the final model output.
        x = self.model(x)
        # Add a leading dimension so the single output can be indexed like
        # the one-element list of features that VISSL trunks return.
        x = x.unsqueeze(0)
        return x
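
A minimal usage sketch, not part of this commit: it builds the nested config by hand instead of loading a VISSL YAML file, and the hyperparameter values below are illustrative ViT-B/16-style assumptions, not defaults. In practice the trunk is constructed by VISSL from the upper-case keys under MODEL.TRUNK.VISION_TRANSFORMERS with TRUNK.NAME set to classy_vit.

import torch
from vissl.config import AttrDict
from vissl.models.trunks.classy_vision_transformer import ClassyViT

# Hand-built stand-in for the MODEL section of a VISSL config; the
# values are hypothetical and would normally come from YAML.
model_config = AttrDict({
    "INPUT_TYPE": "rgb",
    "TRUNK": AttrDict({
        "VISION_TRANSFORMERS": AttrDict({
            "IMAGE_SIZE": 224,
            "PATCH_SIZE": 16,
            "NUM_LAYERS": 12,
            "NUM_HEADS": 12,
            "HIDDEN_DIM": 768,
            "MLP_DIM": 3072,
            "DROPOUT_RATE": 0.0,
            "ATTENTION_DROPOUT_RATE": 0.0,
            "CLASSIFIER": "token",
        })
    }),
})

trunk = ClassyViT(model_config, model_name="classy_vit")
feats = trunk(torch.randn(2, 3, 224, 224))
print(feats.shape)  # leading singleton dim added by the wrapper's unsqueeze(0)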
