Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jul 13, 2020
2 parents 094079b + e169edf commit ea34f84
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 38 deletions.
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ This repository represents Ultralytics open-source research into future object d
| [YOLOv5m](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 43.4 | 43.4 | 62.4 | 3.0ms | 333 || 21.8M | 39.4B
| [YOLOv5l](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 46.6 | 46.7 | 65.4 | 3.9ms | 256 || 47.8M | 88.1B
| [YOLOv5x](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | **48.4** | **48.4** | **66.9** | 6.1ms | 164 || 89.0M | 166.4B
| [YOLOv3-SPP](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 45.6 | 45.5 | 65.2 | 4.5ms | 222 || 63.0M | 118.0B
| [YOLOv3-SPP](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 45.6 | 45.5 | 65.2 | 4.5ms | 222 || 63.0M | 118.0B


** AP<sup>test</sup> denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results in the table denote val2017 accuracy.
Expand Down Expand Up @@ -54,10 +54,11 @@ $ pip install -U -r requirements.txt

Inference can be run on most common media formats. Model [checkpoints](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) are downloaded automatically if available. Results are saved to `./inference/output`.
```bash
$ python detect.py --source file.jpg # image
$ python detect.py --source 0 # webcam
file.jpg # image
file.mp4 # video
./dir # directory
0 # webcam
path/ # directory
path/*.jpg # glob
rtsp://170.93.143.139/rtplive/470011e600ef003a004ee33696235daa # rtsp stream
http://112.50.243.8/PLTV/88888888/224/3221225900/1.m3u8 # http stream
```
Expand Down
6 changes: 3 additions & 3 deletions data/coco.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# COCO 2017 dataset http://cocodataset.org
# Download command: bash yolov5/data/get_coco2017.sh
# Train command: python train.py --data ./data/coco.yaml
# Dataset should be placed next to yolov5 folder:
# Train command: python train.py --data coco.yaml
# Default dataset location is next to /yolov5:
# /parent_folder
# /coco
# /yolov5


# train and val datasets (image directory or *.txt file with image paths)
# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
train: ../coco/train2017.txt # 118k images
val: ../coco/val2017.txt # 5k images
test: ../coco/test-dev2017.txt # 20k images for submission to https://competitions.codalab.org/competitions/20794
Expand Down
12 changes: 6 additions & 6 deletions data/coco128.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# COCO 2017 dataset http://cocodataset.org - first 128 training images
# Download command: python -c "from yolov5.utils.google_utils import gdrive_download; gdrive_download('1n_oKgR81BJtqk75b00eAjdv03qVCQn2f','coco128.zip')"
# Train command: python train.py --data ./data/coco128.yaml
# Dataset should be placed next to yolov5 folder:
# Download command: python -c "from yolov5.utils.google_utils import *; gdrive_download('1n_oKgR81BJtqk75b00eAjdv03qVCQn2f', 'coco128.zip')"
# Train command: python train.py --data coco128.yaml
# Default dataset location is next to /yolov5:
# /parent_folder
# /coco128
# /yolov5


# train and val datasets (image directory or *.txt file with image paths)
train: ../coco128/images/train2017/
val: ../coco128/images/train2017/
# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
train: ../coco128/images/train2017/ # 128 images
val: ../coco128/images/train2017/ # 128 images

# number of classes
nc: 80
Expand Down
5 changes: 3 additions & 2 deletions data/get_coco2017.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#!/bin/bash
# COCO 2017 dataset http://cocodataset.org
# Download command: bash yolov5/data/get_coco2017.sh
# Train command: python train.py --data ./data/coco.yaml
# Dataset should be placed next to yolov5 folder:
# Train command: python train.py --data coco.yaml
# Default dataset location is next to /yolov5:
# /parent_folder
# /coco
# /yolov5


# Download labels from Google Drive, accepting presented query
filename="coco2017labels.zip"
fileid="1cXZR_ckHki6nddOmcysCuuJFM--T-Q6L"
Expand Down
3 changes: 2 additions & 1 deletion data/get_voc.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC/
# Download command: bash ./data/get_voc.sh
# Train command: python train.py --data voc.yaml
# Dataset should be placed next to yolov5 folder:
# Default dataset location is next to /yolov5:
# /parent_folder
# /VOC
# /yolov5


start=`date +%s`

# handle optional download dir
Expand Down
9 changes: 5 additions & 4 deletions data/voc.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC/
# Download command: bash ./data/get_voc.sh
# Train command: python train.py --data voc.yaml
# Dataset should be placed next to yolov5 folder:
# Default dataset location is next to /yolov5:
# /parent_folder
# /VOC
# /yolov5

# train and val datasets (image directory or *.txt file with image paths)
train: ../VOC/images/train/
val: ../VOC/images/val/

# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
train: ../VOC/images/train/ # 16551 images
val: ../VOC/images/val/ # 4952 images

# number of classes
nc: 20
Expand Down
4 changes: 2 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def test(data,
# model = nn.DataParallel(model)

# Half
half = device.type != 'cpu' and torch.cuda.device_count() == 1 # half precision only supported on single-GPU
half = device.type != 'cpu' # half precision only supported on CUDA
if half:
model.half() # to FP16
model.half()

# Configure
model.eval()
Expand Down
43 changes: 27 additions & 16 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,35 +68,39 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa

class LoadImages: # for inference
def __init__(self, path, img_size=640):
path = str(Path(path)) # os-agnostic
files = []
if os.path.isdir(path):
files = sorted(glob.glob(os.path.join(path, '*.*')))
elif os.path.isfile(path):
files = [path]
p = str(Path(path)) # os-agnostic
p = os.path.abspath(p) # absolute path
if '*' in p:
files = sorted(glob.glob(p)) # glob
elif os.path.isdir(p):
files = sorted(glob.glob(os.path.join(p, '*.*'))) # dir
elif os.path.isfile(p):
files = [p] # files
else:
raise Exception('ERROR: %s does not exist' % p)

images = [x for x in files if os.path.splitext(x)[-1].lower() in img_formats]
videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats]
nI, nV = len(images), len(videos)
ni, nv = len(images), len(videos)

self.img_size = img_size
self.files = images + videos
self.nF = nI + nV # number of files
self.video_flag = [False] * nI + [True] * nV
self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv
self.mode = 'images'
if any(videos):
self.new_video(videos[0]) # new video
else:
self.cap = None
assert self.nF > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
(path, img_formats, vid_formats)
assert self.nf > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
(p, img_formats, vid_formats)

def __iter__(self):
self.count = 0
return self

def __next__(self):
if self.count == self.nF:
if self.count == self.nf:
raise StopIteration
path = self.files[self.count]

Expand All @@ -107,22 +111,22 @@ def __next__(self):
if not ret_val:
self.count += 1
self.cap.release()
if self.count == self.nF: # last video
if self.count == self.nf: # last video
raise StopIteration
else:
path = self.files[self.count]
self.new_video(path)
ret_val, img0 = self.cap.read()

self.frame += 1
print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nF, self.frame, self.nframes, path), end='')
print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nf, self.frame, self.nframes, path), end='')

else:
# Read image
self.count += 1
img0 = cv2.imread(path) # BGR
assert img0 is not None, 'Image Not Found ' + path
print('image %g/%g %s: ' % (self.count, self.nF, path), end='')
print('image %g/%g %s: ' % (self.count, self.nf, path), end='')

# Padded resize
img = letterbox(img0, new_shape=self.img_size)[0]
Expand All @@ -140,7 +144,7 @@ def new_video(self, path):
self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

def __len__(self):
return self.nF # number of files
return self.nf # number of files


class LoadWebcam: # for inference
Expand Down Expand Up @@ -470,6 +474,13 @@ def __getitem__(self, index):
img, labels = load_mosaic(self, index)
shapes = None

# MixUp https://arxiv.org/pdf/1710.09412.pdf
# if random.random() < 0.5:
# img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
# r = np.random.beta(0.3, 0.3) # mixup ratio, alpha=beta=0.3
# img = (img * r + img2 * (1 - r)).astype(np.uint8)
# labels = np.concatenate((labels, labels2), 0)

else:
# Load image
img, (h0, w0), (h, w) = load_image(self, index)
Expand Down
2 changes: 2 additions & 0 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ class ModelEMA:
def __init__(self, model, decay=0.9999, updates=0):
# Create EMA
self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
if next(model.parameters()).device.type != 'cpu':
self.ema.half() # FP16 EMA
self.updates = updates # number of EMA updates
self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
for p in self.ema.parameters():
Expand Down

0 comments on commit ea34f84

Please sign in to comment.