Refactor vit pooling to add more reduction options, separately callable
rwightman committed Jun 15, 2024
1 parent 9567cf6 commit 71101eb
Showing 1 changed file with 38 additions and 12 deletions.

timm/models/vision_transformer.py
@@ -386,6 +386,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self._forward(x)
 
 
+def global_pool_nlc(
+        x: torch.Tensor,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+        reduce_include_prefix: bool = False,
+):
+    if not pool_type:
+        return x
+
+    if pool_type == 'token':
+        x = x[:, 0]  # class token
+    else:
+        x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
+        if pool_type == 'avg':
+            x = x.mean(dim=1)
+        elif pool_type == 'avgmax':
+            x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
+        elif pool_type == 'max':
+            x = x.amax(dim=1)
+        else:
+            assert not pool_type, f'Unknown pool type {pool_type}'
+
+    return x
+
+
 class VisionTransformer(nn.Module):
     """ Vision Transformer
@@ -400,7 +425,7 @@ def __init__(
             patch_size: Union[int, Tuple[int, int]] = 16,
             in_chans: int = 3,
             num_classes: int = 1000,
-            global_pool: Literal['', 'avg', 'max', 'token', 'map'] = 'token',
+            global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
             embed_dim: int = 768,
             depth: int = 12,
             num_heads: int = 12,
@@ -459,10 +484,10 @@ def __init__(
             block_fn: Transformer block layer.
         """
         super().__init__()
-        assert global_pool in ('', 'avg', 'max', 'token', 'map')
+        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
         assert class_token or global_pool != 'token'
         assert pos_embed in ('', 'none', 'learn')
-        use_fc_norm = global_pool in ['avg', 'max'] if fc_norm is None else fc_norm
+        use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
         act_layer = get_act_layer(act_layer) or nn.GELU
@@ -596,10 +621,10 @@ def set_grad_checkpointing(self, enable: bool = True) -> None:
     def get_classifier(self) -> nn.Module:
         return self.head
 
-    def reset_classifier(self, num_classes: int, global_pool = None) -> None:
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
-            assert global_pool in ('', 'avg', 'token', 'map')
+            assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
             if global_pool == 'map' and self.attn_pool is None:
                 assert False, "Cannot currently add attention pooling in reset_classifier()."
             elif global_pool != 'map' and self.attn_pool is not None:
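
The widened assert means an existing model can be switched to any non-attention reduction after construction; only 'map' (attention pooling) cannot be added retroactively. For example, continuing the sketch above:

    # Hypothetical: retarget the head to 10 classes with max pooling.
    model.reset_classifier(num_classes=10, global_pool='max')
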
@@ -756,15 +781,16 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.norm(x)
         return x
 
-    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+    def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
         if self.attn_pool is not None:
             x = self.attn_pool(x)
-        elif self.global_pool == 'avg':
-            x = x[:, self.num_prefix_tokens:].mean(dim=1)
-        elif self.global_pool == 'max':
-            x, _ = torch.max(x[:, self.num_prefix_tokens:], dim=1)
-        elif self.global_pool:
-            x = x[:, 0]  # class token
+            return x
+        pool_type = self.global_pool if pool_type is None else pool_type
+        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
+        return x
+
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        x = self.pool(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
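
Because pooling is now a separately callable method, features can be computed once and reduced several ways, per the commit title. A sketch continuing from the model above, using the pool_type override from the new signature:

    import torch

    x = torch.randn(2, 3, 224, 224)
    feats = model.forward_features(x)             # (2, N, C) token sequence
    for pt in ('token', 'avg', 'avgmax', 'max'):
        pooled = model.pool(feats, pool_type=pt)  # (2, C) for each reduction
        print(pt, tuple(pooled.shape))
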
