diff --git a/hubconf.py b/hubconf.py index 51f658a532ff..3488fef76ac5 100644 --- a/hubconf.py +++ b/hubconf.py @@ -30,7 +30,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo from models.experimental import attempt_load from models.yolo import Model from utils.downloads import attempt_download - from utils.general import check_requirements, set_logging + from utils.general import check_requirements, intersect_dicts, set_logging from utils.torch_utils import select_device file = Path(__file__).resolve() @@ -49,9 +49,8 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo model = Model(cfg, channels, classes) # create model if pretrained: ckpt = torch.load(attempt_download(path), map_location=device) # load - msd = model.state_dict() # model state_dict csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 - csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter + csd = intersect_dicts(csd, model.state_dict(), exclude=['anchors']) # intersect model.load_state_dict(csd, strict=False) # load if len(ckpt['model'].names) == classes: model.names = ckpt['model'].names # set class names attribute diff --git a/train.py b/train.py index 75f3b7cb36a7..90abdc59db88 100644 --- a/train.py +++ b/train.py @@ -43,15 +43,14 @@ from utils.downloads import attempt_download from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, - labels_to_class_weights, labels_to_image_weights, methods, one_cycle, print_args, - print_mutation, strip_optimizer) + intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle, + print_args, print_mutation, strip_optimizer) from utils.loggers import Loggers from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loss import ComputeLoss from utils.metrics import fitness from utils.plots import plot_evolve, plot_labels -from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device, - torch_distributed_zero_first) +from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html RANK = int(os.getenv('RANK', -1)) diff --git a/utils/general.py b/utils/general.py index 15b58257eabb..46cb1ddef983 100755 --- a/utils/general.py +++ b/utils/general.py @@ -125,6 +125,11 @@ def init_seeds(seed=0): cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False) +def intersect_dicts(da, db, exclude=()): + # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values + return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape} + + def get_latest_run(search_dir='.'): # Return path to most recent 'last.pt' in /runs (i.e. to --resume from) last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 793e8d8ffd3e..b36e98d0b656 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -153,11 +153,6 @@ def de_parallel(model): return model.module if is_parallel(model) else model -def intersect_dicts(da, db, exclude=()): - # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values - return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape} - - def initialize_weights(model): for m in model.modules(): t = type(m)