From ae3ced93ba257902eb508033976d9e643b01ae73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Mon, 8 Jul 2024 14:36:31 +0200 Subject: [PATCH 01/11] cleaner model handling --- boxmot/appearance/reid_model_factory.py | 259 +++++++++++------------- 1 file changed, 113 insertions(+), 146 deletions(-) diff --git a/boxmot/appearance/reid_model_factory.py b/boxmot/appearance/reid_model_factory.py index f11ba03d0..e54f8feda 100644 --- a/boxmot/appearance/reid_model_factory.py +++ b/boxmot/appearance/reid_model_factory.py @@ -8,137 +8,117 @@ from boxmot.utils import logger as LOGGER -__model_types = [ - "resnet50", - "resnet101", - "mlfn", - "hacnn", - "mobilenetv2_x1_0", - "mobilenetv2_x1_4", - "osnet_x1_0", - "osnet_x0_75", - "osnet_x0_5", - "osnet_x0_25", - "osnet_ibn_x1_0", - "osnet_ain_x1_0", - "lmbn_n", - "clip", -] - -lmbn_loc = 'https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/' - -__trained_urls = { - # resnet50 - "resnet50_market1501.pt": "https://drive.google.com/uc?id=1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV", - "resnet50_dukemtmcreid.pt": "https://drive.google.com/uc?id=17ymnLglnc64NRvGOitY3BqMRS9UWd1wg", - "resnet50_msmt17.pt": "https://drive.google.com/uc?id=1ep7RypVDOthCRIAqDnn4_N-UhkkFHJsj", - "resnet50_fc512_market1501.pt": "https://drive.google.com/uc?id=1kv8l5laX_YCdIGVCetjlNdzKIA3NvsSt", - "resnet50_fc512_dukemtmcreid.pt": "https://drive.google.com/uc?id=13QN8Mp3XH81GK4BPGXobKHKyTGH50Rtx", - "resnet50_fc512_msmt17.pt": "https://drive.google.com/uc?id=1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud", - # mlfn - "mlfn_market1501.pt": "https://drive.google.com/uc?id=1wXcvhA_b1kpDfrt9s2Pma-MHxtj9pmvS", - "mlfn_dukemtmcreid.pt": "https://drive.google.com/uc?id=1rExgrTNb0VCIcOnXfMsbwSUW1h2L1Bum", - "mlfn_msmt17.pt": "https://drive.google.com/uc?id=18JzsZlJb3Wm7irCbZbZ07TN4IFKvR6p-", - # hacnn - "hacnn_market1501.pt": "https://drive.google.com/uc?id=1LRKIQduThwGxMDQMiVkTScBwR7WidmYF", - "hacnn_dukemtmcreid.pt": "https://drive.google.com/uc?id=1zNm6tP4ozFUCUQ7Sv1Z98EAJWXJEhtYH", - "hacnn_msmt17.pt": "https://drive.google.com/uc?id=1MsKRtPM5WJ3_Tk2xC0aGOO7pM3VaFDNZ", - # mobilenetv2 - "mobilenetv2_x1_0_market1501.pt": "https://drive.google.com/uc?id=18DgHC2ZJkjekVoqBWszD8_Xiikz-fewp", - "mobilenetv2_x1_0_dukemtmcreid.pt": "https://drive.google.com/uc?id=1q1WU2FETRJ3BXcpVtfJUuqq4z3psetds", - "mobilenetv2_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1j50Hv14NOUAg7ZeB3frzfX-WYLi7SrhZ", - "mobilenetv2_x1_4_market1501.pt": "https://drive.google.com/uc?id=1t6JCqphJG-fwwPVkRLmGGyEBhGOf2GO5", - "mobilenetv2_x1_4_dukemtmcreid.pt": "https://drive.google.com/uc?id=12uD5FeVqLg9-AFDju2L7SQxjmPb4zpBN", - "mobilenetv2_x1_4_msmt17.pt": "https://drive.google.com/uc?id=1ZY5P2Zgm-3RbDpbXM0kIBMPvspeNIbXz", - # osnet - "osnet_x1_0_market1501.pt": "https://drive.google.com/uc?id=1vduhq5DpN2q1g4fYEZfPI17MJeh9qyrA", - "osnet_x1_0_dukemtmcreid.pt": "https://drive.google.com/uc?id=1QZO_4sNf4hdOKKKzKc-TZU9WW1v6zQbq", - "osnet_x1_0_msmt17.pt": "https://drive.google.com/uc?id=112EMUfBPYeYg70w-syK6V6Mx8-Qb9Q1M", - "osnet_x0_75_market1501.pt": "https://drive.google.com/uc?id=1ozRaDSQw_EQ8_93OUmjDbvLXw9TnfPer", - "osnet_x0_75_dukemtmcreid.pt": "https://drive.google.com/uc?id=1IE3KRaTPp4OUa6PGTFL_d5_KQSJbP0Or", - "osnet_x0_75_msmt17.pt": "https://drive.google.com/uc?id=1QEGO6WnJ-BmUzVPd3q9NoaO_GsPNlmWc", - "osnet_x0_5_market1501.pt": "https://drive.google.com/uc?id=1PLB9rgqrUM7blWrg4QlprCuPT7ILYGKT", - "osnet_x0_5_dukemtmcreid.pt": "https://drive.google.com/uc?id=1KoUVqmiST175hnkALg9XuTi1oYpqcyTu", - "osnet_x0_5_msmt17.pt": "https://drive.google.com/uc?id=1UT3AxIaDvS2PdxzZmbkLmjtiqq7AIKCv", - "osnet_x0_25_market1501.pt": "https://drive.google.com/uc?id=1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj", - "osnet_x0_25_dukemtmcreid.pt": "https://drive.google.com/uc?id=1eumrtiXT4NOspjyEV4j8cHmlOaaCGk5l", - "osnet_x0_25_msmt17.pt": "https://drive.google.com/uc?id=1sSwXSUlj4_tHZequ_iZ8w_Jh0VaRQMqF", - # osnet_ain | osnet_ibn - "osnet_ibn_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ", - "osnet_ain_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal", - # lmbn - "lmbn_n_duke.pt": lmbn_loc + "lmbn_n_duke.pth", - "lmbn_n_market.pt": lmbn_loc + "lmbn_n_market.pth", - "lmbn_n_cuhk03_d.pt": lmbn_loc + "lmbn_n_cuhk03_d.pth", - # clip - "clip_market1501.pt": "https://drive.google.com/uc?id=1GnyAVeNOg3Yug1KBBWMKKbT2x43O5Ch7", - "clip_duke.pt": "https://drive.google.com/uc?id=1ldjSkj-7pXAWmx8on5x0EftlCaolU4dY", - "clip_veri.pt": "https://drive.google.com/uc?id=1RyfHdOBI2pan_wIGSim5-l6cM4S2WN8e", - "clip_vehicleid.pt": "https://drive.google.com/uc?id=168BLegHHxNqatW5wx1YyL2REaThWoof5" +class ModelType(Enum): + RESNET50 = "resnet50" + RESNET101 = "resnet101" + MLFN = "mlfn" + HACNN = "hacnn" + MOBILENETV2_X1_0 = "mobilenetv2_x1_0" + MOBILENETV2_X1_4 = "mobilenetv2_x1_4" + OSNET_X1_0 = "osnet_x1_0" + OSNET_X0_75 = "osnet_x0_75" + OSNET_X0_5 = "osnet_x0_5" + OSNET_X0_25 = "osnet_x0_25" + OSNET_IBN_X1_0 = "osnet_ibn_x1_0" + OSNET_AIN_X1_0 = "osnet_ain_x1_0" + LMBN_N = "lmbn_n" + CLIP = "clip" + + +@dataclass +class ModelInfo: + name: ModelType + url: str + + +trained_urls = { + ModelType.RESNET50: [ + ModelInfo(name=ModelType.RESNET50, url="https://drive.google.com/uc?id=1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV"), + ModelInfo(name=ModelType.RESNET50, url="https://drive.google.com/uc?id=17ymnLglnc64NRvGOitY3BqMRS9UWd1wg"), + ModelInfo(name=ModelType.RESNET50, url="https://drive.google.com/uc?id=1ep7RypVDOthCRIAqDnn4_N-UhkkFHJsj"), + ModelInfo(name=ModelType.RESNET50, url="https://drive.google.com/uc?id=1kv8l5laX_YCdIGVCetjlNdzKIA3NvsSt"), + ModelInfo(name=ModelType.RESNET50, url="https://drive.google.com/uc?id=13QN8Mp3XH81GK4BPGXobKHKyTGH50Rtx"), + ModelInfo(name=ModelType.RESNET50, url="https://drive.google.com/uc?id=1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud"), + ], + ModelType.MLFN: [ + ModelInfo(name=ModelType.MLFN, url="https://drive.google.com/uc?id=1wXcvhA_b1kpDfrt9s2Pma-MHxtj9pmvS"), + ModelInfo(name=ModelType.MLFN, url="https://drive.google.com/uc?id=1rExgrTNb0VCIcOnXfMsbwSUW1h2L1Bum"), + ModelInfo(name=ModelType.MLFN, url="https://drive.google.com/uc?id=18JzsZlJb3Wm7irCbZbZ07TN4IFKvR6p-"), + ], + ModelType.HACNN: [ + ModelInfo(name=ModelType.HACNN, url="https://drive.google.com/uc?id=1LRKIQduThwGxMDQMiVkTScBwR7WidmYF"), + ModelInfo(name=ModelType.HACNN, url="https://drive.google.com/uc?id=1zNm6tP4ozFUCUQ7Sv1Z98EAJWXJEhtYH"), + ModelInfo(name=ModelType.HACNN, url="https://drive.google.com/uc?id=1MsKRtPM5WJ3_Tk2xC0aGOO7pM3VaFDNZ"), + ], + ModelType.MOBILENETV2_X1_0: [ + ModelInfo(name=ModelType.MOBILENETV2_X1_0, url="https://drive.google.com/uc?id=18DgHC2ZJkjekVoqBWszD8_Xiikz-fewp"), + ModelInfo(name=ModelType.MOBILENETV2_X1_0, url="https://drive.google.com/uc?id=1q1WU2FETRJ3BXcpVtfJUuqq4z3psetds"), + ModelInfo(name=ModelType.MOBILENETV2_X1_0, url="https://drive.google.com/uc?id=1j50Hv14NOUAg7ZeB3frzfX-WYLi7SrhZ"), + ], + ModelType.MOBILENETV2_X1_4: [ + ModelInfo(name=ModelType.MOBILENETV2_X1_4, url="https://drive.google.com/uc?id=1t6JCqphJG-fwwPVkRLmGGyEBhGOf2GO5"), + ModelInfo(name=ModelType.MOBILENETV2_X1_4, url="https://drive.google.com/uc?id=12uD5FeVqLg9-AFDju2L7SQxjmPb4zpBN"), + ModelInfo(name=ModelType.MOBILENETV2_X1_4, url="https://drive.google.com/uc?id=1ZY5P2Zgm-3RbDpbXM0kIBMPvspeNIbXz"), + ], + ModelType.OSNET_X1_0: [ + ModelInfo(name=ModelType.OSNET_X1_0, url="https://drive.google.com/uc?id=1vduhq5DpN2q1g4fYEZfPI17MJeh9qyrA"), + ModelInfo(name=ModelType.OSNET_X1_0, url="https://drive.google.com/uc?id=1QZO_4sNf4hdOKKKzKc-TZU9WW1v6zQbq"), + ModelInfo(name=ModelType.OSNET_X1_0, url="https://drive.google.com/uc?id=112EMUfBPYeYg70w-syK6V6Mx8-Qb9Q1M"), + ], + ModelType.OSNET_X0_75: [ + ModelInfo(name=ModelType.OSNET_X0_75, url="https://drive.google.com/uc?id=1ozRaDSQw_EQ8_93OUmjDbvLXw9TnfPer"), + ModelInfo(name=ModelType.OSNET_X0_75, url="https://drive.google.com/uc?id=1IE3KRaTPp4OUa6PGTFL_d5_KQSJbP0Or"), + ModelInfo(name=ModelType.OSNET_X0_75, url="https://drive.google.com/uc?id=1QEGO6WnJ-BmUzVPd3q9NoaO_GsPNlmWc"), + ], + ModelType.OSNET_X0_5: [ + ModelInfo(name=ModelType.OSNET_X0_5, url="https://drive.google.com/uc?id=1PLB9rgqrUM7blWrg4QlprCuPT7ILYGKT"), + ModelInfo(name=ModelType.OSNET_X0_5, url="https://drive.google.com/uc?id=1KoUVqmiST175hnkALg9XuTi1oYpqcyTu"), + ModelInfo(name=ModelType.OSNET_X0_5, url="https://drive.google.com/uc?id=1UT3AxIaDvS2PdxzZmbkLmjtiqq7AIKCv"), + ], + ModelType.OSNET_X0_25: [ + ModelInfo(name=ModelType.OSNET_X0_25, url="https://drive.google.com/uc?id=1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj"), + ModelInfo(name=ModelType.OSNET_X0_25, url="https://drive.google.com/uc?id=1eumrtiXT4NOspjyEV4j8cHmlOaaCGk5l"), + ModelInfo(name=ModelType.OSNET_X0_25, url="https://drive.google.com/uc?id=1sSwXSUlj4_tHZequ_iZ8w_Jh0VaRQMqF"), + ], + ModelType.OSNET_IBN_X1_0: [ + ModelInfo(name=ModelType.OSNET_IBN_X1_0, url="https://drive.google.com/uc?id=1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ") + ], + ModelType.OSNET_AIN_X1_0: [ + ModelInfo(name=ModelType.OSNET_AIN_X1_0, url="https://drive.google.com/uc?id=1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal") + ], + ModelType.LMBN_N: [ + ModelInfo(name=ModelType.LMBN_N, url="https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_duke.pth"), + ModelInfo(name=ModelType.LMBN_N, url="https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_market.pth"), + ModelInfo(name=ModelType.LMBN_N, url="https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_cuhk03_d.pth"), + ], + ModelType.CLIP: [ + ModelInfo(name=ModelType.CLIP, url="https://drive.google.com/uc?id=1GnyAVeNOg3Yug1KBBWMKKbT2x43O5Ch7"), + ModelInfo(name=ModelType.CLIP, url="https://drive.google.com/uc?id=1ldjSkj-7pXAWmx8on5x0EftlCaolU4dY"), + ModelInfo(name=ModelType.CLIP, url="https://drive.google.com/uc?id=1RyfHdOBI2pan_wIGSim5-l6cM4S2WN8e"), + ModelInfo(name=ModelType.CLIP, url="https://drive.google.com/uc?id=168BLegHHxNqatW5wx1YyL2REaThWoof5"), + ], } def show_downloadable_models(): LOGGER.info("\nAvailable .pt ReID models for automatic download") - LOGGER.info(list(__trained_urls.keys())) + for model_type, model_infos in trained_urls.items(): + for model_info in model_infos: + LOGGER.info(f"{model_type.value} - {model_info.url}") -def get_model_url(model): - if model.name in __trained_urls: - return __trained_urls[model.name] - else: - None - - -def is_model_in_model_types(model): - if model.name in __model_types: - return True - else: - return False +def get_model_urls(model_name: ModelType): + return trained_urls.get(model_name, []) def get_model_name(model): - for x in __model_types: - if x in model.name: - return x - return None - - -def download_url(url, dst): - """Downloads file from a url to a destination. - - Args: - url (str): url to download file. - dst (str): destination path. - """ - from six.moves import urllib - - LOGGER.info('* url="{}"'.format(url)) - LOGGER.info('* destination="{}"'.format(dst)) - - def _reporthook(count, block_size, total_size): - global start_time - if count == 0: - start_time = time.time() - return - duration = time.time() - start_time - progress_size = int(count * block_size) - speed = int(progress_size / (1024 * duration)) - percent = int(count * block_size * 100 / total_size) - sys.stdout.write( - "\r...%d%%, %d MB, %d KB/s, %d seconds passed" - % (percent, progress_size / (1024 * 1024), speed, duration) - ) - sys.stdout.flush() - - urllib.request.urlretrieve(url, dst, _reporthook) - sys.stdout.write("\n") + return next((x for x in __model_types if x in model.name), None) def load_pretrained_weights(model, weight_path): - r"""Loads pretrianed weights to model. + """Loads pretrained weights to model. - Features:: + Features: - Incompatible layers (unmatched in name or size) will be ignored. - Can automatically deal with keys containing "module.". @@ -146,7 +126,7 @@ def load_pretrained_weights(model, weight_path): model (nn.Module): network model. weight_path (str): path to pretrained weights. - Examples:: + Examples: >>> from boxmot.appearance.backbones import build_model >>> from boxmot.appearance.reid_model_factory import load_pretrained_weights >>> weight_path = 'log/my_model/model-best.pth.tar' @@ -154,39 +134,29 @@ def load_pretrained_weights(model, weight_path): >>> load_pretrained_weights(model, weight_path) """ - if not torch.cuda.is_available(): - checkpoint = torch.load(weight_path, map_location=torch.device("cpu")) - else: - checkpoint = torch.load(weight_path) - - if "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - else: - state_dict = checkpoint + checkpoint = torch.load(weight_path, map_location=torch.device("cpu") if not torch.cuda.is_available() else None) + state_dict = checkpoint.get("state_dict", checkpoint) model_dict = model.state_dict() - if "lmbn" in str(weight_path): + if "lmbn" in weight_path: model.load_state_dict(model_dict, strict=True) - elif "clip" in str(weight_path): + elif "clip" in weight_path: def forward_override(self, x: torch.Tensor, cv_emb=None, old_forward=None): _, image_features, image_features_proj = old_forward(x, cv_emb) return torch.cat([image_features[:, 0], image_features_proj[:, 0]], dim=1) - # print('model.load_param(str(weight_path))', str(weight_path)) + model.load_param(str(weight_path)) model = model.image_encoder - # old_forward = model.forward - # model.forward = lambda *args, **kwargs: forward_override(model, old_forward=old_forward, *args, **kwargs) - LOGGER.success( - f'Successfully loaded pretrained weights from "{weight_path}"' - ) + + LOGGER.success(f'Successfully loaded pretrained weights from "{weight_path}"') else: new_state_dict = OrderedDict() matched_layers, discarded_layers = [], [] for k, v in state_dict.items(): if k.startswith("module."): - k = k[7:] # discard module. + k = k[7:] if k in model_dict and model_dict[k].size() == v.size(): new_state_dict[k] = v @@ -197,18 +167,15 @@ def forward_override(self, x: torch.Tensor, cv_emb=None, old_forward=None): model_dict.update(new_state_dict) model.load_state_dict(model_dict) - if len(matched_layers) == 0: + if not matched_layers: LOGGER.debug( f'The pretrained weights "{weight_path}" cannot be loaded, ' - "please check the key names manually " - "(** ignored and continue **)" + "please check the key names manually (** ignored and continue **)" ) else: - LOGGER.success( - f'Successfully loaded pretrained weights from "{weight_path}"' - ) - if len(discarded_layers) > 0: + LOGGER.success(f'Successfully loaded pretrained weights from "{weight_path}"') + if discarded_layers: LOGGER.debug( - "The following layers are discarded " - f"due to unmatched keys or layer size: {*discarded_layers,}" - ) + "The following layers are discarded due to unmatched keys or layer size: " + f"{', '.join(discarded_layers)}" + ) \ No newline at end of file From e71ff27c786aebb0f3fcd8f0c9ffb54e252cb19b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Mon, 8 Jul 2024 14:41:37 +0200 Subject: [PATCH 02/11] move out stuff from __init__ --- boxmot/appearance/backbones/__init__.py | 93 +--------------------- boxmot/appearance/backends/base_backend.py | 17 +++- boxmot/appearance/reid_model_factory.py | 36 ++++++++- 3 files changed, 51 insertions(+), 95 deletions(-) diff --git a/boxmot/appearance/backbones/__init__.py b/boxmot/appearance/backbones/__init__.py index 38afeba9d..f6d3b5e38 100644 --- a/boxmot/appearance/backbones/__init__.py +++ b/boxmot/appearance/backbones/__init__.py @@ -1,92 +1 @@ -# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license - -from __future__ import absolute_import - -from boxmot.appearance.backbones.clip.make_model import make_model -from boxmot.appearance.backbones.hacnn import HACNN -from boxmot.appearance.backbones.lmbn.lmbn_n import LMBN_n -from boxmot.appearance.backbones.mlfn import mlfn -from boxmot.appearance.backbones.mobilenetv2 import (mobilenetv2_x1_0, - mobilenetv2_x1_4) -from boxmot.appearance.backbones.osnet import (osnet_ibn_x1_0, osnet_x0_5, - osnet_x0_25, osnet_x0_75, - osnet_x1_0) -from boxmot.appearance.backbones.osnet_ain import (osnet_ain_x0_5, - osnet_ain_x0_25, - osnet_ain_x0_75, - osnet_ain_x1_0) -from boxmot.appearance.backbones.resnet import resnet50, resnet101 - -NR_CLASSES_DICT = {'market1501': 751, 'duke': 702, 'veri': 576, 'vehicleid': 576} - - -__model_factory = { - # image classification models - "resnet50": resnet50, - "resnet101": resnet101, - "mobilenetv2_x1_0": mobilenetv2_x1_0, - "mobilenetv2_x1_4": mobilenetv2_x1_4, - # reid-specific models - "hacnn": HACNN, - "mlfn": mlfn, - "osnet_x1_0": osnet_x1_0, - "osnet_x0_75": osnet_x0_75, - "osnet_x0_5": osnet_x0_5, - "osnet_x0_25": osnet_x0_25, - "osnet_ibn_x1_0": osnet_ibn_x1_0, - "osnet_ain_x1_0": osnet_ain_x1_0, - "osnet_ain_x0_75": osnet_ain_x0_75, - "osnet_ain_x0_5": osnet_ain_x0_5, - "osnet_ain_x0_25": osnet_ain_x0_25, - "lmbn_n": LMBN_n, - "clip": make_model, -} - - -def show_avai_models(): - """Displays available models. - - Examples:: - >>> from torchreid import models - >>> models.show_avai_models() - """ - print(list(__model_factory.keys())) - - -def get_nr_classes(weigths): - num_classes = [value for key, value in NR_CLASSES_DICT.items() if key in str(weigths.name)] - if len(num_classes) == 0: - num_classes = 1 - else: - num_classes = num_classes[0] - return num_classes - - -def build_model(name, num_classes, loss="softmax", pretrained=True, use_gpu=True): - """A function wrapper for building a model. - - Args: - name (str): model name. - num_classes (int): number of training identities. - loss (str, optional): loss function to optimize the model. Currently - supports "softmax" and "triplet". Default is "softmax". - pretrained (bool, optional): whether to load ImageNet-pretrained weights. - Default is True. - use_gpu (bool, optional): whether to use gpu. Default is True. - - Returns: - nn.Module - - Examples:: - >>> from torchreid import models - >>> model = models.build_model('resnet50', 751, loss='softmax') - """ - avai_models = list(__model_factory.keys()) - if name not in avai_models: - raise KeyError("Unknown model: {}. Must be one of {}".format(name, avai_models)) - if 'clip' in name: - from boxmot.appearance.backbones.clip.config.defaults import _C as cfg - return __model_factory[name](cfg, num_class=num_classes, camera_num=2, view_num=1) - return __model_factory[name]( - num_classes=num_classes, loss=loss, pretrained=pretrained, use_gpu=use_gpu - ) +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license \ No newline at end of file diff --git a/boxmot/appearance/backends/base_backend.py b/boxmot/appearance/backends/base_backend.py index 687bf20dc..8889a2600 100644 --- a/boxmot/appearance/backends/base_backend.py +++ b/boxmot/appearance/backends/base_backend.py @@ -3,12 +3,25 @@ import gdown import numpy as np from abc import ABC, abstractmethod -from boxmot.appearance.backbones import build_model, get_nr_classes +from boxmot.appearance.backbones import build_model from boxmot.appearance.reid_model_factory import ( get_model_name, - get_model_url + get_model_url, + build_model ) +NR_CLASSES_DICT = {'market1501': 751, 'duke': 702, 'veri': 576, 'vehicleid': 576} + + +def get_nr_classes(weigths): + num_classes = [value for key, value in NR_CLASSES_DICT.items() if key in str(weigths.name)] + if len(num_classes) == 0: + num_classes = 1 + else: + num_classes = num_classes[0] + return num_classes + + class BaseModelBackend: def __init__(self, weights, device, half): self.weights = weights[0] if isinstance(weights, list) else weights diff --git a/boxmot/appearance/reid_model_factory.py b/boxmot/appearance/reid_model_factory.py index e54f8feda..c214a4415 100644 --- a/boxmot/appearance/reid_model_factory.py +++ b/boxmot/appearance/reid_model_factory.py @@ -178,4 +178,38 @@ def forward_override(self, x: torch.Tensor, cv_emb=None, old_forward=None): LOGGER.debug( "The following layers are discarded due to unmatched keys or layer size: " f"{', '.join(discarded_layers)}" - ) \ No newline at end of file + ) + + +def build_model(name, num_classes, loss="softmax", pretrained=True, use_gpu=True): + """A function wrapper for building a model. + + Args: + name (str): model name. + num_classes (int): number of training identities. + loss (str, optional): loss function to optimize the model. Currently + supports "softmax" and "triplet". Default is "softmax". + pretrained (bool, optional): whether to load ImageNet-pretrained weights. + Default is True. + use_gpu (bool, optional): whether to use gpu. Default is True. + + Returns: + nn.Module + + Examples:: + >>> from torchreid import models + >>> model = models.build_model('resnet50', 751, loss='softmax') + """ + if name not in __model_factory: + raise KeyError(f"Unknown model: {name}. Must be one of {list(__model_factory.keys())}") + + if 'clip' in name: + from boxmot.appearance.backbones.clip.config.defaults import _C as cfg + return __model_factory[name](cfg, num_class=num_classes, camera_num=2, view_num=1) + + return __model_factory[name]( + num_classes=num_classes, + loss=loss, + pretrained=pretrained, + use_gpu=use_gpu + ) \ No newline at end of file From e2af4e51d21c021028b87ccc9deff1e7fcf11e83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Mon, 8 Jul 2024 14:44:20 +0200 Subject: [PATCH 03/11] move out stuff from __init__ --- boxmot/appearance/backends/base_backend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/boxmot/appearance/backends/base_backend.py b/boxmot/appearance/backends/base_backend.py index 8889a2600..eeb0cdbb2 100644 --- a/boxmot/appearance/backends/base_backend.py +++ b/boxmot/appearance/backends/base_backend.py @@ -3,7 +3,6 @@ import gdown import numpy as np from abc import ABC, abstractmethod -from boxmot.appearance.backbones import build_model from boxmot.appearance.reid_model_factory import ( get_model_name, get_model_url, From 4ea550f95e2aa175e1d4cb95965253614a98aa5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Mon, 8 Jul 2024 14:50:43 +0200 Subject: [PATCH 04/11] move out stuff from __init__ --- boxmot/appearance/reid_model_factory.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/boxmot/appearance/reid_model_factory.py b/boxmot/appearance/reid_model_factory.py index c214a4415..8c4e693cf 100644 --- a/boxmot/appearance/reid_model_factory.py +++ b/boxmot/appearance/reid_model_factory.py @@ -3,6 +3,9 @@ import sys import time from collections import OrderedDict +from enum import Enum +from dataclasses import dataclass + import torch From 77d40ccba5b7c87dcf4f9c3a39386d03f06be3a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Mon, 8 Jul 2024 14:54:04 +0200 Subject: [PATCH 05/11] move out stuff from __init__ --- boxmot/appearance/reid_model_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/boxmot/appearance/reid_model_factory.py b/boxmot/appearance/reid_model_factory.py index 8c4e693cf..2b47d82f8 100644 --- a/boxmot/appearance/reid_model_factory.py +++ b/boxmot/appearance/reid_model_factory.py @@ -110,7 +110,7 @@ def show_downloadable_models(): LOGGER.info(f"{model_type.value} - {model_info.url}") -def get_model_urls(model_name: ModelType): +def get_model_url(model_name: ModelType): return trained_urls.get(model_name, []) From bb67b06dbec0bc793eb1d86bfc6e7a9a6cadeca0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Mon, 8 Jul 2024 15:02:18 +0200 Subject: [PATCH 06/11] move out stuff from __init__ --- boxmot/appearance/reid_model_factory.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/boxmot/appearance/reid_model_factory.py b/boxmot/appearance/reid_model_factory.py index 2b47d82f8..84f76fe3f 100644 --- a/boxmot/appearance/reid_model_factory.py +++ b/boxmot/appearance/reid_model_factory.py @@ -110,8 +110,11 @@ def show_downloadable_models(): LOGGER.info(f"{model_type.value} - {model_info.url}") -def get_model_url(model_name: ModelType): - return trained_urls.get(model_name, []) +def get_model_name(model): + try: + return next((x for x in __model_types if x in model.name), None) + except AttributeError: + return None def get_model_name(model): From fdb091eaa2d9ebe8516ec6aa3d7bc13b47e660c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Mon, 8 Jul 2024 15:04:52 +0200 Subject: [PATCH 07/11] move out stuff from __init__ --- boxmot/appearance/reid_model_factory.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/boxmot/appearance/reid_model_factory.py b/boxmot/appearance/reid_model_factory.py index 84f76fe3f..715d05a3a 100644 --- a/boxmot/appearance/reid_model_factory.py +++ b/boxmot/appearance/reid_model_factory.py @@ -116,10 +116,12 @@ def get_model_name(model): except AttributeError: return None - -def get_model_name(model): - return next((x for x in __model_types if x in model.name), None) - +def get_model_url(model): + for model_type, model_infos in trained_urls.items(): + for model_info in model_infos: + if model_info.name.value == model.name: + return model_info.url + return None def load_pretrained_weights(model, weight_path): """Loads pretrained weights to model. From b44b992950ba53a9cf70a4b7d36b61de059066ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Mon, 8 Jul 2024 16:28:34 +0200 Subject: [PATCH 08/11] move out stuff from __init__ --- boxmot/appearance/backends/base_backend.py | 15 +- boxmot/appearance/reid_model_factory.py | 301 +++++++++------------ 2 files changed, 128 insertions(+), 188 deletions(-) diff --git a/boxmot/appearance/backends/base_backend.py b/boxmot/appearance/backends/base_backend.py index eeb0cdbb2..09a1501cc 100644 --- a/boxmot/appearance/backends/base_backend.py +++ b/boxmot/appearance/backends/base_backend.py @@ -6,21 +6,10 @@ from boxmot.appearance.reid_model_factory import ( get_model_name, get_model_url, - build_model + build_model, + get_nr_classes ) -NR_CLASSES_DICT = {'market1501': 751, 'duke': 702, 'veri': 576, 'vehicleid': 576} - - -def get_nr_classes(weigths): - num_classes = [value for key, value in NR_CLASSES_DICT.items() if key in str(weigths.name)] - if len(num_classes) == 0: - num_classes = 1 - else: - num_classes = num_classes[0] - return num_classes - - class BaseModelBackend: def __init__(self, weights, device, half): self.weights = weights[0] if isinstance(weights, list) else weights diff --git a/boxmot/appearance/reid_model_factory.py b/boxmot/appearance/reid_model_factory.py index 715d05a3a..62eacb18c 100644 --- a/boxmot/appearance/reid_model_factory.py +++ b/boxmot/appearance/reid_model_factory.py @@ -1,170 +1,127 @@ -# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license - import sys import time from collections import OrderedDict -from enum import Enum -from dataclasses import dataclass - import torch - +from torch import nn from boxmot.utils import logger as LOGGER -class ModelType(Enum): - RESNET50 = "resnet50" - RESNET101 = "resnet101" - MLFN = "mlfn" - HACNN = "hacnn" - MOBILENETV2_X1_0 = "mobilenetv2_x1_0" - MOBILENETV2_X1_4 = "mobilenetv2_x1_4" - OSNET_X1_0 = "osnet_x1_0" - OSNET_X0_75 = "osnet_x0_75" - OSNET_X0_5 = "osnet_x0_5" - OSNET_X0_25 = "osnet_x0_25" - OSNET_IBN_X1_0 = "osnet_ibn_x1_0" - OSNET_AIN_X1_0 = "osnet_ain_x1_0" - LMBN_N = "lmbn_n" - CLIP = "clip" - - -@dataclass -class ModelInfo: - name: ModelType - url: str - - -trained_urls = { - ModelType.RESNET50: [ - ModelInfo(name=ModelType.RESNET50, url="https://drive.google.com/uc?id=1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV"), - ModelInfo(name=ModelType.RESNET50, url="https://drive.google.com/uc?id=17ymnLglnc64NRvGOitY3BqMRS9UWd1wg"), - ModelInfo(name=ModelType.RESNET50, url="https://drive.google.com/uc?id=1ep7RypVDOthCRIAqDnn4_N-UhkkFHJsj"), - ModelInfo(name=ModelType.RESNET50, url="https://drive.google.com/uc?id=1kv8l5laX_YCdIGVCetjlNdzKIA3NvsSt"), - ModelInfo(name=ModelType.RESNET50, url="https://drive.google.com/uc?id=13QN8Mp3XH81GK4BPGXobKHKyTGH50Rtx"), - ModelInfo(name=ModelType.RESNET50, url="https://drive.google.com/uc?id=1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud"), - ], - ModelType.MLFN: [ - ModelInfo(name=ModelType.MLFN, url="https://drive.google.com/uc?id=1wXcvhA_b1kpDfrt9s2Pma-MHxtj9pmvS"), - ModelInfo(name=ModelType.MLFN, url="https://drive.google.com/uc?id=1rExgrTNb0VCIcOnXfMsbwSUW1h2L1Bum"), - ModelInfo(name=ModelType.MLFN, url="https://drive.google.com/uc?id=18JzsZlJb3Wm7irCbZbZ07TN4IFKvR6p-"), - ], - ModelType.HACNN: [ - ModelInfo(name=ModelType.HACNN, url="https://drive.google.com/uc?id=1LRKIQduThwGxMDQMiVkTScBwR7WidmYF"), - ModelInfo(name=ModelType.HACNN, url="https://drive.google.com/uc?id=1zNm6tP4ozFUCUQ7Sv1Z98EAJWXJEhtYH"), - ModelInfo(name=ModelType.HACNN, url="https://drive.google.com/uc?id=1MsKRtPM5WJ3_Tk2xC0aGOO7pM3VaFDNZ"), - ], - ModelType.MOBILENETV2_X1_0: [ - ModelInfo(name=ModelType.MOBILENETV2_X1_0, url="https://drive.google.com/uc?id=18DgHC2ZJkjekVoqBWszD8_Xiikz-fewp"), - ModelInfo(name=ModelType.MOBILENETV2_X1_0, url="https://drive.google.com/uc?id=1q1WU2FETRJ3BXcpVtfJUuqq4z3psetds"), - ModelInfo(name=ModelType.MOBILENETV2_X1_0, url="https://drive.google.com/uc?id=1j50Hv14NOUAg7ZeB3frzfX-WYLi7SrhZ"), - ], - ModelType.MOBILENETV2_X1_4: [ - ModelInfo(name=ModelType.MOBILENETV2_X1_4, url="https://drive.google.com/uc?id=1t6JCqphJG-fwwPVkRLmGGyEBhGOf2GO5"), - ModelInfo(name=ModelType.MOBILENETV2_X1_4, url="https://drive.google.com/uc?id=12uD5FeVqLg9-AFDju2L7SQxjmPb4zpBN"), - ModelInfo(name=ModelType.MOBILENETV2_X1_4, url="https://drive.google.com/uc?id=1ZY5P2Zgm-3RbDpbXM0kIBMPvspeNIbXz"), - ], - ModelType.OSNET_X1_0: [ - ModelInfo(name=ModelType.OSNET_X1_0, url="https://drive.google.com/uc?id=1vduhq5DpN2q1g4fYEZfPI17MJeh9qyrA"), - ModelInfo(name=ModelType.OSNET_X1_0, url="https://drive.google.com/uc?id=1QZO_4sNf4hdOKKKzKc-TZU9WW1v6zQbq"), - ModelInfo(name=ModelType.OSNET_X1_0, url="https://drive.google.com/uc?id=112EMUfBPYeYg70w-syK6V6Mx8-Qb9Q1M"), - ], - ModelType.OSNET_X0_75: [ - ModelInfo(name=ModelType.OSNET_X0_75, url="https://drive.google.com/uc?id=1ozRaDSQw_EQ8_93OUmjDbvLXw9TnfPer"), - ModelInfo(name=ModelType.OSNET_X0_75, url="https://drive.google.com/uc?id=1IE3KRaTPp4OUa6PGTFL_d5_KQSJbP0Or"), - ModelInfo(name=ModelType.OSNET_X0_75, url="https://drive.google.com/uc?id=1QEGO6WnJ-BmUzVPd3q9NoaO_GsPNlmWc"), - ], - ModelType.OSNET_X0_5: [ - ModelInfo(name=ModelType.OSNET_X0_5, url="https://drive.google.com/uc?id=1PLB9rgqrUM7blWrg4QlprCuPT7ILYGKT"), - ModelInfo(name=ModelType.OSNET_X0_5, url="https://drive.google.com/uc?id=1KoUVqmiST175hnkALg9XuTi1oYpqcyTu"), - ModelInfo(name=ModelType.OSNET_X0_5, url="https://drive.google.com/uc?id=1UT3AxIaDvS2PdxzZmbkLmjtiqq7AIKCv"), - ], - ModelType.OSNET_X0_25: [ - ModelInfo(name=ModelType.OSNET_X0_25, url="https://drive.google.com/uc?id=1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj"), - ModelInfo(name=ModelType.OSNET_X0_25, url="https://drive.google.com/uc?id=1eumrtiXT4NOspjyEV4j8cHmlOaaCGk5l"), - ModelInfo(name=ModelType.OSNET_X0_25, url="https://drive.google.com/uc?id=1sSwXSUlj4_tHZequ_iZ8w_Jh0VaRQMqF"), - ], - ModelType.OSNET_IBN_X1_0: [ - ModelInfo(name=ModelType.OSNET_IBN_X1_0, url="https://drive.google.com/uc?id=1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ") - ], - ModelType.OSNET_AIN_X1_0: [ - ModelInfo(name=ModelType.OSNET_AIN_X1_0, url="https://drive.google.com/uc?id=1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal") - ], - ModelType.LMBN_N: [ - ModelInfo(name=ModelType.LMBN_N, url="https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_duke.pth"), - ModelInfo(name=ModelType.LMBN_N, url="https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_market.pth"), - ModelInfo(name=ModelType.LMBN_N, url="https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_cuhk03_d.pth"), - ], - ModelType.CLIP: [ - ModelInfo(name=ModelType.CLIP, url="https://drive.google.com/uc?id=1GnyAVeNOg3Yug1KBBWMKKbT2x43O5Ch7"), - ModelInfo(name=ModelType.CLIP, url="https://drive.google.com/uc?id=1ldjSkj-7pXAWmx8on5x0EftlCaolU4dY"), - ModelInfo(name=ModelType.CLIP, url="https://drive.google.com/uc?id=1RyfHdOBI2pan_wIGSim5-l6cM4S2WN8e"), - ModelInfo(name=ModelType.CLIP, url="https://drive.google.com/uc?id=168BLegHHxNqatW5wx1YyL2REaThWoof5"), - ], +# Model Factory and Construction +from boxmot.appearance.backbones.clip.make_model import make_model +from boxmot.appearance.backbones.hacnn import HACNN +from boxmot.appearance.backbones.lmbn.lmbn_n import LMBN_n +from boxmot.appearance.backbones.mlfn import mlfn +from boxmot.appearance.backbones.mobilenetv2 import mobilenetv2_x1_0, mobilenetv2_x1_4 +from boxmot.appearance.backbones.osnet import ( + osnet_ibn_x1_0, + osnet_x0_5, + osnet_x0_25, + osnet_x0_75, + osnet_x1_0, +) +from boxmot.appearance.backbones.osnet_ain import ( + osnet_ain_x0_5, + osnet_ain_x0_25, + osnet_ain_x0_75, + osnet_ain_x1_0, +) +from boxmot.appearance.backbones.resnet import resnet50, resnet101 + +# Constants +__model_types = [ + "resnet50", + "resnet101", + "mlfn", + "hacnn", + "mobilenetv2_x1_0", + "mobilenetv2_x1_4", + "osnet_x1_0", + "osnet_x0_75", + "osnet_x0_5", + "osnet_x0_25", + "osnet_ibn_x1_0", + "osnet_ain_x1_0", + "lmbn_n", + "clip", +] + +__trained_urls = { + # Example URLs for pretrained models (partial list for brevity) + "resnet50_market1501.pt": "https://example.com/resnet50_market1501.pt", + "mlfn_market1501.pt": "https://example.com/mlfn_market1501.pt", + # Add more URLs as needed } +NR_CLASSES_DICT = { + 'market1501': 751, + 'duke': 702, + 'veri': 576, + 'vehicleid': 576 +} -def show_downloadable_models(): - LOGGER.info("\nAvailable .pt ReID models for automatic download") - for model_type, model_infos in trained_urls.items(): - for model_info in model_infos: - LOGGER.info(f"{model_type.value} - {model_info.url}") +__model_factory = { + "resnet50": resnet50, + "resnet101": resnet101, + "mobilenetv2_x1_0": mobilenetv2_x1_0, + "mobilenetv2_x1_4": mobilenetv2_x1_4, + "hacnn": HACNN, + "mlfn": mlfn, + "osnet_x1_0": osnet_x1_0, + "osnet_x0_75": osnet_x0_75, + "osnet_x0_5": osnet_x0_5, + "osnet_x0_25": osnet_x0_25, + "osnet_ibn_x1_0": osnet_ibn_x1_0, + "osnet_ain_x1_0": osnet_ain_x1_0, + "osnet_ain_x0_75": osnet_ain_x0_75, + "osnet_ain_x0_5": osnet_ain_x0_5, + "osnet_ain_x0_25": osnet_ain_x0_25, + "lmbn_n": LMBN_n, + "clip": make_model, +} +# Utility functions +def show_downloadable_models(): + LOGGER.info("Available .pt ReID models for automatic download") + LOGGER.info(list(__trained_urls.keys())) + + def get_model_name(model): - try: - return next((x for x in __model_types if x in model.name), None) - except AttributeError: - return None - -def get_model_url(model): - for model_type, model_infos in trained_urls.items(): - for model_info in model_infos: - if model_info.name.value == model.name: - return model_info.url + for x in __model_types: + if x in model.name: + return x return None -def load_pretrained_weights(model, weight_path): - """Loads pretrained weights to model. - - Features: - - Incompatible layers (unmatched in name or size) will be ignored. - - Can automatically deal with keys containing "module.". +def get_model_url(model): + if model.name in __trained_urls: + return __trained_urls[model.name] + else: + None - Args: - model (nn.Module): network model. - weight_path (str): path to pretrained weights. - Examples: - >>> from boxmot.appearance.backbones import build_model - >>> from boxmot.appearance.reid_model_factory import load_pretrained_weights - >>> weight_path = 'log/my_model/model-best.pth.tar' - >>> model = build_model() - >>> load_pretrained_weights(model, weight_path) - """ +def load_pretrained_weights(model, weight_path): + """Loads pretrained weights to a model.""" + if not torch.cuda.is_available(): + checkpoint = torch.load(weight_path, map_location=torch.device("cpu")) + else: + checkpoint = torch.load(weight_path) - checkpoint = torch.load(weight_path, map_location=torch.device("cpu") if not torch.cuda.is_available() else None) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint - state_dict = checkpoint.get("state_dict", checkpoint) model_dict = model.state_dict() - if "lmbn" in weight_path: + if "lmbn" in str(weight_path): model.load_state_dict(model_dict, strict=True) - elif "clip" in weight_path: - def forward_override(self, x: torch.Tensor, cv_emb=None, old_forward=None): - _, image_features, image_features_proj = old_forward(x, cv_emb) - return torch.cat([image_features[:, 0], image_features_proj[:, 0]], dim=1) - - model.load_param(str(weight_path)) - model = model.image_encoder - - LOGGER.success(f'Successfully loaded pretrained weights from "{weight_path}"') else: new_state_dict = OrderedDict() matched_layers, discarded_layers = [], [] for k, v in state_dict.items(): if k.startswith("module."): - k = k[7:] + k = k[7:] # remove 'module.' prefix if present if k in model_dict and model_dict[k].size() == v.size(): new_state_dict[k] = v @@ -175,49 +132,43 @@ def forward_override(self, x: torch.Tensor, cv_emb=None, old_forward=None): model_dict.update(new_state_dict) model.load_state_dict(model_dict) - if not matched_layers: + if len(matched_layers) == 0: LOGGER.debug( - f'The pretrained weights "{weight_path}" cannot be loaded, ' - "please check the key names manually (** ignored and continue **)" + f"Pretrained weights from {weight_path} cannot be loaded. Check key names manually." ) else: - LOGGER.success(f'Successfully loaded pretrained weights from "{weight_path}"') - if discarded_layers: - LOGGER.debug( - "The following layers are discarded due to unmatched keys or layer size: " - f"{', '.join(discarded_layers)}" - ) - - + LOGGER.success(f"Loaded pretrained weights from {weight_path}") + + if len(discarded_layers) > 0: + LOGGER.debug( + f"Discarded layers due to unmatched keys or layer size: {discarded_layers}" + ) + + +def show_available_models(): + """Displays available models.""" + LOGGER.info("Available models:") + LOGGER.info(list(__model_factory.keys())) + + +def get_nr_classes(weights): + """Returns the number of classes based on weights.""" + num_classes = NR_CLASSES_DICT.get(weights.name.split('_')[1], 1) + return num_classes + + def build_model(name, num_classes, loss="softmax", pretrained=True, use_gpu=True): - """A function wrapper for building a model. - - Args: - name (str): model name. - num_classes (int): number of training identities. - loss (str, optional): loss function to optimize the model. Currently - supports "softmax" and "triplet". Default is "softmax". - pretrained (bool, optional): whether to load ImageNet-pretrained weights. - Default is True. - use_gpu (bool, optional): whether to use gpu. Default is True. - - Returns: - nn.Module - - Examples:: - >>> from torchreid import models - >>> model = models.build_model('resnet50', 751, loss='softmax') - """ - if name not in __model_factory: - raise KeyError(f"Unknown model: {name}. Must be one of {list(__model_factory.keys())}") + """Builds a model based on specified parameters.""" + available_models = list(__model_factory.keys()) + + if name not in available_models: + raise KeyError(f"Unknown model '{name}'. Must be one of {available_models}") if 'clip' in name: + # Assuming clip requires special configuration, adjust as needed from boxmot.appearance.backbones.clip.config.defaults import _C as cfg return __model_factory[name](cfg, num_class=num_classes, camera_num=2, view_num=1) return __model_factory[name]( - num_classes=num_classes, - loss=loss, - pretrained=pretrained, - use_gpu=use_gpu - ) \ No newline at end of file + num_classes=num_classes, loss=loss, pretrained=pretrained, use_gpu=use_gpu + ) From f0cf71f88ab67727cae6dfe1f4101de61dc9c012 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Mon, 8 Jul 2024 16:41:06 +0200 Subject: [PATCH 09/11] move out stuff from __init__ --- boxmot/appearance/reid_model_factory.py | 68 +++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 4 deletions(-) diff --git a/boxmot/appearance/reid_model_factory.py b/boxmot/appearance/reid_model_factory.py index 62eacb18c..1a1ede8e8 100644 --- a/boxmot/appearance/reid_model_factory.py +++ b/boxmot/appearance/reid_model_factory.py @@ -46,10 +46,53 @@ ] __trained_urls = { - # Example URLs for pretrained models (partial list for brevity) - "resnet50_market1501.pt": "https://example.com/resnet50_market1501.pt", - "mlfn_market1501.pt": "https://example.com/mlfn_market1501.pt", - # Add more URLs as needed + # resnet50 + "resnet50_market1501.pt": "https://drive.google.com/uc?id=1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV", + "resnet50_dukemtmcreid.pt": "https://drive.google.com/uc?id=17ymnLglnc64NRvGOitY3BqMRS9UWd1wg", + "resnet50_msmt17.pt": "https://drive.google.com/uc?id=1ep7RypVDOthCRIAqDnn4_N-UhkkFHJsj", + "resnet50_fc512_market1501.pt": "https://drive.google.com/uc?id=1kv8l5laX_YCdIGVCetjlNdzKIA3NvsSt", + "resnet50_fc512_dukemtmcreid.pt": "https://drive.google.com/uc?id=13QN8Mp3XH81GK4BPGXobKHKyTGH50Rtx", + "resnet50_fc512_msmt17.pt": "https://drive.google.com/uc?id=1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud", + # mlfn + "mlfn_market1501.pt": "https://drive.google.com/uc?id=1wXcvhA_b1kpDfrt9s2Pma-MHxtj9pmvS", + "mlfn_dukemtmcreid.pt": "https://drive.google.com/uc?id=1rExgrTNb0VCIcOnXfMsbwSUW1h2L1Bum", + "mlfn_msmt17.pt": "https://drive.google.com/uc?id=18JzsZlJb3Wm7irCbZbZ07TN4IFKvR6p-", + # hacnn + "hacnn_market1501.pt": "https://drive.google.com/uc?id=1LRKIQduThwGxMDQMiVkTScBwR7WidmYF", + "hacnn_dukemtmcreid.pt": "https://drive.google.com/uc?id=1zNm6tP4ozFUCUQ7Sv1Z98EAJWXJEhtYH", + "hacnn_msmt17.pt": "https://drive.google.com/uc?id=1MsKRtPM5WJ3_Tk2xC0aGOO7pM3VaFDNZ", + # mobilenetv2 + "mobilenetv2_x1_0_market1501.pt": "https://drive.google.com/uc?id=18DgHC2ZJkjekVoqBWszD8_Xiikz-fewp", + "mobilenetv2_x1_0_dukemtmcreid.pt": "https://drive.google.com/uc?id=1q1WU2FETRJ3BXcpVtfJUuqq4z3psetds", + "mobilenetv2_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1j50Hv14NOUAg7ZeB3frzfX-WYLi7SrhZ", + "mobilenetv2_x1_4_market1501.pt": "https://drive.google.com/uc?id=1t6JCqphJG-fwwPVkRLmGGyEBhGOf2GO5", + "mobilenetv2_x1_4_dukemtmcreid.pt": "https://drive.google.com/uc?id=12uD5FeVqLg9-AFDju2L7SQxjmPb4zpBN", + "mobilenetv2_x1_4_msmt17.pt": "https://drive.google.com/uc?id=1ZY5P2Zgm-3RbDpbXM0kIBMPvspeNIbXz", + # osnet + "osnet_x1_0_market1501.pt": "https://drive.google.com/uc?id=1vduhq5DpN2q1g4fYEZfPI17MJeh9qyrA", + "osnet_x1_0_dukemtmcreid.pt": "https://drive.google.com/uc?id=1QZO_4sNf4hdOKKKzKc-TZU9WW1v6zQbq", + "osnet_x1_0_msmt17.pt": "https://drive.google.com/uc?id=112EMUfBPYeYg70w-syK6V6Mx8-Qb9Q1M", + "osnet_x0_75_market1501.pt": "https://drive.google.com/uc?id=1ozRaDSQw_EQ8_93OUmjDbvLXw9TnfPer", + "osnet_x0_75_dukemtmcreid.pt": "https://drive.google.com/uc?id=1IE3KRaTPp4OUa6PGTFL_d5_KQSJbP0Or", + "osnet_x0_75_msmt17.pt": "https://drive.google.com/uc?id=1QEGO6WnJ-BmUzVPd3q9NoaO_GsPNlmWc", + "osnet_x0_5_market1501.pt": "https://drive.google.com/uc?id=1PLB9rgqrUM7blWrg4QlprCuPT7ILYGKT", + "osnet_x0_5_dukemtmcreid.pt": "https://drive.google.com/uc?id=1KoUVqmiST175hnkALg9XuTi1oYpqcyTu", + "osnet_x0_5_msmt17.pt": "https://drive.google.com/uc?id=1UT3AxIaDvS2PdxzZmbkLmjtiqq7AIKCv", + "osnet_x0_25_market1501.pt": "https://drive.google.com/uc?id=1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj", + "osnet_x0_25_dukemtmcreid.pt": "https://drive.google.com/uc?id=1eumrtiXT4NOspjyEV4j8cHmlOaaCGk5l", + "osnet_x0_25_msmt17.pt": "https://drive.google.com/uc?id=1sSwXSUlj4_tHZequ_iZ8w_Jh0VaRQMqF", + # osnet_ain | osnet_ibn + "osnet_ibn_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ", + "osnet_ain_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal", + # lmbn + "lmbn_n_duke.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_duke.pth", + "lmbn_n_market.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_market.pth", + "lmbn_n_cuhk03_d.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_cuhk03_d.pth", + # clip + "clip_market1501.pt": "https://drive.google.com/uc?id=1GnyAVeNOg3Yug1KBBWMKKbT2x43O5Ch7", + "clip_duke.pt": "https://drive.google.com/uc?id=1ldjSkj-7pXAWmx8on5x0EftlCaolU4dY", + "clip_veri.pt": "https://drive.google.com/uc?id=1RyfHdOBI2pan_wIGSim5-l6cM4S2WN8e", + "clip_vehicleid.pt": "https://drive.google.com/uc?id=168BLegHHxNqatW5wx1YyL2REaThWoof5" } NR_CLASSES_DICT = { @@ -80,6 +123,23 @@ } +# Combine into a single data structure +combined_data = {} + + +for model_type in __model_types: + for url_key, url_value in __trained_urls.items(): + if model_type in url_key: + dataset = next(key for key in NR_CLASSES_DICT.keys() if key in url_key) + num_classes = NR_CLASSES_DICT[dataset] + combined_data[model_type] = { + 'trained_url': url_value, + 'dataset': dataset, + 'num_classes': num_classes + } + break # Stop searching further once found +print(combined_data) + # Utility functions def show_downloadable_models(): LOGGER.info("Available .pt ReID models for automatic download") From d4fa0e9af2145c79d334a159784422b28beadd54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Mon, 8 Jul 2024 16:47:10 +0200 Subject: [PATCH 10/11] move out stuff from __init__ --- boxmot/appearance/reid_model_factory.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/boxmot/appearance/reid_model_factory.py b/boxmot/appearance/reid_model_factory.py index 1a1ede8e8..4d8867029 100644 --- a/boxmot/appearance/reid_model_factory.py +++ b/boxmot/appearance/reid_model_factory.py @@ -123,23 +123,6 @@ } -# Combine into a single data structure -combined_data = {} - - -for model_type in __model_types: - for url_key, url_value in __trained_urls.items(): - if model_type in url_key: - dataset = next(key for key in NR_CLASSES_DICT.keys() if key in url_key) - num_classes = NR_CLASSES_DICT[dataset] - combined_data[model_type] = { - 'trained_url': url_value, - 'dataset': dataset, - 'num_classes': num_classes - } - break # Stop searching further once found -print(combined_data) - # Utility functions def show_downloadable_models(): LOGGER.info("Available .pt ReID models for automatic download") From 2f297f356eaaf9aef70e89ac88031e55bb97baac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Mon, 8 Jul 2024 17:12:15 +0200 Subject: [PATCH 11/11] move out stuff from __init__ --- boxmot/appearance/reid_export.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/boxmot/appearance/reid_export.py b/boxmot/appearance/reid_export.py index 871b34701..fbabc420c 100644 --- a/boxmot/appearance/reid_export.py +++ b/boxmot/appearance/reid_export.py @@ -11,9 +11,10 @@ from torch.utils.mobile_optimizer import optimize_for_mobile from boxmot.appearance import export_formats -from boxmot.appearance.backbones import build_model, get_nr_classes from boxmot.appearance.reid_model_factory import (get_model_name, - load_pretrained_weights) + load_pretrained_weights,build_model, + get_nr_classes + ) from boxmot.utils import WEIGHTS from boxmot.utils import logger as LOGGER from boxmot.utils.checks import TestRequirements