Skip to content

Commit

Permalink
VOC names dictionary fix (ultralytics#9034)
Browse files Browse the repository at this point in the history
* VOC names dictionary fix

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Update dataloaders.py

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
glenn-jocher authored and Clay Januhowski committed Sep 8, 2022
1 parent de21e44 commit 93748b6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
5 changes: 3 additions & 2 deletions data/VOC.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,13 @@ download: |
w = int(size.find('width').text)
h = int(size.find('height').text)
names = list(yaml['names'].values()) # names list
for obj in root.iter('object'):
cls = obj.find('name').text
if cls in yaml['names'] and not int(obj.find('difficult').text) == 1:
if cls in names and int(obj.find('difficult').text) != 1:
xmlbox = obj.find('bndbox')
bb = convert_box((w, h), [float(xmlbox.find(x).text) for x in ('xmin', 'xmax', 'ymin', 'ymax')])
cls_id = yaml['names'].index(cls) # class id
cls_id = names.index(cls) # class id
out_file.write(" ".join([str(a) for a in (cls_id, *bb)]) + '\n')
Expand Down
10 changes: 6 additions & 4 deletions utils/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from utils.torch_utils import torch_distributed_zero_first

# Parameters
HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp' # include image suffixes
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
Expand Down Expand Up @@ -456,7 +456,7 @@ def __init__(self,
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
assert self.im_files, f'{prefix}No images found'
except Exception as e:
raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}')

# Check cache
self.label_files = img2label_paths(self.im_files) # labels
Expand All @@ -475,11 +475,13 @@ def __init__(self,
tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
if cache['msgs']:
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'
assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'

# Read cache
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
labels, shapes, self.segments = zip(*cache.values())
nl = len(np.concatenate(labels, 0)) # number of labels
assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
self.labels = list(labels)
self.shapes = np.array(shapes)
self.im_files = list(cache.keys()) # update
Expand Down Expand Up @@ -572,7 +574,7 @@ def cache_labels(self, path=Path('./labels.cache'), prefix=''):
if msgs:
LOGGER.info('\n'.join(msgs))
if nf == 0:
LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. {HELP_URL}')
x['hash'] = get_hash(self.label_files + self.im_files)
x['results'] = nf, nm, ne, nc, len(self.im_files)
x['msgs'] = msgs # warnings
Expand Down

0 comments on commit 93748b6

Please sign in to comment.