diff --git a/vissl/config/defaults.yaml b/vissl/config/defaults.yaml index c9b3bc7e4..052334eac 100644 --- a/vissl/config/defaults.yaml +++ b/vissl/config/defaults.yaml @@ -692,6 +692,34 @@ config: QKV_BIAS: True # Bias for QKV in attention layers. QK_SCALE: False # Scale + # ------------------------------------------------------------- # + # BEiT. + # https://github.com/microsoft/unilm/blob/master/beit/modeling_finetune.py + # https://arxiv.org/pdf/2106.08254.pdf + # ------------------------------------------------------------- # + BEIT: + IMAGE_SIZE: 224 + PATCH_SIZE: 16 + NUM_LAYERS: 12 + NUM_HEADS: 12 + HIDDEN_DIM: 768 + MLP_RATIO: 4.0 + # MLP and projection layer dropout rate + DROPOUT_RATE: 0 + # Attention dropout rate + ATTENTION_DROPOUT_RATE: 0 + # Stochastic depth dropout rate. Turning on stochastic depth and + # using aggressive augmentation is essentially the difference + # between a DeiT and a ViT. + DROP_PATH_RATE: 0 + QKV_BIAS: False # Bias for QKV in attention layers. + QK_SCALE: False # Scale + USE_ABS_POS_EMB: True + USE_REL_POS_BIAS: False + USE_SHARED_REL_POS_BIAS: False + USE_MEAN_POOLING: True + INIT_VALUES: False + # ------------------------------------------------------------- # # Parameters unique to the ConViT and not used for standard vision # transformers diff --git a/vissl/models/model_helpers.py b/vissl/models/model_helpers.py index 281dab47b..b93e681c1 100644 --- a/vissl/models/model_helpers.py +++ b/vissl/models/model_helpers.py @@ -649,6 +649,9 @@ def __init__(self, drop_prob=None): def forward(self, x): return drop_path(x, self.drop_prob, self.training) + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + to_1tuple = _ntuple(1) to_2tuple = _ntuple(2) diff --git a/vissl/models/trunks/beit_transformer.py b/vissl/models/trunks/beit_transformer.py new file mode 100644 index 000000000..d3867ab58 --- /dev/null +++ b/vissl/models/trunks/beit_transformer.py @@ -0,0 +1,489 @@ +# Portions Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" +Code modified from https://github.com/microsoft/unilm/blob/master/beit/modeling_finetune.py # NOQA +""" +import copy +import logging +import math +from functools import partial +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from vissl.config import AttrDict +from vissl.models.model_helpers import DropPath, to_2tuple, trunc_normal_ +from vissl.models.trunks import register_model_trunk + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + window_size=None, + attn_head_dim=None, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * ( + 2 * window_size[1] - 1 + ) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1,) * 2, + dtype=relative_coords.dtype, + ) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat( + ( + 
self.q_bias, + torch.zeros_like(self.v_bias, requires_grad=False), + self.v_bias, + ) + ) + # qkv = self.qkv(x).reshape( + # B, N, 3, self.num_heads, C // self.num_heads + # ).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if self.relative_position_bias_table is not None: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + init_values=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + window_size=None, + attn_head_dim=None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + attn_head_dim=attn_head_dim, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + if init_values > 0: + self.gamma_1 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True + ) + self.gamma_2 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True + ) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rel_pos_bias=None): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path( + self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias) + ) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], ( + f"Input image size ({H}*{W}) doesn't match model " + f"({self.img_size[0]}*{self.img_size[1]})." 
+ ) + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * ( + 2 * window_size[1] - 1 + ) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype + ) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self): + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, + -1, + ) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +@register_model_trunk("beit") +class BEiT(nn.Module): + """Vision Transformer with support for patch or hybrid CNN input stage""" + + def __init__( + self, + model_config: AttrDict, + model_name: str, + ): + super().__init__() + + assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported" + trunk_config = copy.deepcopy(model_config.TRUNK.BEIT) + + logging.info("Building model: BEiT Vision Transformer from yaml config") + # Hacky workaround + trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()}) + + img_size = trunk_config.image_size + patch_size = trunk_config.patch_size + in_chans = 3 + embed_dim = trunk_config.hidden_dim + depth = trunk_config.num_layers + num_heads = trunk_config.num_heads + mlp_ratio = trunk_config.mlp_ratio + qkv_bias = trunk_config.qkv_bias + qk_scale = trunk_config.qk_scale + drop_rate = trunk_config.dropout_rate + attn_drop_rate = trunk_config.attention_dropout_rate + drop_path_rate = trunk_config.drop_path_rate + init_values = trunk_config.init_values + use_abs_pos_emb = trunk_config.use_abs_pos_emb + use_rel_pos_bias = trunk_config.use_rel_pos_bias + use_shared_rel_pos_bias = trunk_config.use_shared_rel_pos_bias + use_mean_pooling = trunk_config.use_mean_pooling + + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = ( + self.embed_dim + ) = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + 
num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias( + window_size=self.patch_embed.patch_shape, num_heads=num_heads + ) + else: + self.rel_pos_bias = None + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + init_values=init_values, + window_size=self.patch_embed.patch_shape + if use_rel_pos_bias + else None, + ) + for i in range(depth) + ] + ) + self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + trunc_normal_(self.cls_token, std=0.02) + # trunc_normal_(self.mask_token, std=.02) + self.apply(self._init_weights) + self.fix_init_weight() + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def prepare_tokens(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand( + batch_size, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + return x + + def forward_features(self, x): + x = self.prepare_tokens(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias=rel_pos_bias) + + x = self.norm(x) + return x[:, 0] + + def get_intermediate_features(self, x, names): + interms = [] + + x = self.prepare_tokens(x) + + # get feature from every intermediate block and apply norm + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias=rel_pos_bias) + interms.append(self.norm(x)) + + # feature names are as follows + # blkCLS[integer] => CLS token of blk[integer] + # concatCLS[integer] => concat of CLS token from last #"integer" blocks + # lastCLS => CLS token of last block + + output = [] + for name in names: + if name.startswith("blkCLS"): + v = int(name.replace("blkCLS", "")) + output.append(interms[v][:, 0]) + elif name.startswith("concatCLS"): + v = int(name.replace("concatCLS", "")) + feat = torch.cat([x[:, 0] for x in interms[-v:]], dim=-1) + output.append(feat) + elif name == "lastCLS": + output.append(interms[-1][:, 0]) + return output + + def forward(self, x: torch.Tensor, out_feat_keys: List[str] = None): 
+ if out_feat_keys is None or len(out_feat_keys) == 0: + x = self.forward_features(x) + x = x.unsqueeze(0) + else: + # we specified a feature layer name. Follow DINO + # (https://github.com/facebookresearch/dino/blob/main/eval_linear.py#L159) # NOQA + x = self.get_intermediate_features(x, out_feat_keys) + return x
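# ---------------------------------------------------------------------------- #
# Usage sketch (not part of the diff above): a minimal, hypothetical example of
# constructing the new BEiT trunk directly from a hand-written config. The key
# names mirror the TRUNK.BEIT block added to defaults.yaml and the feature name
# "lastCLS" comes from get_intermediate_features(); the explicit AttrDict
# nesting, the variable names, and the random input tensor are illustrative
# assumptions, not part of the PR.
# ---------------------------------------------------------------------------- #
import torch

from vissl.config import AttrDict
from vissl.models.trunks.beit_transformer import BEiT

beit_trunk_config = {
    "IMAGE_SIZE": 224,
    "PATCH_SIZE": 16,
    "NUM_LAYERS": 12,
    "NUM_HEADS": 12,
    "HIDDEN_DIM": 768,
    "MLP_RATIO": 4.0,
    "DROPOUT_RATE": 0,
    "ATTENTION_DROPOUT_RATE": 0,
    "DROP_PATH_RATE": 0,
    "QKV_BIAS": False,
    "QK_SCALE": False,
    "USE_ABS_POS_EMB": True,
    "USE_REL_POS_BIAS": False,
    "USE_SHARED_REL_POS_BIAS": False,
    "USE_MEAN_POOLING": True,
    "INIT_VALUES": False,
}

# Nest AttrDicts explicitly so the attribute access used by the trunk
# (model_config.TRUNK.BEIT, model_config.INPUT_TYPE) works even if nested
# plain dicts are not converted automatically.
model_config = AttrDict(
    {"INPUT_TYPE": "rgb", "TRUNK": AttrDict({"BEIT": AttrDict(beit_trunk_config)})}
)

trunk = BEiT(model_config, model_name="beit")

images = torch.randn(2, 3, 224, 224)

# Default path: forward_features() returns the CLS token of the last block,
# unsqueezed to shape [1, batch, hidden_dim] -> here [1, 2, 768].
feats = trunk(images)

# DINO-style named feature extraction: returns a list, here containing the
# CLS token of the last block with shape [2, 768].
cls_last = trunk(images, out_feat_keys=["lastCLS"])

# In a full VISSL config, this trunk would typically be selected with
# MODEL.TRUNK.NAME: beit, matching the @register_model_trunk("beit") decorator.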