Skip to content

Commit

Permalink
new nc=len(names) check
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jul 10, 2020
1 parent cb527d3 commit e16e9e4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def train(hyp):
os.remove(f)

# Create model
model = Model(opt.cfg, nc=data_dict['nc']).to(device)
model = Model(opt.cfg, nc=nc).to(device)

# Image sizes
gs = int(max(model.stride)) # grid size (max stride)
Expand Down Expand Up @@ -177,7 +177,7 @@ def train(hyp):
model.hyp = hyp # attach hyperparameters to model
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model.names = data_dict['names']
model.names = names

# Class frequency
labels = np.concatenate(dataset.labels, 0)
Expand Down

0 comments on commit e16e9e4

Please sign in to comment.