From 056e695a84ed1e1c63320ddeecb577a9e03cde63 Mon Sep 17 00:00:00 2001 From: Mathilde Caron Date: Wed, 6 Oct 2021 06:46:12 -0700 Subject: [PATCH] xcit integration (#187) Summary: Integration of XCiT into VISSL :) Pull Request resolved: https://github.com/fairinternal/ssl_scaling/pull/187 Reviewed By: iseessel Differential Revision: D31378757 Pulled By: prigoyal fbshipit-source-id: d383390a3ae4def585eaf7bef6d743e9c5c059e8 --- .../dino/dino_16gpus_xcit_small_12_p16.yaml | 140 ++++++ vissl/config/defaults.yaml | 19 + vissl/models/trunks/xcit.py | 401 ++++++++++++++++++ vissl/utils/hydra_config.py | 2 +- 4 files changed, 561 insertions(+), 1 deletion(-) create mode 100644 configs/config/pretrain/dino/dino_16gpus_xcit_small_12_p16.yaml create mode 100644 vissl/models/trunks/xcit.py diff --git a/configs/config/pretrain/dino/dino_16gpus_xcit_small_12_p16.yaml b/configs/config/pretrain/dino/dino_16gpus_xcit_small_12_p16.yaml new file mode 100644 index 000000000..c8c10e698 --- /dev/null +++ b/configs/config/pretrain/dino/dino_16gpus_xcit_small_12_p16.yaml @@ -0,0 +1,140 @@ +# @package _global_ +config: + VERBOSE: False + LOG_FREQUENCY: 10 + TEST_ONLY: False + TEST_MODEL: False + SEED_VALUE: 0 + MULTI_PROCESSING_METHOD: forkserver + HOOKS: + PERF_STATS: + MONITOR_PERF_STATS: True + PERF_STAT_FREQUENCY: 40 + ROLLING_BTIME_FREQ: 5 + DATA: + NUM_DATALOADER_WORKERS: 10 + TRAIN: + DATA_SOURCES: [disk_folder] + DATASET_NAMES: [imagenet1k_folder] + BATCHSIZE_PER_REPLICA: 64 + LABEL_TYPE: sample_index # just an implementation detail. Label isn't used + TRANSFORMS: + - name: ImgPilToMultiCrop + total_num_crops: 10 + size_crops: [224, 96] + num_crops: [2, 8] + crop_scales: [[0.3, 1], [0.05, 0.3]] + - name: RandomHorizontalFlip + p: 0.5 + - name: ImgPilColorDistortion + strength: 0.5 + - name: ImgPilMultiCropRandomApply + transforms: [{"name": "ImgPilGaussianBlur", "p": 1., "radius_min": 0.1, "radius_max": 2.0}] + prob: [1., 0.1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5] + - name: ImgPilMultiCropRandomApply + transforms: [{"name": "ImgPilRandomSolarize", "p": 1.}] + prob: [0., 0.2, 0., 0., 0, 0, 0, 0, 0, 0] + - name: ToTensor + - name: Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + COLLATE_FUNCTION: multicrop_collator + MMAP_MODE: True + COPY_TO_LOCAL_DISK: False + COPY_DESTINATION_DIR: /tmp/imagenet1k/ + DROP_LAST: True + TRAINER: + TRAIN_STEP_NAME: standard_train_step + METERS: + name: "" + MODEL: + TRUNK: + NAME: xcit + XCIT: + IMAGE_SIZE: 224 + PATCH_SIZE: 16 + HIDDEN_DIM: 384 + NUM_LAYERS: 12 + NUM_HEADS: 8 + DROPOUT_RATE: 0 + ATTENTION_DROPOUT_RATE: 0 + DROP_PATH_RATE: 0.05 + ETA: 1 + TOKENS_NORM: True + QKV_BIAS: True + QK_SCALE: False + HEAD: + PARAMS: [ + ["swav_head", {"use_weight_norm_prototypes": True, "dims": [384, 2048, 2048, 256], "use_bn": False, "return_embeddings": False, "activation_name": "GELU", "num_clusters": [65536]}], + ] + TEMP_FROZEN_PARAMS_ITER_MAP: [ + ['module.heads.0.prototypes0.weight_v', 1251], + ['module.heads.0.prototypes0.weight_g', 1251], + ] + AMP_PARAMS: + AMP_TYPE: pytorch + USE_AMP: True + LOSS: + name: dino_loss + dino_loss: + momentum: 0.996 + teacher_temp_warmup_iters: 37530 # 30 epochs + teacher_temp_min: 0.04 + teacher_temp_max: 0.07 + ema_center: 0.9 + normalize_last_layer: false + OPTIMIZER: + name: adamw + momentum: 0.9 + nesterov: False + num_epochs: 300 + regularize_bn: False + regularize_bias: False + param_schedulers: + lr_head: + name: composite + schedulers: + - name: linear + start_value: 0.00001 + end_value: 0.002 + - name: cosine + start_value: 0.002 + end_value: 0.00001 + update_interval: epoch + interval_scaling: [rescaled, fixed] + lengths: [0.0333, 0.9667] + lr: + name: composite + schedulers: + - name: linear + start_value: 0.00001 + end_value: 0.002 + - name: cosine + start_value: 0.002 + end_value: 0.00001 + update_interval: epoch + interval_scaling: [rescaled, fixed] + lengths: [0.0333, 0.9667] + weight_decay: + name: cosine + start_value: 0.04 + end_value: 0.4 + update_interval: epoch + weight_decay_head: + name: cosine + start_value: 0.04 + end_value: 0.4 + update_interval: epoch + DISTRIBUTED: + BACKEND: nccl + NUM_NODES: 2 + NUM_PROC_PER_NODE: 8 + INIT_METHOD: tcp + RUN_ID: auto + MACHINE: + DEVICE: gpu + CHECKPOINT: + DIR: "." + AUTO_RESUME: True + CHECKPOINT_FREQUENCY: 5 + OVERWRITE_EXISTING: true diff --git a/vissl/config/defaults.yaml b/vissl/config/defaults.yaml index c7ac1e494..8aaf11194 100644 --- a/vissl/config/defaults.yaml +++ b/vissl/config/defaults.yaml @@ -622,6 +622,25 @@ config: QKV_BIAS: False # Bias for QKV in attention layers. QK_SCALE: False # Scale + # ------------------------------------------------------------- # + # XCiT (https://arxiv.org/abs/2106.09681) + # default config is that of xcit_small_12_p16 + # ------------------------------------------------------------- # + XCIT: + name: + IMAGE_SIZE: 224 + PATCH_SIZE: 16 + HIDDEN_DIM: 384 + NUM_LAYERS: 12 + NUM_HEADS: 8 + DROPOUT_RATE: 0 + ATTENTION_DROPOUT_RATE: 0 + DROP_PATH_RATE: 0.05 + ETA: 1 + TOKENS_NORM: True + QKV_BIAS: True # Bias for QKV in attention layers. + QK_SCALE: False # Scale + # ------------------------------------------------------------- # # Parameters unique to the ConViT and not used for standard vision # transformers diff --git a/vissl/models/trunks/xcit.py b/vissl/models/trunks/xcit.py new file mode 100644 index 000000000..38f48347c --- /dev/null +++ b/vissl/models/trunks/xcit.py @@ -0,0 +1,401 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Code modified from https://github.com/facebookresearch/xcit/blob/main/xcit.py # NOQA +""" + +import copy +import logging +import math +from functools import partial + +import torch +import torch.nn as nn +from functools import partial +from vissl.config import AttrDict +from vissl.models.model_helpers import DropPath, to_2tuple, trunc_normal_ +from vissl.models.trunks import register_model_trunk +from vissl.models.trunks.vision_transformer import Mlp + + +class PositionalEncodingFourier(nn.Module): + """ + Positional encoding relying on a fourier kernel matching the one used in the + "Attention is all of Need" paper. The implementation builds on DeTR code + https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py + """ + + def __init__(self, hidden_dim=32, dim=768, temperature=10000): + super().__init__() + self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) + self.scale = 2 * math.pi + self.temperature = temperature + self.hidden_dim = hidden_dim + self.dim = dim + + def forward(self, B, H, W): + mask = torch.zeros(B, H, W).bool().to(self.token_projection.weight.device) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=mask.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.hidden_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), + pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), + pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.token_projection(pos) + return pos + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return torch.nn.Sequential( + nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False + ), + nn.SyncBatchNorm(out_planes) + ) + + +class ConvPatchEmbed(nn.Module): + """ Image to Patch Embedding using multiple convolutional layers + """ + + def __init__(self, img_size=224, patch_size=16, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + if patch_size[0] == 16: + self.proj = torch.nn.Sequential( + conv3x3(3, embed_dim // 8, 2), + nn.GELU(), + conv3x3(embed_dim // 8, embed_dim // 4, 2), + nn.GELU(), + conv3x3(embed_dim // 4, embed_dim // 2, 2), + nn.GELU(), + conv3x3(embed_dim // 2, embed_dim, 2), + ) + elif patch_size[0] == 8: + self.proj = torch.nn.Sequential( + conv3x3(3, embed_dim // 4, 2), + nn.GELU(), + conv3x3(embed_dim // 4, embed_dim // 2, 2), + nn.GELU(), + conv3x3(embed_dim // 2, embed_dim, 2), + ) + + def forward(self, x, padding_size=None): + B, C, H, W = x.shape + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + x = x.flatten(2).transpose(1, 2) + + return x, (Hp, Wp) + + +class LPI(nn.Module): + """ + Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows + to augment the implicit communcation performed by the block diagonal scatter attention. + Implemented using 2 layers of separable 3x3 convolutions with GeLU and BatchNorm2d + """ + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + drop=0., kernel_size=3): + super().__init__() + out_features = out_features or in_features + + padding = kernel_size // 2 + + self.conv1 = torch.nn.Conv2d(in_features, out_features, kernel_size=kernel_size, + padding=padding, groups=out_features) + self.act = act_layer() + self.bn = nn.SyncBatchNorm(in_features) + self.conv2 = torch.nn.Conv2d(in_features, out_features, kernel_size=kernel_size, + padding=padding, groups=out_features) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = self.conv1(x) + x = self.act(x) + x = self.bn(x) + x = self.conv2(x) + x = x.reshape(B, C, N).permute(0, 2, 1) + + return x + + +class ClassAttention(nn.Module): + """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239 + """ + + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + qc = q[:, :, 0:1] # CLS token + attn_cls = (qc * k).sum(dim=-1) * self.scale + attn_cls = attn_cls.softmax(dim=-1) + attn_cls = self.attn_drop(attn_cls) + + cls_tkn = (attn_cls.unsqueeze(2) @ v).transpose(1, 2).reshape(B, 1, C) + cls_tkn = self.proj(cls_tkn) + x = torch.cat([self.proj_drop(cls_tkn), x[:, 1:]], dim=1) + return x + + +class ClassAttentionBlock(nn.Module): + """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239 + """ + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., + attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=None, + tokens_norm=False): + super().__init__() + self.norm1 = norm_layer(dim) + + self.attn = ClassAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, + proj_drop=drop + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, + drop=drop) + + if eta is not None: # LayerScale Initialization (no layerscale when None) + self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) + self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) + else: + self.gamma1, self.gamma2 = 1.0, 1.0 + + # FIXME: A hack for models pre-trained with layernorm over all the tokens not just the CLS + self.tokens_norm = tokens_norm + + def forward(self, x, H, W, mask=None): + x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x))) + if self.tokens_norm: + x = self.norm2(x) + else: + x[:, 0:1] = self.norm2(x[:, 0:1]) + + x_res = x + cls_token = x[:, 0:1] + cls_token = self.gamma2 * self.mlp(cls_token) + x = torch.cat([cls_token, x[:, 1:]], dim=1) + x = x_res + self.drop_path(x) + return x + + +class XCA(nn.Module): + """ Cross-Covariance Attention (XCA) operation where the channels are updated using a weighted + sum. The weights are obtained from the (softmax normalized) Cross-covariance + matrix (Q^T K \\in d_h \\times d_h) + """ + + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q.transpose(-2, -1) + k = k.transpose(-2, -1) + v = v.transpose(-2, -1) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @torch.jit.ignore + def no_weight_decay(self): + return {'temperature'} + + +class XCABlock(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., + attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, + num_tokens=196, eta=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = XCA( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, + proj_drop=drop + ) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, + drop=drop) + + self.norm3 = norm_layer(dim) + self.local_mp = LPI(in_features=dim, act_layer=act_layer) + + self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) + self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) + self.gamma3 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) + + def forward(self, x, H, W): + x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x))) + x = x + self.drop_path(self.gamma3 * self.local_mp(self.norm3(x), H, W)) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + return x + + +@register_model_trunk("xcit") +class XCiT(nn.Module): + """ + Based on timm and DeiT code bases + https://github.com/rwightman/pytorch-image-models/tree/master/timm + https://github.com/facebookresearch/deit/ + """ + + def __init__(self, model_config: AttrDict, model_name: str): + super().__init__() + + assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported" + trunk_config = copy.deepcopy(model_config.TRUNK.XCIT) + + logging.info("Building model: XCiT from yaml config") + # Hacky workaround + trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()}) + img_size = trunk_config.image_size + patch_size = trunk_config.patch_size + embed_dim = trunk_config.hidden_dim + depth = trunk_config.num_layers + num_heads = trunk_config.num_heads + mlp_ratio = 4.0 + qkv_bias = trunk_config.qkv_bias + qk_scale = trunk_config.qk_scale + drop_rate = trunk_config.dropout_rate + attn_drop_rate = trunk_config.attention_dropout_rate + drop_path_rate = trunk_config.drop_path_rate + eta = trunk_config.eta + tokens_norm = trunk_config.tokens_norm + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.num_features = ( + self.embed_dim + ) = embed_dim # num_features for consistency with other models + self.patch_embed = ConvPatchEmbed(img_size=img_size, embed_dim=embed_dim, + patch_size=patch_size) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [drop_path_rate for i in range(depth)] + self.blocks = nn.ModuleList([ + XCABlock( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], + norm_layer=norm_layer, num_tokens=num_patches, eta=eta) + for i in range(depth)]) + + cls_attn_layers = 2 + self.cls_attn_blocks = nn.ModuleList([ + ClassAttentionBlock( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, + eta=eta, tokens_norm=tokens_norm) + for i in range(cls_attn_layers)]) + self.norm = norm_layer(embed_dim) + + self.pos_embeder = PositionalEncodingFourier(dim=embed_dim) + self.use_pos = True + + # Classifier head + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + def forward_features(self, x): + B, C, H, W = x.shape + + x, (Hp, Wp) = self.patch_embed(x) + + if self.use_pos: + pos_encoding = self.pos_embeder(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1) + x = x + pos_encoding + + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x, Hp, Wp) + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + for blk in self.cls_attn_blocks: + x = blk(x, Hp, Wp) + + x = self.norm(x)[:, 0] + return x + + def forward(self, x, out_feat_keys=None): + x = self.forward_features(x) + x = x.unsqueeze(0) + return x diff --git a/vissl/utils/hydra_config.py b/vissl/utils/hydra_config.py index a236766ba..ba5af521c 100644 --- a/vissl/utils/hydra_config.py +++ b/vissl/utils/hydra_config.py @@ -451,7 +451,7 @@ def infer_losses_config(cfg): queue_length // world_size ) - # some inference for Simdist loss. + # some inference for DINO loss. if cfg.LOSS.name == "dino_loss": assert len(cfg.MODEL.HEAD.PARAMS) == 1 assert cfg.MODEL.HEAD.PARAMS[0][0] == "swav_head"