From c9042dc2adbb635aeca407c10cf492a6eb14d772 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 19 Apr 2022 17:32:15 -0700 Subject: [PATCH] Improved non-latin `Annotator()` plotting (#7488) * Improved non-latin labels Annotator plotting May resolve https://github.com/ultralytics/yolov5/issues/7460 * Update train.py * Update train.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add progress arg Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- train.py | 8 +++++--- utils/general.py | 4 ++-- utils/plots.py | 7 ++++--- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/train.py b/train.py index 806e2cebe561..c774430df293 100644 --- a/train.py +++ b/train.py @@ -48,13 +48,13 @@ from utils.downloads import attempt_download from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, - intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle, - print_args, print_mutation, strip_optimizer) + intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, methods, + one_cycle, print_args, print_mutation, strip_optimizer) from utils.loggers import Loggers from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loss import ComputeLoss from utils.metrics import fitness -from utils.plots import plot_evolve, plot_labels +from utils.plots import check_font, plot_evolve, plot_labels from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html @@ -105,6 +105,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio init_seeds(1 + RANK) with torch_distributed_zero_first(LOCAL_RANK): data_dict = data_dict or check_dataset(data) # check if None + if not is_ascii(data_dict['names']): # non-latin labels, i.e. asian, arabic, cyrillic + check_font('Arial.Unicode.ttf', progress=True) train_path, val_path = data_dict['train'], data_dict['val'] nc = 1 if single_cls else int(data_dict['nc']) # number of classes names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names diff --git a/utils/general.py b/utils/general.py index daef2a427111..a4bc3cae9315 100755 --- a/utils/general.py +++ b/utils/general.py @@ -424,13 +424,13 @@ def check_file(file, suffix=''): return files[0] # return file -def check_font(font=FONT): +def check_font(font=FONT, progress=False): # Download font to CONFIG_DIR if necessary font = Path(font) if not font.exists() and not (CONFIG_DIR / font.name).exists(): url = "https://ultralytics.com/assets/" + font.name LOGGER.info(f'Downloading {url} to {CONFIG_DIR / font.name}...') - torch.hub.download_url_to_file(url, str(font), progress=False) + torch.hub.download_url_to_file(url, str(font), progress=progress) def check_dataset(data, autodownload=True): diff --git a/utils/plots.py b/utils/plots.py index 51e9cfdf6e04..842894e745df 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -19,7 +19,7 @@ from PIL import Image, ImageDraw, ImageFont from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords, - increment_path, is_ascii, is_chinese, try_except, xywh2xyxy, xyxy2xywh) + increment_path, is_ascii, try_except, xywh2xyxy, xyxy2xywh) from utils.metrics import fitness # Settings @@ -72,11 +72,12 @@ class Annotator: # YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'): assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.' - self.pil = pil or not is_ascii(example) or is_chinese(example) + non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic + self.pil = pil or non_ascii if self.pil: # use PIL self.im = im if isinstance(im, Image.Image) else Image.fromarray(im) self.draw = ImageDraw.Draw(self.im) - self.font = check_pil_font(font='Arial.Unicode.ttf' if is_chinese(example) else font, + self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font, size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)) else: # use cv2 self.im = im