diff --git a/models/yolo.py b/models/yolo.py index a9e1da43d913..a047fef397ee 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -62,7 +62,7 @@ def _make_grid(nx=20, ny=20): class Model(nn.Module): - def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes + def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes super(Model, self).__init__() if isinstance(cfg, dict): self.yaml = cfg # model dict @@ -75,8 +75,11 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, # Define model ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels if nc and nc != self.yaml['nc']: - logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc)) + logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") self.yaml['nc'] = nc # override yaml value + if anchors: + logger.info(f'Overriding model.yaml anchors with anchors={anchors}') + self.yaml['anchors'] = round(anchors) # override yaml value self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist self.names = [str(i) for i in range(self.yaml['nc'])] # default names # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))]) diff --git a/train.py b/train.py index e2c82339f7fe..1b8b315ce927 100644 --- a/train.py +++ b/train.py @@ -84,7 +84,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): model.load_state_dict(state_dict, strict=False) # load logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report else: - model = Model(opt.cfg, ch=3, nc=nc).to(device) # create + model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create # Freeze freeze = [] # parameter names to freeze (full or partial)