
Commit

Add simclrv2 vissl compatible models pretrained and in1k finetuned (#512)

Summary:
Pull Request resolved: #512

As the title says: add SimCLR-v2 models in a VISSL-compatible format, both pretrained and in1k finetuned.

Reviewed By: QuentinDuval

Differential Revision: D33794631

fbshipit-source-id: 252094911ee0d7f8ad335ed58a19104303261f71
prigoyal authored and facebook-github-bot committed Feb 1, 2022
1 parent 7877a45 commit 1260872
Showing 4 changed files with 232 additions and 10 deletions.
10 changes: 10 additions & 0 deletions vissl/config/defaults.yaml
@@ -587,6 +587,8 @@ config:
# if we want to evaluate the full model, this requires loading the head weights as well
# from model weights file. In this case, set the following to True.
EVAL_TRUNK_AND_HEAD: False
# whether to assert that the head layers shape matches exactly
ASSERT_HEAD_LAYER_SHAPE_INIT: True
# whether features should be flattened to result in N x D feature shape
SHOULD_FLATTEN_FEATS: True
# model features that should be evaluated for linear classification and what
@@ -634,6 +636,14 @@ config:
# use this so we set the default as 2.
LAYER4_STRIDE: 2

# ------------------------------------------------------------- #
# SimCLR-v2 ResNet params that includes selective kernel
# ------------------------------------------------------------- #
RESNETS_SK:
DEPTH: 152
WIDTH_MULTIPLIER: 3
SK_RATIO: 0.0625

# ------------------------------------------------------------- #
# EfficientNet params
# ------------------------------------------------------------- #
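
For context, the new RESNETS_SK block sits under the TRUNK section of the model config, next to the existing RESNETS options, and it only takes effect when the new "resnet_sk" trunk (added below) is selected, presumably via MODEL.TRUNK.NAME as for the other registered trunks. A minimal sketch of the same settings as a plain Python mapping, purely illustrative; the field names and values come from the diff above:

# Defaults for the SimCLR-v2 selective-kernel ResNet trunk added in this commit.
RESNETS_SK_DEFAULTS = {
    "DEPTH": 152,           # the only depth listed in BLOCK_CONFIG below
    "WIDTH_MULTIPLIER": 3,  # 3x-wide ResNet, as in the largest SimCLR-v2 release
    "SK_RATIO": 0.0625,     # selective-kernel ratio; 0 would disable the SK branch
}
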
3 changes: 2 additions & 1 deletion vissl/models/heads/swav_prototypes_head.py
@@ -63,7 +63,8 @@ def __init__(
This could be particularly useful when performing full finetuning on
hidden layers.
use_weight_norm_prototypes (bool): whether to use weight norm module for the prototypes layers.
use_weight_norm_prototypes (bool): whether to use weight norm module for the
prototypes layers.
"""

super().__init__()
194 changes: 194 additions & 0 deletions vissl/models/trunks/resnext_selective_kernels.py
@@ -0,0 +1,194 @@
# Portions Copyright (c) Facebook, Inc. and its affiliates.

# Code from: https://github.com/Separius/SimCLRv2-Pytorch/blob/main/resnet.py
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F
from vissl.config import AttrDict
from vissl.models.trunks import register_model_trunk

BATCH_NORM_EPSILON = 1e-5
BATCH_NORM_DECAY = 0.9 # == pytorch's default value as well


BLOCK_CONFIG = {
152: (3, 8, 36, 3),
}


class SUPPORTED_DEPTHS(int, Enum):
RN152 = 152


class BatchNormRelu(nn.Sequential):
def __init__(self, num_channels, relu=True):
super().__init__(
nn.BatchNorm2d(num_channels, eps=BATCH_NORM_EPSILON),
nn.ReLU() if relu else nn.Identity(),
)


def conv(in_channels, out_channels, kernel_size=3, stride=1, bias=False):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
bias=bias,
)


class SelectiveKernel(nn.Module):
def __init__(self, in_channels, out_channels, stride, sk_ratio, min_dim=32):
super().__init__()
assert sk_ratio > 0.0
self.main_conv = nn.Sequential(
conv(in_channels, 2 * out_channels, stride=stride),
BatchNormRelu(2 * out_channels),
)
mid_dim = max(int(out_channels * sk_ratio), min_dim)
self.mixing_conv = nn.Sequential(
conv(out_channels, mid_dim, kernel_size=1),
BatchNormRelu(mid_dim),
conv(mid_dim, 2 * out_channels, kernel_size=1),
)

def forward(self, x):
x = self.main_conv(x)
x = torch.stack(torch.chunk(x, 2, dim=1), dim=0) # 2, B, C, H, W
g = x.sum(dim=0).mean(dim=[2, 3], keepdim=True)
m = self.mixing_conv(g)
m = torch.stack(torch.chunk(m, 2, dim=1), dim=0) # 2, B, C, 1, 1
return (x * F.softmax(m, dim=0)).sum(dim=0)


class Projection(nn.Module):
def __init__(self, in_channels, out_channels, stride, sk_ratio=0):
super().__init__()
if sk_ratio > 0:
self.shortcut = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
# kernel_size = 2 => padding = 1
nn.AvgPool2d(kernel_size=2, stride=stride, padding=0),
conv(in_channels, out_channels, kernel_size=1),
)
else:
self.shortcut = conv(
in_channels, out_channels, kernel_size=1, stride=stride
)
self.bn = BatchNormRelu(out_channels, relu=False)

def forward(self, x):
return self.bn(self.shortcut(x))


class BottleneckBlock(nn.Module):
expansion = 4

def __init__(
self, in_channels, out_channels, stride, sk_ratio=0, use_projection=False
):
super().__init__()
if use_projection:
self.projection = Projection(
in_channels, out_channels * 4, stride, sk_ratio
)
else:
self.projection = nn.Identity()
ops = [
conv(in_channels, out_channels, kernel_size=1),
BatchNormRelu(out_channels),
]
if sk_ratio > 0:
ops.append(SelectiveKernel(out_channels, out_channels, stride, sk_ratio))
else:
ops.append(conv(out_channels, out_channels, stride=stride))
ops.append(BatchNormRelu(out_channels))
ops.append(conv(out_channels, out_channels * 4, kernel_size=1))
ops.append(BatchNormRelu(out_channels * 4, relu=False))
self.net = nn.Sequential(*ops)

def forward(self, x):
shortcut = self.projection(x)
return F.relu(shortcut + self.net(x))


class Blocks(nn.Module):
def __init__(self, num_blocks, in_channels, out_channels, stride, sk_ratio=0):
super().__init__()
self.blocks = nn.ModuleList(
[BottleneckBlock(in_channels, out_channels, stride, sk_ratio, True)]
)
self.channels_out = out_channels * BottleneckBlock.expansion
for _ in range(num_blocks - 1):
self.blocks.append(
BottleneckBlock(self.channels_out, out_channels, 1, sk_ratio)
)

def forward(self, x):
for b in self.blocks:
x = b(x)
return x


class Stem(nn.Sequential):
def __init__(self, sk_ratio, width_multiplier):
ops = []
channels = 64 * width_multiplier // 2
if sk_ratio > 0:
ops.append(conv(3, channels, stride=2))
ops.append(BatchNormRelu(channels))
ops.append(conv(channels, channels))
ops.append(BatchNormRelu(channels))
ops.append(conv(channels, channels * 2))
else:
ops.append(conv(3, channels * 2, kernel_size=7, stride=2))
ops.append(BatchNormRelu(channels * 2))
ops.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
super().__init__(*ops)


@register_model_trunk("resnet_sk")
class ResNetSelectiveKernel(nn.Module):
def __init__(self, model_config: AttrDict, model_name: str):
super(ResNetSelectiveKernel, self).__init__()
self.model_config = model_config

self.trunk_config = self.model_config.TRUNK.RESNETS_SK
self.depth = SUPPORTED_DEPTHS(self.trunk_config.DEPTH)
self.width_multiplier = self.trunk_config.WIDTH_MULTIPLIER
self.sk_ratio = self.trunk_config.SK_RATIO

layers = BLOCK_CONFIG[self.depth]
width_multiplier = self.width_multiplier
sk_ratio = self.sk_ratio

logging.info(
f"Building model: ResNet-SK"
f"-d{self.depth}-{self.width_multiplier}x"
f"-sk{self.sk_ratio}"
)

ops = [Stem(sk_ratio, width_multiplier)]
channels_in = 64 * width_multiplier
ops.append(Blocks(layers[0], channels_in, 64 * width_multiplier, 1, sk_ratio))
channels_in = ops[-1].channels_out
ops.append(Blocks(layers[1], channels_in, 128 * width_multiplier, 2, sk_ratio))
channels_in = ops[-1].channels_out
ops.append(Blocks(layers[2], channels_in, 256 * width_multiplier, 2, sk_ratio))
channels_in = ops[-1].channels_out
ops.append(Blocks(layers[3], channels_in, 512 * width_multiplier, 2, sk_ratio))
channels_in = ops[-1].channels_out
self.channels_out = channels_in
self.net = nn.Sequential(*ops)

def forward(self, x, apply_fc=False):
h = self.net(x).mean(dim=[2, 3])
return [h]
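
As a sanity check of the new trunk's interface, a usage sketch under two assumptions: that AttrDict (the same class the trunk imports) can wrap a nested plain dict, and that constructing the trunk directly is acceptable for illustration; in practice VISSL builds trunks through its config and model-building machinery rather than like this:

import torch

from vissl.config import AttrDict
from vissl.models.trunks.resnext_selective_kernels import ResNetSelectiveKernel

# Only the fields that __init__ above actually reads are filled in.
model_config = AttrDict(
    {
        "TRUNK": {
            "RESNETS_SK": {
                "DEPTH": 152,
                "WIDTH_MULTIPLIER": 3,
                "SK_RATIO": 0.0625,
            }
        }
    }
)

# Note: the 3x-wide ResNet-152 is a very large model; this is only an interface check.
trunk = ResNetSelectiveKernel(model_config, model_name="resnet_sk")
out = trunk(torch.randn(2, 3, 224, 224))  # forward() returns a single-element list
print(out[0].shape)        # (2, trunk.channels_out)
print(trunk.channels_out)  # 512 * WIDTH_MULTIPLIER * 4 = 6144 here
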
35 changes: 26 additions & 9 deletions vissl/utils/checkpoint.py
@@ -997,16 +997,33 @@ def init_model_from_consolidated_weights(
param = interpolate_position_embeddings(
model, all_layers[layername], param
)
assert all_layers[layername].shape == param.shape, (
f"{layername} have different shapes: "
f"checkpoint: {param.shape}, model: {all_layers[layername].shape}"
)
all_layers[layername].copy_(param)
if local_rank == 0:
logging.info(
f"Loaded: {layername: <{max_len_model}} of "
f"shape: {all_layers[layername].size()} from checkpoint"
if (
"heads" in layername
and not config.MODEL.FEATURE_EVAL_SETTINGS.ASSERT_HEAD_LAYER_SHAPE_INIT
):
if local_rank == 0:
logging.info(
f"Ignore shape check: {layername} "
f"checkpoint: {param.shape}, model: {all_layers[layername].shape}"
)
if all_layers[layername].shape == param.shape:
all_layers[layername].copy_(param)
if local_rank == 0:
logging.info(
f"Loaded: {layername: <{max_len_model}} of "
f"shape: {all_layers[layername].size()} from checkpoint"
)
else:
assert all_layers[layername].shape == param.shape, (
f"{layername} have different shapes: "
f"checkpoint: {param.shape}, model: {all_layers[layername].shape}"
)
all_layers[layername].copy_(param)
if local_rank == 0:
logging.info(
f"Loaded: {layername: <{max_len_model}} of "
f"shape: {all_layers[layername].size()} from checkpoint"
)

# In case the layer is ignored by settings
else:
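
With the ASSERT_HEAD_LAYER_SHAPE_INIT flag added to defaults.yaml above, this block now distinguishes two cases: by default (True), a shape mismatch between a checkpoint parameter and the model still trips the assert exactly as before; with the flag set to False, head layers whose shapes disagree are logged and skipped instead of aborting, while matching head layers are copied as usual. A hedged sketch of flipping the flag in code; config is assumed to be an already-loaded VISSL config AttrDict, and the usual route would be to set the same key in YAML or as a command-line override:

# Illustrative only: relax the head-shape check before weights are initialized,
# e.g. when evaluating a trunk whose checkpoint head differs from the model head.
config.MODEL.FEATURE_EVAL_SETTINGS.ASSERT_HEAD_LAYER_SHAPE_INIT = False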
