diff --git a/.dockerignore b/.dockerignore index a68626df5f2e..42f241f28c7b 100644 --- a/.dockerignore +++ b/.dockerignore @@ -14,8 +14,10 @@ data/samples/* # Neural Network weights ----------------------------------------------------------------------------------------------- **/*.weights **/*.pt +**/*.pth **/*.onnx **/*.mlmodel +**/*.torchscript # Below Copied From .gitignore ----------------------------------------------------------------------------------------- diff --git a/.gitignore b/.gitignore index 5a95798f0f61..07993ab27f15 100755 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ gcp_test*.sh *.pt *.onnx *.mlmodel +*.torchscript darknet53.conv.74 yolov3-tiny.conv.15 diff --git a/Dockerfile b/Dockerfile index 01551a0e49e4..357c6dbc4cb9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Start FROM Nvidia PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch -FROM nvcr.io/nvidia/pytorch:20.06-py3 +FROM nvcr.io/nvidia/pytorch:20.03-py3 RUN pip install -U gsutil # Create working directory @@ -47,4 +47,4 @@ COPY . /usr/src/app # sudo docker commit 6d525e299258 user/test_image && sudo docker run -it --gpus all --ipc=host -v "$(pwd)"/coco:/usr/src/coco --entrypoint=sh user/test_image # Clean up -# docker system prune -a --volumes \ No newline at end of file +# docker system prune -a --volumes diff --git a/README.md b/README.md index 1e29d1835196..6306e55ec866 100755 --- a/README.md +++ b/README.md @@ -25,8 +25,8 @@ This repository represents Ultralytics open-source research into future object d ** APtest denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results in the table denote val2017 accuracy. -** All AP numbers are for single-model single-scale without ensemble or test-time augmentation. Reproduce by `python test.py --img 736 --conf 0.001` -** SpeedGPU measures end-to-end time per image averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) instance with one V100 GPU, and includes image preprocessing, PyTorch FP16 image inference at --batch-size 32 --img-size 640, postprocessing and NMS. Average NMS time included in this chart is 1-2ms/img. Reproduce by `python test.py --img 640 --conf 0.1` +** All AP numbers are for single-model single-scale without ensemble or test-time augmentation. Reproduce by `python test.py --data data/coco.yaml --img 736 --conf 0.001` +** SpeedGPU measures end-to-end time per image averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) instance with one V100 GPU, and includes image preprocessing, PyTorch FP16 image inference at --batch-size 32 --img-size 640, postprocessing and NMS. Average NMS time included in this chart is 1-2ms/img. Reproduce by `python test.py --data data/coco.yaml --img 640 --conf 0.1` ** All checkpoints are trained to 300 epochs with default settings and hyperparameters (no autoaugmentation). diff --git a/detect.py b/detect.py index bb84a0df0c2c..2650c202d49d 100644 --- a/detect.py +++ b/detect.py @@ -158,7 +158,7 @@ def detect(save_img=False): with torch.no_grad(): detect() - # Update all models + # # Update all models # for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']: - # detect() - # create_pretrained(opt.weights, opt.weights) + # detect() + # create_pretrained(opt.weights, opt.weights) diff --git a/models/common.py b/models/common.py index 3c4a0d729210..2c2d600394c1 100644 --- a/models/common.py +++ b/models/common.py @@ -1,9 +1,15 @@ # This file contains modules common to various models - from utils.utils import * +def autopad(k, p=None): # kernel, padding + # Pad to 'same' + if p is None: + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad + return p + + def DWConv(c1, c2, k=1, s=1, act=True): # Depthwise convolution return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act) @@ -11,10 +17,9 @@ def DWConv(c1, c2, k=1, s=1, act=True): class Conv(nn.Module): # Standard convolution - def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups super(Conv, self).__init__() - p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # padding - self.conv = nn.Conv2d(c1, c2, k, s, p, groups=g, bias=False) + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) self.bn = nn.BatchNorm2d(c2) self.act = nn.LeakyReLU(0.1, inplace=True) if act else nn.Identity() @@ -46,7 +51,7 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) - self.cv4 = Conv(c2, c2, 1, 1) + self.cv4 = Conv(2 * c_, c2, 1, 1) self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) self.act = nn.LeakyReLU(0.1, inplace=True) self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)]) @@ -79,9 +84,9 @@ def forward(self, x): class Focus(nn.Module): # Focus wh information into c-space - def __init__(self, c1, c2, k=1): + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups super(Focus, self).__init__() - self.conv = Conv(c1 * 4, c2, k, 1) + self.conv = Conv(c1 * 4, c2, k, s, p, g, act) def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) diff --git a/models/experimental.py b/models/experimental.py index 60cb7aa14cd5..cff9d141446d 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -1,6 +1,40 @@ +# This file contains experimental modules + from models.common import * +class CrossConv(nn.Module): + # Cross Convolution + def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion + super(CrossConv, self).__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, (1, 3), 1) + self.cv2 = Conv(c_, c2, (3, 1), 1, g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class C3(nn.Module): + # Cross Convolution CSP + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super(C3, self).__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) + self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) + self.cv4 = Conv(2 * c_, c2, 1, 1) + self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) + self.act = nn.LeakyReLU(0.1, inplace=True) + self.m = nn.Sequential(*[CrossConv(c_, c_, shortcut, g, e=1.0) for _ in range(n)]) + + def forward(self, x): + y1 = self.cv3(self.m(self.cv1(x))) + y2 = self.cv2(x) + return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) + + class Sum(nn.Module): # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 def __init__(self, n, weight=False): # n: number of inputs diff --git a/models/export.py b/models/export.py index 2aa6ce403ac6..bb310f3f89a0 100644 --- a/models/export.py +++ b/models/export.py @@ -1,4 +1,4 @@ -"""Exports a YOLOv5 *.pt model to *.onnx and *.torchscript formats +"""Exports a YOLOv5 *.pt model to ONNX and TorchScript formats Usage: $ export PYTHONPATH="$PWD" && python models/export.py --weights ./weights/yolov5s.pt --img 640 --batch 1 @@ -6,8 +6,6 @@ import argparse -import onnx - from models.common import * from utils import google_utils @@ -21,7 +19,7 @@ print(opt) # Input - img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size, (1, 3, 320, 192) iDetection + img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size(1,3,320,192) iDetection # Load PyTorch model google_utils.attempt_download(opt.weights) @@ -30,20 +28,22 @@ model.model[-1].export = True # set Detect() layer export=True _ = model(img) # dry run - # Export to torchscript + # TorchScript export try: f = opt.weights.replace('.pt', '.torchscript') # filename ts = torch.jit.trace(model, img) ts.save(f) - print('Torchscript export success, saved as %s' % f) - except: - print('Torchscript export failed.') + print('TorchScript export success, saved as %s' % f) + except Exception as e: + print('TorchScript export failed: %s' % e) - # Export to ONNX + # ONNX export try: + import onnx + f = opt.weights.replace('.pt', '.onnx') # filename model.fuse() # only for ONNX - torch.onnx.export(model, img, f, verbose=False, opset_version=11, input_names=['images'], + torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'], output_names=['output']) # output_names=['classes', 'boxes'] # Checks @@ -51,5 +51,5 @@ onnx.checker.check_model(onnx_model) # check onnx model print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable representation of the graph print('ONNX export success, saved as %s\nView with https://github.com/lutzroeder/netron' % f) - except: - print('ONNX export failed.') + except Exception as e: + print('ONNX export failed: %s' % e) diff --git a/train.py b/train.py index 0c63b80ae27b..b7d202d2fbe4 100644 --- a/train.py +++ b/train.py @@ -67,13 +67,8 @@ def train(hyp, tb_writer, opt, device): total_batch_size = opt.batch_size if opt.local_rank == -1 else opt.batch_size * torch.distributed.get_world_size() # 64 weights = opt.weights # initial training weights - if opt.local_rank in [-1, 0]: - # TODO: Init DDP logging. Only the first process is allowed to log. - # Since I see lots of print here, the logging is skipped here. - pass - else: - tb_writer = None - + # TODO: Init DDP logging. Only the first process is allowed to log. + # Since I see lots of print here, the logging is skipped here. # Configure init_seeds(1) @@ -84,13 +79,13 @@ def train(hyp, tb_writer, opt, device): nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes # Remove previous results - for f in glob.glob('*_batch*.jpg') + glob.glob(results_file): - os.remove(f) + if opt.local_rank in [-1, 0]: + for f in glob.glob('*_batch*.jpg') + glob.glob(results_file): + os.remove(f) # Create model model = Model(opt.cfg).to(device) assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc']) - model.names = data_dict['names'] # Image sizes gs = int(max(model.stride)) # grid size (max stride) @@ -138,7 +133,7 @@ def train(hyp, tb_writer, opt, device): model.load_state_dict(ckpt['model'], strict=False) except KeyError as e: s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \ - "Please delete or update %s and try again, or use --weights '' to train from scatch." \ + "Please delete or update %s and try again, or use --weights '' to train from scratch." \ % (opt.weights, opt.cfg, opt.weights, opt.weights) raise KeyError(s) from e @@ -205,6 +200,7 @@ def train(hyp, tb_writer, opt, device): model.hyp = hyp # attach hyperparameters to model model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou) model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights + model.names = data_dict['names'] # Class frequency if tb_writer: @@ -326,10 +322,9 @@ def train(hyp, tb_writer, opt, device): batch_size=total_batch_size, imgsz=imgsz_test, save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'), - model=ema.ema, + model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema, single_cls=opt.single_cls, dataloader=testloader) - # Write with open(results_file, 'a') as f: f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls) @@ -368,22 +363,21 @@ def train(hyp, tb_writer, opt, device): # end epoch ---------------------------------------------------------------------------------------------------- # end training - results = None if opt.local_rank in [-1, 0]: - n = opt.name - if len(n): - n = '_' + n if not n.isnumeric() else n - fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n - for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]): - if os.path.exists(f1): - os.rename(f1, f2) # rename - ispt = f2.endswith('.pt') # is *.pt - strip_optimizer(f2) if ispt else None # strip optimizer - os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload + # Strip optimizers + n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name + fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n + for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]): + if os.path.exists(f1): + os.rename(f1, f2) # rename + ispt = f2.endswith('.pt') # is *.pt + strip_optimizer(f2) if ispt else None # strip optimizer + os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload + # Finish if not opt.evolve: plot_results() # save as results.png print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) - if opt.local_rank == -1: + if opt.local_rank == 0: dist.destroy_process_group() torch.cuda.empty_cache() return results @@ -414,16 +408,16 @@ def train(hyp, tb_writer, opt, device): # Parameter For DDP. parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Don't use it manually.") opt = parser.parse_args() - opt.weights = last if opt.resume else opt.weights + opt.weights = last if opt.resume and not opt.weights else opt.weights opt.cfg = check_file(opt.cfg) # check file opt.data = check_file(opt.data) # check file print(opt) opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) - # If local_rank is not -1, the DDP mode is triggered. Use local_rank to overwrite the opt.device config. device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size) if device.type == 'cpu': mixed_precision = False elif opt.local_rank != -1: + # DDP mode assert torch.cuda.device_count() > opt.local_rank torch.cuda.set_device(opt.local_rank) device = torch.device("cuda") @@ -435,10 +429,10 @@ def train(hyp, tb_writer, opt, device): # Train if not opt.evolve: if opt.local_rank in [-1, 0]: + print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/') tb_writer = SummaryWriter(comment=opt.name) else: tb_writer = None - print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/') train(hyp, tb_writer, opt, device) # Evolve hyperparameters (optional) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index a2f69c1a92cb..b9c1ad6155c5 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -54,6 +54,11 @@ def time_synchronized(): return time.time() +def is_parallel(model): + # is model is parallel with DP or DDP + return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) + + def initialize_weights(model): for m in model.modules(): t = type(m) @@ -111,8 +116,8 @@ def model_info(model, verbose=False): try: # FLOPS from thop import profile - macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False) - fs = ', %.1f GFLOPS' % (macs / 1E9 * 2) + flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2 + fs = ', %.1f GFLOPS' % (flops * 100) # 640x640 FLOPS except: fs = '' @@ -187,7 +192,6 @@ def update(self, model): with torch.no_grad(): msd = model.module.state_dict() if hasattr(model, 'module') else model.state_dict() esd = self.ema.module.state_dict() if hasattr(self.ema, 'module') else self.ema.state_dict() - for k, v in esd.items(): if v.dtype.is_floating_point: v *= d @@ -196,6 +200,6 @@ def update(self, model): def update_attr(self, model): # Assign attributes (which may change during training) for k in model.__dict__.keys(): - if not k.startswith('_') and not isinstance(getattr(model, k), - (torch.distributed.ProcessGroupNCCL, torch.distributed.Reducer)): + if not k.startswith('_') and (k != 'module' or not isinstance(getattr(model, k), + (torch.distributed.ProcessGroupNCCL, torch.distributed.Reducer))): setattr(self.ema, k, getattr(model, k))