From 02b66fe63f26df2acdf67a8d09946c2665fa4dc2 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 21 Jul 2020 18:53:06 +0700 Subject: [PATCH] Update to master (#11) * update test.py --save-txt * update test.py --save-txt * add GH action tests * requirements * requirements * requirements * fix tests * add badge * lower batch-size * weights * args * parallel * rename eval * rename eval * paths * rename * lower bs * timeout * less xOS * drop xOS * git attrib * paths * paths * Apply suggestions from code review * Update eval.py * Update eval.py * update requirements.txt * Update ci-testing.yml * Update ci-testing.yml * rename test * revert test module to confuse users... * update hubconf.py * update common.py add Classify() * Update ci-testing.yml * Update ci-testing.yml * Update ci-testing.yml * Update ci-testing.yml * update common.py Classify() * Update ci-testing.yml * update test.py * update train.py ckpt loading * update train.py class count assertion #424 * update train.py class count assertion #424 Signed-off-by: Glenn Jocher * Update requirements.txt * [WIP] Feature/ddp fixed (#401) * Squashed commit of the following: commit d738487089e41c22b3b1cd73aa7c1c40320a6ebf Author: NanoCode012 Date: Tue Jul 14 17:33:38 2020 +0700 Adding world_size Reduce calls to torch.distributed. For use in create_dataloader. commit e742dd9619d29306c7541821238d3d7cddcdc508 Author: yizhi.chen Date: Tue Jul 14 15:38:48 2020 +0800 Make SyncBN a choice commit e90d4004387e6103fecad745f8cbc2edc918e906 Merge: 5bf8beb cd90360 Author: yzchen Date: Tue Jul 14 15:32:10 2020 +0800 Merge pull request #6 from NanoCode012/patch-5 Update train.py commit cd9036017e7f8bd519a8b62adab0f47ea67f4962 Author: NanoCode012 Date: Tue Jul 14 13:39:29 2020 +0700 Update train.py Remove redundant `opt.` prefix. commit 5bf8bebe8873afb18b762fe1f409aca116fac073 Merge: c9558a9 a1c8406 Author: yizhi.chen Date: Tue Jul 14 14:09:51 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov5 into feature/DDP_fixed commit c9558a9b51547febb03d9c1ca42e2ef0fc15bb31 Author: yizhi.chen Date: Tue Jul 14 13:51:34 2020 +0800 Add device allocation for loss compute commit 4f08c692fb5e943a89e0ee354ef6c80a50eeb28d Author: yizhi.chen Date: Thu Jul 9 11:16:27 2020 +0800 Revert drop_last commit 1dabe33a5a223b758cc761fc8741c6224205a34b Merge: a1ce9b1 4b8450b Author: yizhi.chen Date: Thu Jul 9 11:15:49 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit a1ce9b1e96b71d7fcb9d3e8143013eb8cebe5e27 Author: yizhi.chen Date: Thu Jul 9 11:15:21 2020 +0800 fix lr warning commit 4b8450b46db76e5e58cd95df965d4736077cfb0e Merge: b9a50ae 02c63ef Author: yzchen Date: Wed Jul 8 21:24:24 2020 +0800 Merge pull request #4 from NanoCode012/patch-4 Add drop_last for multi gpu commit 02c63ef81cf98b28b10344fe2cce08a03b143941 Author: NanoCode012 Date: Wed Jul 8 10:08:30 2020 +0700 Add drop_last for multi gpu commit b9a50aed48ab1536f94d49269977e2accd67748f Merge: ec2dc6c 121d90b Author: yizhi.chen Date: Tue Jul 7 19:48:04 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov5 into feature/DDP_fixed commit ec2dc6cc56de43ddff939e14c450672d0fbf9b3d Merge: d0326e3 82a6182 Author: yizhi.chen Date: Tue Jul 7 19:34:31 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit d0326e398dfeeeac611ccc64198d4fe91b7aa969 Author: yizhi.chen Date: Tue Jul 7 19:31:24 2020 +0800 Add SyncBN commit 82a6182b3ad0689a4432b631b438004e5acb3b74 Merge: 96fa40a 050b2a5 Author: yzchen Date: Tue Jul 7 19:21:01 2020 +0800 Merge pull request #1 from NanoCode012/patch-2 Convert BatchNorm to SyncBatchNorm commit 050b2a5a79a89c9405854d439a1f70f892139b1c Author: NanoCode012 Date: Tue Jul 7 12:38:14 2020 +0700 Add cleanup for process_group commit 2aa330139f3cc1237aeb3132245ed7e5d6da1683 Author: NanoCode012 Date: Tue Jul 7 12:07:40 2020 +0700 Remove apex.parallel. Use torch.nn.parallel For future compatibility commit 77c8e27e603bea9a69e7647587ca8d509dc1990d Author: NanoCode012 Date: Tue Jul 7 01:54:39 2020 +0700 Convert BatchNorm to SyncBatchNorm commit 96fa40a3a925e4ffd815fe329e1b5181ec92adc8 Author: yizhi.chen Date: Mon Jul 6 21:53:56 2020 +0800 Fix the datset inconsistency problem commit 16e7c269d062c8d16c4d4ff70cc80fd87935dc95 Author: yizhi.chen Date: Mon Jul 6 11:34:03 2020 +0800 Add loss multiplication to preserver the single-process performance commit e83805563065ffd2e38f85abe008fc662cc17909 Merge: 625bb49 3bdea3f Author: yizhi.chen Date: Fri Jul 3 20:56:30 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov5 into feature/DDP_fixed commit 625bb49f4e52d781143fea0af36d14e5be8b040c Author: yizhi.chen Date: Thu Jul 2 22:45:15 2020 +0800 DDP established * Squashed commit of the following: commit 94147314e559a6bdd13cb9de62490d385c27596f Merge: 65157e2 37acbdc Author: yizhi.chen Date: Thu Jul 16 14:00:17 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov4 into feature/DDP_fixed commit 37acbdc0b6ef8c3343560834b914c83bbb0abbd1 Author: Glenn Jocher Date: Wed Jul 15 20:03:41 2020 -0700 update test.py --save-txt commit b8c2da4a0d6880afd7857207340706666071145b Author: Glenn Jocher Date: Wed Jul 15 20:00:48 2020 -0700 update test.py --save-txt commit 65157e2fc97d371bc576e18b424e130eb3026917 Author: yizhi.chen Date: Wed Jul 15 16:44:13 2020 +0800 Revert the README.md removal commit 1c802bfa503623661d8617ca3f259835d27c5345 Merge: cd55b44 0f3b8bb Author: yizhi.chen Date: Wed Jul 15 16:43:38 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit cd55b445c4dcd8003ff4b0b46b64adf7c16e5ce7 Author: yizhi.chen Date: Wed Jul 15 16:42:33 2020 +0800 fix the DDP performance deterioration bug. commit 0f3b8bb1fae5885474ba861bbbd1924fb622ee93 Author: Glenn Jocher Date: Wed Jul 15 00:28:53 2020 -0700 Delete README.md commit f5921ba1e35475f24b062456a890238cb7a3cf94 Merge: 85ab2f3 bd3fdbb Author: yizhi.chen Date: Wed Jul 15 11:20:17 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit bd3fdbbf1b08ef87931eef49fa8340621caa7e87 Author: Glenn Jocher Date: Tue Jul 14 18:38:20 2020 -0700 Update README.md commit c1a97a7767ccb2aa9afc7a5e72fd159e7c62ec02 Merge: 2bf86b8 f796708 Author: Glenn Jocher Date: Tue Jul 14 18:36:53 2020 -0700 Merge branch 'master' into feature/DDP_fixed commit 2bf86b892fa2fd712f6530903a0d9b8533d7447a Author: NanoCode012 Date: Tue Jul 14 22:18:15 2020 +0700 Fixed world_size not found when called from test commit 85ab2f38cdda28b61ad15a3a5a14c3aafb620dc8 Merge: 5a19011 c8357ad Author: yizhi.chen Date: Tue Jul 14 22:19:58 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit 5a19011949398d06e744d8d5521ab4e6dfa06ab7 Author: yizhi.chen Date: Tue Jul 14 22:19:15 2020 +0800 Add assertion for <=2 gpus DDP commit c8357ad5b15a0e6aeef4d7fe67ca9637f7322a4d Merge: e742dd9 787582f Author: yzchen Date: Tue Jul 14 22:10:02 2020 +0800 Merge pull request #8 from MagicFrogSJTU/NanoCode012-patch-1 Modify number of dataloaders' workers commit 787582f97251834f955ef05a77072b8c673a8397 Author: NanoCode012 Date: Tue Jul 14 20:38:58 2020 +0700 Fixed issue with single gpu not having world_size commit 63648925288d63a21174a4dd28f92dbfebfeb75a Author: NanoCode012 Date: Tue Jul 14 19:16:15 2020 +0700 Add assert message for clarification Clarify why assertion was thrown to users commit 69364d6050e048d0d8834e0f30ce84da3f6a13f3 Author: NanoCode012 Date: Tue Jul 14 17:36:48 2020 +0700 Changed number of workers check commit d738487089e41c22b3b1cd73aa7c1c40320a6ebf Author: NanoCode012 Date: Tue Jul 14 17:33:38 2020 +0700 Adding world_size Reduce calls to torch.distributed. For use in create_dataloader. commit e742dd9619d29306c7541821238d3d7cddcdc508 Author: yizhi.chen Date: Tue Jul 14 15:38:48 2020 +0800 Make SyncBN a choice commit e90d4004387e6103fecad745f8cbc2edc918e906 Merge: 5bf8beb cd90360 Author: yzchen Date: Tue Jul 14 15:32:10 2020 +0800 Merge pull request #6 from NanoCode012/patch-5 Update train.py commit cd9036017e7f8bd519a8b62adab0f47ea67f4962 Author: NanoCode012 Date: Tue Jul 14 13:39:29 2020 +0700 Update train.py Remove redundant `opt.` prefix. commit 5bf8bebe8873afb18b762fe1f409aca116fac073 Merge: c9558a9 a1c8406 Author: yizhi.chen Date: Tue Jul 14 14:09:51 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov5 into feature/DDP_fixed commit c9558a9b51547febb03d9c1ca42e2ef0fc15bb31 Author: yizhi.chen Date: Tue Jul 14 13:51:34 2020 +0800 Add device allocation for loss compute commit 4f08c692fb5e943a89e0ee354ef6c80a50eeb28d Author: yizhi.chen Date: Thu Jul 9 11:16:27 2020 +0800 Revert drop_last commit 1dabe33a5a223b758cc761fc8741c6224205a34b Merge: a1ce9b1 4b8450b Author: yizhi.chen Date: Thu Jul 9 11:15:49 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit a1ce9b1e96b71d7fcb9d3e8143013eb8cebe5e27 Author: yizhi.chen Date: Thu Jul 9 11:15:21 2020 +0800 fix lr warning commit 4b8450b46db76e5e58cd95df965d4736077cfb0e Merge: b9a50ae 02c63ef Author: yzchen Date: Wed Jul 8 21:24:24 2020 +0800 Merge pull request #4 from NanoCode012/patch-4 Add drop_last for multi gpu commit 02c63ef81cf98b28b10344fe2cce08a03b143941 Author: NanoCode012 Date: Wed Jul 8 10:08:30 2020 +0700 Add drop_last for multi gpu commit b9a50aed48ab1536f94d49269977e2accd67748f Merge: ec2dc6c 121d90b Author: yizhi.chen Date: Tue Jul 7 19:48:04 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov5 into feature/DDP_fixed commit ec2dc6cc56de43ddff939e14c450672d0fbf9b3d Merge: d0326e3 82a6182 Author: yizhi.chen Date: Tue Jul 7 19:34:31 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit d0326e398dfeeeac611ccc64198d4fe91b7aa969 Author: yizhi.chen Date: Tue Jul 7 19:31:24 2020 +0800 Add SyncBN commit 82a6182b3ad0689a4432b631b438004e5acb3b74 Merge: 96fa40a 050b2a5 Author: yzchen Date: Tue Jul 7 19:21:01 2020 +0800 Merge pull request #1 from NanoCode012/patch-2 Convert BatchNorm to SyncBatchNorm commit 050b2a5a79a89c9405854d439a1f70f892139b1c Author: NanoCode012 Date: Tue Jul 7 12:38:14 2020 +0700 Add cleanup for process_group commit 2aa330139f3cc1237aeb3132245ed7e5d6da1683 Author: NanoCode012 Date: Tue Jul 7 12:07:40 2020 +0700 Remove apex.parallel. Use torch.nn.parallel For future compatibility commit 77c8e27e603bea9a69e7647587ca8d509dc1990d Author: NanoCode012 Date: Tue Jul 7 01:54:39 2020 +0700 Convert BatchNorm to SyncBatchNorm commit 96fa40a3a925e4ffd815fe329e1b5181ec92adc8 Author: yizhi.chen Date: Mon Jul 6 21:53:56 2020 +0800 Fix the datset inconsistency problem commit 16e7c269d062c8d16c4d4ff70cc80fd87935dc95 Author: yizhi.chen Date: Mon Jul 6 11:34:03 2020 +0800 Add loss multiplication to preserver the single-process performance commit e83805563065ffd2e38f85abe008fc662cc17909 Merge: 625bb49 3bdea3f Author: yizhi.chen Date: Fri Jul 3 20:56:30 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov5 into feature/DDP_fixed commit 625bb49f4e52d781143fea0af36d14e5be8b040c Author: yizhi.chen Date: Thu Jul 2 22:45:15 2020 +0800 DDP established * Fixed destroy_process_group in DP mode * Update torch_utils.py * Update utils.py Revert build_targets() to current master. * Update datasets.py * Fixed world_size attribute not found Co-authored-by: NanoCode012 Co-authored-by: Glenn Jocher * Update ci-testing.yml (#445) * Update ci-testing.yml * Update ci-testing.yml * Update requirements.txt * Update requirements.txt * Update google_utils.py * Update test.py * Update ci-testing.yml * pretrained model loading bug fix (#450) Signed-off-by: Glenn Jocher * Update datasets.py (#454) Co-authored-by: Glenn Jocher Co-authored-by: Jirka Co-authored-by: Jirka Borovec Co-authored-by: yzchen Co-authored-by: pritul dave <41751718+pritul2@users.noreply.github.com> --- .gitattributes | 2 + .github/workflows/ci-testing.yml | 72 ++++++++++++++++++++++++++++++++ README.md | 2 + hubconf.py | 4 +- models/common.py | 26 +++++++++--- requirements.txt | 6 +-- test.py | 13 +++--- train.py | 41 ++++++++---------- utils/datasets.py | 2 +- utils/google_utils.py | 2 +- weights/download_weights.sh | 6 ++- 11 files changed, 131 insertions(+), 45 deletions(-) create mode 100644 .gitattributes create mode 100644 .github/workflows/ci-testing.yml diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000000..dad4239ebad5 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# this drop notebooks from GitHub language stats +*.ipynb linguist-vendored diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml new file mode 100644 index 000000000000..0ee330a45483 --- /dev/null +++ b/.github/workflows/ci-testing.yml @@ -0,0 +1,72 @@ +name: CI CPU testing + +# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows +on: [push, pull_request] + +jobs: + cpu-tests: + + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: [3.8] + model: ['yolov5s'] # models to test + + # Timeout: https://stackoverflow.com/a/59076067/4521646 + timeout-minutes: 50 + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + # Note: This uses an internal pip API and may not always work + # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow + - name: Get pip cache + id: pip-cache + run: | + python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)" + + - name: Cache pip + uses: actions/cache@v1 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.python-version }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -qr requirements.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install -q onnx + python --version + pip --version + pip list + shell: bash + + - name: Download data + run: | + python -c "from utils.google_utils import * ; gdrive_download('1n_oKgR81BJtqk75b00eAjdv03qVCQn2f', 'coco128.zip')" + mv ./coco128 ../ + + - name: Tests workflow + run: | + export PYTHONPATH="$PWD" # to run *.py. files in subdirectories + di=cpu # inference devices # define device + + # train + python train.py --img 256 --batch 8 --weights weights/${{ matrix.model }}.pt --cfg models/${{ matrix.model }}.yaml --epochs 1 --device $di + # detect + python detect.py --weights weights/${{ matrix.model }}.pt --device $di + python detect.py --weights runs/exp0/weights/last.pt --device $di + # test + python test.py --img 256 --batch 8 --weights weights/${{ matrix.model }}.pt --device $di + python test.py --img 256 --batch 8 --weights runs/exp0/weights/last.pt --device $di + + python models/yolo.py --cfg models/${{ matrix.model }}.yaml # inspect + python models/export.py --img 256 --batch 1 --weights weights/${{ matrix.model }}.pt # export + shell: bash diff --git a/README.md b/README.md index df4060b813c2..c80b139a2014 100755 --- a/README.md +++ b/README.md @@ -2,6 +2,8 @@   +![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg) + This repository represents Ultralytics open-source research into future object detection methods, and incorporates our lessons learned and best practices evolved over training thousands of models on custom client datasets with our previous YOLO repository https://github.com/ultralytics/yolov3. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk. ** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 8, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS. diff --git a/hubconf.py b/hubconf.py index 29e93bdf2135..bbca702f326b 100644 --- a/hubconf.py +++ b/hubconf.py @@ -37,9 +37,11 @@ def create(name, pretrained, channels, classes): state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter model.load_state_dict(state_dict, strict=False) # load return model + except Exception as e: help_url = 'https://github.com/ultralytics/yolov5/issues/36' - print('%s\nCache maybe be out of date. Delete cache and retry. See %s for help.' % (e, help_url)) + s = 'Cache maybe be out of date, deleting cache and retrying may solve this. See %s for help.' % help_url + raise Exception(s) from e def yolov5s(pretrained=False, channels=3, classes=80): diff --git a/models/common.py b/models/common.py index 2c2d600394c1..7a7272be9a5c 100644 --- a/models/common.py +++ b/models/common.py @@ -76,12 +76,6 @@ def forward(self, x): return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) -class Flatten(nn.Module): - # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions - def forward(self, x): - return x.view(x.size(0), -1) - - class Focus(nn.Module): # Focus wh information into c-space def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups @@ -100,3 +94,23 @@ def __init__(self, dimension=1): def forward(self, x): return torch.cat(x, self.d) + + +class Flatten(nn.Module): + # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions + @staticmethod + def forward(x): + return x.view(x.size(0), -1) + + +class Classify(nn.Module): + # Classification head, i.e. x(b,c1,20,20) to x(b,c2) + def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups + super(Classify, self).__init__() + self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1) + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) # to x(b,c2,1,1) + self.flat = Flatten() + + def forward(self, x): + z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list + return self.flat(self.conv(z)) # flatten to x(b,c2) diff --git a/requirements.txt b/requirements.txt index 0deceacc74fb..c3926610d4e6 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,16 @@ # pip install -U -r requirements.txt Cython -numpy==1.17.3 +numpy>=1.18.5 opencv-python torch>=1.5.1 matplotlib pillow tensorboard PyYAML>=5.3 -torchvision +torchvision>=0.6 scipy tqdm -git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI +# pycocotools>=2.0 # Nvidia Apex (optional) for mixed precision training -------------------------- # git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . --user && cd .. && rm -rf apex diff --git a/test.py b/test.py index ed7e29caf66a..b1e6a231eec1 100644 --- a/test.py +++ b/test.py @@ -126,13 +126,13 @@ def test(data, # Append to pycocotools JSON dictionary if save_json: # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ... - image_id = int(Path(paths[si]).stem.split('_')[-1]) + image_id = Path(paths[si]).stem box = pred[:, :4].clone() # xyxy scale_coords(img[si].shape[1:], box, shapes[si][0], shapes[si][1]) # to original shape box = xyxy2xywh(box) # xywh box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner for p, b in zip(pred.tolist(), box.tolist()): - jdict.append({'image_id': image_id, + jdict.append({'image_id': int(image_id) if image_id.isnumeric() else image_id, 'category_id': coco91class[int(p[5])], 'bbox': [round(x, 3) for x in b], 'score': round(p[4], 5)}) @@ -200,8 +200,7 @@ def test(data, print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t) # Save JSON - if save_json and map50 and len(jdict): - imgIds = [int(Path(x).stem.split('_')[-1]) for x in dataloader.dataset.img_files] + if save_json and len(jdict): f = 'detections_val2017_%s_results.json' % \ (weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '') # filename print('\nCOCO mAP with pycocotools... saving %s...' % f) @@ -212,6 +211,7 @@ def test(data, from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval + imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files] cocoGt = COCO(glob.glob('../coco/annotations/instances_val*.json')[0]) # initialize COCO ground truth api cocoDt = cocoGt.loadRes(f) # initialize COCO pred api cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') @@ -220,9 +220,8 @@ def test(data, cocoEval.accumulate() cocoEval.summarize() map, map50 = cocoEval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5) - except: - print('WARNING: pycocotools must be installed with numpy==1.17 to run correctly. ' - 'See https://github.com/cocodataset/cocoapi/issues/356') + except Exception as e: + print('ERROR: pycocotools unable to run: %s' % e) # Return results model.float() # for training diff --git a/train.py b/train.py index 40f82bbed9c9..ac381b316fd7 100644 --- a/train.py +++ b/train.py @@ -6,6 +6,7 @@ import torch.optim as optim import torch.optim.lr_scheduler as lr_scheduler import torch.utils.data +from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter import torch.multiprocessing as mp @@ -25,6 +26,7 @@ print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex') mixed_precision = False # not installed + def train(local_rank, hyp, opt, device): print(f'Hyperparameters {hyp}') if local_rank in [-1, 0]: @@ -58,11 +60,12 @@ def train(local_rank, hyp, opt, device): torch.cuda.set_device(local_rank) dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:9999', rank=local_rank, world_size=opt.world_size) # distributed backend + # TODO: Init DDP logging. Only the first process is allowed to log. # Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs. # Configure - init_seeds(2+local_rank) + init_seeds(2 + local_rank) with open(opt.data) as f: data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict train_path = data_dict['train'] @@ -124,7 +127,7 @@ def train(local_rank, hyp, opt, device): # load model try: ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items() - if model.state_dict()[k].shape == v.shape} # to FP32, filter + if k in model.state_dict() and model.state_dict()[k].shape == v.shape} model.load_state_dict(ckpt['model'], strict=False) except KeyError as e: s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \ @@ -160,7 +163,6 @@ def train(local_rank, hyp, opt, device): scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822 # plot_lr_scheduler(optimizer, scheduler, epochs) - # Exponential moving average # From https://github.com/rwightman/pytorch-image-models/blob/master/train.py: # "Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper" @@ -175,20 +177,22 @@ def train(local_rank, hyp, opt, device): model = DDP(model, device_ids=[local_rank], output_device=local_rank) elif (opt.parallel): model = DP(model) - + # Trainloader dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True, - cache=opt.cache_images, rect=opt.rect, local_rank=local_rank, world_size=opt.world_size) + cache=opt.cache_images, rect=opt.rect, local_rank=local_rank, + world_size=opt.world_size) + mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class nb = len(dataloader) # number of batches - assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg) + assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1) # Testloader if local_rank in [-1, 0]: # local_rank is set to -1. Because only the first process is expected to do evaluation. - testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False, - cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0] - + testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False, + cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0] + # Model parameters hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset model.nc = nc # attach number of classes to model @@ -233,7 +237,8 @@ def train(local_rank, hyp, opt, device): if local_rank in [-1, 0]: w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w) - dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx + dataset.indices = random.choices(range(dataset.n), weights=image_weights, + k=dataset.n) # rand weighted idx # Broadcast. if local_rank != -1: indices = torch.zeros([dataset.n], dtype=torch.int) @@ -248,6 +253,7 @@ def train(local_rank, hyp, opt, device): # dataset.mosaic_border = [b - imgsz, -b] # height, width borders mloss = torch.zeros(4, device=device) # mean losses + if opt.distributed: dataloader.sampler.set_epoch(epoch) pbar = enumerate(dataloader) @@ -393,7 +399,7 @@ def train(local_rank, hyp, opt, device): plot_results() # save as results.png print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) - dist.destroy_process_group() if local_rank not in [-1,0] else None + dist.destroy_process_group() if opt.distributed else None torch.cuda.empty_cache() return results @@ -449,7 +455,6 @@ def run(fn, hyp, opt, device): parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') # Parameter For DDP. - # parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Don't use it manually.") parser.add_argument("--sync-bn", action="store_true", help="Use sync-bn, only avaible in DDP mode.") parser.add_argument("--distributed", action="store_true", help="Set ddp mode") opt = parser.parse_args() @@ -458,8 +463,6 @@ def run(fn, hyp, opt, device): if last and not opt.weights: print(f'Resuming training from {last}') opt.weights = last if opt.resume and not opt.weights else opt.weights - # if opt.local_rank in [-1, 0]: - # check_git_status() check_git_status() opt.cfg = check_file(opt.cfg) # check file opt.data = check_file(opt.data) # check file @@ -479,16 +482,6 @@ def run(fn, hyp, opt, device): if (opt.distributed): assert torch.cuda.is_available() and torch.cuda.device_count() > 1, "DDP is not available" if device.type == 'cpu': mixed_precision = False - # elif opt.local_rank != -1: - # # DDP mode - # assert torch.cuda.device_count() > opt.local_rank - # torch.cuda.set_device(opt.local_rank) - # device = torch.device("cuda", opt.local_rank) - # dist.init_process_group(backend='nccl', init_method='env://') # distributed backend - - # opt.world_size = dist.get_world_size() - # assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!" - # opt.batch_size = opt.total_batch_size // opt.world_size elif torch.cuda.is_available() and torch.cuda.device_count() > 1: opt.parallel = True if (opt.distributed): diff --git a/utils/datasets.py b/utils/datasets.py index d0d647fb9964..a0a077528b9e 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -17,7 +17,7 @@ from utils.utils import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data' -img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng'] +img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff','.dng'] vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv'] # Get orientation exif tag diff --git a/utils/google_utils.py b/utils/google_utils.py index 0a3dec1d4bab..ca9600b35a13 100644 --- a/utils/google_utils.py +++ b/utils/google_utils.py @@ -51,7 +51,7 @@ def gdrive_download(id='1n_oKgR81BJtqk75b00eAjdv03qVCQn2f', name='coco128.zip'): s = "curl -Lb ./cookie \"drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=%s\" -o %s" % ( id, name) else: # small file - s = "curl -s -L -o %s 'drive.google.com/uc?export=download&id=%s'" % (name, id) + s = 'curl -s -L -o %s "drive.google.com/uc?export=download&id=%s"' % (name, id) r = os.system(s) # execute, capture return values os.remove('cookie') if os.path.exists('cookie') else None diff --git a/weights/download_weights.sh b/weights/download_weights.sh index 6834ddb37bb2..206b7002aeca 100755 --- a/weights/download_weights.sh +++ b/weights/download_weights.sh @@ -1,8 +1,10 @@ #!/bin/bash # Download common models -python3 -c "from utils.google_utils import *; +python -c " +from utils.google_utils import *; attempt_download('weights/yolov5s.pt'); attempt_download('weights/yolov5m.pt'); attempt_download('weights/yolov5l.pt'); -attempt_download('weights/yolov5x.pt')" +attempt_download('weights/yolov5x.pt') +"