Update SparseML Integration to V6.1 #26

Merged: 17 commits, Apr 8, 2022
39 changes: 39 additions & 0 deletions data/hyps/hyp.finetune.yaml
@@ -0,0 +1,39 @@
# Hyperparameters for VOC finetuning
# python train.py --batch 64 --weights yolov5m.pt --data voc.yaml --img 512 --epochs 50
# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials


# Hyperparameter Evolution Results
# Generations: 306
# P R mAP.5 mAP.5:.95 box obj cls
# Metrics: 0.6 0.936 0.896 0.684 0.0115 0.00805 0.00146

lr0: 0.0032
lrf: 0.12
momentum: 0.843
weight_decay: 0.00036
warmup_epochs: 2.0
warmup_momentum: 0.5
warmup_bias_lr: 0.05
box: 0.0296
cls: 0.243
cls_pw: 0.631
obj: 0.301
obj_pw: 0.911
iou_t: 0.2
anchor_t: 2.91
# anchors: 3.63
fl_gamma: 0.0
hsv_h: 0.0138
hsv_s: 0.664
hsv_v: 0.464
degrees: 0.373
translate: 0.245
scale: 0.898
shear: 0.602
perspective: 0.0
flipud: 0.00856
fliplr: 0.5
mosaic: 1.0
mixup: 0.243
copy_paste: 0.0
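
For reference, a minimal sketch of how a hyperparameter file like this is typically consumed on the training side (the loader shown here is illustrative, not part of this diff):

import yaml

with open('data/hyps/hyp.finetune.yaml') as f:
    hyp = yaml.safe_load(f)  # plain dict, e.g. hyp['lr0'] == 0.0032

# the values then feed the optimizer and augmentation pipeline, roughly:
# optimizer = torch.optim.SGD(model.parameters(), lr=hyp['lr0'],
#                             momentum=hyp['momentum'], weight_decay=hyp['weight_decay'])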
34 changes: 34 additions & 0 deletions data/hyps/hyp.scratch.yaml
@@ -0,0 +1,34 @@
# Hyperparameters for COCO training from scratch
# python train.py --batch 40 --cfg yolov5m.yaml --weights '' --data coco.yaml --img 640 --epochs 300
# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials


lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.2 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.0005 # optimizer weight decay 5e-4
warmup_epochs: 3.0 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.1 # warmup initial bias lr
box: 0.05 # box loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 1.0 # obj loss gain (scale with pixels)
obj_pw: 1.0 # obj BCELoss positive_weight
iou_t: 0.20 # IoU training threshold
anchor_t: 4.0 # anchor-multiple threshold
# anchors: 3 # anchors per output layer (0 to ignore)
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
degrees: 0.0 # image rotation (+/- deg)
translate: 0.1 # image translation (+/- fraction)
scale: 0.5 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
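
The lr0/lrf pair above defines the learning-rate sweep: training starts at lr0 and decays toward lr0 * lrf. A hedged sketch of the cosine "one cycle" lambda that YOLOv5-style trainers pair with these values (names illustrative):

import math

def one_cycle(y1=1.0, y2=0.2, steps=300):
    # cosine ramp from y1 down to y2 over `steps` epochs
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

lf = one_cycle(1.0, 0.2, steps=300)  # with lr0=0.01, lrf=0.2 as above
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
# epoch 0   -> lr = 0.01 * 1.0 = 0.01
# epoch 300 -> lr = 0.01 * 0.2 = 0.002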
3 changes: 2 additions & 1 deletion detect.py
@@ -45,6 +45,7 @@
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, time_sync
from export import load_checkpoint


@torch.no_grad()
@@ -89,7 +90,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)

# Load model
device = select_device(device)
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
model, extras = load_checkpoint(type_='val', weights=weights, device=device) # load FP32 model
stride, names, pt = model.stride, model.names, model.pt
imgsz = check_img_size(imgsz, s=stride) # check image size

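
The new load_checkpoint entry point returns the model plus an extras dict with 'ckpt', 'state_dict', 'sparseml_wrapper', and 'report' keys (see export.py below). A hedged usage sketch mirroring the detect.py change above:

from export import load_checkpoint

model, extras = load_checkpoint(type_='val', weights='yolov5s.pt', device=device)
sparseml_wrapper = extras['sparseml_wrapper']  # recipe state travels with the model
print(extras['report'])  # e.g. 'Transferred 349/349 items from yolov5s.pt' (illustrative)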
173 changes: 156 additions & 17 deletions export.py
@@ -43,6 +43,7 @@
"""

import argparse
from copy import deepcopy
import json
import os
import platform
@@ -57,20 +58,26 @@
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

from sparseml.pytorch.utils import ModuleExporter
from sparseml.pytorch.sparsification.quantization import skip_onnx_input_quantize

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

from models.common import Conv
from models.common import Conv, DetectMultiBackend
from models.experimental import attempt_load
from models.yolo import Detect
from models.yolo import Detect, Model
from utils.activations import SiLU
from utils.datasets import LoadImages
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr,
file_size, print_args, url2file)
from utils.torch_utils import select_device
file_size, print_args, url2file, intersect_dicts)
from utils.torch_utils import select_device, torch_distributed_zero_first, is_parallel
from utils.downloads import attempt_download
from utils.sparse import SparseMLWrapper, check_download_sparsezoo_weights



def export_formats():
@@ -118,14 +125,33 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
f = file.with_suffix('.onnx')

torch.onnx.export(model, im, f, verbose=False, opset_version=opset,
training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
do_constant_folding=not train,
input_names=['images'],
output_names=['output'],
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
} if dynamic else None)
# export through SparseML so quantized and pruned graphs can be corrected
save_dir = f.parent.absolute()
save_name = str(f).split(os.path.sep)[-1]

# get the number of outputs so we know how to name and change dynamic axes
# nested outputs can be returned if model is exported with dynamic
def _count_outputs(outputs):
count = 0
if isinstance(outputs, list) or isinstance(outputs, tuple):
for out in outputs:
count += _count_outputs(out)
else:
count += 1
return count

outputs = model(im)
num_outputs = _count_outputs(outputs)
input_names = ['input']
output_names = [f'out_{i}' for i in range(num_outputs)]
dynamic_axes = {k: {0: 'batch'} for k in (input_names + output_names)} if dynamic else None
exporter = ModuleExporter(model, save_dir)
exporter.export_onnx(im, name=save_name, convert_qat=True,
input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)
try:
skip_onnx_input_quantize(f, f)
except Exception:
pass
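
# Illustration (not part of the diff): for a model returning (det, [p3, p4, p5]),
# _count_outputs counts the flattened leaves, so num_outputs == 4 and
# output_names == ['out_0', 'out_1', 'out_2', 'out_3']; with dynamic=True each
# of these names (plus 'input') gets a {0: 'batch'} axis, keeping batch size flexible.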

# Checks
model_onnx = onnx.load(f) # load onnx model
@@ -407,14 +433,123 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')

def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
pickle = not sparseml_wrapper.qat_active(epoch) # qat does not support pickled exports
ckpt_model = deepcopy(model.module if is_parallel(model) else model).float()
yaml = ckpt_model.yaml
if not pickle:
ckpt_model = ckpt_model.state_dict()

return {'epoch': epoch,
'model': ckpt_model,
'optimizer': optimizer.state_dict(),
'yaml': yaml,
'hyp': model.hyp,
**ema.state_dict(pickle),
**sparseml_wrapper.state_dict(),
**kwargs}

def load_checkpoint(
type_,
weights,
device,
cfg=None,
hyp=None,
nc=None,
data=None,
dnn=False,
half=False,
recipe=None,
resume=None,
rank=-1
):
with torch_distributed_zero_first(rank):
# download if not found locally, or fetch from the SparseZoo if weights is a zoo stub
weights = attempt_download(weights) or check_download_sparsezoo_weights(weights)
ckpt = torch.load(weights[0] if isinstance(weights, list) or isinstance(weights, tuple)
else weights, map_location="cpu") # load checkpoint
start_epoch = ckpt['epoch'] + 1 if 'epoch' in ckpt else 0
pickled = isinstance(ckpt['model'], nn.Module)
train_type = type_ == 'train'
ensemble_type = type_ == 'ensemble'
val_type = type_ == 'val'

if pickled and ensemble_type:
cfg = None
if ensemble_type:
model = attempt_load(weights, map_location=device) # load ensemble using pickled
state_dict = model.state_dict()
elif val_type:
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
state_dict = model.model.state_dict()
else:
# load model from config and weights
cfg = cfg or (ckpt['yaml'] if 'yaml' in ckpt else None) or \
(ckpt['model'].yaml if pickled else None)
model = Model(cfg, ch=3, nc=ckpt['nc'] if ('nc' in ckpt and not nc) else nc,
anchors=hyp.get('anchors') if hyp else None).to(device)
model_key = 'ema' if (not train_type and 'ema' in ckpt and ckpt['ema']) else 'model'
state_dict = ckpt[model_key].float().state_dict() if pickled else ckpt[model_key]
if val_type:
model = DetectMultiBackend(model=model, device=device, dnn=dnn, data=data, fp16=half)

# turn gradients back on for params in case they were removed
for p in model.parameters():
p.requires_grad = True

# load sparseml recipe for applying pruning and quantization
checkpoint_recipe = train_recipe = None
if resume:
train_recipe = ckpt['recipe'] if ('recipe' in ckpt) else None
elif ckpt.get('recipe') or recipe:
train_recipe, checkpoint_recipe = recipe, ckpt.get('recipe')

sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, checkpoint_recipe, train_recipe)
exclude_anchors = train_type and (cfg or hyp.get('anchors')) and not resume
loaded = False

sparseml_wrapper.apply_checkpoint_structure(float("inf"))
if train_type:
# initialize the recipe for training, restoring the weights first if they are not quantized
quantized_state_dict = any([name.endswith('.zero_point') for name in state_dict.keys()])
if not quantized_state_dict:
state_dict = load_state_dict(model, state_dict, train=True, exclude_anchors=exclude_anchors)
loaded = True
sparseml_wrapper.initialize(start_epoch)

if not loaded:
state_dict = load_state_dict(model, state_dict, train=train_type, exclude_anchors=exclude_anchors)

model.float()
report = 'Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)

return model, {
'ckpt': ckpt,
'state_dict': state_dict,
'sparseml_wrapper': sparseml_wrapper,
'report': report,
}


def load_state_dict(model, state_dict, train, exclude_anchors):
# fix older state_dict names not porting to the new model setup
state_dict = {key if not key.startswith("module.") else key[7:]: val for key, val in state_dict.items()}

if train:
# load any missing weights from the model
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=['anchor'] if exclude_anchors else [])

model.load_state_dict(state_dict, strict=not train) # load

return state_dict
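
# Hedged usage sketch (not part of the diff): create_checkpoint and load_checkpoint
# are designed to round-trip, so a training loop can save QAT-safe checkpoints:
#
#   ckpt = create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, nc=nc)
#   torch.save(ckpt, 'last.pt')
#   ...
#   model, extras = load_checkpoint(type_='train', weights='last.pt', device=device,
#                                   hyp=hyp, resume=True)  # restores epoch and recipe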

@torch.no_grad()
def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
weights=ROOT / 'yolov5s.pt', # weights path
imgsz=(640, 640), # image (height, width)
batch_size=1, # batch size
device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu
include=('torchscript', 'onnx'), # include formats
include=('onnx',), # include formats
half=False, # FP16 half-precision export
inplace=False, # set YOLOv5 Detect() inplace=True
train=False, # model.train() mode
@@ -430,7 +565,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
topk_per_class=100, # TF.js NMS: topk per class to keep
topk_all=100, # TF.js NMS: topk for all classes to keep
iou_thres=0.45, # TF.js NMS: IoU threshold
conf_thres=0.25 # TF.js NMS: confidence threshold
conf_thres=0.25, # TF.js NMS: confidence threshold
remove_grid=False,
):
t = time.time()
include = [x.lower() for x in include] # to lowercase
@@ -443,8 +579,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
# Load PyTorch model
device = select_device(device)
assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
model = attempt_load(weights, map_location=device, inplace=True, fuse=True) # load FP32 model
nc, names = model.nc, model.names # number of classes, class names
model, extras = load_checkpoint(type_='ensemble', weights=weights, device=device) # load FP32 model
sparseml_wrapper = extras['sparseml_wrapper']
nc, names = extras["ckpt"]["nc"], model.names # number of classes, class names

# Checks
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
@@ -469,6 +606,7 @@
m.onnx_dynamic = dynamic
if hasattr(m, 'forward_export'):
m.forward = m.forward_export # assign custom forward (optional)
model.model[-1].export = not remove_grid # set Detect() layer grid export

for _ in range(2):
y = model(im) # dry runs
@@ -541,6 +679,7 @@ def parse_opt():
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
parser.add_argument("--remove-grid", action="store_true", help="remove export of Detect() layer grid")
parser.add_argument('--include', nargs='+',
default=['torchscript', 'onnx'],
help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
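# Example invocations (hedged, using the flags defined above):
#   python export.py --weights yolov5s.pt --include onnx
#   python export.py --weights yolov5s.pt --include onnx --remove-grid  # drop Detect() grid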
@@ -556,4 +695,4 @@ def main(opt):

if __name__ == "__main__":
opt = parse_opt()
main(opt)
main(opt)