Commit b1a6f4a

Some missed reset_classifier() type annotations

rwightman committed Jun 16, 2024
1 parent 71101eb · commit b1a6f4a

Showing 16 changed files with 33 additions and 23 deletions.
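The pattern is the same across all 16 files: reset_classifier() gains explicit annotations, num_classes: int plus global_pool, the latter typed str = 'avg' where the model rebuilds its pooling directly and Optional[str] = None where the head treats None as "keep the current pooling". A minimal sketch of the annotated API from the caller's side (the model name and class count are illustrative, not part of the commit):

    import torch
    import timm

    # 'efficientnet_b0' and the 10-class head are illustrative examples.
    model = timm.create_model('efficientnet_b0', pretrained=False)

    # After this commit the signature reads:
    #     reset_classifier(num_classes: int, global_pool: str = 'avg')
    model.reset_classifier(10, global_pool='avg')

    x = torch.randn(1, 3, 224, 224)
    assert model(x).shape == (1, 10)
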
timm/models/efficientnet.py (2 changes: 1 addition & 1 deletion)

@@ -156,7 +156,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.classifier

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

timm/models/ghostnet.py (2 changes: 1 addition & 1 deletion)

@@ -273,7 +273,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.classifier

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         # cannot meaningfully change pooling of efficient head after creation
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)

timm/models/hrnet.py (2 changes: 1 addition & 1 deletion)

@@ -739,7 +739,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.classifier

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

timm/models/inception_v4.py (2 changes: 1 addition & 1 deletion)

@@ -280,7 +280,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.last_linear

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

timm/models/metaformer.py (6 changes: 3 additions & 3 deletions)

@@ -26,9 +26,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 from collections import OrderedDict
 from functools import partial
+from typing import Optional

 import torch
 import torch.nn as nn

@@ -548,7 +548,7 @@ def __init__(
         # if using MlpHead, dropout is handled by MlpHead
         if num_classes > 0:
             if self.use_mlp_head:
-                # FIXME hidden size
+                # FIXME not actually returning mlp hidden state right now as pre-logits.
                 final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
                 self.head_hidden_size = self.num_features
             else:

@@ -583,7 +583,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head.fc

-    def reset_classifier(self, num_classes=0, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         if global_pool is not None:
             self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
             self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()

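Note the convention behind Optional[str] = None in signatures like this one: None means "leave the existing pooling alone", while a string such as 'avg' or 'max' rebuilds it. A toy sketch of that contract (a stand-in class for illustration, not timm's actual head code):

    from typing import Optional

    import torch.nn as nn


    class ToyHead(nn.Module):
        # Illustrative stand-in for a classifier head; not timm's implementation.
        def __init__(self, in_features: int, num_classes: int, pool_type: str = 'avg'):
            super().__init__()
            self.in_features = in_features
            self.pool_type = pool_type
            self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()

        def reset(self, num_classes: int, pool_type: Optional[str] = None):
            if pool_type is not None:
                # Only an explicit string changes the pooling.
                self.pool_type = pool_type
            # The classifier itself is always rebuilt for the new class count.
            self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
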
timm/models/nasnet.py (2 changes: 1 addition & 1 deletion)

@@ -518,7 +518,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.last_linear

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

timm/models/pnasnet.py (2 changes: 1 addition & 1 deletion)

@@ -307,7 +307,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.last_linear

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

timm/models/regnet.py (2 changes: 1 addition & 1 deletion)

@@ -514,7 +514,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, pool_type=global_pool)

     def forward_intermediates(

timm/models/rexnet.py (3 changes: 2 additions & 1 deletion)

@@ -12,6 +12,7 @@

 from functools import partial
 from math import ceil
+from typing import Optional

 import torch
 import torch.nn as nn

@@ -229,7 +230,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)

timm/models/selecsls.py (2 changes: 1 addition & 1 deletion)

@@ -161,7 +161,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

timm/models/senet.py (2 changes: 1 addition & 1 deletion)

@@ -337,7 +337,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.last_linear

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

timm/models/vision_transformer_relpos.py (2 changes: 1 addition & 1 deletion)

@@ -381,7 +381,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head

-    def reset_classifier(self, num_classes: int, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'avg', 'token')

timm/models/vision_transformer_sam.py (2 changes: 1 addition & 1 deletion)

@@ -536,7 +536,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head

-    def reset_classifier(self, num_classes=0, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, global_pool)

     def forward_intermediates(

timm/models/vovnet.py (21 changes: 15 additions & 6 deletions)

@@ -11,7 +11,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """

-from typing import List
+from typing import List, Optional

 import torch
 import torch.nn as nn

@@ -134,9 +134,17 @@ def __init__(
             else:
                 drop_path = None
             blocks += [OsaBlock(
-                in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise,
-                attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path)
-            ]
+                in_chs,
+                mid_chs,
+                out_chs,
+                layer_per_block,
+                residual=residual and i > 0,
+                depthwise=depthwise,
+                attn=attn if last_block else '',
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                drop_path=drop_path
+            )]
             in_chs = out_chs
         self.blocks = nn.Sequential(*blocks)

@@ -252,8 +260,9 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
-        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+    def reset_classifier(self, num_classes, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool)

     def forward_features(self, x):
         x = self.stem(x)

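The vovnet change goes slightly beyond annotations: rather than constructing a fresh ClassifierHead, reset_classifier() now records num_classes and delegates to the existing head's reset(), matching regnet, rexnet, and vision_transformer_sam above. A quick sanity check of the delegating behavior (model name and shapes are illustrative assumptions):

    import torch
    import timm

    model = timm.create_model('vovnet39a', pretrained=False)
    model.reset_classifier(5)  # global_pool=None keeps the head's current pooling

    x = torch.randn(1, 3, 224, 224)
    assert model(x).shape == (1, 5)
    assert model.get_classifier().out_features == 5
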
timm/models/xception.py (2 changes: 1 addition & 1 deletion)

@@ -174,7 +174,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

timm/models/xception_aligned.py (2 changes: 1 addition & 1 deletion)

@@ -274,7 +274,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, pool_type=global_pool)

     def forward_features(self, x):
