diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index a83f997cbfc2..537ba96e7225 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -15,6 +15,7 @@ jobs: Benchmarks: runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: [ ubuntu-latest ] python-version: [ '3.9' ] # requires python<=3.9 @@ -37,9 +38,12 @@ jobs: python --version pip --version pip list - - name: Run benchmarks + - name: Benchmark DetectionModel + run: | + python benchmarks.py --data coco128.yaml --weights ${{ matrix.model }}.pt --img 320 --hard-fail 0.29 + - name: Benchmark SegmentationModel run: | - python utils/benchmarks.py --weights ${{ matrix.model }}.pt --img 320 --hard-fail 0.29 + python benchmarks.py --data coco128-seg.yaml --weights ${{ matrix.model }}-seg.pt --img 320 Tests: timeout-minutes: 60 @@ -126,6 +130,20 @@ jobs: model(im) # warmup, build grids for trace torch.jit.trace(model, [im]) EOF + - name: Test segmentation + shell: bash # for Windows compatibility + run: | + m=${{ matrix.model }}-seg # official weights + b=runs/train-seg/exp/weights/best # best.pt checkpoint + python segment/train.py --imgsz 64 --batch 32 --weights $m.pt --cfg $m.yaml --epochs 1 --device cpu # train + python segment/train.py --imgsz 64 --batch 32 --weights '' --cfg $m.yaml --epochs 1 --device cpu # train + for d in cpu; do # devices + for w in $m $b; do # weights + python segment/val.py --imgsz 64 --batch 32 --weights $w.pt --device $d # val + python segment/predict.py --imgsz 64 --weights $w.pt --device $d # predict + python export.py --weights $w.pt --img 64 --include torchscript --device $d # export + done + done - name: Test classification shell: bash # for Windows compatibility run: | diff --git a/utils/benchmarks.py b/benchmarks.py similarity index 87% rename from utils/benchmarks.py rename to benchmarks.py index 9d5c7f2965d5..58e083c95d55 100644 --- a/utils/benchmarks.py +++ b/benchmarks.py @@ -34,16 +34,19 @@ import pandas as pd FILE = Path(__file__).resolve() -ROOT = FILE.parents[1] # YOLOv5 root directory +ROOT = FILE.parents[0] # YOLOv5 root directory if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) # add ROOT to PATH # ROOT = ROOT.relative_to(Path.cwd()) # relative import export -import val +from models.experimental import attempt_load +from models.yolo import SegmentationModel +from segment.val import run as val_seg from utils import notebook_init from utils.general import LOGGER, check_yaml, file_size, print_args from utils.torch_utils import select_device +from val import run as val_det def run( @@ -59,6 +62,7 @@ def run( ): y, t = [], time.time() device = select_device(device) + model_type = type(attempt_load(weights, fuse=False)) # DetectionModel, SegmentationModel, etc. for i, (name, f, suffix, cpu, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, CPU, GPU) try: assert i not in (9, 10, 11), 'inference not supported' # Edge TPU, TF.js and Paddle are unsupported @@ -76,10 +80,14 @@ def run( assert suffix in str(w), 'export failed' # Validate - result = val.run(data, w, batch_size, imgsz, plots=False, device=device, task='benchmark', half=half) - metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls)) - speeds = result[2] # times (preprocess, inference, postprocess) - y.append([name, round(file_size(w), 1), round(metrics[3], 4), round(speeds[1], 2)]) # MB, mAP, t_inference + if model_type == SegmentationModel: + result = val_seg(data, w, batch_size, imgsz, plots=False, device=device, task='benchmark', half=half) + metric = result[0][7] # (box(p, r, map50, map), mask(p, r, map50, map), *loss(box, obj, cls)) + else: # DetectionModel: + result = val_det(data, w, batch_size, imgsz, plots=False, device=device, task='benchmark', half=half) + metric = result[0][3] # (p, r, map50, map, *loss(box, obj, cls)) + speed = result[2][1] # times (preprocess, inference, postprocess) + y.append([name, round(file_size(w), 1), round(metric, 4), round(speed, 2)]) # MB, mAP, t_inference except Exception as e: if hard_fail: assert type(e) is AssertionError, f'Benchmark --hard-fail for {name}: {e}' diff --git a/data/coco128-seg.yaml b/data/coco128-seg.yaml new file mode 100644 index 000000000000..5e81910cc456 --- /dev/null +++ b/data/coco128-seg.yaml @@ -0,0 +1,101 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +# COCO128-seg dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics +# Example usage: python train.py --data coco128.yaml +# parent +# ├── yolov5 +# └── datasets +# └── coco128-seg ← downloads here (7 MB) + + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/coco128-seg # dataset root dir +train: images/train2017 # train images (relative to 'path') 128 images +val: images/train2017 # val images (relative to 'path') 128 images +test: # test images (optional) + +# Classes +names: + 0: person + 1: bicycle + 2: car + 3: motorcycle + 4: airplane + 5: bus + 6: train + 7: truck + 8: boat + 9: traffic light + 10: fire hydrant + 11: stop sign + 12: parking meter + 13: bench + 14: bird + 15: cat + 16: dog + 17: horse + 18: sheep + 19: cow + 20: elephant + 21: bear + 22: zebra + 23: giraffe + 24: backpack + 25: umbrella + 26: handbag + 27: tie + 28: suitcase + 29: frisbee + 30: skis + 31: snowboard + 32: sports ball + 33: kite + 34: baseball bat + 35: baseball glove + 36: skateboard + 37: surfboard + 38: tennis racket + 39: bottle + 40: wine glass + 41: cup + 42: fork + 43: knife + 44: spoon + 45: bowl + 46: banana + 47: apple + 48: sandwich + 49: orange + 50: broccoli + 51: carrot + 52: hot dog + 53: pizza + 54: donut + 55: cake + 56: chair + 57: couch + 58: potted plant + 59: bed + 60: dining table + 61: toilet + 62: tv + 63: laptop + 64: mouse + 65: remote + 66: keyboard + 67: cell phone + 68: microwave + 69: oven + 70: toaster + 71: sink + 72: refrigerator + 73: book + 74: clock + 75: vase + 76: scissors + 77: teddy bear + 78: hair drier + 79: toothbrush + + +# Download script/URL (optional) +download: https://ultralytics.com/assets/coco128-seg.zip diff --git a/detect.py b/detect.py index a69606a3dff9..310d169281bf 100644 --- a/detect.py +++ b/detect.py @@ -149,8 +149,8 @@ def run( det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round() # Print results - for c in det[:, -1].unique(): - n = (det[:, -1] == c).sum() # detections per class + for c in det[:, 5].unique(): + n = (det[:, 5] == c).sum() # detections per class s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string # Write results diff --git a/models/common.py b/models/common.py index 8b7dbbfa95fe..0d90ff4f8827 100644 --- a/models/common.py +++ b/models/common.py @@ -375,7 +375,6 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, if batch_dim.is_static: batch_size = batch_dim.get_length() executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2 - output_layer = next(iter(executable_network.outputs)) stride, names = self._load_metadata(Path(w).with_suffix('.yaml')) # load metadata elif engine: # TensorRT LOGGER.info(f'Loading {w} for TensorRT inference...') @@ -491,7 +490,7 @@ def forward(self, im, augment=False, visualize=False): y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im}) elif self.xml: # OpenVINO im = im.cpu().numpy() # FP32 - y = self.executable_network([im])[self.output_layer] + y = list(self.executable_network([im]).values()) elif self.engine: # TensorRT if self.dynamic and im.shape != self.bindings['images'].shape: i_in, i_out = (self.model.get_binding_index(x) for x in ('images', 'output')) @@ -786,8 +785,21 @@ def __str__(self): return '' +class Proto(nn.Module): + # YOLOv5 mask Proto module for segmentation models + def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks + super().__init__() + self.cv1 = Conv(c1, c_, k=3) + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + self.cv2 = Conv(c_, c_, k=3) + self.cv3 = Conv(c_, c2) + + def forward(self, x): + return self.cv3(self.cv2(self.upsample(self.cv1(x)))) + + class Classify(nn.Module): - # Classification head, i.e. x(b,c1,20,20) to x(b,c2) + # YOLOv5 classification head, i.e. x(b,c1,20,20) to x(b,c2) def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups super().__init__() c_ = 1280 # efficientnet_b0 size diff --git a/models/segment/yolov5l-seg.yaml b/models/segment/yolov5l-seg.yaml new file mode 100644 index 000000000000..4782de11dd2d --- /dev/null +++ b/models/segment/yolov5l-seg.yaml @@ -0,0 +1,48 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5) + ] diff --git a/models/segment/yolov5m-seg.yaml b/models/segment/yolov5m-seg.yaml new file mode 100644 index 000000000000..f73d1992ac19 --- /dev/null +++ b/models/segment/yolov5m-seg.yaml @@ -0,0 +1,48 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 0.67 # model depth multiple +width_multiple: 0.75 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5) + ] \ No newline at end of file diff --git a/models/segment/yolov5n-seg.yaml b/models/segment/yolov5n-seg.yaml new file mode 100644 index 000000000000..c28225ab4a50 --- /dev/null +++ b/models/segment/yolov5n-seg.yaml @@ -0,0 +1,48 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.25 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5) + ] diff --git a/models/segment/yolov5s-seg.yaml b/models/segment/yolov5s-seg.yaml new file mode 100644 index 000000000000..7cbdb36b425c --- /dev/null +++ b/models/segment/yolov5s-seg.yaml @@ -0,0 +1,48 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.5 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5) + ] \ No newline at end of file diff --git a/models/segment/yolov5x-seg.yaml b/models/segment/yolov5x-seg.yaml new file mode 100644 index 000000000000..5d0c4524a99c --- /dev/null +++ b/models/segment/yolov5x-seg.yaml @@ -0,0 +1,48 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.33 # model depth multiple +width_multiple: 1.25 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5) + ] diff --git a/models/tf.py b/models/tf.py index ecb0d4d79c78..8cce147059d3 100644 --- a/models/tf.py +++ b/models/tf.py @@ -30,7 +30,7 @@ from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv, DWConvTranspose2d, Focus, autopad) from models.experimental import MixConv2d, attempt_load -from models.yolo import Detect +from models.yolo import Detect, Segment from utils.activations import SiLU from utils.general import LOGGER, make_divisible, print_args @@ -320,6 +320,36 @@ def _make_grid(nx=20, ny=20): return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32) +class TFSegment(TFDetect): + # YOLOv5 Segment head for segmentation models + def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), imgsz=(640, 640), w=None): + super().__init__(nc, anchors, ch, imgsz, w) + self.nm = nm # number of masks + self.npr = npr # number of protos + self.no = 5 + nc + self.nm # number of outputs per anchor + self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)] # output conv + self.proto = TFProto(ch[0], self.npr, self.nm, w=w.proto) # protos + self.detect = TFDetect.call + + def call(self, x): + p = self.proto(x[0]) + x = self.detect(self, x) + return (x, p) if self.training else ((x[0], p),) + + +class TFProto(keras.layers.Layer): + + def __init__(self, c1, c_=256, c2=32, w=None): + super().__init__() + self.cv1 = TFConv(c1, c_, k=3, w=w.cv1) + self.upsample = TFUpsample(None, scale_factor=2, mode='nearest') + self.cv2 = TFConv(c_, c_, k=3, w=w.cv2) + self.cv3 = TFConv(c_, c2, w=w.cv3) + + def call(self, inputs): + return self.cv3(self.cv2(self.upsample(self.cv1(inputs)))) + + class TFUpsample(keras.layers.Layer): # TF version of torch.nn.Upsample() def __init__(self, size, scale_factor, mode, w=None): # warning: all arguments needed including 'w' @@ -377,10 +407,12 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3) args = [ch[f]] elif m is Concat: c2 = sum(ch[-1 if x == -1 else x + 1] for x in f) - elif m is Detect: + elif m in [Detect, Segment]: args.append([ch[x + 1] for x in f]) if isinstance(args[1], int): # number of anchors args[1] = [list(range(args[1] * 2))] * len(f) + if m is Segment: + args[3] = make_divisible(args[3] * gw, 8) args.append(imgsz) else: c2 = ch[f] diff --git a/models/yolo.py b/models/yolo.py index fa05fcf9a8d9..a0702a7c0257 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -36,6 +36,7 @@ class Detect(nn.Module): + # YOLOv5 Detect head for detection models stride = None # strides computed during build dynamic = False # force grid reconstruction export = False # export mode @@ -63,15 +64,16 @@ def forward(self, x): if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]: self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i) - y = x[i].sigmoid() + y = x[i].clone() + y[..., :5 + self.nc].sigmoid_() if self.inplace: y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i] # xy y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953 - xy, wh, conf = y.split((2, 2, self.nc + 1), 4) # y.tensor_split((2, 4, 5), 4) # torch 1.8.0 + xy, wh, etc = y.split((2, 2, self.no - 4), 4) # tensor_split((2, 4, 5), 4) if torch 1.8.0 xy = (xy * 2 + self.grid[i]) * self.stride[i] # xy wh = (wh * 2) ** 2 * self.anchor_grid[i] # wh - y = torch.cat((xy, wh, conf), 4) + y = torch.cat((xy, wh, etc), 4) z.append(y.view(bs, -1, self.no)) return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x) @@ -87,6 +89,23 @@ def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version return grid, anchor_grid +class Segment(Detect): + # YOLOv5 Segment head for segmentation models + def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), inplace=True): + super().__init__(nc, anchors, ch, inplace) + self.nm = nm # number of masks + self.npr = npr # number of protos + self.no = 5 + nc + self.nm # number of outputs per anchor + self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv + self.proto = Proto(ch[0], self.npr, self.nm) # protos + self.detect = Detect.forward + + def forward(self, x): + p = self.proto(x[0]) + x = self.detect(self, x) + return (x, p) if self.training else (x[0], p) if self.export else (x[0], p, x[1]) + + class BaseModel(nn.Module): # YOLOv5 base model def forward(self, x, profile=False, visualize=False): @@ -135,7 +154,7 @@ def _apply(self, fn): # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers self = super()._apply(fn) m = self.model[-1] # Detect() - if isinstance(m, Detect): + if isinstance(m, (Detect, Segment)): m.stride = fn(m.stride) m.grid = list(map(fn, m.grid)) if isinstance(m.anchor_grid, list): @@ -169,11 +188,12 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, i # Build strides, anchors m = self.model[-1] # Detect() - if isinstance(m, Detect): + if isinstance(m, (Detect, Segment)): s = 256 # 2x min stride m.inplace = self.inplace - m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.empty(1, ch, s, s))]) # forward - check_anchor_order(m) # must be in pixel-space (not grid-space) + forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x) + m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward + check_anchor_order(m) m.anchors /= m.stride.view(-1, 1, 1) self.stride = m.stride self._initialize_biases() # only run once @@ -235,15 +255,21 @@ def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1. m = self.model[-1] # Detect() module for mi, s in zip(m.m, m.stride): # from - b = mi.bias.view(m.na, -1).detach() # conv.bias(255) to (3,85) - b[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image) - b[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # cls + b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85) + b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image) + b.data[:, 5:5 + m.nc] += math.log(0.6 / (m.nc - 0.99999)) if cf is None else torch.log(cf / cf.sum()) # cls mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) Model = DetectionModel # retain YOLOv5 'Model' class for backwards compatibility +class SegmentationModel(DetectionModel): + # YOLOv5 segmentation model + def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None, anchors=None): + super().__init__(cfg, ch, nc, anchors) + + class ClassificationModel(BaseModel): # YOLOv5 classification model def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, number of classes, cutoff index @@ -284,24 +310,28 @@ def parse_model(d, ch): # model_dict, input_channels(3) args[j] = eval(a) if isinstance(a, str) else a # eval strings n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain - if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, - BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x): + if m in { + Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, + BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}: c1, c2 = ch[f], args[0] if c2 != no: # if not output c2 = make_divisible(c2 * gw, 8) args = [c1, c2, *args[1:]] - if m in [BottleneckCSP, C3, C3TR, C3Ghost, C3x]: + if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x}: args.insert(2, n) # number of repeats n = 1 elif m is nn.BatchNorm2d: args = [ch[f]] elif m is Concat: c2 = sum(ch[x] for x in f) - elif m is Detect: + # TODO: channel, gw, gd + elif m in {Detect, Segment}: args.append([ch[x] for x in f]) if isinstance(args[1], int): # number of anchors args[1] = [list(range(args[1] * 2))] * len(f) + if m is Segment: + args[3] = make_divisible(args[3] * gw, 8) elif m is Contract: c2 = ch[f] * args[0] ** 2 elif m is Expand: diff --git a/segment/predict.py b/segment/predict.py new file mode 100644 index 000000000000..ba4cf2905255 --- /dev/null +++ b/segment/predict.py @@ -0,0 +1,266 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Run YOLOv5 segmentation inference on images, videos, directories, streams, etc. + +Usage - sources: + $ python segment/predict.py --weights yolov5s-seg.pt --source 0 # webcam + img.jpg # image + vid.mp4 # video + path/ # directory + 'path/*.jpg' # glob + 'https://youtu.be/Zgi9g1ksQHc' # YouTube + 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream + +Usage - formats: + $ python segment/predict.py --weights yolov5s-seg.pt # PyTorch + yolov5s-seg.torchscript # TorchScript + yolov5s-seg.onnx # ONNX Runtime or OpenCV DNN with --dnn + yolov5s-seg.xml # OpenVINO + yolov5s-seg.engine # TensorRT + yolov5s-seg.mlmodel # CoreML (macOS-only) + yolov5s-seg_saved_model # TensorFlow SavedModel + yolov5s-seg.pb # TensorFlow GraphDef + yolov5s-seg.tflite # TensorFlow Lite + yolov5s-seg_edgetpu.tflite # TensorFlow Edge TPU + yolov5s-seg_paddle_model # PaddlePaddle +""" + +import argparse +import os +import platform +import sys +from pathlib import Path + +import torch + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[1] # YOLOv5 root directory +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) # add ROOT to PATH +ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative + +from models.common import DetectMultiBackend +from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams +from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2, + increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh) +from utils.plots import Annotator, colors, save_one_box +from utils.segment.general import process_mask +from utils.torch_utils import select_device, smart_inference_mode + + +@smart_inference_mode() +def run( + weights=ROOT / 'yolov5s-seg.pt', # model.pt path(s) + source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam + data=ROOT / 'data/coco128.yaml', # dataset.yaml path + imgsz=(640, 640), # inference size (height, width) + conf_thres=0.25, # confidence threshold + iou_thres=0.45, # NMS IOU threshold + max_det=1000, # maximum detections per image + device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu + view_img=False, # show results + save_txt=False, # save results to *.txt + save_conf=False, # save confidences in --save-txt labels + save_crop=False, # save cropped prediction boxes + nosave=False, # do not save images/videos + classes=None, # filter by class: --class 0, or --class 0 2 3 + agnostic_nms=False, # class-agnostic NMS + augment=False, # augmented inference + visualize=False, # visualize features + update=False, # update all models + project=ROOT / 'runs/predict-seg', # save results to project/name + name='exp', # save results to project/name + exist_ok=False, # existing project/name ok, do not increment + line_thickness=3, # bounding box thickness (pixels) + hide_labels=False, # hide labels + hide_conf=False, # hide confidences + half=False, # use FP16 half-precision inference + dnn=False, # use OpenCV DNN for ONNX inference + vid_stride=1, # video frame-rate stride + retina_masks=False, +): + source = str(source) + save_img = not nosave and not source.endswith('.txt') # save inference images + is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) + is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://')) + webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file) + if is_url and is_file: + source = check_file(source) # download + + # Directories + save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run + (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir + + # Load model + device = select_device(device) + model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half) + stride, names, pt = model.stride, model.names, model.pt + imgsz = check_img_size(imgsz, s=stride) # check image size + + # Dataloader + if webcam: + view_img = check_imshow() + dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) + bs = len(dataset) # batch_size + else: + dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) + bs = 1 # batch_size + vid_path, vid_writer = [None] * bs, [None] * bs + + # Run inference + model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup + seen, windows, dt = 0, [], (Profile(), Profile(), Profile()) + for path, im, im0s, vid_cap, s in dataset: + with dt[0]: + im = torch.from_numpy(im).to(device) + im = im.half() if model.fp16 else im.float() # uint8 to fp16/32 + im /= 255 # 0 - 255 to 0.0 - 1.0 + if len(im.shape) == 3: + im = im[None] # expand for batch dim + + # Inference + with dt[1]: + visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False + pred, proto = model(im, augment=augment, visualize=visualize)[:2] + + # NMS + with dt[2]: + pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det, nm=32) + + # Second-stage classifier (optional) + # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s) + + # Process predictions + for i, det in enumerate(pred): # per image + seen += 1 + if webcam: # batch_size >= 1 + p, im0, frame = path[i], im0s[i].copy(), dataset.count + s += f'{i}: ' + else: + p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0) + + p = Path(p) # to Path + save_path = str(save_dir / p.name) # im.jpg + txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt + s += '%gx%g ' % im.shape[2:] # print string + gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh + imc = im0.copy() if save_crop else im0 # for save_crop + annotator = Annotator(im0, line_width=line_thickness, example=str(names)) + if len(det): + masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True) # HWC + + # Rescale boxes from img_size to im0 size + det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round() + + # Print results + for c in det[:, 5].unique(): + n = (det[:, 5] == c).sum() # detections per class + s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string + + # Mask plotting + annotator.masks(masks, + colors=[colors(x, True) for x in det[:, 5]], + im_gpu=None if retina_masks else im[i]) + + # Write results + for *xyxy, conf, cls in reversed(det[:, :6]): + if save_txt: # Write to file + xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh + line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format + with open(f'{txt_path}.txt', 'a') as f: + f.write(('%g ' * len(line)).rstrip() % line + '\n') + + if save_img or save_crop or view_img: # Add bbox to image + c = int(cls) # integer class + label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}') + annotator.box_label(xyxy, label, color=colors(c, True)) + if save_crop: + save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True) + + # Stream results + im0 = annotator.result() + if view_img: + if platform.system() == 'Linux' and p not in windows: + windows.append(p) + cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) + cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0]) + cv2.imshow(str(p), im0) + if cv2.waitKey(1) == ord('q'): # 1 millisecond + exit() + + # Save results (image with detections) + if save_img: + if dataset.mode == 'image': + cv2.imwrite(save_path, im0) + else: # 'video' or 'stream' + if vid_path[i] != save_path: # new video + vid_path[i] = save_path + if isinstance(vid_writer[i], cv2.VideoWriter): + vid_writer[i].release() # release previous video writer + if vid_cap: # video + fps = vid_cap.get(cv2.CAP_PROP_FPS) + w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + else: # stream + fps, w, h = 30, im0.shape[1], im0.shape[0] + save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos + vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + vid_writer[i].write(im0) + + # Print time (inference-only) + LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms") + + # Print results + t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image + LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t) + if save_txt or save_img: + s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' + LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}") + if update: + strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning) + + +def parse_opt(): + parser = argparse.ArgumentParser() + parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-seg.pt', help='model path(s)') + parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam') + parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path') + parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w') + parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold') + parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold') + parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image') + parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--view-img', action='store_true', help='show results') + parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') + parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') + parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes') + parser.add_argument('--nosave', action='store_true', help='do not save images/videos') + parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3') + parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') + parser.add_argument('--augment', action='store_true', help='augmented inference') + parser.add_argument('--visualize', action='store_true', help='visualize features') + parser.add_argument('--update', action='store_true', help='update all models') + parser.add_argument('--project', default=ROOT / 'runs/predict-seg', help='save results to project/name') + parser.add_argument('--name', default='exp', help='save results to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)') + parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels') + parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences') + parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') + parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference') + parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride') + parser.add_argument('--retina-masks', action='store_true', help='whether to plot masks in native resolution') + opt = parser.parse_args() + opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand + print_args(vars(opt)) + return opt + + +def main(opt): + check_requirements(exclude=('tensorboard', 'thop')) + run(**vars(opt)) + + +if __name__ == "__main__": + opt = parse_opt() + main(opt) diff --git a/segment/train.py b/segment/train.py new file mode 100644 index 000000000000..bda379176151 --- /dev/null +++ b/segment/train.py @@ -0,0 +1,676 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Train a YOLOv5 segment model on a segment dataset +Models and datasets download automatically from the latest YOLOv5 release. + +Usage - Single-GPU training: + $ python segment/train.py --data coco128-seg.yaml --weights yolov5s-seg.pt --img 640 # from pretrained (recommended) + $ python segment/train.py --data coco128-seg.yaml --weights '' --cfg yolov5s-seg.yaml --img 640 # from scratch + +Usage - Multi-GPU DDP training: + $ python -m torch.distributed.run --nproc_per_node 4 --master_port 1 segment/train.py --data coco128-seg.yaml --weights yolov5s-seg.pt --img 640 --device 0,1,2,3 + +Models: https://github.com/ultralytics/yolov5/tree/master/models +Datasets: https://github.com/ultralytics/yolov5/tree/master/data +Tutorial: https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data +""" + +import argparse +import math +import os +import random +import sys +import time +from copy import deepcopy +from datetime import datetime +from pathlib import Path + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import yaml +from torch.optim import lr_scheduler +from tqdm import tqdm + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[1] # YOLOv5 root directory +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) # add ROOT to PATH +ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative + +import torch.nn.functional as F + +import segment.val as validate # for end-of-epoch mAP +from models.experimental import attempt_load +from models.yolo import SegmentationModel +from utils.autoanchor import check_anchors +from utils.autobatch import check_train_batch_size +from utils.callbacks import Callbacks +from utils.downloads import attempt_download, is_url +from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size, + check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path, + init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, one_cycle, + print_args, print_mutation, strip_optimizer, yaml_save) +from utils.loggers import GenericLogger +from utils.plots import plot_evolve, plot_labels +from utils.segment.dataloaders import create_dataloader +from utils.segment.loss import ComputeLoss +from utils.segment.metrics import KEYS, fitness +from utils.segment.plots import plot_images_and_masks, plot_results_with_masks +from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer, + smart_resume, torch_distributed_zero_first) + +LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html +RANK = int(os.getenv('RANK', -1)) +WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) + + +def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictionary + save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, mask_ratio = \ + Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \ + opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze, opt.mask_ratio + # callbacks.run('on_pretrain_routine_start') + + # Directories + w = save_dir / 'weights' # weights dir + (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir + last, best = w / 'last.pt', w / 'best.pt' + + # Hyperparameters + if isinstance(hyp, str): + with open(hyp, errors='ignore') as f: + hyp = yaml.safe_load(f) # load hyps dict + LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items())) + opt.hyp = hyp.copy() # for saving hyps to checkpoints + + # Save run settings + if not evolve: + yaml_save(save_dir / 'hyp.yaml', hyp) + yaml_save(save_dir / 'opt.yaml', vars(opt)) + + # Loggers + data_dict = None + if RANK in {-1, 0}: + logger = GenericLogger(opt=opt, console_logger=LOGGER) + # loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance + # if loggers.clearml: + # data_dict = loggers.clearml.data_dict # None if no ClearML dataset or filled in by ClearML + # if loggers.wandb: + # data_dict = loggers.wandb.data_dict + # if resume: + # weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size + # + # # Register actions + # for k in methods(loggers): + # callbacks.register_action(k, callback=getattr(loggers, k)) + + # Config + plots = not evolve and not opt.noplots # create plots + overlap = not opt.no_overlap + cuda = device.type != 'cpu' + init_seeds(opt.seed + 1 + RANK, deterministic=True) + with torch_distributed_zero_first(LOCAL_RANK): + data_dict = data_dict or check_dataset(data) # check if None + train_path, val_path = data_dict['train'], data_dict['val'] + nc = 1 if single_cls else int(data_dict['nc']) # number of classes + names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names + is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset + + # Model + check_suffix(weights, '.pt') # check weights + pretrained = weights.endswith('.pt') + if pretrained: + with torch_distributed_zero_first(LOCAL_RANK): + weights = attempt_download(weights) # download if not found locally + ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak + model = SegmentationModel(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) + exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys + csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 + csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect + model.load_state_dict(csd, strict=False) # load + LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report + else: + model = SegmentationModel(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create + amp = check_amp(model) # check AMP + + # Freeze + freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze + for k, v in model.named_parameters(): + v.requires_grad = True # train all layers + # v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results) + if any(x in k for x in freeze): + LOGGER.info(f'freezing {k}') + v.requires_grad = False + + # Image size + gs = max(int(model.stride.max()), 32) # grid size (max stride) + imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple + + # Batch size + if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size + batch_size = check_train_batch_size(model, imgsz, amp) + logger.update_params({"batch_size": batch_size}) + # loggers.on_params_update({"batch_size": batch_size}) + + # Optimizer + nbs = 64 # nominal batch size + accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing + hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay + optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay']) + + # Scheduler + if opt.cos_lr: + lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf'] + else: + lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs) + + # EMA + ema = ModelEMA(model) if RANK in {-1, 0} else None + + # Resume + best_fitness, start_epoch = 0.0, 0 + if pretrained: + if resume: + best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume) + del ckpt, csd + + # DP mode + if cuda and RANK == -1 and torch.cuda.device_count() > 1: + LOGGER.warning('WARNING: DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n' + 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.') + model = torch.nn.DataParallel(model) + + # SyncBatchNorm + if opt.sync_bn and cuda and RANK != -1: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device) + LOGGER.info('Using SyncBatchNorm()') + + # Trainloader + train_loader, dataset = create_dataloader( + train_path, + imgsz, + batch_size // WORLD_SIZE, + gs, + single_cls, + hyp=hyp, + augment=True, + cache=None if opt.cache == 'val' else opt.cache, + rect=opt.rect, + rank=LOCAL_RANK, + workers=workers, + image_weights=opt.image_weights, + quad=opt.quad, + prefix=colorstr('train: '), + shuffle=True, + mask_downsample_ratio=mask_ratio, + overlap_mask=overlap, + ) + labels = np.concatenate(dataset.labels, 0) + mlc = int(labels[:, 0].max()) # max label class + assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}' + + # Process 0 + if RANK in {-1, 0}: + val_loader = create_dataloader(val_path, + imgsz, + batch_size // WORLD_SIZE * 2, + gs, + single_cls, + hyp=hyp, + cache=None if noval else opt.cache, + rect=True, + rank=-1, + workers=workers * 2, + pad=0.5, + mask_downsample_ratio=mask_ratio, + overlap_mask=overlap, + prefix=colorstr('val: '))[0] + + if not resume: + if not opt.noautoanchor: + check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # run AutoAnchor + model.half().float() # pre-reduce anchor precision + + if plots: + plot_labels(labels, names, save_dir) + # callbacks.run('on_pretrain_routine_end', labels, names) + + # DDP mode + if cuda and RANK != -1: + model = smart_DDP(model) + + # Model attributes + nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps) + hyp['box'] *= 3 / nl # scale to layers + hyp['cls'] *= nc / 80 * 3 / nl # scale to classes and layers + hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl # scale to image size and layers + hyp['label_smoothing'] = opt.label_smoothing + model.nc = nc # attach number of classes to model + model.hyp = hyp # attach hyperparameters to model + model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights + model.names = names + + # Start training + t0 = time.time() + nb = len(train_loader) # number of batches + nw = max(round(hyp['warmup_epochs'] * nb), 100) # number of warmup iterations, max(3 epochs, 100 iterations) + # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training + last_opt_step = -1 + maps = np.zeros(nc) # mAP per class + results = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) + scheduler.last_epoch = start_epoch - 1 # do not move + scaler = torch.cuda.amp.GradScaler(enabled=amp) + stopper, stop = EarlyStopping(patience=opt.patience), False + compute_loss = ComputeLoss(model, overlap=overlap) # init loss class + # callbacks.run('on_train_start') + LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n' + f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n' + f"Logging results to {colorstr('bold', save_dir)}\n" + f'Starting training for {epochs} epochs...') + for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------ + # callbacks.run('on_train_epoch_start') + model.train() + + # Update image weights (optional, single-GPU only) + if opt.image_weights: + cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights + iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights + dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx + + # Update mosaic border (optional) + # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs) + # dataset.mosaic_border = [b - imgsz, -b] # height, width borders + + mloss = torch.zeros(4, device=device) # mean losses + if RANK != -1: + train_loader.sampler.set_epoch(epoch) + pbar = enumerate(train_loader) + LOGGER.info(('\n' + '%11s' * 8) % + ('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Instances', 'Size')) + if RANK in {-1, 0}: + pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar + optimizer.zero_grad() + for i, (imgs, targets, paths, _, masks) in pbar: # batch ------------------------------------------------------ + # callbacks.run('on_train_batch_start') + ni = i + nb * epoch # number integrated batches (since train start) + imgs = imgs.to(device, non_blocking=True).float() / 255 # uint8 to float32, 0-255 to 0.0-1.0 + + # Warmup + if ni <= nw: + xi = [0, nw] # x interp + # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou) + accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round()) + for j, x in enumerate(optimizer.param_groups): + # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 + x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)]) + if 'momentum' in x: + x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']]) + + # Multi-scale + if opt.multi_scale: + sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size + sf = sz / max(imgs.shape[2:]) # scale factor + if sf != 1: + ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple) + imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) + + # Forward + with torch.cuda.amp.autocast(amp): + pred = model(imgs) # forward + loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float()) + if RANK != -1: + loss *= WORLD_SIZE # gradient averaged between devices in DDP mode + if opt.quad: + loss *= 4. + + # Backward + scaler.scale(loss).backward() + + # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html + if ni - last_opt_step >= accumulate: + scaler.unscale_(optimizer) # unscale gradients + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) # clip gradients + scaler.step(optimizer) # optimizer.step + scaler.update() + optimizer.zero_grad() + if ema: + ema.update(model) + last_opt_step = ni + + # Log + if RANK in {-1, 0}: + mloss = (mloss * i + loss_items) / (i + 1) # update mean losses + mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB) + pbar.set_description(('%11s' * 2 + '%11.4g' * 6) % + (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) + # callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths) + # if callbacks.stop_training: + # return + + # Mosaic plots + if plots: + if ni < 3: + plot_images_and_masks(imgs, targets, masks, paths, save_dir / f"train_batch{ni}.jpg") + if ni == 10: + files = sorted(save_dir.glob('train*.jpg')) + logger.log_images(files, "Mosaics", epoch) + # end batch ------------------------------------------------------------------------------------------------ + + # Scheduler + lr = [x['lr'] for x in optimizer.param_groups] # for loggers + scheduler.step() + + if RANK in {-1, 0}: + # mAP + # callbacks.run('on_train_epoch_end', epoch=epoch) + ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) + final_epoch = (epoch + 1 == epochs) or stopper.possible_stop + if not noval or final_epoch: # Calculate mAP + results, maps, _ = validate.run(data_dict, + batch_size=batch_size // WORLD_SIZE * 2, + imgsz=imgsz, + half=amp, + model=ema.ema, + single_cls=single_cls, + dataloader=val_loader, + save_dir=save_dir, + plots=False, + callbacks=callbacks, + compute_loss=compute_loss, + mask_downsample_ratio=mask_ratio, + overlap=overlap) + + # Update best mAP + fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95] + stop = stopper(epoch=epoch, fitness=fi) # early stop check + if fi > best_fitness: + best_fitness = fi + log_vals = list(mloss) + list(results) + lr + # callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi) + # Log val metrics and media + metrics_dict = dict(zip(KEYS, log_vals)) + logger.log_metrics(metrics_dict, epoch) + + # Save model + if (not nosave) or (final_epoch and not evolve): # if save + ckpt = { + 'epoch': epoch, + 'best_fitness': best_fitness, + 'model': deepcopy(de_parallel(model)).half(), + 'ema': deepcopy(ema.ema).half(), + 'updates': ema.updates, + 'optimizer': optimizer.state_dict(), + # 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None, + 'opt': vars(opt), + 'date': datetime.now().isoformat()} + + # Save last, best and delete + torch.save(ckpt, last) + if best_fitness == fi: + torch.save(ckpt, best) + if opt.save_period > 0 and epoch % opt.save_period == 0: + torch.save(ckpt, w / f'epoch{epoch}.pt') + logger.log_model(w / f'epoch{epoch}.pt') + del ckpt + # callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi) + + # EarlyStopping + if RANK != -1: # if DDP training + broadcast_list = [stop if RANK == 0 else None] + dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks + if RANK != 0: + stop = broadcast_list[0] + if stop: + break # must break all DDP ranks + + # end epoch ---------------------------------------------------------------------------------------------------- + # end training ----------------------------------------------------------------------------------------------------- + if RANK in {-1, 0}: + LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.') + for f in last, best: + if f.exists(): + strip_optimizer(f) # strip optimizers + if f is best: + LOGGER.info(f'\nValidating {f}...') + results, _, _ = validate.run( + data_dict, + batch_size=batch_size // WORLD_SIZE * 2, + imgsz=imgsz, + model=attempt_load(f, device).half(), + iou_thres=0.65 if is_coco else 0.60, # best pycocotools at iou 0.65 + single_cls=single_cls, + dataloader=val_loader, + save_dir=save_dir, + save_json=is_coco, + verbose=True, + plots=plots, + callbacks=callbacks, + compute_loss=compute_loss, + mask_downsample_ratio=mask_ratio, + overlap=overlap) # val best model with plots + if is_coco: + # callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi) + metrics_dict = dict(zip(KEYS, list(mloss) + list(results) + lr)) + logger.log_metrics(metrics_dict, epoch) + + # callbacks.run('on_train_end', last, best, epoch, results) + # on train end callback using genericLogger + logger.log_metrics(dict(zip(KEYS[4:16], results)), epochs) + if not opt.evolve: + logger.log_model(best, epoch) + if plots: + plot_results_with_masks(file=save_dir / 'results.csv') # save results.png + files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))] + files = [(save_dir / f) for f in files if (save_dir / f).exists()] # filter + LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") + logger.log_images(files, "Results", epoch + 1) + logger.log_images(sorted(save_dir.glob('val*.jpg')), "Validation", epoch + 1) + torch.cuda.empty_cache() + return results + + +def parse_opt(known=False): + parser = argparse.ArgumentParser() + parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s-seg.pt', help='initial weights path') + parser.add_argument('--cfg', type=str, default='', help='model.yaml path') + parser.add_argument('--data', type=str, default=ROOT / 'data/coco128-seg.yaml', help='dataset.yaml path') + parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path') + parser.add_argument('--epochs', type=int, default=300, help='total training epochs') + parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch') + parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)') + parser.add_argument('--rect', action='store_true', help='rectangular training') + parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training') + parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') + parser.add_argument('--noval', action='store_true', help='only validate final epoch') + parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor') + parser.add_argument('--noplots', action='store_true', help='save no plot files') + parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations') + parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') + parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"') + parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training') + parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') + parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class') + parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer') + parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode') + parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)') + parser.add_argument('--project', default=ROOT / 'runs/train-seg', help='save to project/name') + parser.add_argument('--name', default='exp', help='save to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + parser.add_argument('--quad', action='store_true', help='quad dataloader') + parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler') + parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon') + parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)') + parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2') + parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)') + parser.add_argument('--seed', type=int, default=0, help='Global training seed') + parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify') + + # Instance Segmentation Args + parser.add_argument('--mask-ratio', type=int, default=4, help='Downsample the truth masks to saving memory') + parser.add_argument('--no-overlap', action='store_true', help='Overlap masks train faster at slightly less mAP') + + # Weights & Biases arguments + # parser.add_argument('--entity', default=None, help='W&B: Entity') + # parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option') + # parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval') + # parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use') + + return parser.parse_known_args()[0] if known else parser.parse_args() + + +def main(opt, callbacks=Callbacks()): + # Checks + if RANK in {-1, 0}: + print_args(vars(opt)) + check_git_status() + check_requirements() + + # Resume + if opt.resume and not opt.evolve: # resume from specified or most recent last.pt + last = Path(check_file(opt.resume) if isinstance(opt.resume, str) else get_latest_run()) + opt_yaml = last.parent.parent / 'opt.yaml' # train options yaml + opt_data = opt.data # original dataset + if opt_yaml.is_file(): + with open(opt_yaml, errors='ignore') as f: + d = yaml.safe_load(f) + else: + d = torch.load(last, map_location='cpu')['opt'] + opt = argparse.Namespace(**d) # replace + opt.cfg, opt.weights, opt.resume = '', str(last), True # reinstate + if is_url(opt_data): + opt.data = check_file(opt_data) # avoid HUB resume auth timeout + else: + opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \ + check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks + assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified' + if opt.evolve: + if opt.project == str(ROOT / 'runs/train'): # if default project name, rename to runs/evolve + opt.project = str(ROOT / 'runs/evolve') + opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume + if opt.name == 'cfg': + opt.name = Path(opt.cfg).stem # use model.yaml as name + opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) + + # DDP mode + device = select_device(opt.device, batch_size=opt.batch_size) + if LOCAL_RANK != -1: + msg = 'is not compatible with YOLOv5 Multi-GPU DDP training' + assert not opt.image_weights, f'--image-weights {msg}' + assert not opt.evolve, f'--evolve {msg}' + assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size' + assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE' + assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command' + torch.cuda.set_device(LOCAL_RANK) + device = torch.device('cuda', LOCAL_RANK) + dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo") + + # Train + if not opt.evolve: + train(opt.hyp, opt, device, callbacks) + + # Evolve hyperparameters (optional) + else: + # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit) + meta = { + 'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3) + 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf) + 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1 + 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay + 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok) + 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum + 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr + 'box': (1, 0.02, 0.2), # box loss gain + 'cls': (1, 0.2, 4.0), # cls loss gain + 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight + 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels) + 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight + 'iou_t': (0, 0.1, 0.7), # IoU training threshold + 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold + 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore) + 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5) + 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction) + 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction) + 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction) + 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg) + 'translate': (1, 0.0, 0.9), # image translation (+/- fraction) + 'scale': (1, 0.0, 0.9), # image scale (+/- gain) + 'shear': (1, 0.0, 10.0), # image shear (+/- deg) + 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 + 'flipud': (1, 0.0, 1.0), # image flip up-down (probability) + 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability) + 'mosaic': (1, 0.0, 1.0), # image mixup (probability) + 'mixup': (1, 0.0, 1.0), # image mixup (probability) + 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability) + + with open(opt.hyp, errors='ignore') as f: + hyp = yaml.safe_load(f) # load hyps dict + if 'anchors' not in hyp: # anchors commented in hyp.yaml + hyp['anchors'] = 3 + if opt.noautoanchor: + del hyp['anchors'], meta['anchors'] + opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch + # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices + evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv' + if opt.bucket: + os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {evolve_csv}') # download evolve.csv if exists + + for _ in range(opt.evolve): # generations to evolve + if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate + # Select parent(s) + parent = 'single' # parent selection method: 'single' or 'weighted' + x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1) + n = min(5, len(x)) # number of previous results to consider + x = x[np.argsort(-fitness(x))][:n] # top n mutations + w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0) + if parent == 'single' or len(x) == 1: + # x = x[random.randint(0, n - 1)] # random selection + x = x[random.choices(range(n), weights=w)[0]] # weighted selection + elif parent == 'weighted': + x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination + + # Mutate + mp, s = 0.8, 0.2 # mutation probability, sigma + npr = np.random + npr.seed(int(time.time())) + g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1 + ng = len(meta) + v = np.ones(ng) + while all(v == 1): # mutate until a change occurs (prevent duplicates) + v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0) + for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300) + hyp[k] = float(x[i + 7] * v[i]) # mutate + + # Constrain to limits + for k, v in meta.items(): + hyp[k] = max(hyp[k], v[1]) # lower limit + hyp[k] = min(hyp[k], v[2]) # upper limit + hyp[k] = round(hyp[k], 5) # significant digits + + # Train mutation + results = train(hyp.copy(), opt, device, callbacks) + callbacks = Callbacks() + # Write mutation results + print_mutation(results, hyp.copy(), save_dir, opt.bucket) + + # Plot results + plot_evolve(evolve_csv) + LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n' + f"Results saved to {colorstr('bold', save_dir)}\n" + f'Usage example: $ python train.py --hyp {evolve_yaml}') + + +def run(**kwargs): + # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt') + opt = parse_opt(True) + for k, v in kwargs.items(): + setattr(opt, k, v) + main(opt) + return opt + + +if __name__ == "__main__": + opt = parse_opt() + main(opt) diff --git a/segment/val.py b/segment/val.py new file mode 100644 index 000000000000..138aa00aaed3 --- /dev/null +++ b/segment/val.py @@ -0,0 +1,471 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Validate a trained YOLOv5 segment model on a segment dataset + +Usage: + $ bash data/scripts/get_coco.sh --val --segments # download COCO-segments val split (1G, 5000 images) + $ python segment/val.py --weights yolov5s-seg.pt --data coco.yaml --img 640- # validate COCO-segments + +Usage - formats: + $ python segment/val.py --weights yolov5s-seg.pt # PyTorch + yolov5s-seg.torchscript # TorchScript + yolov5s-seg.onnx # ONNX Runtime or OpenCV DNN with --dnn + yolov5s-seg.xml # OpenVINO + yolov5s-seg.engine # TensorRT + yolov5s-seg.mlmodel # CoreML (macOS-only) + yolov5s-seg_saved_model # TensorFlow SavedModel + yolov5s-seg.pb # TensorFlow GraphDef + yolov5s-seg.tflite # TensorFlow Lite + yolov5s-seg_edgetpu.tflite # TensorFlow Edge TPU + yolov5s-seg_paddle_model # PaddlePaddle +""" + +import argparse +import json +import os +import sys +from multiprocessing.pool import ThreadPool +from pathlib import Path + +import numpy as np +import torch +from tqdm import tqdm + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[1] # YOLOv5 root directory +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) # add ROOT to PATH +ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative + +import torch.nn.functional as F + +from models.common import DetectMultiBackend +from models.yolo import SegmentationModel +from utils.callbacks import Callbacks +from utils.general import (LOGGER, NUM_THREADS, Profile, check_dataset, check_img_size, check_requirements, check_yaml, + coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args, + scale_coords, xywh2xyxy, xyxy2xywh) +from utils.metrics import ConfusionMatrix, box_iou +from utils.plots import output_to_target, plot_val_study +from utils.segment.dataloaders import create_dataloader +from utils.segment.general import mask_iou, process_mask, process_mask_upsample, scale_image +from utils.segment.metrics import Metrics, ap_per_class_box_and_mask +from utils.segment.plots import plot_images_and_masks +from utils.torch_utils import de_parallel, select_device, smart_inference_mode + + +def save_one_txt(predn, save_conf, shape, file): + # Save one txt result + gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh + for *xyxy, conf, cls in predn.tolist(): + xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh + line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format + with open(file, 'a') as f: + f.write(('%g ' * len(line)).rstrip() % line + '\n') + + +def save_one_json(predn, jdict, path, class_map, pred_masks): + # Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} + from pycocotools.mask import encode + + def single_encode(x): + rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0] + rle["counts"] = rle["counts"].decode("utf-8") + return rle + + image_id = int(path.stem) if path.stem.isnumeric() else path.stem + box = xyxy2xywh(predn[:, :4]) # xywh + box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner + pred_masks = np.transpose(pred_masks, (2, 0, 1)) + with ThreadPool(NUM_THREADS) as pool: + rles = pool.map(single_encode, pred_masks) + for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())): + jdict.append({ + 'image_id': image_id, + 'category_id': class_map[int(p[5])], + 'bbox': [round(x, 3) for x in b], + 'score': round(p[4], 5), + 'segmentation': rles[i]}) + + +def process_batch(detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False): + """ + Return correct prediction matrix + Arguments: + detections (array[N, 6]), x1, y1, x2, y2, conf, class + labels (array[M, 5]), class, x1, y1, x2, y2 + Returns: + correct (array[N, 10]), for 10 IoU levels + """ + if masks: + if overlap: + nl = len(labels) + index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1 + gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640) + gt_masks = torch.where(gt_masks == index, 1.0, 0.0) + if gt_masks.shape[1:] != pred_masks.shape[1:]: + gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0] + gt_masks = gt_masks.gt_(0.5) + iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1)) + else: # boxes + iou = box_iou(labels[:, 1:], detections[:, :4]) + + correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool) + correct_class = labels[:, 0:1] == detections[:, 5] + for i in range(len(iouv)): + x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detect, iou] + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + # matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + correct[matches[:, 1].astype(int), i] = True + return torch.tensor(correct, dtype=torch.bool, device=iouv.device) + + +@smart_inference_mode() +def run( + data, + weights=None, # model.pt path(s) + batch_size=32, # batch size + imgsz=640, # inference size (pixels) + conf_thres=0.001, # confidence threshold + iou_thres=0.6, # NMS IoU threshold + max_det=300, # maximum detections per image + task='val', # train, val, test, speed or study + device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu + workers=8, # max dataloader workers (per RANK in DDP mode) + single_cls=False, # treat as single-class dataset + augment=False, # augmented inference + verbose=False, # verbose output + save_txt=False, # save results to *.txt + save_hybrid=False, # save label+prediction hybrid results to *.txt + save_conf=False, # save confidences in --save-txt labels + save_json=False, # save a COCO-JSON results file + project=ROOT / 'runs/val-seg', # save to project/name + name='exp', # save to project/name + exist_ok=False, # existing project/name ok, do not increment + half=True, # use FP16 half-precision inference + dnn=False, # use OpenCV DNN for ONNX inference + model=None, + dataloader=None, + save_dir=Path(''), + plots=True, + overlap=False, + mask_downsample_ratio=1, + compute_loss=None, + callbacks=Callbacks(), +): + if save_json: + check_requirements(['pycocotools']) + process = process_mask_upsample # more accurate + else: + process = process_mask # faster + + # Initialize/load model and set device + training = model is not None + if training: # called by train.py + device, pt, jit, engine = next(model.parameters()).device, True, False, False # get model device, PyTorch model + half &= device.type != 'cpu' # half precision only supported on CUDA + model.half() if half else model.float() + nm = de_parallel(model).model[-1].nm # number of masks + else: # called directly + device = select_device(device, batch_size=batch_size) + + # Directories + save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run + (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir + + # Load model + model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half) + stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine + imgsz = check_img_size(imgsz, s=stride) # check image size + half = model.fp16 # FP16 supported on limited backends with CUDA + nm = de_parallel(model).model.model[-1].nm if isinstance(model, SegmentationModel) else 32 # number of masks + if engine: + batch_size = model.batch_size + else: + device = model.device + if not (pt or jit): + batch_size = 1 # export.py models default to batch-size 1 + LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') + + # Data + data = check_dataset(data) # check + + # Configure + model.eval() + cuda = device.type != 'cpu' + is_coco = isinstance(data.get('val'), str) and data['val'].endswith(f'coco{os.sep}val2017.txt') # COCO dataset + nc = 1 if single_cls else int(data['nc']) # number of classes + iouv = torch.linspace(0.5, 0.95, 10, device=device) # iou vector for mAP@0.5:0.95 + niou = iouv.numel() + + # Dataloader + if not training: + if pt and not single_cls: # check --weights are trained on --data + ncm = model.model.nc + assert ncm == nc, f'{weights} ({ncm} classes) trained on different --data than what you passed ({nc} ' \ + f'classes). Pass correct combination of --weights and --data that are trained together.' + model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz)) # warmup + pad = 0.0 if task in ('speed', 'benchmark') else 0.5 + rect = False if task == 'benchmark' else pt # square inference for benchmarks + task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images + dataloader = create_dataloader(data[task], + imgsz, + batch_size, + stride, + single_cls, + pad=pad, + rect=rect, + workers=workers, + prefix=colorstr(f'{task}: '), + overlap_mask=overlap, + mask_downsample_ratio=mask_downsample_ratio)[0] + + seen = 0 + confusion_matrix = ConfusionMatrix(nc=nc) + names = model.names if hasattr(model, 'names') else model.module.names # get class names + if isinstance(names, (list, tuple)): # old format + names = dict(enumerate(names)) + class_map = coco80_to_coco91_class() if is_coco else list(range(1000)) + s = ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P", "R", + "mAP50", "mAP50-95)") + dt = Profile(), Profile(), Profile() + metrics = Metrics() + loss = torch.zeros(4, device=device) + jdict, stats = [], [] + # callbacks.run('on_val_start') + pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar + for batch_i, (im, targets, paths, shapes, masks) in enumerate(pbar): + # callbacks.run('on_val_batch_start') + with dt[0]: + if cuda: + im = im.to(device, non_blocking=True) + targets = targets.to(device) + masks = masks.to(device) + masks = masks.float() + im = im.half() if half else im.float() # uint8 to fp16/32 + im /= 255 # 0 - 255 to 0.0 - 1.0 + nb, _, height, width = im.shape # batch size, channels, height, width + + # Inference + with dt[1]: + preds, protos, train_out = model(im) if compute_loss else (*model(im, augment=augment)[:2], None) + + # Loss + if compute_loss: + loss += compute_loss((train_out, protos), targets, masks)[1] # box, obj, cls + + # NMS + targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels + lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling + with dt[2]: + preds = non_max_suppression(preds, + conf_thres, + iou_thres, + labels=lb, + multi_label=True, + agnostic=single_cls, + max_det=max_det, + nm=nm) + + # Metrics + plot_masks = [] # masks for plotting + for si, (pred, proto) in enumerate(zip(preds, protos)): + labels = targets[targets[:, 0] == si, 1:] + nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions + path, shape = Path(paths[si]), shapes[si][0] + correct_masks = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init + correct_bboxes = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init + seen += 1 + + if npr == 0: + if nl: + stats.append((correct_masks, correct_bboxes, *torch.zeros((2, 0), device=device), labels[:, 0])) + if plots: + confusion_matrix.process_batch(detections=None, labels=labels[:, 0]) + continue + + # Masks + midx = [si] if overlap else targets[:, 0] == si + gt_masks = masks[midx] + pred_masks = process(proto, pred[:, 6:], pred[:, :4], shape=im[si].shape[1:]) + + # Predictions + if single_cls: + pred[:, 5] = 0 + predn = pred.clone() + scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1]) # native-space pred + + # Evaluate + if nl: + tbox = xywh2xyxy(labels[:, 1:5]) # target boxes + scale_coords(im[si].shape[1:], tbox, shape, shapes[si][1]) # native-space labels + labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels + correct_bboxes = process_batch(predn, labelsn, iouv) + correct_masks = process_batch(predn, labelsn, iouv, pred_masks, gt_masks, overlap=overlap, masks=True) + if plots: + confusion_matrix.process_batch(predn, labelsn) + stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0])) # (conf, pcls, tcls) + + pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) + if plots and batch_i < 3: + plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot + + # Save/log + if save_txt: + save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') + if save_json: + pred_masks = scale_image(im[si].shape[1:], + pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape, shapes[si][1]) + save_one_json(predn, jdict, path, class_map, pred_masks) # append to COCO-JSON dictionary + # callbacks.run('on_val_image_end', pred, predn, path, names, im[si]) + + # Plot images + if plots and batch_i < 3: + if len(plot_masks): + plot_masks = torch.cat(plot_masks, dim=0) + plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) + plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, paths, + save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred + + # callbacks.run('on_val_batch_end') + + # Compute metrics + stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy + if len(stats) and stats[0].any(): + results = ap_per_class_box_and_mask(*stats, plot=plots, save_dir=save_dir, names=names) + metrics.update(results) + nt = np.bincount(stats[4].astype(int), minlength=nc) # number of targets per class + + # Print results + pf = '%22s' + '%11i' * 2 + '%11.3g' * 8 # print format + LOGGER.info(pf % ("all", seen, nt.sum(), *metrics.mean_results())) + if nt.sum() == 0: + LOGGER.warning(f'WARNING: no labels found in {task} set, can not compute metrics without labels ⚠️') + + # Print results per class + if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats): + for i, c in enumerate(metrics.ap_class_index): + LOGGER.info(pf % (names[c], seen, nt[c], *metrics.class_result(i))) + + # Print speeds + t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image + if not training: + shape = (batch_size, 3, imgsz, imgsz) + LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t) + + # Plots + if plots: + confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) + # callbacks.run('on_val_end') + + mp_bbox, mr_bbox, map50_bbox, map_bbox, mp_mask, mr_mask, map50_mask, map_mask = metrics.mean_results() + + # Save JSON + if save_json and len(jdict): + w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights + anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json') # annotations json + pred_json = str(save_dir / f"{w}_predictions.json") # predictions json + LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...') + with open(pred_json, 'w') as f: + json.dump(jdict, f) + + try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + + anno = COCO(anno_json) # init annotations api + pred = anno.loadRes(pred_json) # init predictions api + results = [] + for eval in COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm'): + if is_coco: + eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.im_files] # img ID to evaluate + eval.evaluate() + eval.accumulate() + eval.summarize() + results.extend(eval.stats[:2]) # update results (mAP@0.5:0.95, mAP@0.5) + map_bbox, map50_bbox, map_mask, map50_mask = results + except Exception as e: + LOGGER.info(f'pycocotools unable to run: {e}') + + # Return results + model.float() # for training + if not training: + s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' + LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}") + final_metric = mp_bbox, mr_bbox, map50_bbox, map_bbox, mp_mask, mr_mask, map50_mask, map_mask + return (*final_metric, *(loss.cpu() / len(dataloader)).tolist()), metrics.get_maps(nc), t + + +def parse_opt(): + parser = argparse.ArgumentParser() + parser.add_argument('--data', type=str, default=ROOT / 'data/coco128-seg.yaml', help='dataset.yaml path') + parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-seg.pt', help='model path(s)') + parser.add_argument('--batch-size', type=int, default=32, help='batch size') + parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)') + parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold') + parser.add_argument('--iou-thres', type=float, default=0.6, help='NMS IoU threshold') + parser.add_argument('--max-det', type=int, default=300, help='maximum detections per image') + parser.add_argument('--task', default='val', help='train, val, test, speed or study') + parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)') + parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset') + parser.add_argument('--augment', action='store_true', help='augmented inference') + parser.add_argument('--verbose', action='store_true', help='report mAP by class') + parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') + parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt') + parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') + parser.add_argument('--save-json', action='store_true', help='save a COCO-JSON results file') + parser.add_argument('--project', default=ROOT / 'runs/val-seg', help='save results to project/name') + parser.add_argument('--name', default='exp', help='save to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') + parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference') + opt = parser.parse_args() + opt.data = check_yaml(opt.data) # check YAML + # opt.save_json |= opt.data.endswith('coco.yaml') + opt.save_txt |= opt.save_hybrid + print_args(vars(opt)) + return opt + + +def main(opt): + check_requirements(requirements=ROOT / 'requirements.txt', exclude=('tensorboard', 'thop')) + + if opt.task in ('train', 'val', 'test'): # run normally + if opt.conf_thres > 0.001: # https://github.com/ultralytics/yolov5/issues/1466 + LOGGER.info(f'WARNING: confidence threshold {opt.conf_thres} > 0.001 produces invalid results ⚠️') + if opt.save_hybrid: + LOGGER.info('WARNING: --save-hybrid will return high mAP from hybrid labels, not from predictions alone ⚠️') + run(**vars(opt)) + + else: + weights = opt.weights if isinstance(opt.weights, list) else [opt.weights] + opt.half = True # FP16 for fastest results + if opt.task == 'speed': # speed benchmarks + # python val.py --task speed --data coco.yaml --batch 1 --weights yolov5n.pt yolov5s.pt... + opt.conf_thres, opt.iou_thres, opt.save_json = 0.25, 0.45, False + for opt.weights in weights: + run(**vars(opt), plots=False) + + elif opt.task == 'study': # speed vs mAP benchmarks + # python val.py --task study --data coco.yaml --iou 0.7 --weights yolov5n.pt yolov5s.pt... + for opt.weights in weights: + f = f'study_{Path(opt.data).stem}_{Path(opt.weights).stem}.txt' # filename to save to + x, y = list(range(256, 1536 + 128, 128)), [] # x axis (image sizes), y axis + for opt.imgsz in x: # img-size + LOGGER.info(f'\nRunning {f} --imgsz {opt.imgsz}...') + r, _, t = run(**vars(opt), plots=False) + y.append(r + t) # results and times + np.savetxt(f, y, fmt='%10.4g') # save + os.system('zip -r study.zip study_*.txt') + plot_val_study(x=x) # plot + + +if __name__ == "__main__": + opt = parse_opt() + main(opt) diff --git a/utils/dataloaders.py b/utils/dataloaders.py old mode 100755 new mode 100644 index d8ef11fd94b4..c04be853c580 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -484,6 +484,7 @@ def __init__(self, self.im_files = [self.im_files[i] for i in irect] self.label_files = [self.label_files[i] for i in irect] self.labels = [self.labels[i] for i in irect] + self.segments = [self.segments[i] for i in irect] self.shapes = s[irect] # wh ar = ar[irect] diff --git a/utils/general.py b/utils/general.py old mode 100755 new mode 100644 index f5fb2c93a3d5..8633511f89f5 --- a/utils/general.py +++ b/utils/general.py @@ -798,15 +798,18 @@ def clip_coords(boxes, shape): boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2 -def non_max_suppression(prediction, - conf_thres=0.25, - iou_thres=0.45, - classes=None, - agnostic=False, - multi_label=False, - labels=(), - max_det=300): - """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes +def non_max_suppression( + prediction, + conf_thres=0.25, + iou_thres=0.45, + classes=None, + agnostic=False, + multi_label=False, + labels=(), + max_det=300, + nm=0, # number of masks +): + """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections Returns: list of detections, on (n,6) tensor per image [xyxy, conf, cls] @@ -816,7 +819,7 @@ def non_max_suppression(prediction, prediction = prediction[0] # select only inference output bs = prediction.shape[0] # batch size - nc = prediction.shape[2] - 5 # number of classes + nc = prediction.shape[2] - nm - 5 # number of classes xc = prediction[..., 4] > conf_thres # candidates # Checks @@ -827,13 +830,14 @@ def non_max_suppression(prediction, # min_wh = 2 # (pixels) minimum box width and height max_wh = 7680 # (pixels) maximum box width and height max_nms = 30000 # maximum number of boxes into torchvision.ops.nms() - time_limit = 0.3 + 0.03 * bs # seconds to quit after + time_limit = 0.5 + 0.05 * bs # seconds to quit after redundant = True # require redundant detections multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) merge = False # use merge-NMS t = time.time() - output = [torch.zeros((0, 6), device=prediction.device)] * bs + mi = 5 + nc # mask start index + output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs for xi, x in enumerate(prediction): # image index, image inference # Apply constraints # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height @@ -842,7 +846,7 @@ def non_max_suppression(prediction, # Cat apriori labels if autolabelling if labels and len(labels[xi]): lb = labels[xi] - v = torch.zeros((len(lb), nc + 5), device=x.device) + v = torch.zeros((len(lb), nc + nm + 5), device=x.device) v[:, :4] = lb[:, 1:5] # box v[:, 4] = 1.0 # conf v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls @@ -855,16 +859,17 @@ def non_max_suppression(prediction, # Compute conf x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf - # Box (center x, center y, width, height) to (x1, y1, x2, y2) - box = xywh2xyxy(x[:, :4]) + # Box/Mask + box = xywh2xyxy(x[:, :4]) # center_x, center_y, width, height) to (x1, y1, x2, y2) + mask = x[:, mi:] # zero columns if no masks # Detections matrix nx6 (xyxy, conf, cls) if multi_label: - i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T - x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) + i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1) else: # best class only - conf, j = x[:, 5:].max(1, keepdim=True) - x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + conf, j = x[:, 5:mi].max(1, keepdim=True) + x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres] # Filter by class if classes is not None: @@ -880,6 +885,8 @@ def non_max_suppression(prediction, continue elif n > max_nms: # excess boxes x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence + else: + x = x[x[:, 4].argsort(descending=True)] # sort by confidence # Batched NMS c = x[:, 5:6] * (0 if agnostic else max_wh) # classes diff --git a/utils/metrics.py b/utils/metrics.py index ee7d33982cfc..001813cbcd65 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -28,7 +28,7 @@ def smooth(y, f=0.05): return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed -def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16): +def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16, prefix=""): """ Compute the average precision, given the recall and precision curves. Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. # Arguments @@ -83,10 +83,10 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data names = dict(enumerate(names)) # to dict if plot: - plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names) - plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1') - plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision') - plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall') + plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names) + plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1') + plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision') + plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall') i = smooth(f1.mean(0), 0.1).argmax() # max F1 index p, r, f1 = p[:, i], r[:, i], f1[:, i] diff --git a/utils/plots.py b/utils/plots.py index 0530d0abdf48..d8d5b225a774 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -23,6 +23,7 @@ from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path, is_ascii, xywh2xyxy, xyxy2xywh) from utils.metrics import fitness +from utils.segment.general import scale_image # Settings RANK = int(os.getenv('RANK', -1)) @@ -113,6 +114,52 @@ def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 2 thickness=tf, lineType=cv2.LINE_AA) + def masks(self, masks, colors, im_gpu=None, alpha=0.5): + """Plot masks at once. + Args: + masks (tensor): predicted masks on cuda, shape: [n, h, w] + colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n] + im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1] + alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque + """ + if self.pil: + # convert to numpy first + self.im = np.asarray(self.im).copy() + if im_gpu is None: + # Add multiple masks of shape(h,w,n) with colors list([r,g,b], [r,g,b], ...) + if len(masks) == 0: + return + if isinstance(masks, torch.Tensor): + masks = torch.as_tensor(masks, dtype=torch.uint8) + masks = masks.permute(1, 2, 0).contiguous() + masks = masks.cpu().numpy() + # masks = np.ascontiguousarray(masks.transpose(1, 2, 0)) + masks = scale_image(masks.shape[:2], masks, self.im.shape) + masks = np.asarray(masks, dtype=np.float32) + colors = np.asarray(colors, dtype=np.float32) # shape(n,3) + s = masks.sum(2, keepdims=True).clip(0, 1) # add all masks together + masks = (masks @ colors).clip(0, 255) # (h,w,n) @ (n,3) = (h,w,3) + self.im[:] = masks * alpha + self.im * (1 - s * alpha) + else: + if len(masks) == 0: + self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255 + colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0 + colors = colors[:, None, None] # shape(n,1,1,3) + masks = masks.unsqueeze(3) # shape(n,h,w,1) + masks_color = masks * (colors * alpha) # shape(n,h,w,3) + + inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1) + mcs = (masks_color * inv_alph_masks).sum(0) * 2 # mask color summand shape(n,h,w,3) + + im_gpu = im_gpu.flip(dims=[0]) # flip channel + im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3) + im_gpu = im_gpu * inv_alph_masks[-1] + mcs + im_mask = (im_gpu * 255).byte().cpu().numpy() + self.im[:] = scale_image(im_gpu.shape, im_mask, self.im.shape) + if self.pil: + # convert im back to PIL and update draw + self.fromarray(self.im) + def rectangle(self, xy, fill=None, outline=None, width=1): # Add rectangle to image (PIL-only) self.draw.rectangle(xy, fill, outline, width) @@ -124,6 +171,11 @@ def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'): xy[1] += 1 - h self.draw.text(xy, text, fill=txt_color, font=self.font) + def fromarray(self, im): + # Update self.im from a numpy array + self.im = im if isinstance(im, Image.Image) else Image.fromarray(im) + self.draw = ImageDraw.Draw(self.im) + def result(self): # Return annotated image as array return np.asarray(self.im) @@ -180,26 +232,31 @@ def butter_lowpass(cutoff, fs, order): return filtfilt(b, a, data) # forward-backward filter -def output_to_target(output): - # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] +def output_to_target(output, max_det=300): + # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting targets = [] for i, o in enumerate(output): - targets.extend([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf] for *box, conf, cls in o.cpu().numpy()) - return np.array(targets) + box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1) + j = torch.full((conf.shape[0], 1), i) + targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1)) + return torch.cat(targets, 0).numpy() @threaded -def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16): +def plot_images(images, targets, paths=None, fname='images.jpg', names=None): # Plot image grid with labels if isinstance(images, torch.Tensor): images = images.cpu().float().numpy() if isinstance(targets, torch.Tensor): targets = targets.cpu().numpy() - if np.max(images[0]) <= 1: - images *= 255 # de-normalise (optional) + + max_size = 1920 # max image size + max_subplots = 16 # max image subplots, i.e. 4x4 bs, _, h, w = images.shape # batch size, _, height, width bs = min(bs, max_subplots) # limit plot images ns = np.ceil(bs ** 0.5) # number of subplots (square) + if np.max(images[0]) <= 1: + images *= 255 # de-normalise (optional) # Build Image mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init diff --git a/utils/segment/__init__.py b/utils/segment/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/utils/segment/augmentations.py b/utils/segment/augmentations.py new file mode 100644 index 000000000000..169addedf0f5 --- /dev/null +++ b/utils/segment/augmentations.py @@ -0,0 +1,104 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Image augmentation functions +""" + +import math +import random + +import cv2 +import numpy as np + +from ..augmentations import box_candidates +from ..general import resample_segments, segment2box + + +def mixup(im, labels, segments, im2, labels2, segments2): + # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf + r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0 + im = (im * r + im2 * (1 - r)).astype(np.uint8) + labels = np.concatenate((labels, labels2), 0) + segments = np.concatenate((segments, segments2), 0) + return im, labels, segments + + +def random_perspective(im, + targets=(), + segments=(), + degrees=10, + translate=.1, + scale=.1, + shear=10, + perspective=0.0, + border=(0, 0)): + # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10)) + # targets = [cls, xyxy] + + height = im.shape[0] + border[0] * 2 # shape(h,w,c) + width = im.shape[1] + border[1] * 2 + + # Center + C = np.eye(3) + C[0, 2] = -im.shape[1] / 2 # x translation (pixels) + C[1, 2] = -im.shape[0] / 2 # y translation (pixels) + + # Perspective + P = np.eye(3) + P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y) + P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x) + + # Rotation and Scale + R = np.eye(3) + a = random.uniform(-degrees, degrees) + # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations + s = random.uniform(1 - scale, 1 + scale) + # s = 2 ** random.uniform(-scale, scale) + R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s) + + # Shear + S = np.eye(3) + S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg) + S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg) + + # Translation + T = np.eye(3) + T[0, 2] = (random.uniform(0.5 - translate, 0.5 + translate) * width) # x translation (pixels) + T[1, 2] = (random.uniform(0.5 - translate, 0.5 + translate) * height) # y translation (pixels) + + # Combined rotation matrix + M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT + if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed + if perspective: + im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114)) + else: # affine + im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114)) + + # Visualize + # import matplotlib.pyplot as plt + # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel() + # ax[0].imshow(im[:, :, ::-1]) # base + # ax[1].imshow(im2[:, :, ::-1]) # warped + + # Transform label coordinates + n = len(targets) + new_segments = [] + if n: + new = np.zeros((n, 4)) + segments = resample_segments(segments) # upsample + for i, segment in enumerate(segments): + xy = np.ones((len(segment), 3)) + xy[:, :2] = segment + xy = xy @ M.T # transform + xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]) # perspective rescale or affine + + # clip + new[i] = segment2box(xy, width, height) + new_segments.append(xy) + + # filter candidates + i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01) + targets = targets[i] + targets[:, 1:5] = new[i] + new_segments = np.array(new_segments)[i] + + return im, targets, new_segments diff --git a/utils/segment/dataloaders.py b/utils/segment/dataloaders.py new file mode 100644 index 000000000000..f6fe642d077f --- /dev/null +++ b/utils/segment/dataloaders.py @@ -0,0 +1,330 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Dataloaders +""" + +import os +import random + +import cv2 +import numpy as np +import torch +from torch.utils.data import DataLoader, distributed + +from ..augmentations import augment_hsv, copy_paste, letterbox +from ..dataloaders import InfiniteDataLoader, LoadImagesAndLabels, seed_worker +from ..general import LOGGER, xyn2xy, xywhn2xyxy, xyxy2xywhn +from ..torch_utils import torch_distributed_zero_first +from .augmentations import mixup, random_perspective + + +def create_dataloader(path, + imgsz, + batch_size, + stride, + single_cls=False, + hyp=None, + augment=False, + cache=False, + pad=0.0, + rect=False, + rank=-1, + workers=8, + image_weights=False, + quad=False, + prefix='', + shuffle=False, + mask_downsample_ratio=1, + overlap_mask=False): + if rect and shuffle: + LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False') + shuffle = False + with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP + dataset = LoadImagesAndLabelsAndMasks( + path, + imgsz, + batch_size, + augment=augment, # augmentation + hyp=hyp, # hyperparameters + rect=rect, # rectangular batches + cache_images=cache, + single_cls=single_cls, + stride=int(stride), + pad=pad, + image_weights=image_weights, + prefix=prefix, + downsample_ratio=mask_downsample_ratio, + overlap=overlap_mask) + + batch_size = min(batch_size, len(dataset)) + nd = torch.cuda.device_count() # number of CUDA devices + nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers + sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) + loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates + # generator = torch.Generator() + # generator.manual_seed(0) + return loader( + dataset, + batch_size=batch_size, + shuffle=shuffle and sampler is None, + num_workers=nw, + sampler=sampler, + pin_memory=True, + collate_fn=LoadImagesAndLabelsAndMasks.collate_fn4 if quad else LoadImagesAndLabelsAndMasks.collate_fn, + worker_init_fn=seed_worker, + # generator=generator, + ), dataset + + +class LoadImagesAndLabelsAndMasks(LoadImagesAndLabels): # for training/testing + + def __init__( + self, + path, + img_size=640, + batch_size=16, + augment=False, + hyp=None, + rect=False, + image_weights=False, + cache_images=False, + single_cls=False, + stride=32, + pad=0, + prefix="", + downsample_ratio=1, + overlap=False, + ): + super().__init__(path, img_size, batch_size, augment, hyp, rect, image_weights, cache_images, single_cls, + stride, pad, prefix) + self.downsample_ratio = downsample_ratio + self.overlap = overlap + + def __getitem__(self, index): + index = self.indices[index] # linear, shuffled, or image_weights + + hyp = self.hyp + mosaic = self.mosaic and random.random() < hyp['mosaic'] + masks = [] + if mosaic: + # Load mosaic + img, labels, segments = self.load_mosaic(index) + shapes = None + + # MixUp augmentation + if random.random() < hyp["mixup"]: + img, labels, segments = mixup(img, labels, segments, *self.load_mosaic(random.randint(0, self.n - 1))) + + else: + # Load image + img, (h0, w0), (h, w) = self.load_image(index) + + # Letterbox + shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape + img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment) + shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling + + labels = self.labels[index].copy() + # [array, array, ....], array.shape=(num_points, 2), xyxyxyxy + segments = self.segments[index].copy() + if len(segments): + for i_s in range(len(segments)): + segments[i_s] = xyn2xy( + segments[i_s], + ratio[0] * w, + ratio[1] * h, + padw=pad[0], + padh=pad[1], + ) + if labels.size: # normalized xywh to pixel xyxy format + labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1]) + + if self.augment: + img, labels, segments = random_perspective( + img, + labels, + segments=segments, + degrees=hyp["degrees"], + translate=hyp["translate"], + scale=hyp["scale"], + shear=hyp["shear"], + perspective=hyp["perspective"], + return_seg=True, + ) + + nl = len(labels) # number of labels + if nl: + labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1e-3) + if self.overlap: + masks, sorted_idx = polygons2masks_overlap(img.shape[:2], + segments, + downsample_ratio=self.downsample_ratio) + masks = masks[None] # (640, 640) -> (1, 640, 640) + labels = labels[sorted_idx] + else: + masks = polygons2masks(img.shape[:2], segments, color=1, downsample_ratio=self.downsample_ratio) + + masks = (torch.from_numpy(masks) if len(masks) else torch.zeros(1 if self.overlap else nl, img.shape[0] // + self.downsample_ratio, img.shape[1] // + self.downsample_ratio)) + # TODO: albumentations support + if self.augment: + # Albumentations + # there are some augmentation that won't change boxes and masks, + # so just be it for now. + img, labels = self.albumentations(img, labels) + nl = len(labels) # update after albumentations + + # HSV color-space + augment_hsv(img, hgain=hyp["hsv_h"], sgain=hyp["hsv_s"], vgain=hyp["hsv_v"]) + + # Flip up-down + if random.random() < hyp["flipud"]: + img = np.flipud(img) + if nl: + labels[:, 2] = 1 - labels[:, 2] + masks = torch.flip(masks, dims=[1]) + + # Flip left-right + if random.random() < hyp["fliplr"]: + img = np.fliplr(img) + if nl: + labels[:, 1] = 1 - labels[:, 1] + masks = torch.flip(masks, dims=[2]) + + # Cutouts # labels = cutout(img, labels, p=0.5) + + labels_out = torch.zeros((nl, 6)) + if nl: + labels_out[:, 1:] = torch.from_numpy(labels) + + # Convert + img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB + img = np.ascontiguousarray(img) + + return (torch.from_numpy(img), labels_out, self.im_files[index], shapes, masks) + + def load_mosaic(self, index): + # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic + labels4, segments4 = [], [] + s = self.img_size + yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border) # mosaic center x, y + + # 3 additional image indices + indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices + for i, index in enumerate(indices): + # Load image + img, _, (h, w) = self.load_image(index) + + # place img in img4 + if i == 0: # top left + img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles + x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image) + x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image) + elif i == 1: # top right + x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc + x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h + elif i == 2: # bottom left + x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h) + x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h) + elif i == 3: # bottom right + x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h) + x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h) + + img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax] + padw = x1a - x1b + padh = y1a - y1b + + labels, segments = self.labels[index].copy(), self.segments[index].copy() + + if labels.size: + labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format + segments = [xyn2xy(x, w, h, padw, padh) for x in segments] + labels4.append(labels) + segments4.extend(segments) + + # Concat/clip labels + labels4 = np.concatenate(labels4, 0) + for x in (labels4[:, 1:], *segments4): + np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective() + # img4, labels4 = replicate(img4, labels4) # replicate + + # Augment + img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp["copy_paste"]) + img4, labels4, segments4 = random_perspective(img4, + labels4, + segments4, + degrees=self.hyp["degrees"], + translate=self.hyp["translate"], + scale=self.hyp["scale"], + shear=self.hyp["shear"], + perspective=self.hyp["perspective"], + border=self.mosaic_border) # border to remove + return img4, labels4, segments4 + + @staticmethod + def collate_fn(batch): + img, label, path, shapes, masks = zip(*batch) # transposed + batched_masks = torch.cat(masks, 0) + for i, l in enumerate(label): + l[:, 0] = i # add target image index for build_targets() + return torch.stack(img, 0), torch.cat(label, 0), path, shapes, batched_masks + + +def polygon2mask(img_size, polygons, color=1, downsample_ratio=1): + """ + Args: + img_size (tuple): The image size. + polygons (np.ndarray): [N, M], N is the number of polygons, + M is the number of points(Be divided by 2). + """ + mask = np.zeros(img_size, dtype=np.uint8) + polygons = np.asarray(polygons) + polygons = polygons.astype(np.int32) + shape = polygons.shape + polygons = polygons.reshape(shape[0], -1, 2) + cv2.fillPoly(mask, polygons, color=color) + nh, nw = (img_size[0] // downsample_ratio, img_size[1] // downsample_ratio) + # NOTE: fillPoly firstly then resize is trying the keep the same way + # of loss calculation when mask-ratio=1. + mask = cv2.resize(mask, (nw, nh)) + return mask + + +def polygons2masks(img_size, polygons, color, downsample_ratio=1): + """ + Args: + img_size (tuple): The image size. + polygons (list[np.ndarray]): each polygon is [N, M], + N is the number of polygons, + M is the number of points(Be divided by 2). + """ + masks = [] + for si in range(len(polygons)): + mask = polygon2mask(img_size, [polygons[si].reshape(-1)], color, downsample_ratio) + masks.append(mask) + return np.array(masks) + + +def polygons2masks_overlap(img_size, segments, downsample_ratio=1): + """Return a (640, 640) overlap mask.""" + masks = np.zeros((img_size[0] // downsample_ratio, img_size[1] // downsample_ratio), dtype=np.uint8) + areas = [] + ms = [] + for si in range(len(segments)): + mask = polygon2mask( + img_size, + [segments[si].reshape(-1)], + downsample_ratio=downsample_ratio, + color=1, + ) + ms.append(mask) + areas.append(mask.sum()) + areas = np.asarray(areas) + index = np.argsort(-areas) + ms = np.array(ms)[index] + for i in range(len(segments)): + mask = ms[i] * (i + 1) + masks = masks + mask + masks = np.clip(masks, a_min=0, a_max=i + 1) + return masks, index diff --git a/utils/segment/general.py b/utils/segment/general.py new file mode 100644 index 000000000000..36547ed0889c --- /dev/null +++ b/utils/segment/general.py @@ -0,0 +1,120 @@ +import cv2 +import torch +import torch.nn.functional as F + + +def crop_mask(masks, boxes): + """ + "Crop" predicted masks by zeroing out everything not in the predicted bbox. + Vectorized by Chong (thanks Chong). + + Args: + - masks should be a size [h, w, n] tensor of masks + - boxes should be a size [n, 4] tensor of bbox coords in relative point form + """ + + n, h, w = masks.shape + x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(1,1,n) + r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,w,1) + c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(h,1,1) + + return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) + + +def process_mask_upsample(protos, masks_in, bboxes, shape): + """ + Crop after upsample. + proto_out: [mask_dim, mask_h, mask_w] + out_masks: [n, mask_dim], n is number of masks after nms + bboxes: [n, 4], n is number of masks after nms + shape:input_image_size, (h, w) + + return: h, w, n + """ + + c, mh, mw = protos.shape # CHW + masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) + masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW + masks = crop_mask(masks, bboxes) # CHW + return masks.gt_(0.5) + + +def process_mask(protos, masks_in, bboxes, shape, upsample=False): + """ + Crop before upsample. + proto_out: [mask_dim, mask_h, mask_w] + out_masks: [n, mask_dim], n is number of masks after nms + bboxes: [n, 4], n is number of masks after nms + shape:input_image_size, (h, w) + + return: h, w, n + """ + + c, mh, mw = protos.shape # CHW + ih, iw = shape + masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW + + downsampled_bboxes = bboxes.clone() + downsampled_bboxes[:, 0] *= mw / iw + downsampled_bboxes[:, 2] *= mw / iw + downsampled_bboxes[:, 3] *= mh / ih + downsampled_bboxes[:, 1] *= mh / ih + + masks = crop_mask(masks, downsampled_bboxes) # CHW + if upsample: + masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW + return masks.gt_(0.5) + + +def scale_image(im1_shape, masks, im0_shape, ratio_pad=None): + """ + img1_shape: model input shape, [h, w] + img0_shape: origin pic shape, [h, w, 3] + masks: [h, w, num] + """ + # Rescale coordinates (xyxy) from im1_shape to im0_shape + if ratio_pad is None: # calculate from im0_shape + gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new + pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding + else: + pad = ratio_pad[1] + top, left = int(pad[1]), int(pad[0]) # y, x + bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0]) + + if len(masks.shape) < 2: + raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}') + masks = masks[top:bottom, left:right] + # masks = masks.permute(2, 0, 1).contiguous() + # masks = F.interpolate(masks[None], im0_shape[:2], mode='bilinear', align_corners=False)[0] + # masks = masks.permute(1, 2, 0).contiguous() + masks = cv2.resize(masks, (im0_shape[1], im0_shape[0])) + + if len(masks.shape) == 2: + masks = masks[:, :, None] + return masks + + +def mask_iou(mask1, mask2, eps=1e-7): + """ + mask1: [N, n] m1 means number of predicted objects + mask2: [M, n] m2 means number of gt objects + Note: n means image_w x image_h + + return: masks iou, [N, M] + """ + intersection = torch.matmul(mask1, mask2.t()).clamp(0) + union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection + return intersection / (union + eps) + + +def masks_iou(mask1, mask2, eps=1e-7): + """ + mask1: [N, n] m1 means number of predicted objects + mask2: [N, n] m2 means number of gt objects + Note: n means image_w x image_h + + return: masks iou, (N, ) + """ + intersection = (mask1 * mask2).sum(1).clamp(0) # (N, ) + union = (mask1.sum(1) + mask2.sum(1))[None] - intersection # (area1 + area2) - intersection + return intersection / (union + eps) diff --git a/utils/segment/loss.py b/utils/segment/loss.py new file mode 100644 index 000000000000..b45b2c27e0a0 --- /dev/null +++ b/utils/segment/loss.py @@ -0,0 +1,186 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..general import xywh2xyxy +from ..loss import FocalLoss, smooth_BCE +from ..metrics import bbox_iou +from ..torch_utils import de_parallel +from .general import crop_mask + + +class ComputeLoss: + # Compute losses + def __init__(self, model, autobalance=False, overlap=False): + self.sort_obj_iou = False + self.overlap = overlap + device = next(model.parameters()).device # get model device + h = model.hyp # hyperparameters + self.device = device + + # Define criteria + BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device)) + BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device)) + + # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 + self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0)) # positive, negative BCE targets + + # Focal loss + g = h['fl_gamma'] # focal loss gamma + if g > 0: + BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) + + m = de_parallel(model).model[-1] # Detect() module + self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7 + self.ssi = list(m.stride).index(16) if autobalance else 0 # stride 16 index + self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance + self.na = m.na # number of anchors + self.nc = m.nc # number of classes + self.nl = m.nl # number of layers + self.nm = m.nm # number of masks + self.anchors = m.anchors + self.device = device + + def __call__(self, preds, targets, masks): # predictions, targets, model + p, proto = preds + bs, nm, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width + lcls = torch.zeros(1, device=self.device) + lbox = torch.zeros(1, device=self.device) + lobj = torch.zeros(1, device=self.device) + lseg = torch.zeros(1, device=self.device) + tcls, tbox, indices, anchors, tidxs, xywhn = self.build_targets(p, targets) # targets + + # Losses + for i, pi in enumerate(p): # layer index, layer predictions + b, a, gj, gi = indices[i] # image, anchor, gridy, gridx + tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj + + n = b.shape[0] # number of targets + if n: + pxy, pwh, _, pcls, pmask = pi[b, a, gj, gi].split((2, 2, 1, self.nc, nm), 1) # subset of predictions + + # Box regression + pxy = pxy.sigmoid() * 2 - 0.5 + pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i] + pbox = torch.cat((pxy, pwh), 1) # predicted box + iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(prediction, target) + lbox += (1.0 - iou).mean() # iou loss + + # Objectness + iou = iou.detach().clamp(0).type(tobj.dtype) + if self.sort_obj_iou: + j = iou.argsort() + b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j] + if self.gr < 1: + iou = (1.0 - self.gr) + self.gr * iou + tobj[b, a, gj, gi] = iou # iou ratio + + # Classification + if self.nc > 1: # cls loss (only if multiple classes) + t = torch.full_like(pcls, self.cn, device=self.device) # targets + t[range(n), tcls[i]] = self.cp + lcls += self.BCEcls(pcls, t) # BCE + + # Mask regression + if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample + masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0] + marea = xywhn[i][:, 2:].prod(1) # mask width, height normalized + mxyxy = xywh2xyxy(xywhn[i] * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)) + for bi in b.unique(): + j = b == bi # matching index + if self.overlap: + mask_gti = torch.where(masks[bi][None] == tidxs[i][j].view(-1, 1, 1), 1.0, 0.0) + else: + mask_gti = masks[tidxs[i]][j] + lseg += self.single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j]) + + obji = self.BCEobj(pi[..., 4], tobj) + lobj += obji * self.balance[i] # obj loss + if self.autobalance: + self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item() + + if self.autobalance: + self.balance = [x / self.balance[self.ssi] for x in self.balance] + lbox *= self.hyp["box"] + lobj *= self.hyp["obj"] + lcls *= self.hyp["cls"] + lseg *= self.hyp["box"] / bs + + loss = lbox + lobj + lcls + lseg + return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach() + + def single_mask_loss(self, gt_mask, pred, proto, xyxy, area): + # Mask loss for one image + pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n,32) @ (32,80,80) -> (n,80,80) + loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none") + return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean() + + def build_targets(self, p, targets): + # Build targets for compute_loss(), input targets(image,class,x,y,w,h) + na, nt = self.na, targets.shape[0] # number of anchors, targets + tcls, tbox, indices, anch, tidxs, xywhn = [], [], [], [], [], [] + gain = torch.ones(8, device=self.device) # normalized to gridspace gain + ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt) + if self.overlap: + batch = p[0].shape[0] + ti = [] + for i in range(batch): + num = (targets[:, 0] == i).sum() # find number of targets of each image + ti.append(torch.arange(num, device=self.device).float().view(1, num).repeat(na, 1) + 1) # (na, num) + ti = torch.cat(ti, 1) # (na, nt) + else: + ti = torch.arange(nt, device=self.device).float().view(1, nt).repeat(na, 1) + targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None], ti[..., None]), 2) # append anchor indices + + g = 0.5 # bias + off = torch.tensor( + [ + [0, 0], + [1, 0], + [0, 1], + [-1, 0], + [0, -1], # j,k,l,m + # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm + ], + device=self.device).float() * g # offsets + + for i in range(self.nl): + anchors, shape = self.anchors[i], p[i].shape + gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]] # xyxy gain + + # Match targets to anchors + t = targets * gain # shape(3,n,7) + if nt: + # Matches + r = t[..., 4:6] / anchors[:, None] # wh ratio + j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t'] # compare + # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2)) + t = t[j] # filter + + # Offsets + gxy = t[:, 2:4] # grid xy + gxi = gain[[2, 3]] - gxy # inverse + j, k = ((gxy % 1 < g) & (gxy > 1)).T + l, m = ((gxi % 1 < g) & (gxi > 1)).T + j = torch.stack((torch.ones_like(j), j, k, l, m)) + t = t.repeat((5, 1, 1))[j] + offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] + else: + t = targets[0] + offsets = 0 + + # Define + bc, gxy, gwh, at = t.chunk(4, 1) # (image, class), grid xy, grid wh, anchors + (a, tidx), (b, c) = at.long().T, bc.long().T # anchors, image, class + gij = (gxy - offsets).long() + gi, gj = gij.T # grid indices + + # Append + indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, anchor, grid + tbox.append(torch.cat((gxy - gij, gwh), 1)) # box + anch.append(anchors[a]) # anchors + tcls.append(c) # class + tidxs.append(tidx) + xywhn.append(torch.cat((gxy, gwh), 1) / gain[2:6]) # xywh normalized + + return tcls, tbox, indices, anch, tidxs, xywhn diff --git a/utils/segment/metrics.py b/utils/segment/metrics.py new file mode 100644 index 000000000000..b09ce23fb9e3 --- /dev/null +++ b/utils/segment/metrics.py @@ -0,0 +1,210 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Model validation metrics +""" + +import numpy as np + +from ..metrics import ap_per_class + + +def fitness(x): + # Model fitness as a weighted combination of metrics + w = [0.0, 0.0, 0.1, 0.9, 0.0, 0.0, 0.1, 0.9] + return (x[:, :8] * w).sum(1) + + +def ap_per_class_box_and_mask( + tp_m, + tp_b, + conf, + pred_cls, + target_cls, + plot=False, + save_dir=".", + names=(), +): + """ + Args: + tp_b: tp of boxes. + tp_m: tp of masks. + other arguments see `func: ap_per_class`. + """ + results_boxes = ap_per_class(tp_b, + conf, + pred_cls, + target_cls, + plot=plot, + save_dir=save_dir, + names=names, + prefix="Box")[2:] + results_masks = ap_per_class(tp_m, + conf, + pred_cls, + target_cls, + plot=plot, + save_dir=save_dir, + names=names, + prefix="Mask")[2:] + + results = { + "boxes": { + "p": results_boxes[0], + "r": results_boxes[1], + "ap": results_boxes[3], + "f1": results_boxes[2], + "ap_class": results_boxes[4]}, + "masks": { + "p": results_masks[0], + "r": results_masks[1], + "ap": results_masks[3], + "f1": results_masks[2], + "ap_class": results_masks[4]}} + return results + + +class Metric: + + def __init__(self) -> None: + self.p = [] # (nc, ) + self.r = [] # (nc, ) + self.f1 = [] # (nc, ) + self.all_ap = [] # (nc, 10) + self.ap_class_index = [] # (nc, ) + + @property + def ap50(self): + """AP@0.5 of all classes. + Return: + (nc, ) or []. + """ + return self.all_ap[:, 0] if len(self.all_ap) else [] + + @property + def ap(self): + """AP@0.5:0.95 + Return: + (nc, ) or []. + """ + return self.all_ap.mean(1) if len(self.all_ap) else [] + + @property + def mp(self): + """mean precision of all classes. + Return: + float. + """ + return self.p.mean() if len(self.p) else 0.0 + + @property + def mr(self): + """mean recall of all classes. + Return: + float. + """ + return self.r.mean() if len(self.r) else 0.0 + + @property + def map50(self): + """Mean AP@0.5 of all classes. + Return: + float. + """ + return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0 + + @property + def map(self): + """Mean AP@0.5:0.95 of all classes. + Return: + float. + """ + return self.all_ap.mean() if len(self.all_ap) else 0.0 + + def mean_results(self): + """Mean of results, return mp, mr, map50, map""" + return (self.mp, self.mr, self.map50, self.map) + + def class_result(self, i): + """class-aware result, return p[i], r[i], ap50[i], ap[i]""" + return (self.p[i], self.r[i], self.ap50[i], self.ap[i]) + + def get_maps(self, nc): + maps = np.zeros(nc) + self.map + for i, c in enumerate(self.ap_class_index): + maps[c] = self.ap[i] + return maps + + def update(self, results): + """ + Args: + results: tuple(p, r, ap, f1, ap_class) + """ + p, r, all_ap, f1, ap_class_index = results + self.p = p + self.r = r + self.all_ap = all_ap + self.f1 = f1 + self.ap_class_index = ap_class_index + + +class Metrics: + """Metric for boxes and masks.""" + + def __init__(self) -> None: + self.metric_box = Metric() + self.metric_mask = Metric() + + def update(self, results): + """ + Args: + results: Dict{'boxes': Dict{}, 'masks': Dict{}} + """ + self.metric_box.update(list(results["boxes"].values())) + self.metric_mask.update(list(results["masks"].values())) + + def mean_results(self): + return self.metric_box.mean_results() + self.metric_mask.mean_results() + + def class_result(self, i): + return self.metric_box.class_result(i) + self.metric_mask.class_result(i) + + def get_maps(self, nc): + return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc) + + @property + def ap_class_index(self): + # boxes and masks have the same ap_class_index + return self.metric_box.ap_class_index + + +KEYS = [ + "train/box_loss", + "train/seg_loss", # train loss + "train/obj_loss", + "train/cls_loss", + "metrics/precision(B)", + "metrics/recall(B)", + "metrics/mAP_0.5(B)", + "metrics/mAP_0.5:0.95(B)", # metrics + "metrics/precision(M)", + "metrics/recall(M)", + "metrics/mAP_0.5(M)", + "metrics/mAP_0.5:0.95(M)", # metrics + "val/box_loss", + "val/seg_loss", # val loss + "val/obj_loss", + "val/cls_loss", + "x/lr0", + "x/lr1", + "x/lr2",] + +BEST_KEYS = [ + "best/epoch", + "best/precision(B)", + "best/recall(B)", + "best/mAP_0.5(B)", + "best/mAP_0.5:0.95(B)", + "best/precision(M)", + "best/recall(M)", + "best/mAP_0.5(M)", + "best/mAP_0.5:0.95(M)",] diff --git a/utils/segment/plots.py b/utils/segment/plots.py new file mode 100644 index 000000000000..e882c14390f0 --- /dev/null +++ b/utils/segment/plots.py @@ -0,0 +1,143 @@ +import contextlib +import math +from pathlib import Path + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch + +from .. import threaded +from ..general import xywh2xyxy +from ..plots import Annotator, colors + + +@threaded +def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg', names=None): + # Plot image grid with labels + if isinstance(images, torch.Tensor): + images = images.cpu().float().numpy() + if isinstance(targets, torch.Tensor): + targets = targets.cpu().numpy() + if isinstance(masks, torch.Tensor): + masks = masks.cpu().numpy().astype(int) + + max_size = 1920 # max image size + max_subplots = 16 # max image subplots, i.e. 4x4 + bs, _, h, w = images.shape # batch size, _, height, width + bs = min(bs, max_subplots) # limit plot images + ns = np.ceil(bs ** 0.5) # number of subplots (square) + if np.max(images[0]) <= 1: + images *= 255 # de-normalise (optional) + + # Build Image + mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init + for i, im in enumerate(images): + if i == max_subplots: # if last batch has fewer images than we expect + break + x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin + im = im.transpose(1, 2, 0) + mosaic[y:y + h, x:x + w, :] = im + + # Resize (optional) + scale = max_size / ns / max(h, w) + if scale < 1: + h = math.ceil(scale * h) + w = math.ceil(scale * w) + mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h))) + + # Annotate + fs = int((h + w) * ns * 0.01) # font size + annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names) + for i in range(i + 1): + x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin + annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders + if paths: + annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames + if len(targets) > 0: + idx = targets[:, 0] == i + ti = targets[idx] # image targets + + boxes = xywh2xyxy(ti[:, 2:6]).T + classes = ti[:, 1].astype('int') + labels = ti.shape[1] == 6 # labels if no conf column + conf = None if labels else ti[:, 6] # check for confidence presence (label vs pred) + + if boxes.shape[1]: + if boxes.max() <= 1.01: # if normalized with tolerance 0.01 + boxes[[0, 2]] *= w # scale to pixels + boxes[[1, 3]] *= h + elif scale < 1: # absolute coords need scale if image scales + boxes *= scale + boxes[[0, 2]] += x + boxes[[1, 3]] += y + for j, box in enumerate(boxes.T.tolist()): + cls = classes[j] + color = colors(cls) + cls = names[cls] if names else cls + if labels or conf[j] > 0.25: # 0.25 conf thresh + label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}' + annotator.box_label(box, label, color=color) + + # Plot masks + if len(masks): + if masks.max() > 1.0: # mean that masks are overlap + image_masks = masks[[i]] # (1, 640, 640) + nl = len(ti) + index = np.arange(nl).reshape(nl, 1, 1) + 1 + image_masks = np.repeat(image_masks, nl, axis=0) + image_masks = np.where(image_masks == index, 1.0, 0.0) + else: + image_masks = masks[idx] + + im = np.asarray(annotator.im).copy() + for j, box in enumerate(boxes.T.tolist()): + if labels or conf[j] > 0.25: # 0.25 conf thresh + color = colors(classes[j]) + mh, mw = image_masks[j].shape + if mh != h or mw != w: + mask = image_masks[j].astype(np.uint8) + mask = cv2.resize(mask, (w, h)) + mask = mask.astype(np.bool) + else: + mask = image_masks[j].astype(np.bool) + with contextlib.suppress(Exception): + im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6 + annotator.fromarray(im) + annotator.im.save(fname) # save + + +def plot_results_with_masks(file="path/to/results.csv", dir="", best=True): + # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv') + save_dir = Path(file).parent if file else Path(dir) + fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True) + ax = ax.ravel() + files = list(save_dir.glob("results*.csv")) + assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot." + for f in files: + try: + data = pd.read_csv(f) + index = np.argmax(0.9 * data.values[:, 8] + 0.1 * data.values[:, 7] + 0.9 * data.values[:, 12] + + 0.1 * data.values[:, 11]) + s = [x.strip() for x in data.columns] + x = data.values[:, 0] + for i, j in enumerate([1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]): + y = data.values[:, j] + # y[y == 0] = np.nan # don't show zero values + ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=2) + if best: + # best + ax[i].scatter(index, y[index], color="r", label=f"best:{index}", marker="*", linewidth=3) + ax[i].set_title(s[j] + f"\n{round(y[index], 5)}") + else: + # last + ax[i].scatter(x[-1], y[-1], color="r", label="last", marker="*", linewidth=3) + ax[i].set_title(s[j] + f"\n{round(y[-1], 5)}") + # if j in [8, 9, 10]: # share train and val loss y axes + # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) + except Exception as e: + print(f"Warning: Plotting error for {f}: {e}") + ax[1].legend() + fig.savefig(save_dir / "results.png", dpi=200) + plt.close() diff --git a/val.py b/val.py index 4b0bdddae3b1..6a0f18e28392 100644 --- a/val.py +++ b/val.py @@ -71,12 +71,12 @@ def save_one_json(predn, jdict, path, class_map): def process_batch(detections, labels, iouv): """ - Return correct predictions matrix. Both sets of boxes are in (x1, y1, x2, y2) format. + Return correct prediction matrix Arguments: - detections (Array[N, 6]), x1, y1, x2, y2, conf, class - labels (Array[M, 5]), class, x1, y1, x2, y2 + detections (array[N, 6]), x1, y1, x2, y2, conf, class + labels (array[M, 5]), class, x1, y1, x2, y2 Returns: - correct (Array[N, 10]), for 10 IoU levels + correct (array[N, 10]), for 10 IoU levels """ correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool) iou = box_iou(labels[:, 1:], detections[:, :4]) @@ -102,6 +102,7 @@ def run( imgsz=640, # inference size (pixels) conf_thres=0.001, # confidence threshold iou_thres=0.6, # NMS IoU threshold + max_det=300, # maximum detections per image task='val', # train, val, test, speed or study device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu workers=8, # max dataloader workers (per RANK in DDP mode) @@ -187,7 +188,7 @@ def run( if isinstance(names, (list, tuple)): # old format names = dict(enumerate(names)) class_map = coco80_to_coco91_class() if is_coco else list(range(1000)) - s = ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'P', 'R', 'mAP@.5', 'mAP@.5:.95') + s = ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'P', 'R', 'mAP50', 'mAP50-95') dt, p, r, f1, mp, mr, map50, map = (Profile(), Profile(), Profile()), 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 loss = torch.zeros(3, device=device) jdict, stats, ap, ap_class = [], [], [], [] @@ -205,7 +206,7 @@ def run( # Inference with dt[1]: - out, train_out = model(im) if compute_loss else (model(im, augment=augment), None) + preds, train_out = model(im) if compute_loss else (model(im, augment=augment), None) # Loss if compute_loss: @@ -215,10 +216,16 @@ def run( targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling with dt[2]: - out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls) + preds = non_max_suppression(preds, + conf_thres, + iou_thres, + labels=lb, + multi_label=True, + agnostic=single_cls, + max_det=max_det) # Metrics - for si, pred in enumerate(out): + for si, pred in enumerate(preds): labels = targets[targets[:, 0] == si, 1:] nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions path, shape = Path(paths[si]), shapes[si][0] @@ -258,9 +265,9 @@ def run( # Plot images if plots and batch_i < 3: plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) # labels - plot_images(im, output_to_target(out), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred + plot_images(im, output_to_target(preds), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred - callbacks.run('on_val_batch_end', batch_i, im, targets, paths, shapes, out) + callbacks.run('on_val_batch_end', batch_i, im, targets, paths, shapes, preds) # Compute metrics stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy @@ -332,11 +339,12 @@ def run( def parse_opt(): parser = argparse.ArgumentParser() parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path') - parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)') + parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)') parser.add_argument('--batch-size', type=int, default=32, help='batch size') parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)') parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold') parser.add_argument('--iou-thres', type=float, default=0.6, help='NMS IoU threshold') + parser.add_argument('--max-det', type=int, default=300, help='maximum detections per image') parser.add_argument('--task', default='val', help='train, val, test, speed or study') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')