diff --git a/.github/ISSUE_TEMPLATE/-question.md b/.github/ISSUE_TEMPLATE/-question.md
new file mode 100644
index 000000000000..2c22aea70a7b
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/-question.md
@@ -0,0 +1,13 @@
+---
+name: "❓Question"
+about: Ask a general question
+title: ''
+labels: question
+assignees: ''
+
+---
+
+## ❔Question
+
+
+## Additional context
diff --git a/README.md b/README.md
index 1e29d1835196..d97be1597184 100755
--- a/README.md
+++ b/README.md
@@ -41,9 +41,13 @@ $ pip install -U -r requirements.txt
 ## Tutorials
 
 * [Notebook](https://github.com/ultralytics/yolov5/blob/master/tutorial.ipynb) Open In Colab
+* [Kaggle](https://www.kaggle.com/ultralytics/yolov5-tutorial)
 * [Train Custom Data](https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data)
-* [Google Cloud Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart)
-* [Docker Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart) ![Docker Pulls](https://img.shields.io/docker/pulls/ultralytics/yolov5?logo=docker)
+* [PyTorch Hub](https://github.com/ultralytics/yolov5/issues/36)
+* [ONNX and TorchScript Export](https://github.com/ultralytics/yolov5/issues/251)
+* [Test-Time Augmentation (TTA)](https://github.com/ultralytics/yolov5/issues/303)
+* [Google Cloud Quickstart](https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart)
+* [Docker Quickstart](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart) ![Docker Pulls](https://img.shields.io/docker/pulls/ultralytics/yolov5?logo=docker)
 
 ## Inference
diff --git a/detect.py b/detect.py
index 2650c202d49d..7b9d9b69b142 100644
--- a/detect.py
+++ b/detect.py
@@ -2,7 +2,7 @@
 
 import torch.backends.cudnn as cudnn
 
-from utils import google_utils
+from models.experimental import *
 from utils.datasets import *
 from utils.utils import *
 
@@ -20,12 +20,8 @@ def detect(save_img=False):
     half = device.type != 'cpu'  # half precision only supported on CUDA
 
     # Load model
-    google_utils.attempt_download(weights)
-    model = torch.load(weights, map_location=device)['model'].float()  # load to FP32
-    # torch.save(torch.load(weights, map_location=device), weights)  # update model if SourceChangeWarning
-    # model.fuse()
-    model.to(device).eval()
-    imgsz = check_img_size(imgsz, s=model.model[-1].stride.max())  # check img_size
+    model = attempt_load(weights, map_location=device)  # load FP32 model
+    imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
     if half:
         model.half()  # to FP16
 
@@ -123,10 +119,11 @@ def detect(save_img=False):
                     if isinstance(vid_writer, cv2.VideoWriter):
                         vid_writer.release()  # release previous video writer
 
+                    fourcc = 'mp4v'  # output video codec
                     fps = vid_cap.get(cv2.CAP_PROP_FPS)
                     w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                     h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-                    vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*opt.fourcc), fps, (w, h))
+                    vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
                 vid_writer.write(im0)
 
     if save_txt or save_img:
@@ -139,26 +136,26 @@
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path')
+    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
     parser.add_argument('--source', type=str, default='inference/images', help='source')  # file/folder, 0 for webcam
     parser.add_argument('--output', type=str, default='inference/output', help='output folder')  # output folder
     parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
     parser.add_argument('--conf-thres', type=float, default=0.4, help='object confidence threshold')
     parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
-    parser.add_argument('--fourcc', type=str, default='mp4v', help='output video codec (verify ffmpeg support)')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--view-img', action='store_true', help='display results')
     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
     parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
     parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
+    parser.add_argument('--update', action='store_true', help='update all models')
     opt = parser.parse_args()
     print(opt)
 
     with torch.no_grad():
-        detect()
-
-        # # Update all models
-        # for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
-        #    detect()
-        #    create_pretrained(opt.weights, opt.weights)
+        if opt.update:  # update all models (to fix SourceChangeWarning)
+            for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
+                detect()
+                create_pretrained(opt.weights, opt.weights)
+        else:
+            detect()
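With `nargs='+'`, `--weights` now parses to a list, which `attempt_load()` (added in `models/experimental.py` below) accepts alongside a plain string. A minimal sketch of just the new parsing behaviour, separate from the rest of detect.py (paths hypothetical):

```python
# Reproduces only the new argparse behaviour of detect.py/test.py.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')

print(parser.parse_args([]).weights)  # 'yolov5s.pt' (default stays a single string)
print(parser.parse_args(['--weights', 'yolov5s.pt', 'yolov5x.pt']).weights)  # ['yolov5s.pt', 'yolov5x.pt']
```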
diff --git a/models/experimental.py b/models/experimental.py
index 539e7f970ac3..a22f6bbf60e8 100644
--- a/models/experimental.py
+++ b/models/experimental.py
@@ -1,6 +1,7 @@
 # This file contains experimental modules
 
 from models.common import *
+from utils import google_utils
 
 
 class CrossConv(nn.Module):
@@ -107,3 +108,34 @@ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
 
     def forward(self, x):
         return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
+
+
+class Ensemble(nn.ModuleList):
+    # Ensemble of models
+    def __init__(self):
+        super(Ensemble, self).__init__()
+
+    def forward(self, x, augment=False):
+        y = []
+        for module in self:
+            y.append(module(x, augment)[0])
+        # y = torch.stack(y).max(0)[0]  # max ensemble
+        # y = torch.cat(y, 1)  # nms ensemble
+        y = torch.stack(y).mean(0)  # mean ensemble
+        return y, None  # inference, train output
+
+
+def attempt_load(weights, map_location=None):
+    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
+    model = Ensemble()
+    for w in weights if isinstance(weights, list) else [weights]:
+        google_utils.attempt_download(w)
+        model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())  # load FP32 model
+
+    if len(model) == 1:
+        return model[-1]  # return model
+    else:
+        print('Ensemble created with %s\n' % weights)
+        for k in ['names', 'stride']:
+            setattr(model, k, getattr(model[-1], k))
+        return model  # return ensemble
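`Ensemble.forward()` runs every member model and averages the stacked inference outputs; max and NMS variants are left commented out. A toy demonstration of the mean-ensemble rule, with random tensors standing in for per-model detections (shapes hypothetical):

```python
# Mean ensemble as in Ensemble.forward(): stack per-model outputs, average over dim 0.
import torch

y1 = torch.rand(1, 25200, 85)  # stand-in for model 1 inference output
y2 = torch.rand(1, 25200, 85)  # stand-in for model 2 inference output
y = torch.stack([y1, y2]).mean(0)  # mean ensemble
assert y.shape == y1.shape  # ensembling preserves the output shape
```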
diff --git a/models/export.py b/models/export.py
index c11c0a391197..990c86e92313 100644
--- a/models/export.py
+++ b/models/export.py
@@ -61,7 +61,8 @@
         import coremltools as ct
 
         print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
-        model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape)])  # convert
+        # convert model from torchscript and apply pixel scaling as per detect.py
+        model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape, scale=1/255.0, bias=[0, 0, 0])])
         f = opt.weights.replace('.pt', '.mlmodel')  # filename
         model.save(f)
         print('CoreML export success, saved as %s' % f)
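The added `scale`/`bias` arguments make the exported CoreML model normalize pixels itself, matching the `img /= 255.0` preprocessing in detect.py; CoreML's `ImageType` applies `x * scale + bias` per pixel. A quick equivalence check on a dummy image, no coremltools required:

```python
import numpy as np

img = np.random.randint(0, 256, (640, 640, 3), dtype=np.uint8)  # dummy uint8 image
torch_style = img.astype(np.float32) / 255.0  # detect.py preprocessing
coreml_style = img.astype(np.float32) * (1 / 255.0) + 0.0  # ImageType(scale=1/255.0, bias=0)
assert np.allclose(torch_style, coreml_style)
```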
diff --git a/models/yolo.py b/models/yolo.py
index 7961fabbd62d..3fd87a336c68 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -48,6 +48,7 @@ def __init__(self, model_cfg='yolov5s.yaml', ch=3, nc=None):  # model, input cha
         if type(model_cfg) is dict:
             self.md = model_cfg  # model dict
         else:  # is *.yaml
+            import yaml  # for torch hub
             with open(model_cfg) as f:
                 self.md = yaml.load(f, Loader=yaml.FullLoader)  # model dict
 
@@ -141,14 +142,14 @@ def _print_biases(self):
             # print('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights
 
     def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
-        print('Fusing layers...')
+        print('Fusing layers... ', end='')
        for m in self.model.modules():
             if type(m) is Conv:
                 m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn)  # update conv
                 m.bn = None  # remove batchnorm
                 m.forward = m.fuseforward  # update forward
         torch_utils.model_info(self)
-
+        return self
 
 def parse_model(md, ch):  # model_dict, input_channels(3)
     print('\n%3s%18s%3s%10s  %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
diff --git a/requirements.txt b/requirements.txt
index 1100495b9c0d..bca726aa33f0 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ Cython
 numpy==1.17
 opencv-python
-torch>=1.4
+torch>=1.5.1
 matplotlib
 pillow
 tensorboard
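`fuse()` returning `self` is what allows the `.float().fuse().eval()` chain used by `attempt_load()`. A stand-in class (not the real `models.yolo.Model`) showing the pattern:

```python
class TinyModel:  # hypothetical stand-in for models.yolo.Model
    def fuse(self):
        print('Fusing layers... ', end='')  # single-line progress print, as above
        print('done')
        return self  # returning self is the one-line change that enables chaining

    def eval(self):
        return self

m = TinyModel().fuse().eval()
```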
diff --git a/test.py b/test.py
index 259d44444bcd..ef97d26c1942 100644
--- a/test.py
+++ b/test.py
@@ -1,9 +1,8 @@
 import argparse
 import json
 
-from utils import google_utils
+from models.experimental import *
 from utils.datasets import *
-from utils.utils import *
 
 
 def test(data,
@@ -18,32 +17,29 @@ def test(data,
          verbose=False,
          model=None,
         dataloader=None,
+         save_dir='',
          merge=False):
 
     # Initialize/load model and set device
-    if model is None:
-        training = False
+    training = model is not None
+    if training:  # called by train.py
+        device = next(model.parameters()).device  # get model device
+
+    else:  # called directly
         device = torch_utils.select_device(opt.device, batch_size=batch_size)
+        merge = opt.merge  # use Merge NMS
 
         # Remove previous
-        for f in glob.glob('test_batch*.jpg'):
+        for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')):
             os.remove(f)
 
         # Load model
-        google_utils.attempt_download(weights)
-        model = torch.load(weights, map_location=device)['model'].float()  # load to FP32
-        torch_utils.model_info(model)
-        model.fuse()
-        model.to(device)
-        imgsz = check_img_size(imgsz, s=model.model[-1].stride.max())  # check img_size
+        model = attempt_load(weights, map_location=device)  # load FP32 model
+        imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
 
         # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
         # if device.type != 'cpu' and torch.cuda.device_count() > 1:
         #     model = nn.DataParallel(model)
 
-    else:  # called by train.py
-        training = True
-        device = next(model.parameters()).device  # get model device
-
     # Half
     half = device.type != 'cpu' and torch.cuda.device_count() == 1  # half precision only supported on single-GPU
     if half:
@@ -58,12 +54,11 @@ def test(data,
     niou = iouv.numel()
 
     # Dataloader
-    if dataloader is None:  # not training
-        merge = opt.merge  # use Merge NMS
+    if not training:
         img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
         _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
         path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
-        dataloader = create_dataloader(path, imgsz, batch_size, int(max(model.stride)), opt,
+        dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt,
                                        hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]
 
     seen = 0
@@ -163,10 +158,10 @@ def test(data,
 
         # Plot images
         if batch_i < 1:
-            f = 'test_batch%g_gt.jpg' % batch_i  # filename
-            plot_images(img, targets, paths, f, names)  # ground truth
-            f = 'test_batch%g_pred.jpg' % batch_i
-            plot_images(img, output_to_target(output, width, height), paths, f, names)  # predictions
+            f = Path(save_dir) / ('test_batch%g_gt.jpg' % batch_i)  # filename
+            plot_images(img, targets, paths, str(f), names)  # ground truth
+            f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i)
+            plot_images(img, output_to_target(output, width, height), paths, str(f), names)  # predictions
 
     # Compute statistics
     stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
@@ -196,7 +191,7 @@ def test(data,
     if save_json and map50 and len(jdict):
         imgIds = [int(Path(x).stem.split('_')[-1]) for x in dataloader.dataset.img_files]
         f = 'detections_val2017_%s_results.json' % \
-            (weights.split(os.sep)[-1].replace('.pt', '') if weights else '')  # filename
+            (weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '')  # filename
         print('\nCOCO mAP with pycocotools... saving %s...' % f)
         with open(f, 'w') as file:
             json.dump(jdict, file)
@@ -229,7 +224,7 @@ def test(data,
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(prog='test.py')
-    parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path')
+    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
     parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
     parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
     parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
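Because `weights` may now be a list, the COCO results filename guard keys on `isinstance(weights, str)` rather than truthiness. A behaviour sketch (paths hypothetical):

```python
import os

for weights in ('weights/yolov5s.pt', ['yolov5s.pt', 'yolov5x.pt']):
    tag = weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else ''
    print('detections_val2017_%s_results.json' % tag)
# detections_val2017_yolov5s_results.json  (single model)
# detections_val2017__results.json         (ensemble: no tag)
```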
diff --git a/train.py b/train.py
index 61ed84ea8fa0..bba0883ca8d6 100644
--- a/train.py
+++ b/train.py
@@ -20,15 +20,10 @@
     print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
     mixed_precision = False  # not installed
 
-wdir = 'weights' + os.sep  # weights dir
-os.makedirs(wdir, exist_ok=True)
-last = wdir + 'last.pt'
-best = wdir + 'best.pt'
-results_file = 'results.txt'
-
 # Hyperparameters
-hyp = {'lr0': 0.01,  # initial learning rate (SGD=1E-2, Adam=1E-3)
-       'momentum': 0.937,  # SGD momentum
+hyp = {'optimizer': 'SGD',  # ['adam', 'SGD', None] if none, default is SGD
+       'lr0': 0.01,  # initial learning rate (SGD=1E-2, Adam=1E-3)
+       'momentum': 0.937,  # SGD momentum/Adam beta1
        'weight_decay': 5e-4,  # optimizer weight decay
        'giou': 0.05,  # giou loss gain
        'cls': 0.58,  # cls loss gain
@@ -45,21 +40,24 @@
        'translate': 0.0,  # image translation (+/- fraction)
        'scale': 0.5,  # image scale (+/- gain)
        'shear': 0.0}  # image shear (+/- deg)
-print(hyp)
 
-# Overwrite hyp with hyp*.txt (optional)
-f = glob.glob('hyp*.txt')
-if f:
-    print('Using %s' % f[0])
-    for k, v in zip(hyp.keys(), np.loadtxt(f[0])):
-        hyp[k] = v
 
-# Print focal loss if gamma > 0
-if hyp['fl_gamma']:
-    print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])
 
+def train(hyp):
+    print(f'Hyperparameters {hyp}')
+    log_dir = tb_writer.log_dir  # run directory
+    wdir = str(Path(log_dir) / 'weights') + os.sep  # weights directory
+    os.makedirs(wdir, exist_ok=True)
+    last = wdir + 'last.pt'
+    best = wdir + 'best.pt'
+    results_file = log_dir + os.sep + 'results.txt'
+
+    # Save run settings
+    with open(Path(log_dir) / 'hyp.yaml', 'w') as f:
+        yaml.dump(hyp, f, sort_keys=False)
+    with open(Path(log_dir) / 'opt.yaml', 'w') as f:
+        yaml.dump(vars(opt), f, sort_keys=False)
 
-def train(hyp):
     epochs = opt.epochs  # 300
     batch_size = opt.batch_size  # 64
     weights = opt.weights  # initial training weights
@@ -70,14 +68,15 @@ def train(hyp):
         data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict
     train_path = data_dict['train']
     test_path = data_dict['val']
-    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
+    nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names'])  # number classes, names
+    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
 
     # Remove previous results
     for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
         os.remove(f)
 
     # Create model
-    model = Model(opt.cfg, nc=data_dict['nc']).to(device)
+    model = Model(opt.cfg, nc=nc).to(device)
 
     # Image sizes
     gs = int(max(model.stride))  # grid size (max stride)
@@ -97,15 +96,20 @@ def train(hyp):
         else:
             pg0.append(v)  # all else
 
-    optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \
-        optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
+    if hyp['optimizer'] == 'adam':  # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
+        optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
+    else:
+        optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
+
     optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
     optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
+    print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
+    del pg0, pg1, pg2
+
     # Scheduler https://arxiv.org/pdf/1812.01187.pdf
     lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
     scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
-    print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
-    del pg0, pg1, pg2
+    # plot_lr_scheduler(optimizer, scheduler, epochs, save_dir=log_dir)
 
     # Load Model
     google_utils.attempt_download(weights)
@@ -147,12 +151,7 @@ def train(hyp):
 
     if mixed_precision:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
-
-    scheduler.last_epoch = start_epoch - 1  # do not move
-    # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
-    # plot_lr_scheduler(optimizer, scheduler, epochs)
-
-    # Initialize distributed training
+    # Distributed training
     if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
         dist.init_process_group(backend='nccl',  # distributed backend
                                 init_method='tcp://127.0.0.1:9999',  # init method
@@ -165,6 +164,7 @@ def train(hyp):
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
+    nb = len(dataloader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
 
     # Testloader
@@ -177,15 +177,15 @@ def train(hyp):
     model.hyp = hyp  # attach hyperparameters to model
     model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
-    model.names = data_dict['names']
+    model.names = names
 
     # Class frequency
     labels = np.concatenate(dataset.labels, 0)
     c = torch.tensor(labels[:, 0])  # classes
     # cf = torch.bincount(c.long(), minlength=nc) + 1.
     # model._initialize_biases(cf.to(device))
+    plot_labels(labels, save_dir=log_dir)
     if tb_writer:
-        plot_labels(labels)
         tb_writer.add_histogram('classes', c, 0)
 
     # Check anchors
@@ -193,14 +193,14 @@ def train(hyp):
         check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
 
     # Exponential moving average
-    ema = torch_utils.ModelEMA(model)
+    ema = torch_utils.ModelEMA(model, updates=start_epoch * nb / accumulate)
 
     # Start training
     t0 = time.time()
-    nb = len(dataloader)  # number of batches
-    n_burn = max(3 * nb, 1e3)  # burn-in iterations, max(3 epochs, 1k iterations)
+    nw = max(3 * nb, 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
+    scheduler.last_epoch = start_epoch - 1  # do not move
     print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
     print('Using %g dataloader workers' % dataloader.num_workers)
     print('Starting training for %g epochs...' % epochs)
@@ -225,9 +225,9 @@ def train(hyp):
             ni = i + nb * epoch  # number integrated batches (since train start)
             imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
 
-            # Burn-in
-            if ni <= n_burn:
-                xi = [0, n_burn]  # x interp
+            # Warmup
+            if ni <= nw:
+                xi = [0, nw]  # x interp
                 # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
                 accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
@@ -275,7 +275,7 @@ def train(hyp):
 
             # Plot
             if ni < 3:
-                f = 'train_batch%g.jpg' % ni  # filename
+                f = str(Path(log_dir) / ('train_batch%g.jpg' % ni))  # filename
                 result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
                 if tb_writer and result is not None:
                     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
@@ -296,7 +296,8 @@ def train(hyp):
                                                save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
                                                model=ema.ema,
                                                single_cls=opt.single_cls,
-                                               dataloader=testloader)
+                                               dataloader=testloader,
+                                               save_dir=log_dir)
 
         # Write
         with open(results_file, 'a') as f:
@@ -348,7 +349,7 @@ def train(hyp):
 
     # Finish
     if not opt.evolve:
-        plot_results()  # save as results.png
+        plot_results(save_dir=log_dir)  # save as results.png
     print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
     dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
     torch.cuda.empty_cache()
@@ -358,13 +359,15 @@ def train(hyp):
 if __name__ == '__main__':
     check_git_status()
     parser = argparse.ArgumentParser()
+    parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='model.yaml path')
+    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
+    parser.add_argument('--hyp', type=str, default='', help='hyp.yaml path (optional)')
     parser.add_argument('--epochs', type=int, default=300)
     parser.add_argument('--batch-size', type=int, default=16)
-    parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='*.cfg path')
-    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
     parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
     parser.add_argument('--rect', action='store_true', help='rectangular training')
-    parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
+    parser.add_argument('--resume', nargs='?', const='get_last', default=False,
+                        help='resume from given path/to/last.pt, or most recent run if blank.')
     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
     parser.add_argument('--notest', action='store_true', help='only test final epoch')
     parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
@@ -374,13 +377,17 @@ def train(hyp):
     parser.add_argument('--weights', type=str, default='', help='initial weights path')
     parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
-    parser.add_argument('--adam', action='store_true', help='use adam optimizer')
-    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%')
+    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
     opt = parser.parse_args()
+
+    last = get_latest_run() if opt.resume == 'get_last' else opt.resume  # resume from most recent run
+    if last and not opt.weights:
+        print(f'Resuming training from {last}')
     opt.weights = last if opt.resume and not opt.weights else opt.weights
     opt.cfg = check_file(opt.cfg)  # check file
     opt.data = check_file(opt.data)  # check file
+    opt.hyp = check_file(opt.hyp) if opt.hyp else ''  # check file
     print(opt)
     opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
     device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
@@ -389,8 +396,12 @@ def train(hyp):
 
     # Train
     if not opt.evolve:
-        tb_writer = SummaryWriter(comment=opt.name)
         print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
+        tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
+        if opt.hyp:  # update hyps
+            with open(opt.hyp) as f:
+                hyp.update(yaml.load(f, Loader=yaml.FullLoader))
+
         train(hyp)
 
     # Evolve hyperparameters (optional)
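The new `--hyp` flag merges an optional YAML file over the built-in `hyp` defaults via `dict.update()`, and the `hyp.yaml`/`opt.yaml` saved in each run directory can be fed back in later. A self-contained rehearsal of the merge, with a string standing in for the file (keys and values hypothetical):

```python
import io

import yaml

hyp = {'optimizer': 'SGD', 'lr0': 0.01, 'momentum': 0.937}  # subset of the defaults
override = io.StringIO('lr0: 0.001\nmomentum: 0.9\n')  # stands in for a --hyp file
hyp.update(yaml.load(override, Loader=yaml.FullLoader))
print(hyp)  # {'optimizer': 'SGD', 'lr0': 0.001, 'momentum': 0.9}
```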
diff --git a/utils/datasets.py b/utils/datasets.py
index 1ebd709482fe..406744a8465a 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -26,6 +26,11 @@
         break
 
 
+def get_hash(files):
+    # Returns a single hash value of a list of files
+    return sum(os.path.getsize(f) for f in files)
+
+
 def exif_size(img):
     # Returns exif-corrected PIL size
     s = img.size  # (width, height)
@@ -48,7 +53,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
                                   rect=rect,  # rectangular training
                                   cache_images=cache,
                                   single_cls=opt.single_cls,
-                                  stride=stride,
+                                  stride=int(stride),
                                   pad=pad)
 
     batch_size = min(batch_size, len(dataset))
@@ -280,19 +285,21 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
     def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                  cache_images=False, single_cls=False, stride=32, pad=0.0):
         try:
-            path = str(Path(path))  # os-agnostic
-            parent = str(Path(path).parent) + os.sep
-            if os.path.isfile(path):  # file
-                with open(path, 'r') as f:
-                    f = f.read().splitlines()
-                    f = [x.replace('./', parent) if x.startswith('./') else x for x in f]  # local to global path
-            elif os.path.isdir(path):  # folder
-                f = glob.iglob(path + os.sep + '*.*')
-            else:
-                raise Exception('%s does not exist' % path)
+            f = []  # image files
+            for p in path if isinstance(path, list) else [path]:
+                p = str(Path(p))  # os-agnostic
+                parent = str(Path(p).parent) + os.sep
+                if os.path.isfile(p):  # file
+                    with open(p, 'r') as t:
+                        t = t.read().splitlines()
+                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
+                elif os.path.isdir(p):  # folder
+                    f += glob.iglob(p + os.sep + '*.*')
+                else:
+                    raise Exception('%s does not exist' % p)
             self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
-        except:
-            raise Exception('Error loading data from %s. See %s' % (path, help_url))
+        except Exception as e:
+            raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
 
         n = len(self.img_files)
         assert n > 0, 'No images found in %s. See %s' % (path, help_url)
@@ -311,20 +318,22 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
         self.stride = stride
 
         # Define labels
-        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
-                            for x in self.img_files]
-
-        # Read image shapes (wh)
-        sp = path.replace('.txt', '') + '.shapes'  # shapefile path
-        try:
-            with open(sp, 'r') as f:  # read existing shapefile
-                s = [x.split() for x in f.read().splitlines()]
-                assert len(s) == n, 'Shapefile out of sync'
-        except:
-            s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
-            np.savetxt(sp, s, fmt='%g')  # overwrites existing (if any)
+        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') for x in
+                            self.img_files]
+
+        # Check cache
+        cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
+        if os.path.isfile(cache_path):
+            cache = torch.load(cache_path)  # load
+            if cache['hash'] != get_hash(self.label_files + self.img_files):  # dataset changed
+                cache = self.cache_labels(cache_path)  # re-cache
+        else:
+            cache = self.cache_labels(cache_path)  # cache
 
-        self.shapes = np.array(s, dtype=np.float64)
+        # Get labels
+        labels, shapes = zip(*[cache[x] for x in self.img_files])
+        self.shapes = np.array(shapes, dtype=np.float64)
+        self.labels = list(labels)
 
         # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
         if self.rect:
@@ -350,33 +359,11 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
             self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
 
         # Cache labels
-        self.imgs = [None] * n
-        self.labels = [np.zeros((0, 5), dtype=np.float32)] * n
         create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
         nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
-        np_labels_path = str(Path(self.label_files[0]).parent) + '.npy'  # saved labels in *.npy file
-        if os.path.isfile(np_labels_path):
-            s = np_labels_path  # print string
-            x = np.load(np_labels_path, allow_pickle=True)
-            if len(x) == n:
-                self.labels = x
-                labels_loaded = True
-        else:
-            s = path.replace('images', 'labels')
-
         pbar = tqdm(self.label_files)
         for i, file in enumerate(pbar):
-            if labels_loaded:
-                l = self.labels[i]
-                # np.savetxt(file, l, '%g')  # save *.txt from *.npy file
-            else:
-                try:
-                    with open(file, 'r') as f:
-                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
-                except:
-                    nm += 1  # print('missing labels for image %s' % self.img_files[i])  # file missing
-                    continue
-
+            l = self.labels[i]  # label
             if l.shape[0]:
                 assert l.shape[1] == 5, '> 5 label columns: %s' % file
                 assert (l >= 0).all(), 'negative labels: %s' % file
@@ -422,15 +409,13 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
                 ne += 1  # print('empty labels for image %s' % self.img_files[i])  # file empty
                 # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i]))  # remove
 
-            pbar.desc = 'Caching labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
-                s, nf, nm, ne, nd, n)
-        assert nf > 0 or n == 20288, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
-        if not labels_loaded and n > 1000:
-            print('Saving labels to %s for faster future loading' % np_labels_path)
-            np.save(np_labels_path, self.labels)  # save for next time
+            pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
+                cache_path, nf, nm, ne, nd, n)
+        assert nf > 0, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
 
         # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
-        if cache_images:  # if training
+        self.imgs = [None] * n
+        if cache_images:
             gb = 0  # Gigabytes of cached images
             pbar = tqdm(range(len(self.img_files)), desc='Caching images')
             self.img_hw0, self.img_hw = [None] * n, [None] * n
@@ -439,15 +424,30 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
                 gb += self.imgs[i].nbytes
                 pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)
 
-        # Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3
-        detect_corrupted_images = False
-        if detect_corrupted_images:
-            from skimage import io  # conda install -c conda-forge scikit-image
-            for file in tqdm(self.img_files, desc='Detecting corrupted images'):
-                try:
-                    _ = io.imread(file)
-                except:
-                    print('Corrupted image detected: %s' % file)
+    def cache_labels(self, path='labels.cache'):
+        # Cache dataset labels, check images and read shapes
+        x = {}  # dict
+        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
+        for (img, label) in pbar:
+            try:
+                l = []
+                image = Image.open(img)
+                image.verify()  # PIL verify
+                # _ = io.imread(img)  # skimage verify (from skimage import io)
+                shape = exif_size(image)  # image size
+                if os.path.isfile(label):
+                    with open(label, 'r') as f:
+                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)  # labels
+                if len(l) == 0:
+                    l = np.zeros((0, 5), dtype=np.float32)
+                x[img] = [l, shape]
+            except Exception as e:
+                x[img] = None
+                print('WARNING: %s: %s' % (img, e))
+
+        x['hash'] = get_hash(self.label_files + self.img_files)
+        torch.save(x, path)  # save for next time
+        return x
 
     def __len__(self):
         return len(self.img_files)
@@ -679,8 +679,8 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale
         dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
     elif scaleFill:  # stretch
         dw, dh = 0.0, 0.0
-        new_unpad = new_shape
-        ratio = new_shape[0] / shape[1], new_shape[1] / shape[0]  # width, height ratios
+        new_unpad = (new_shape[1], new_shape[0])
+        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios
 
     dw /= 2  # divide padding into 2 sides
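The label cache's `'hash'` is not a cryptographic hash, just the summed byte size of all label and image files, so adding, removing or resizing any file triggers a re-cache. A standalone demo with temporary files (contents hypothetical):

```python
import os
import tempfile

def get_hash(files):
    # Same logic as utils/datasets.py
    return sum(os.path.getsize(f) for f in files)

with tempfile.TemporaryDirectory() as d:
    paths = []
    for i, data in enumerate([b'0 0.5 0.5 0.1 0.1\n', b'1 0.2 0.2 0.3 0.3\n']):
        p = os.path.join(d, '%g.txt' % i)
        with open(p, 'wb') as f:
            f.write(data)
        paths.append(p)
    print(get_hash(paths))  # 36; changes whenever any file size changes
```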
     dh /= 2
 
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index fd00b8bde080..59f9268d2601 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -76,16 +76,36 @@ def find_modules(model, mclass=nn.Conv2d):
     return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
 
 
+def sparsity(model):
+    # Return global model sparsity
+    a, b = 0., 0.
+    for p in model.parameters():
+        a += p.numel()
+        b += (p == 0).sum()
+    return b / a
+
+
+def prune(model, amount=0.3):
+    # Prune model to requested global sparsity
+    import torch.nn.utils.prune as prune
+    print('Pruning model... ', end='')
+    for name, m in model.named_modules():
+        if isinstance(m, nn.Conv2d):
+            prune.l1_unstructured(m, name='weight', amount=amount)  # prune
+            prune.remove(m, 'weight')  # make permanent
+    print(' %.3g global sparsity' % sparsity(model))
+
+
 def fuse_conv_and_bn(conv, bn):
     # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
     with torch.no_grad():
         # init
-        fusedconv = torch.nn.Conv2d(conv.in_channels,
-                                    conv.out_channels,
-                                    kernel_size=conv.kernel_size,
-                                    stride=conv.stride,
-                                    padding=conv.padding,
-                                    bias=True)
+        fusedconv = nn.Conv2d(conv.in_channels,
+                              conv.out_channels,
+                              kernel_size=conv.kernel_size,
+                              stride=conv.stride,
+                              padding=conv.padding,
+                              bias=True).to(conv.weight.device)
 
         # prepare filters
         w_conv = conv.weight.clone().view(conv.out_channels, -1)
@@ -93,10 +113,7 @@ def fuse_conv_and_bn(conv, bn):
         fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
 
         # prepare spatial bias
-        if conv.bias is not None:
-            b_conv = conv.bias
-        else:
-            b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
+        b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
         b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
         fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
 
@@ -139,8 +156,8 @@ def load_classifier(name='resnet101', n=2):
 
     # Reshape output to n classes
     filters = model.fc.weight.shape[1]
-    model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True)
-    model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True)
+    model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
+    model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
     model.fc.out_features = n
     return model
 
@@ -174,15 +191,11 @@ class ModelEMA:
     I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
     """
 
-    def __init__(self, model, decay=0.9999, device=''):
+    def __init__(self, model, decay=0.9999, updates=0):
         # Create EMA
-        self.ema = deepcopy(model.module if is_parallel(model) else model)  # FP32 EMA
-        self.ema.eval()
-        self.updates = 0  # number of EMA updates
+        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
+        self.updates = updates  # number of EMA updates
         self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
-        self.device = device  # perform ema on different device from model if set
-        if device:
-            self.ema.to(device)
         for p in self.ema.parameters():
             p.requires_grad_(False)
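A usage sketch for the new `prune()`/`sparsity()` helpers on a toy conv stack; no YOLOv5 weights needed, assuming the repo root is on `PYTHONPATH` (otherwise paste the two helpers in):

```python
import torch.nn as nn

from utils.torch_utils import prune, sparsity  # assumes the repo root is importable

m = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 16, 3))
prune(m, amount=0.3)  # prints roughly 0.3 global sparsity (biases dilute it slightly)
print(float(sparsity(m)))  # fraction of parameters that are exactly zero
```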
diff --git a/utils/utils.py b/utils/utils.py
index 85253a68d8da..439a4e03469c 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -37,6 +37,12 @@ def init_seeds(seed=0):
     torch_utils.init_seeds(seed=seed)
 
 
+def get_latest_run(search_dir='./runs'):
+    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
+    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
+    return max(last_list, key=os.path.getctime)
+
+
 def check_git_status():
     # Suggest 'git pull' if repo is out of date
     if platform in ['linux', 'darwin']:
@@ -173,7 +179,7 @@ def xywh2xyxy(x):
 def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
     # Rescale coords (xyxy) from img1_shape to img0_shape
     if ratio_pad is None:  # calculate from img0_shape
-        gain = max(img1_shape) / max(img0_shape)  # gain = old / new
+        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
         pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
     else:
         gain = ratio_pad[0][0]
@@ -898,6 +904,16 @@ def output_to_target(output, width, height):
     return np.array(targets)
 
 
+def increment_dir(dir, comment=''):
+    # Increments a directory runs/exp1 --> runs/exp2_comment
+    n = 0  # number
+    d = sorted(glob.glob(dir + '*'))  # directories
+    if len(d):
+        d = d[-1].replace(dir, '')
+        n = int(d[:d.find('_')] if '_' in d else d) + 1  # increment
+    return dir + str(n) + ('_' + comment if comment else '')
+
+
 # Plotting functions ---------------------------------------------------------------------------------------------------
 def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
     # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
@@ -1028,7 +1044,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
     return mosaic
 
 
-def plot_lr_scheduler(optimizer, scheduler, epochs=300):
+def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
     # Plot LR simulating training for full epochs
     optimizer, scheduler = copy(optimizer), copy(scheduler)  # do not modify originals
     y = []
@@ -1042,7 +1058,7 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300):
     plt.xlim(0, epochs)
     plt.ylim(0)
     plt.tight_layout()
-    plt.savefig('LR.png', dpi=200)
+    plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
 
 
 def plot_test_txt():  # from utils.utils import *; plot_test()
@@ -1107,7 +1123,7 @@ def plot_study_txt(f='study.txt', x=None):  # from utils.utils import *; plot_st
     plt.savefig(f.replace('.txt', '.png'), dpi=200)
 
 
-def plot_labels(labels):
+def plot_labels(labels, save_dir=''):
     # plot dataset labels
     c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
 
@@ -1128,7 +1144,7 @@ def hist2d(x, y, n=100):
     ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
     ax[2].set_xlabel('width')
     ax[2].set_ylabel('height')
-    plt.savefig('labels.png', dpi=200)
+    plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
     plt.close()
 
 
@@ -1174,7 +1190,8 @@ def plot_results_overlay(start=0, stop=0):  # from utils.utils import *; plot_re
     fig.savefig(f.replace('.txt', '.png'), dpi=200)
 
 
-def plot_results(start=0, stop=0, bucket='', id=(), labels=()):  # from utils.utils import *; plot_results()
+def plot_results(start=0, stop=0, bucket='', id=(), labels=(),
+                 save_dir=''):  # from utils.utils import *; plot_results()
     # Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
     fig, ax = plt.subplots(2, 5, figsize=(12, 6))
     ax = ax.ravel()
@@ -1184,7 +1201,7 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=()):  # from utils.ut
         os.system('rm -rf storage.googleapis.com')
         files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
     else:
-        files = glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')
+        files = glob.glob(str(Path(save_dir) / 'results*.txt')) + glob.glob('../../Downloads/results*.txt')
 
     for fi, f in enumerate(files):
         try:
             results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
@@ -1205,4 +1222,4 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=()):  # from utils.ut
 
     fig.tight_layout()
     ax[1].legend()
-    fig.savefig('results.png', dpi=200)
+    fig.savefig(Path(save_dir) / 'results.png', dpi=200)
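One more note on the `scale_coords()` hunk in utils/utils.py above: the old `max/max` gain disagrees with `letterbox()` whenever rectangular batches pad the two axes differently, while the per-axis `min` ratio reproduces the scale letterbox actually applied. Hypothetical shapes where the two formulas diverge:

```python
img1_shape = (448, 640)  # padded rectangular network input (h, w), hypothetical
img0_shape = (375, 500)  # original image (h, w), hypothetical

gain_old = max(img1_shape) / max(img0_shape)  # 1.28, not the scale letterbox used
gain_new = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # ~1.195, matches letterbox
print(gain_old, gain_new)
```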