diff --git a/detect.py b/detect.py index 802b99fd761d..80e006251b69 100644 --- a/detect.py +++ b/detect.py @@ -9,8 +9,8 @@ from models.experimental import attempt_load from utils.datasets import LoadStreams, LoadImages -from utils.general import check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, \ - strip_optimizer, set_logging, increment_path +from utils.general import check_img_size, check_requirements, non_max_suppression, apply_classifier, scale_coords, \ + xyxy2xywh, strip_optimizer, set_logging, increment_path from utils.plots import plot_one_box from utils.torch_utils import select_device, load_classifier, time_synchronized @@ -162,6 +162,7 @@ def detect(save_img=False): parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') opt = parser.parse_args() print(opt) + check_requirements() with torch.no_grad(): if opt.update: # update all models (to fix SourceChangeWarning) diff --git a/test.py b/test.py index b520eae98d00..de63d365b0ee 100644 --- a/test.py +++ b/test.py @@ -11,8 +11,8 @@ from models.experimental import attempt_load from utils.datasets import create_dataloader -from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, box_iou, \ - non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path +from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \ + box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path from utils.loss import compute_loss from utils.metrics import ap_per_class, ConfusionMatrix from utils.plots import plot_images, output_to_target, plot_study_txt @@ -302,6 +302,7 @@ def test(data, opt.save_json |= opt.data.endswith('coco.yaml') opt.data = check_file(opt.data) # check file print(opt) + check_requirements() if opt.task in ['val', 'test']: # run normally test(opt.data, diff --git a/train.py b/train.py index d61ca4b3ac2a..3a42db7f767d 100644 --- a/train.py +++ b/train.py @@ -28,7 +28,7 @@ from utils.datasets import create_dataloader from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \ fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \ - print_mutation, set_logging, one_cycle + check_requirements, print_mutation, set_logging, one_cycle from utils.google_utils import attempt_download from utils.loss import compute_loss from utils.plots import plot_images, plot_labels, plot_results, plot_evolution @@ -472,6 +472,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): set_logging(opt.global_rank) if opt.global_rank in [-1, 0]: check_git_status() + check_requirements() # Resume if opt.resume: # resume an interrupted run diff --git a/utils/general.py b/utils/general.py index 90265c692f6d..f1fb7d2af539 100755 --- a/utils/general.py +++ b/utils/general.py @@ -53,6 +53,14 @@ def check_git_status(): print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n') +def check_requirements(file='requirements.txt'): + # Check installed dependencies meet requirements + import pkg_resources + requirements = pkg_resources.parse_requirements(Path(file).open()) + requirements = [x.name + ''.join(*x.specs) if len(x.specs) else x.name for x in requirements] + pkg_resources.require(requirements) # DistributionNotFound or VersionConflict exception if requirements not met + + def check_img_size(img_size, s=32): # Verify img_size is a multiple of stride s new_size = make_divisible(img_size, int(s)) # ceil gs-multiple