Skip to content

Commit

Permalink
Merge pull request #1511 from mikel-brostrom/appearance-factory-refactor
Browse files Browse the repository at this point in the history
Appearance factory refactor
  • Loading branch information
mikel-brostrom committed Jul 8, 2024
2 parents 4bd3f6f + 2f297f3 commit c1af860
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 192 deletions.
93 changes: 1 addition & 92 deletions boxmot/appearance/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,92 +1 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

from __future__ import absolute_import

from boxmot.appearance.backbones.clip.make_model import make_model
from boxmot.appearance.backbones.hacnn import HACNN
from boxmot.appearance.backbones.lmbn.lmbn_n import LMBN_n
from boxmot.appearance.backbones.mlfn import mlfn
from boxmot.appearance.backbones.mobilenetv2 import (mobilenetv2_x1_0,
mobilenetv2_x1_4)
from boxmot.appearance.backbones.osnet import (osnet_ibn_x1_0, osnet_x0_5,
osnet_x0_25, osnet_x0_75,
osnet_x1_0)
from boxmot.appearance.backbones.osnet_ain import (osnet_ain_x0_5,
osnet_ain_x0_25,
osnet_ain_x0_75,
osnet_ain_x1_0)
from boxmot.appearance.backbones.resnet import resnet50, resnet101

# Class count per ReID dataset. get_nr_classes() matches these keys as
# substrings of the weights file name and falls back to 1 when none match.
# NOTE(review): 'veri' and 'vehicleid' carry the same count - confirm intended.
NR_CLASSES_DICT = {'market1501': 751, 'duke': 702, 'veri': 576, 'vehicleid': 576}


# Registry mapping each supported model name to the callable that constructs
# it; build_model() looks names up here and show_avai_models() lists the keys.
__model_factory = {
# image classification models
"resnet50": resnet50,
"resnet101": resnet101,
"mobilenetv2_x1_0": mobilenetv2_x1_0,
"mobilenetv2_x1_4": mobilenetv2_x1_4,
# reid-specific models
"hacnn": HACNN,
"mlfn": mlfn,
"osnet_x1_0": osnet_x1_0,
"osnet_x0_75": osnet_x0_75,
"osnet_x0_5": osnet_x0_5,
"osnet_x0_25": osnet_x0_25,
"osnet_ibn_x1_0": osnet_ibn_x1_0,
"osnet_ain_x1_0": osnet_ain_x1_0,
"osnet_ain_x0_75": osnet_ain_x0_75,
"osnet_ain_x0_5": osnet_ain_x0_5,
"osnet_ain_x0_25": osnet_ain_x0_25,
"lmbn_n": LMBN_n,
# "clip" is built through make_model, which build_model() calls with a
# different argument list than the other entries.
"clip": make_model,
}


def show_avai_models():
    """Print the list of model names registered in the model factory.

    Examples::
        >>> from boxmot.appearance.backbones import show_avai_models
        >>> show_avai_models()
    """
    registered_names = list(__model_factory.keys())
    print(registered_names)


def get_nr_classes(weigths):
    """Return the class count of the dataset named in a weights file name.

    Args:
        weigths: object exposing a ``name`` attribute (the weights file
            name); it is converted with ``str`` before matching.

    Returns:
        The value of the first ``NR_CLASSES_DICT`` entry whose key occurs as
        a substring of the name, or 1 when no key occurs in it.
    """
    return next(
        (
            nr_classes
            for dataset, nr_classes in NR_CLASSES_DICT.items()
            if dataset in str(weigths.name)
        ),
        1,
    )


def build_model(name, num_classes, loss="softmax", pretrained=True, use_gpu=True):
    """Construct the model registered under ``name``.

    Args:
        name (str): key of the model in the model factory.
        num_classes (int): number of classes passed to the constructor.
        loss (str, optional): forwarded to the constructor. Default is
            "softmax". Not forwarded for "clip" models.
        pretrained (bool, optional): forwarded to the constructor. Default
            is True. Not forwarded for "clip" models.
        use_gpu (bool, optional): forwarded to the constructor. Default is
            True. Not forwarded for "clip" models.

    Returns:
        Whatever the registered constructor returns.

    Raises:
        KeyError: if ``name`` is not a registered model name.

    Examples::
        >>> model = build_model('resnet50', 751, loss='softmax')
    """
    known_names = list(__model_factory.keys())
    if name in known_names:
        constructor = __model_factory[name]
    else:
        raise KeyError("Unknown model: {}. Must be one of {}".format(name, known_names))

    if 'clip' not in name:
        return constructor(
            num_classes=num_classes,
            loss=loss,
            pretrained=pretrained,
            use_gpu=use_gpu,
        )

    # clip models are built from the packaged default config and take a
    # different set of arguments than the other backbones.
    from boxmot.appearance.backbones.clip.config.defaults import _C as cfg
    return constructor(cfg, num_class=num_classes, camera_num=2, view_num=1)
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
5 changes: 3 additions & 2 deletions boxmot/appearance/backends/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import gdown
import numpy as np
from abc import ABC, abstractmethod
from boxmot.appearance.backbones import build_model, get_nr_classes
from boxmot.appearance.reid_model_factory import (
get_model_name,
get_model_url
get_model_url,
build_model,
get_nr_classes
)

class BaseModelBackend:
Expand Down
5 changes: 3 additions & 2 deletions boxmot/appearance/reid_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from torch.utils.mobile_optimizer import optimize_for_mobile

from boxmot.appearance import export_formats
from boxmot.appearance.backbones import build_model, get_nr_classes
from boxmot.appearance.reid_model_factory import (get_model_name,
load_pretrained_weights)
load_pretrained_weights,build_model,
get_nr_classes
)
from boxmot.utils import WEIGHTS
from boxmot.utils import logger as LOGGER
from boxmot.utils.checks import TestRequirements
Expand Down
195 changes: 99 additions & 96 deletions boxmot/appearance/reid_model_factory.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,33 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

import sys
import time
from collections import OrderedDict

import torch

from torch import nn
from boxmot.utils import logger as LOGGER

# Model Factory and Construction
from boxmot.appearance.backbones.clip.make_model import make_model
from boxmot.appearance.backbones.hacnn import HACNN
from boxmot.appearance.backbones.lmbn.lmbn_n import LMBN_n
from boxmot.appearance.backbones.mlfn import mlfn
from boxmot.appearance.backbones.mobilenetv2 import mobilenetv2_x1_0, mobilenetv2_x1_4
from boxmot.appearance.backbones.osnet import (
osnet_ibn_x1_0,
osnet_x0_5,
osnet_x0_25,
osnet_x0_75,
osnet_x1_0,
)
from boxmot.appearance.backbones.osnet_ain import (
osnet_ain_x0_5,
osnet_ain_x0_25,
osnet_ain_x0_75,
osnet_ain_x1_0,
)
from boxmot.appearance.backbones.resnet import resnet50, resnet101

# Constants
__model_types = [
"resnet50",
"resnet101",
Expand All @@ -25,8 +45,6 @@
"clip",
]

lmbn_loc = 'https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/'

__trained_urls = {
# resnet50
"resnet50_market1501.pt": "https://drive.google.com/uc?id=1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV",
Expand Down Expand Up @@ -67,93 +85,65 @@
"osnet_ibn_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ",
"osnet_ain_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal",
# lmbn
"lmbn_n_duke.pt": lmbn_loc + "lmbn_n_duke.pth",
"lmbn_n_market.pt": lmbn_loc + "lmbn_n_market.pth",
"lmbn_n_cuhk03_d.pt": lmbn_loc + "lmbn_n_cuhk03_d.pth",
"lmbn_n_duke.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_duke.pth",
"lmbn_n_market.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_market.pth",
"lmbn_n_cuhk03_d.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_cuhk03_d.pth",
# clip
"clip_market1501.pt": "https://drive.google.com/uc?id=1GnyAVeNOg3Yug1KBBWMKKbT2x43O5Ch7",
"clip_duke.pt": "https://drive.google.com/uc?id=1ldjSkj-7pXAWmx8on5x0EftlCaolU4dY",
"clip_veri.pt": "https://drive.google.com/uc?id=1RyfHdOBI2pan_wIGSim5-l6cM4S2WN8e",
"clip_vehicleid.pt": "https://drive.google.com/uc?id=168BLegHHxNqatW5wx1YyL2REaThWoof5"
}

# Class count per ReID dataset tag; get_nr_classes() looks a weights file
# name up in this table and uses 1 when it finds no entry.
# NOTE(review): 'veri' and 'vehicleid' carry the same count - confirm intended.
NR_CLASSES_DICT = {
'market1501': 751,
'duke': 702,
'veri': 576,
'vehicleid': 576
}

def show_downloadable_models():
    """Log the weight file names that have a registered download URL."""
    downloadable = list(__trained_urls.keys())
    LOGGER.info("\nAvailable .pt ReID models for automatic download")
    LOGGER.info(downloadable)


def get_model_url(model):
    """Return the download URL registered for a weights file.

    Args:
        model: object exposing a ``name`` attribute holding the weights
            file name (used as the key into ``__trained_urls``).

    Returns:
        The URL string, or None when the file name has no registered URL.
    """
    # The previous ``else: None`` was a bare expression, not a return; make
    # the "unknown model" result explicit through dict.get's default.
    return __trained_urls.get(model.name)


def is_model_in_model_types(model):
    """Tell whether ``model.name`` is exactly one of the known model types.

    Args:
        model: object exposing a ``name`` attribute.

    Returns:
        bool: True when the name is an entry of ``__model_types``.
    """
    return model.name in __model_types
# Registry mapping each supported model name to the callable that constructs
# it; build_model() looks names up here and show_available_models() lists
# the keys.
__model_factory = {
"resnet50": resnet50,
"resnet101": resnet101,
"mobilenetv2_x1_0": mobilenetv2_x1_0,
"mobilenetv2_x1_4": mobilenetv2_x1_4,
"hacnn": HACNN,
"mlfn": mlfn,
"osnet_x1_0": osnet_x1_0,
"osnet_x0_75": osnet_x0_75,
"osnet_x0_5": osnet_x0_5,
"osnet_x0_25": osnet_x0_25,
"osnet_ibn_x1_0": osnet_ibn_x1_0,
"osnet_ain_x1_0": osnet_ain_x1_0,
"osnet_ain_x0_75": osnet_ain_x0_75,
"osnet_ain_x0_5": osnet_ain_x0_5,
"osnet_ain_x0_25": osnet_ain_x0_25,
"lmbn_n": LMBN_n,
# "clip" is built through make_model, which build_model() calls with a
# different argument list than the other entries.
"clip": make_model,
}


# Utility functions
def show_downloadable_models():
    """Log the weight file names that have a registered download URL."""
    downloadable = list(__trained_urls.keys())
    LOGGER.info("Available .pt ReID models for automatic download")
    LOGGER.info(downloadable)


def get_model_name(model):
    """Find the model type that a weights file name refers to.

    Args:
        model: object exposing a ``name`` attribute (the weights file name).

    Returns:
        The first entry of ``__model_types`` that occurs as a substring of
        ``model.name``, or None when no entry occurs in it.
    """
    return next(
        (model_type for model_type in __model_types if model_type in model.name),
        None,
    )


def download_url(url, dst):
    """Download the file at ``url`` to ``dst`` while reporting progress.

    The url and destination are logged through ``LOGGER``; a single progress
    line (percentage, size, speed, elapsed time) is rewritten on stdout as
    blocks arrive, followed by a newline once the transfer finishes.

    Args:
        url (str): url to download file.
        dst (str): destination path.
    """
    # Python 3 standard library; removes the need for the ``six`` shim.
    from urllib.request import urlretrieve

    LOGGER.info('* url="{}"'.format(url))
    LOGGER.info('* destination="{}"'.format(dst))

    # Held in the enclosing scope rather than a module-level global so a
    # download neither leaks state into the module nor shares it with others.
    start_time = None

    def _reporthook(count, block_size, total_size):
        nonlocal start_time
        if count == 0 or start_time is None:
            # First callback only starts the clock; nothing transferred yet.
            start_time = time.time()
            return

        # Clamp so a coarse clock cannot cause a ZeroDivisionError below.
        duration = max(time.time() - start_time, 1e-6)
        progress_size = int(count * block_size)
        speed = int(progress_size / (1024 * duration))
        size_mb = progress_size / (1024 * 1024)

        if total_size > 0:
            # count * block_size overshoots on the final partial block.
            percent = min(int(progress_size * 100 / total_size), 100)
            sys.stdout.write(
                "\r...%d%%, %d MB, %d KB/s, %d seconds passed"
                % (percent, size_mb, speed, duration)
            )
        else:
            # Total size unknown (reported as non-positive): a percentage
            # would be meaningless or divide by zero, so leave it out.
            sys.stdout.write(
                "\r...%d MB, %d KB/s, %d seconds passed"
                % (size_mb, speed, duration)
            )
        sys.stdout.flush()

    urlretrieve(url, dst, _reporthook)
    sys.stdout.write("\n")
def get_model_url(model):
    """Return the download URL registered for a weights file.

    Args:
        model: object exposing a ``name`` attribute holding the weights
            file name (used as the key into ``__trained_urls``).

    Returns:
        The URL string, or None when the file name has no registered URL.
    """
    # The previous ``else: None`` was a bare expression, not a return; make
    # the "unknown model" result explicit through dict.get's default.
    return __trained_urls.get(model.name)


def load_pretrained_weights(model, weight_path):
r"""Loads pretrianed weights to model.
Features::
- Incompatible layers (unmatched in name or size) will be ignored.
- Can automatically deal with keys containing "module.".
Args:
model (nn.Module): network model.
weight_path (str): path to pretrained weights.
Examples::
>>> from boxmot.appearance.backbones import build_model
>>> from boxmot.appearance.reid_model_factory import load_pretrained_weights
>>> weight_path = 'log/my_model/model-best.pth.tar'
>>> model = build_model()
>>> load_pretrained_weights(model, weight_path)
"""

"""Loads pretrained weights to a model."""
if not torch.cuda.is_available():
checkpoint = torch.load(weight_path, map_location=torch.device("cpu"))
else:
Expand All @@ -168,25 +158,13 @@ def load_pretrained_weights(model, weight_path):

if "lmbn" in str(weight_path):
model.load_state_dict(model_dict, strict=True)
elif "clip" in str(weight_path):
def forward_override(self, x: torch.Tensor, cv_emb=None, old_forward=None):
_, image_features, image_features_proj = old_forward(x, cv_emb)
return torch.cat([image_features[:, 0], image_features_proj[:, 0]], dim=1)
# print('model.load_param(str(weight_path))', str(weight_path))
model.load_param(str(weight_path))
model = model.image_encoder
# old_forward = model.forward
# model.forward = lambda *args, **kwargs: forward_override(model, old_forward=old_forward, *args, **kwargs)
LOGGER.success(
f'Successfully loaded pretrained weights from "{weight_path}"'
)
else:
new_state_dict = OrderedDict()
matched_layers, discarded_layers = [], []

for k, v in state_dict.items():
if k.startswith("module."):
k = k[7:] # discard module.
k = k[7:] # remove 'module.' prefix if present

if k in model_dict and model_dict[k].size() == v.size():
new_state_dict[k] = v
Expand All @@ -199,16 +177,41 @@ def forward_override(self, x: torch.Tensor, cv_emb=None, old_forward=None):

if len(matched_layers) == 0:
LOGGER.debug(
f'The pretrained weights "{weight_path}" cannot be loaded, '
"please check the key names manually "
"(** ignored and continue **)"
f"Pretrained weights from {weight_path} cannot be loaded. Check key names manually."
)
else:
LOGGER.success(
f'Successfully loaded pretrained weights from "{weight_path}"'
LOGGER.success(f"Loaded pretrained weights from {weight_path}")

if len(discarded_layers) > 0:
LOGGER.debug(
f"Discarded layers due to unmatched keys or layer size: {discarded_layers}"
)
if len(discarded_layers) > 0:
LOGGER.debug(
"The following layers are discarded "
f"due to unmatched keys or layer size: {*discarded_layers,}"
)


def show_available_models():
    """Log the names of every model registered in the model factory."""
    registered_names = list(__model_factory.keys())
    LOGGER.info("Available models:")
    LOGGER.info(registered_names)


def get_nr_classes(weights):
    """Return the class count of the dataset named in a weights file name.

    Args:
        weights: object exposing a ``name`` attribute (the weights file
            name, e.g. ``osnet_x0_25_market1501.pt``).

    Returns:
        int: the value of the first ``NR_CLASSES_DICT`` entry whose key
        occurs in the file name, or 1 when no key occurs in it.
    """
    # Match dataset tags as substrings of the whole name. Taking the second
    # "_"-separated token does not work: it keeps the extension
    # ("clip_market1501.pt" -> "market1501.pt"), lands on the wrong token for
    # multi-part model names ("osnet_x0_25_market1501.pt" -> "x0"), and
    # raises IndexError when the name contains no underscore.
    weights_name = str(weights.name)
    for dataset, nr_classes in NR_CLASSES_DICT.items():
        if dataset in weights_name:
            return nr_classes
    return 1


def build_model(name, num_classes, loss="softmax", pretrained=True, use_gpu=True):
    """Construct the model registered under ``name``.

    Args:
        name (str): key of the model in the model factory.
        num_classes (int): number of classes passed to the constructor.
        loss (str, optional): forwarded to the constructor. Default is
            "softmax". Not forwarded for "clip" models.
        pretrained (bool, optional): forwarded to the constructor. Default
            is True. Not forwarded for "clip" models.
        use_gpu (bool, optional): forwarded to the constructor. Default is
            True. Not forwarded for "clip" models.

    Returns:
        Whatever the registered constructor returns.

    Raises:
        KeyError: if ``name`` is not a registered model name.
    """
    known_names = list(__model_factory.keys())
    if name in known_names:
        constructor = __model_factory[name]
    else:
        raise KeyError(f"Unknown model '{name}'. Must be one of {known_names}")

    if 'clip' not in name:
        return constructor(
            num_classes=num_classes,
            loss=loss,
            pretrained=pretrained,
            use_gpu=use_gpu,
        )

    # clip models are built from the packaged default config and take a
    # different set of arguments than the other backbones.
    from boxmot.appearance.backbones.clip.config.defaults import _C as cfg
    return constructor(cfg, num_class=num_classes, camera_num=2, view_num=1)

0 comments on commit c1af860

Please sign in to comment.