diff --git a/classify/val.py b/classify/val.py index 9d965d9f1fdc..b76fb5147ecd 100644 --- a/classify/val.py +++ b/classify/val.py @@ -116,7 +116,7 @@ def run( if verbose: # all classes LOGGER.info(f"{'Class':>24}{'Images':>12}{'top1_acc':>12}{'top5_acc':>12}") LOGGER.info(f"{'all':>24}{targets.shape[0]:>12}{top1:>12.3g}{top5:>12.3g}") - for i, c in enumerate(model.names): + for i, c in model.names.items(): aci = acc[targets == i] top1i, top5i = aci.mean(0).tolist() LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}") @@ -127,6 +127,7 @@ def run( LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t) LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") + model.float() # for training return top1, top5, loss diff --git a/export.py b/export.py index 595039b24bce..7b398fdc4d93 100644 --- a/export.py +++ b/export.py @@ -599,7 +599,7 @@ def parse_opt(): parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold') parser.add_argument('--include', nargs='+', - default=['torchscript', 'onnx'], + default=['torchscript'], help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs') opt = parser.parse_args() print_args(vars(opt)) diff --git a/models/experimental.py b/models/experimental.py index cb32d01ba46a..02d35b9ebd11 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -8,7 +8,6 @@ import torch import torch.nn as nn -from models.common import Conv from utils.downloads import attempt_download @@ -79,11 +78,16 @@ def attempt_load(weights, device=None, inplace=True, fuse=True): for w in weights if isinstance(weights, list) else [weights]: ckpt = torch.load(attempt_download(w), map_location='cpu') # load ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model + + # Model compatibility updates if not hasattr(ckpt, 'stride'): - ckpt.stride = torch.tensor([32.]) # compatibility update for ResNet etc. + ckpt.stride = torch.tensor([32.]) + if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)): + ckpt.names = dict(enumerate(ckpt.names)) # convert to dict + model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode - # Compatibility updates + # Module compatibility updates for m in model.modules(): t = type(m) if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model): diff --git a/train.py b/train.py index bbb26cdeafeb..10a3bdb56002 100644 --- a/train.py +++ b/train.py @@ -107,8 +107,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio data_dict = data_dict or check_dataset(data) # check if None train_path, val_path = data_dict['train'], data_dict['val'] nc = 1 if single_cls else int(data_dict['nc']) # number of classes - names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names - assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check + names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset # Model diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 1cdbe20f8670..ed56064ce02e 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -408,8 +408,6 @@ class ModelEMA: def __init__(self, model, decay=0.9999, tau=2000, updates=0): # Create EMA self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA - # if next(model.parameters()).device.type != 'cpu': - # self.ema.half() # FP16 EMA self.updates = updates # number of EMA updates self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs) for p in self.ema.parameters(): @@ -423,9 +421,10 @@ def update(self, model): msd = de_parallel(model).state_dict() # model state_dict for k, v in self.ema.state_dict().items(): - if v.dtype.is_floating_point: + if v.dtype.is_floating_point: # true for FP16 and FP32 v *= d - v += (1 - d) * msd[k].detach() + v += (1 - d) * msd[k] + assert v.dtype == msd[k].dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must be updated in FP32' def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): # Update EMA attributes