Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/ultralytics/yolov5 into f…
Browse files Browse the repository at this point in the history
…eature/DDP_fixed
  • Loading branch information
yizhi.chen committed Jul 3, 2020
2 parents 625bb49 + 3bdea3f commit e838055
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 59 deletions.
2 changes: 2 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ data/samples/*
# Neural Network weights -----------------------------------------------------------------------------------------------
**/*.weights
**/*.pt
**/*.pth
**/*.onnx
**/*.mlmodel
**/*.torchscript


# Below Copied From .gitignore -----------------------------------------------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ gcp_test*.sh
*.pt
*.onnx
*.mlmodel
*.torchscript
darknet53.conv.74
yolov3-tiny.conv.15

Expand Down
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Start FROM Nvidia PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
FROM nvcr.io/nvidia/pytorch:20.06-py3
FROM nvcr.io/nvidia/pytorch:20.03-py3
RUN pip install -U gsutil

# Create working directory
Expand Down Expand Up @@ -47,4 +47,4 @@ COPY . /usr/src/app
# sudo docker commit 6d525e299258 user/test_image && sudo docker run -it --gpus all --ipc=host -v "$(pwd)"/coco:/usr/src/coco --entrypoint=sh user/test_image

# Clean up
# docker system prune -a --volumes
# docker system prune -a --volumes
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ This repository represents Ultralytics open-source research into future object d


** AP<sup>test</sup> denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results in the table denote val2017 accuracy.
** All AP numbers are for single-model single-scale without ensemble or test-time augmentation. Reproduce by `python test.py --img 736 --conf 0.001`
** Speed<sub>GPU</sub> measures end-to-end time per image averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) instance with one V100 GPU, and includes image preprocessing, PyTorch FP16 image inference at --batch-size 32 --img-size 640, postprocessing and NMS. Average NMS time included in this chart is 1-2ms/img. Reproduce by `python test.py --img 640 --conf 0.1`
** All AP numbers are for single-model single-scale without ensemble or test-time augmentation. Reproduce by `python test.py --data data/coco.yaml --img 736 --conf 0.001`
** Speed<sub>GPU</sub> measures end-to-end time per image averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) instance with one V100 GPU, and includes image preprocessing, PyTorch FP16 image inference at --batch-size 32 --img-size 640, postprocessing and NMS. Average NMS time included in this chart is 1-2ms/img. Reproduce by `python test.py --data data/coco.yaml --img 640 --conf 0.1`
** All checkpoints are trained to 300 epochs with default settings and hyperparameters (no autoaugmentation).


Expand Down
6 changes: 3 additions & 3 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def detect(save_img=False):
with torch.no_grad():
detect()

# Update all models
# # Update all models
# for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
# detect()
# create_pretrained(opt.weights, opt.weights)
# detect()
# create_pretrained(opt.weights, opt.weights)
19 changes: 12 additions & 7 deletions models/common.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
# This file contains modules common to various models


from utils.utils import *


def autopad(k, p=None): # kernel, padding
# Pad to 'same'
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p


def DWConv(c1, c2, k=1, s=1, act=True):
# Depthwise convolution
return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)


class Conv(nn.Module):
# Standard convolution
def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
super(Conv, self).__init__()
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # padding
self.conv = nn.Conv2d(c1, c2, k, s, p, groups=g, bias=False)
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = nn.LeakyReLU(0.1, inplace=True) if act else nn.Identity()

Expand Down Expand Up @@ -46,7 +51,7 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
self.cv4 = Conv(c2, c2, 1, 1)
self.cv4 = Conv(2 * c_, c2, 1, 1)
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
self.act = nn.LeakyReLU(0.1, inplace=True)
self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
Expand Down Expand Up @@ -79,9 +84,9 @@ def forward(self, x):

class Focus(nn.Module):
# Focus wh information into c-space
def __init__(self, c1, c2, k=1):
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
super(Focus, self).__init__()
self.conv = Conv(c1 * 4, c2, k, 1)
self.conv = Conv(c1 * 4, c2, k, s, p, g, act)

def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
Expand Down
34 changes: 34 additions & 0 deletions models/experimental.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,40 @@
# This file contains experimental modules

from models.common import *


class CrossConv(nn.Module):
# Cross Convolution
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
super(CrossConv, self).__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, (1, 3), 1)
self.cv2 = Conv(c_, c2, (3, 1), 1, g=g)
self.add = shortcut and c1 == c2

def forward(self, x):
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class C3(nn.Module):
# Cross Convolution CSP
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super(C3, self).__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
self.cv4 = Conv(2 * c_, c2, 1, 1)
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
self.act = nn.LeakyReLU(0.1, inplace=True)
self.m = nn.Sequential(*[CrossConv(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

def forward(self, x):
y1 = self.cv3(self.m(self.cv1(x)))
y2 = self.cv2(x)
return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))


class Sum(nn.Module):
# Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
def __init__(self, n, weight=False): # n: number of inputs
Expand Down
24 changes: 12 additions & 12 deletions models/export.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""Exports a YOLOv5 *.pt model to *.onnx and *.torchscript formats
"""Exports a YOLOv5 *.pt model to ONNX and TorchScript formats
Usage:
$ export PYTHONPATH="$PWD" && python models/export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
"""

import argparse

import onnx

from models.common import *
from utils import google_utils

Expand All @@ -21,7 +19,7 @@
print(opt)

# Input
img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size, (1, 3, 320, 192) iDetection
img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size(1,3,320,192) iDetection

# Load PyTorch model
google_utils.attempt_download(opt.weights)
Expand All @@ -30,26 +28,28 @@
model.model[-1].export = True # set Detect() layer export=True
_ = model(img) # dry run

# Export to torchscript
# TorchScript export
try:
f = opt.weights.replace('.pt', '.torchscript') # filename
ts = torch.jit.trace(model, img)
ts.save(f)
print('Torchscript export success, saved as %s' % f)
except:
print('Torchscript export failed.')
print('TorchScript export success, saved as %s' % f)
except Exception as e:
print('TorchScript export failed: %s' % e)

# Export to ONNX
# ONNX export
try:
import onnx

f = opt.weights.replace('.pt', '.onnx') # filename
model.fuse() # only for ONNX
torch.onnx.export(model, img, f, verbose=False, opset_version=11, input_names=['images'],
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
output_names=['output']) # output_names=['classes', 'boxes']

# Checks
onnx_model = onnx.load(f) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model
print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable representation of the graph
print('ONNX export success, saved as %s\nView with https://github.com/lutzroeder/netron' % f)
except:
print('ONNX export failed.')
except Exception as e:
print('ONNX export failed: %s' % e)
50 changes: 22 additions & 28 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,8 @@ def train(hyp, tb_writer, opt, device):
total_batch_size = opt.batch_size if opt.local_rank == -1 else opt.batch_size * torch.distributed.get_world_size() # 64
weights = opt.weights # initial training weights

if opt.local_rank in [-1, 0]:
# TODO: Init DDP logging. Only the first process is allowed to log.
# Since I see lots of print here, the logging is skipped here.
pass
else:
tb_writer = None

# TODO: Init DDP logging. Only the first process is allowed to log.
# Since I see lots of print here, the logging is skipped here.

# Configure
init_seeds(1)
Expand All @@ -84,13 +79,13 @@ def train(hyp, tb_writer, opt, device):
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes

# Remove previous results
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
os.remove(f)
if opt.local_rank in [-1, 0]:
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
os.remove(f)

# Create model
model = Model(opt.cfg).to(device)
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
model.names = data_dict['names']

# Image sizes
gs = int(max(model.stride)) # grid size (max stride)
Expand Down Expand Up @@ -138,7 +133,7 @@ def train(hyp, tb_writer, opt, device):
model.load_state_dict(ckpt['model'], strict=False)
except KeyError as e:
s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
"Please delete or update %s and try again, or use --weights '' to train from scatch." \
"Please delete or update %s and try again, or use --weights '' to train from scratch." \
% (opt.weights, opt.cfg, opt.weights, opt.weights)
raise KeyError(s) from e

Expand Down Expand Up @@ -205,6 +200,7 @@ def train(hyp, tb_writer, opt, device):
model.hyp = hyp # attach hyperparameters to model
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model.names = data_dict['names']

# Class frequency
if tb_writer:
Expand Down Expand Up @@ -326,10 +322,9 @@ def train(hyp, tb_writer, opt, device):
batch_size=total_batch_size,
imgsz=imgsz_test,
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
model=ema.ema,
model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema,
single_cls=opt.single_cls,
dataloader=testloader)

# Write
with open(results_file, 'a') as f:
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
Expand Down Expand Up @@ -368,22 +363,21 @@ def train(hyp, tb_writer, opt, device):
# end epoch ----------------------------------------------------------------------------------------------------
# end training

results = None
if opt.local_rank in [-1, 0]:
n = opt.name
if len(n):
n = '_' + n if not n.isnumeric() else n
fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
if os.path.exists(f1):
os.rename(f1, f2) # rename
ispt = f2.endswith('.pt') # is *.pt
strip_optimizer(f2) if ispt else None # strip optimizer
os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload
# Strip optimizers
n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
if os.path.exists(f1):
os.rename(f1, f2) # rename
ispt = f2.endswith('.pt') # is *.pt
strip_optimizer(f2) if ispt else None # strip optimizer
os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload
# Finish
if not opt.evolve:
plot_results() # save as results.png
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
if opt.local_rank == -1:
if opt.local_rank == 0:
dist.destroy_process_group()
torch.cuda.empty_cache()
return results
Expand Down Expand Up @@ -414,16 +408,16 @@ def train(hyp, tb_writer, opt, device):
# Parameter For DDP.
parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Don't use it manually.")
opt = parser.parse_args()
opt.weights = last if opt.resume else opt.weights
opt.weights = last if opt.resume and not opt.weights else opt.weights
opt.cfg = check_file(opt.cfg) # check file
opt.data = check_file(opt.data) # check file
print(opt)
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
# If local_rank is not -1, the DDP mode is triggered. Use local_rank to overwrite the opt.device config.
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
if device.type == 'cpu':
mixed_precision = False
elif opt.local_rank != -1:
# DDP mode
assert torch.cuda.device_count() > opt.local_rank
torch.cuda.set_device(opt.local_rank)
device = torch.device("cuda")
Expand All @@ -435,10 +429,10 @@ def train(hyp, tb_writer, opt, device):
# Train
if not opt.evolve:
if opt.local_rank in [-1, 0]:
print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
tb_writer = SummaryWriter(comment=opt.name)
else:
tb_writer = None
print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
train(hyp, tb_writer, opt, device)

# Evolve hyperparameters (optional)
Expand Down
14 changes: 9 additions & 5 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ def time_synchronized():
return time.time()


def is_parallel(model):
# is model is parallel with DP or DDP
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def initialize_weights(model):
for m in model.modules():
t = type(m)
Expand Down Expand Up @@ -111,8 +116,8 @@ def model_info(model, verbose=False):

try: # FLOPS
from thop import profile
macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
fs = ', %.1f GFLOPS' % (macs / 1E9 * 2)
flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
fs = ', %.1f GFLOPS' % (flops * 100) # 640x640 FLOPS
except:
fs = ''

Expand Down Expand Up @@ -187,7 +192,6 @@ def update(self, model):
with torch.no_grad():
msd = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
esd = self.ema.module.state_dict() if hasattr(self.ema, 'module') else self.ema.state_dict()

for k, v in esd.items():
if v.dtype.is_floating_point:
v *= d
Expand All @@ -196,6 +200,6 @@ def update(self, model):
def update_attr(self, model):
# Assign attributes (which may change during training)
for k in model.__dict__.keys():
if not k.startswith('_') and not isinstance(getattr(model, k),
(torch.distributed.ProcessGroupNCCL, torch.distributed.Reducer)):
if not k.startswith('_') and (k != 'module' or not isinstance(getattr(model, k),
(torch.distributed.ProcessGroupNCCL, torch.distributed.Reducer))):
setattr(self.ema, k, getattr(model, k))

0 comments on commit e838055

Please sign in to comment.