diff --git a/hubconf.py b/hubconf.py index 55d15abe2ac5..39fa614b2e34 100644 --- a/hubconf.py +++ b/hubconf.py @@ -12,10 +12,10 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None): - """Creates a specified YOLOv5 model + """Creates or loads a YOLOv5 model Arguments: - name (str): name of model, i.e. 'yolov5s' + name (str): model name 'yolov5s' or path 'path/to/best.pt' pretrained (bool): load pretrained weights into the model channels (int): number of input channels classes (int): number of model classes @@ -24,19 +24,19 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo device (str, torch.device, None): device to use for model parameters Returns: - YOLOv5 pytorch model + YOLOv5 model """ from pathlib import Path from models.common import AutoShape, DetectMultiBackend from models.yolo import Model from utils.downloads import attempt_download - from utils.general import check_requirements, intersect_dicts, set_logging + from utils.general import LOGGER, check_requirements, intersect_dicts, logging from utils.torch_utils import select_device + if not verbose: + LOGGER.setLevel(logging.WARNING) check_requirements(exclude=('tensorboard', 'thop', 'opencv-python')) - set_logging(verbose=verbose) - name = Path(name) path = name.with_suffix('.pt') if name.suffix == '' else name # checkpoint path try: diff --git a/utils/general.py b/utils/general.py index 3d6da2fdb173..e9f5ec2ac128 100755 --- a/utils/general.py +++ b/utils/general.py @@ -36,7 +36,7 @@ FILE = Path(__file__).resolve() ROOT = FILE.parents[1] # YOLOv5 root directory NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads -VERBOSE = str(os.getenv('VERBOSE', True)).lower() == 'true' # global verbose mode +VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode torch.set_printoptions(linewidth=320, precision=5, profile='long') np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5 @@ -241,20 +241,20 @@ def check_online(): def check_git_status(): # Recommend 'git pull' if code is out of date msg = ', for updates see https://github.com/ultralytics/yolov5' - print(colorstr('github: '), end='') - assert Path('.git').exists(), 'skipping check (not a git repository)' + msg - assert not is_docker(), 'skipping check (Docker image)' + msg - assert check_online(), 'skipping check (offline)' + msg + s = colorstr('github: ') # string + assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg + assert not is_docker(), s + 'skipping check (Docker image)' + msg + assert check_online(), s + 'skipping check (offline)' + msg cmd = 'git fetch && git config --get remote.origin.url' url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind if n > 0: - s = f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update." + s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update." else: - s = f'up to date with {url} ✅' - print(emojis(s)) # emoji-safe + s += f'up to date with {url} ✅' + LOGGER.info(emojis(s)) # emoji-safe def check_python(minimum='3.6.2'): @@ -294,21 +294,21 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta except Exception as e: # DistributionNotFound or VersionConflict if requirements not met s = f"{prefix} {r} not found and is required by YOLOv5" if install: - print(f"{s}, attempting auto-update...") + LOGGER.info(f"{s}, attempting auto-update...") try: assert check_online(), f"'pip install {r}' skipped (offline)" - print(check_output(f"pip install '{r}'", shell=True).decode()) + LOGGER.info(check_output(f"pip install '{r}'", shell=True).decode()) n += 1 except Exception as e: - print(f'{prefix} {e}') + LOGGER.warning(f'{prefix} {e}') else: - print(f'{s}. Please install and rerun your command.') + LOGGER.info(f'{s}. Please install and rerun your command.') if n: # if packages updated source = file.resolve() if 'file' in locals() else requirements s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \ f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n" - print(emojis(s)) + LOGGER.info(emojis(s)) def check_img_size(imgsz, s=32, floor=0): @@ -318,7 +318,7 @@ def check_img_size(imgsz, s=32, floor=0): else: # list i.e. img_size=[640, 480] new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz] if new_size != imgsz: - print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}') + LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}') return new_size @@ -333,7 +333,7 @@ def check_imshow(): cv2.waitKey(1) return True except Exception as e: - print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}') + LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}') return False @@ -363,9 +363,9 @@ def check_file(file, suffix=''): url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/ file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth if Path(file).is_file(): - print(f'Found {url} locally at {file}') # file already exists + LOGGER.info(f'Found {url} locally at {file}') # file already exists else: - print(f'Downloading {url} to {file}...') + LOGGER.info(f'Downloading {url} to {file}...') torch.hub.download_url_to_file(url, file) assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check return file @@ -407,23 +407,23 @@ def check_dataset(data, autodownload=True): if val: val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path if not all(x.exists() for x in val): - print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()]) + LOGGER.info('\nDataset not found, missing paths: %s' % [str(x) for x in val if not x.exists()]) if s and autodownload: # download script root = path.parent if 'path' in data else '..' # unzip directory i.e. '../' if s.startswith('http') and s.endswith('.zip'): # URL f = Path(s).name # filename - print(f'Downloading {s} to {f}...') + LOGGER.info(f'Downloading {s} to {f}...') torch.hub.download_url_to_file(s, f) Path(root).mkdir(parents=True, exist_ok=True) # create root ZipFile(f).extractall(path=root) # unzip Path(f).unlink() # remove zip r = None # success elif s.startswith('bash '): # bash script - print(f'Running {s} ...') + LOGGER.info(f'Running {s} ...') r = os.system(s) else: # python script r = exec(s, {'yaml': data}) # return None - print(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n") + LOGGER.info(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n") else: raise Exception('Dataset not found.') @@ -445,13 +445,13 @@ def download_one(url, dir): if Path(url).is_file(): # exists in current path Path(url).rename(f) # move to dir elif not f.exists(): - print(f'Downloading {url} to {f}...') + LOGGER.info(f'Downloading {url} to {f}...') if curl: os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail else: torch.hub.download_url_to_file(url, f, progress=True) # torch download if unzip and f.suffix in ('.zip', '.gz'): - print(f'Unzipping {f}...') + LOGGER.info(f'Unzipping {f}...') if f.suffix == '.zip': ZipFile(f).extractall(path=dir) # unzip elif f.suffix == '.gz': @@ -744,7 +744,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non output[xi] = x[i] if (time.time() - t) > time_limit: - print(f'WARNING: NMS time limit {time_limit}s exceeded') + LOGGER.warning(f'WARNING: NMS time limit {time_limit}s exceeded') break # time limit exceeded return output @@ -763,7 +763,7 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op p.requires_grad = False torch.save(x, s or f) mb = os.path.getsize(s or f) / 1E6 # filesize - print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB") + LOGGER.info(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB") def print_mutation(results, hyp, save_dir, bucket): @@ -786,8 +786,8 @@ def print_mutation(results, hyp, save_dir, bucket): f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n') # Print to screen - print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys)) - print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n') + LOGGER.info(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys)) + LOGGER.info(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals) + '\n\n') # Save yaml with open(evolve_yaml, 'w') as f: diff --git a/utils/plots.py b/utils/plots.py index 69037ee9af70..74868403edc0 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -57,7 +57,7 @@ def check_font(font='Arial.ttf', size=10): return ImageFont.truetype(str(font) if font.exists() else font.name, size) except Exception as e: # download if missing url = "https://ultralytics.com/assets/" + font.name - print(f'Downloading {url} to {font}...') + LOGGER.info(f'Downloading {url} to {font}...') torch.hub.download_url_to_file(url, str(font), progress=False) try: return ImageFont.truetype(str(font), size) @@ -143,7 +143,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec ax[i].imshow(blocks[i].squeeze()) # cmap='gray' ax[i].axis('off') - print(f'Saving {f}... ({n}/{channels})') + LOGGER.info(f'Saving {f}... ({n}/{channels})') plt.savefig(f, dpi=300, bbox_inches='tight') plt.close() np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save @@ -417,7 +417,7 @@ def plot_results(file='path/to/results.csv', dir=''): # if j in [8, 9, 10]: # share train and val loss y axes # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) except Exception as e: - print(f'Warning: Plotting error for {f}: {e}') + LOGGER.info(f'Warning: Plotting error for {f}: {e}') ax[1].legend() fig.savefig(save_dir / 'results.png', dpi=200) plt.close()