From e16e9e43e1b4b8f9b9819fe81d537c00ec6d606e Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 9 Jul 2020 17:10:43 -0700 Subject: [PATCH] new nc=len(names) check --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 6bfa4286b1aa..bba0883ca8d6 100644 --- a/train.py +++ b/train.py @@ -76,7 +76,7 @@ def train(hyp): os.remove(f) # Create model - model = Model(opt.cfg, nc=data_dict['nc']).to(device) + model = Model(opt.cfg, nc=nc).to(device) # Image sizes gs = int(max(model.stride)) # grid size (max stride) @@ -177,7 +177,7 @@ def train(hyp): model.hyp = hyp # attach hyperparameters to model model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou) model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights - model.names = data_dict['names'] + model.names = names # Class frequency labels = np.concatenate(dataset.labels, 0)