From 98f905869c2c952e2b6921522e287e0830ad4cb4 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 9 Mar 2021 17:47:36 -0800 Subject: [PATCH] add checks --- classifier.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/classifier.py b/classifier.py index acbd17ccddc5..566de6cef2be 100644 --- a/classifier.py +++ b/classifier.py @@ -18,7 +18,7 @@ from tqdm import tqdm from models.common import Classify -from utils.general import set_logging, check_file, increment_path +from utils.general import set_logging, check_file, increment_path, check_git_status, check_requirements from utils.torch_utils import model_info, select_device, is_parallel # Settings @@ -225,10 +225,16 @@ def test(model, dataloader, names, criterion=None, verbose=False, pbar=None): parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') opt = parser.parse_args() + # Checks + check_git_status() + check_requirements() + + # Parameters device = select_device(opt.device, batch_size=opt.batch_size) cuda = device.type != 'cpu' opt.hyp = check_file(opt.hyp) # check files opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run resize = torch.nn.Upsample(size=(opt.img_size, opt.img_size), mode='bilinear', align_corners=False) # image resize + # Train train()