Skip to content

Commit

Permalink
model.names multi-GPU bug fix #94
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jun 25, 2020
1 parent 9a9c4f1 commit b50fdf1
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def detect(save_img=False):
dataset = LoadImages(source, img_size=imgsz)

# Get names and colors
names = model.names if hasattr(model, 'names') else model.modules.names
names = model.module.names if hasattr(model, 'module') else model.names
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

# Run inference
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def train(hyp):
# Create model
model = Model(opt.cfg).to(device)
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
model.names = data_dict['names']

# Image sizes
gs = int(max(model.stride)) # grid size (max stride)
Expand Down Expand Up @@ -193,7 +194,6 @@ def train(hyp):
model.hyp = hyp # attach hyperparameters to model
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model.names = data_dict['names']

# Class frequency
labels = np.concatenate(dataset.labels, 0)
Expand Down

0 comments on commit b50fdf1

Please sign in to comment.