diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000000..dad4239ebad5 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# this drop notebooks from GitHub language stats +*.ipynb linguist-vendored diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml new file mode 100644 index 000000000000..0ee330a45483 --- /dev/null +++ b/.github/workflows/ci-testing.yml @@ -0,0 +1,72 @@ +name: CI CPU testing + +# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows +on: [push, pull_request] + +jobs: + cpu-tests: + + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: [3.8] + model: ['yolov5s'] # models to test + + # Timeout: https://stackoverflow.com/a/59076067/4521646 + timeout-minutes: 50 + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + # Note: This uses an internal pip API and may not always work + # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow + - name: Get pip cache + id: pip-cache + run: | + python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)" + + - name: Cache pip + uses: actions/cache@v1 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.python-version }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -qr requirements.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install -q onnx + python --version + pip --version + pip list + shell: bash + + - name: Download data + run: | + python -c "from utils.google_utils import * ; gdrive_download('1n_oKgR81BJtqk75b00eAjdv03qVCQn2f', 'coco128.zip')" + mv ./coco128 ../ + + - name: Tests workflow + run: | + export PYTHONPATH="$PWD" # to run *.py. files in subdirectories + di=cpu # inference devices # define device + + # train + python train.py --img 256 --batch 8 --weights weights/${{ matrix.model }}.pt --cfg models/${{ matrix.model }}.yaml --epochs 1 --device $di + # detect + python detect.py --weights weights/${{ matrix.model }}.pt --device $di + python detect.py --weights runs/exp0/weights/last.pt --device $di + # test + python test.py --img 256 --batch 8 --weights weights/${{ matrix.model }}.pt --device $di + python test.py --img 256 --batch 8 --weights runs/exp0/weights/last.pt --device $di + + python models/yolo.py --cfg models/${{ matrix.model }}.yaml # inspect + python models/export.py --img 256 --batch 1 --weights weights/${{ matrix.model }}.pt # export + shell: bash diff --git a/README.md b/README.md index df4060b813c2..c80b139a2014 100755 --- a/README.md +++ b/README.md @@ -2,6 +2,8 @@   +![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg) + This repository represents Ultralytics open-source research into future object detection methods, and incorporates our lessons learned and best practices evolved over training thousands of models on custom client datasets with our previous YOLO repository https://github.com/ultralytics/yolov3. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk. ** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 8, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS. diff --git a/hubconf.py b/hubconf.py index 29e93bdf2135..bbca702f326b 100644 --- a/hubconf.py +++ b/hubconf.py @@ -37,9 +37,11 @@ def create(name, pretrained, channels, classes): state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter model.load_state_dict(state_dict, strict=False) # load return model + except Exception as e: help_url = 'https://github.com/ultralytics/yolov5/issues/36' - print('%s\nCache maybe be out of date. Delete cache and retry. See %s for help.' % (e, help_url)) + s = 'Cache maybe be out of date, deleting cache and retrying may solve this. See %s for help.' % help_url + raise Exception(s) from e def yolov5s(pretrained=False, channels=3, classes=80): diff --git a/models/common.py b/models/common.py index 2c2d600394c1..7a7272be9a5c 100644 --- a/models/common.py +++ b/models/common.py @@ -76,12 +76,6 @@ def forward(self, x): return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) -class Flatten(nn.Module): - # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions - def forward(self, x): - return x.view(x.size(0), -1) - - class Focus(nn.Module): # Focus wh information into c-space def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups @@ -100,3 +94,23 @@ def __init__(self, dimension=1): def forward(self, x): return torch.cat(x, self.d) + + +class Flatten(nn.Module): + # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions + @staticmethod + def forward(x): + return x.view(x.size(0), -1) + + +class Classify(nn.Module): + # Classification head, i.e. x(b,c1,20,20) to x(b,c2) + def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups + super(Classify, self).__init__() + self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1) + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) # to x(b,c2,1,1) + self.flat = Flatten() + + def forward(self, x): + z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list + return self.flat(self.conv(z)) # flatten to x(b,c2) diff --git a/requirements.txt b/requirements.txt index 0deceacc74fb..c3926610d4e6 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,16 @@ # pip install -U -r requirements.txt Cython -numpy==1.17.3 +numpy>=1.18.5 opencv-python torch>=1.5.1 matplotlib pillow tensorboard PyYAML>=5.3 -torchvision +torchvision>=0.6 scipy tqdm -git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI +# pycocotools>=2.0 # Nvidia Apex (optional) for mixed precision training -------------------------- # git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . --user && cd .. && rm -rf apex diff --git a/test.py b/test.py index ed7e29caf66a..b1e6a231eec1 100644 --- a/test.py +++ b/test.py @@ -126,13 +126,13 @@ def test(data, # Append to pycocotools JSON dictionary if save_json: # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ... - image_id = int(Path(paths[si]).stem.split('_')[-1]) + image_id = Path(paths[si]).stem box = pred[:, :4].clone() # xyxy scale_coords(img[si].shape[1:], box, shapes[si][0], shapes[si][1]) # to original shape box = xyxy2xywh(box) # xywh box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner for p, b in zip(pred.tolist(), box.tolist()): - jdict.append({'image_id': image_id, + jdict.append({'image_id': int(image_id) if image_id.isnumeric() else image_id, 'category_id': coco91class[int(p[5])], 'bbox': [round(x, 3) for x in b], 'score': round(p[4], 5)}) @@ -200,8 +200,7 @@ def test(data, print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t) # Save JSON - if save_json and map50 and len(jdict): - imgIds = [int(Path(x).stem.split('_')[-1]) for x in dataloader.dataset.img_files] + if save_json and len(jdict): f = 'detections_val2017_%s_results.json' % \ (weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '') # filename print('\nCOCO mAP with pycocotools... saving %s...' % f) @@ -212,6 +211,7 @@ def test(data, from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval + imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files] cocoGt = COCO(glob.glob('../coco/annotations/instances_val*.json')[0]) # initialize COCO ground truth api cocoDt = cocoGt.loadRes(f) # initialize COCO pred api cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') @@ -220,9 +220,8 @@ def test(data, cocoEval.accumulate() cocoEval.summarize() map, map50 = cocoEval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5) - except: - print('WARNING: pycocotools must be installed with numpy==1.17 to run correctly. ' - 'See https://github.com/cocodataset/cocoapi/issues/356') + except Exception as e: + print('ERROR: pycocotools unable to run: %s' % e) # Return results model.float() # for training diff --git a/train.py b/train.py index 40f82bbed9c9..ac381b316fd7 100644 --- a/train.py +++ b/train.py @@ -6,6 +6,7 @@ import torch.optim as optim import torch.optim.lr_scheduler as lr_scheduler import torch.utils.data +from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter import torch.multiprocessing as mp @@ -25,6 +26,7 @@ print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex') mixed_precision = False # not installed + def train(local_rank, hyp, opt, device): print(f'Hyperparameters {hyp}') if local_rank in [-1, 0]: @@ -58,11 +60,12 @@ def train(local_rank, hyp, opt, device): torch.cuda.set_device(local_rank) dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:9999', rank=local_rank, world_size=opt.world_size) # distributed backend + # TODO: Init DDP logging. Only the first process is allowed to log. # Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs. # Configure - init_seeds(2+local_rank) + init_seeds(2 + local_rank) with open(opt.data) as f: data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict train_path = data_dict['train'] @@ -124,7 +127,7 @@ def train(local_rank, hyp, opt, device): # load model try: ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items() - if model.state_dict()[k].shape == v.shape} # to FP32, filter + if k in model.state_dict() and model.state_dict()[k].shape == v.shape} model.load_state_dict(ckpt['model'], strict=False) except KeyError as e: s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \ @@ -160,7 +163,6 @@ def train(local_rank, hyp, opt, device): scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822 # plot_lr_scheduler(optimizer, scheduler, epochs) - # Exponential moving average # From https://github.com/rwightman/pytorch-image-models/blob/master/train.py: # "Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper" @@ -175,20 +177,22 @@ def train(local_rank, hyp, opt, device): model = DDP(model, device_ids=[local_rank], output_device=local_rank) elif (opt.parallel): model = DP(model) - + # Trainloader dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True, - cache=opt.cache_images, rect=opt.rect, local_rank=local_rank, world_size=opt.world_size) + cache=opt.cache_images, rect=opt.rect, local_rank=local_rank, + world_size=opt.world_size) + mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class nb = len(dataloader) # number of batches - assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg) + assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1) # Testloader if local_rank in [-1, 0]: # local_rank is set to -1. Because only the first process is expected to do evaluation. - testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False, - cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0] - + testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False, + cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0] + # Model parameters hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset model.nc = nc # attach number of classes to model @@ -233,7 +237,8 @@ def train(local_rank, hyp, opt, device): if local_rank in [-1, 0]: w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w) - dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx + dataset.indices = random.choices(range(dataset.n), weights=image_weights, + k=dataset.n) # rand weighted idx # Broadcast. if local_rank != -1: indices = torch.zeros([dataset.n], dtype=torch.int) @@ -248,6 +253,7 @@ def train(local_rank, hyp, opt, device): # dataset.mosaic_border = [b - imgsz, -b] # height, width borders mloss = torch.zeros(4, device=device) # mean losses + if opt.distributed: dataloader.sampler.set_epoch(epoch) pbar = enumerate(dataloader) @@ -393,7 +399,7 @@ def train(local_rank, hyp, opt, device): plot_results() # save as results.png print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) - dist.destroy_process_group() if local_rank not in [-1,0] else None + dist.destroy_process_group() if opt.distributed else None torch.cuda.empty_cache() return results @@ -449,7 +455,6 @@ def run(fn, hyp, opt, device): parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') # Parameter For DDP. - # parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Don't use it manually.") parser.add_argument("--sync-bn", action="store_true", help="Use sync-bn, only avaible in DDP mode.") parser.add_argument("--distributed", action="store_true", help="Set ddp mode") opt = parser.parse_args() @@ -458,8 +463,6 @@ def run(fn, hyp, opt, device): if last and not opt.weights: print(f'Resuming training from {last}') opt.weights = last if opt.resume and not opt.weights else opt.weights - # if opt.local_rank in [-1, 0]: - # check_git_status() check_git_status() opt.cfg = check_file(opt.cfg) # check file opt.data = check_file(opt.data) # check file @@ -479,16 +482,6 @@ def run(fn, hyp, opt, device): if (opt.distributed): assert torch.cuda.is_available() and torch.cuda.device_count() > 1, "DDP is not available" if device.type == 'cpu': mixed_precision = False - # elif opt.local_rank != -1: - # # DDP mode - # assert torch.cuda.device_count() > opt.local_rank - # torch.cuda.set_device(opt.local_rank) - # device = torch.device("cuda", opt.local_rank) - # dist.init_process_group(backend='nccl', init_method='env://') # distributed backend - - # opt.world_size = dist.get_world_size() - # assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!" - # opt.batch_size = opt.total_batch_size // opt.world_size elif torch.cuda.is_available() and torch.cuda.device_count() > 1: opt.parallel = True if (opt.distributed): diff --git a/utils/datasets.py b/utils/datasets.py index d0d647fb9964..a0a077528b9e 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -17,7 +17,7 @@ from utils.utils import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data' -img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng'] +img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff','.dng'] vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv'] # Get orientation exif tag diff --git a/utils/google_utils.py b/utils/google_utils.py index 0a3dec1d4bab..ca9600b35a13 100644 --- a/utils/google_utils.py +++ b/utils/google_utils.py @@ -51,7 +51,7 @@ def gdrive_download(id='1n_oKgR81BJtqk75b00eAjdv03qVCQn2f', name='coco128.zip'): s = "curl -Lb ./cookie \"drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=%s\" -o %s" % ( id, name) else: # small file - s = "curl -s -L -o %s 'drive.google.com/uc?export=download&id=%s'" % (name, id) + s = 'curl -s -L -o %s "drive.google.com/uc?export=download&id=%s"' % (name, id) r = os.system(s) # execute, capture return values os.remove('cookie') if os.path.exists('cookie') else None diff --git a/weights/download_weights.sh b/weights/download_weights.sh index 6834ddb37bb2..206b7002aeca 100755 --- a/weights/download_weights.sh +++ b/weights/download_weights.sh @@ -1,8 +1,10 @@ #!/bin/bash # Download common models -python3 -c "from utils.google_utils import *; +python -c " +from utils.google_utils import *; attempt_download('weights/yolov5s.pt'); attempt_download('weights/yolov5m.pt'); attempt_download('weights/yolov5l.pt'); -attempt_download('weights/yolov5x.pt')" +attempt_download('weights/yolov5x.pt') +"