From 4fa932698b6a3b5e2acc914c86a869a352f512fe Mon Sep 17 00:00:00 2001 From: Yiwei Ding Date: Sat, 6 Feb 2021 21:04:09 +0800 Subject: [PATCH 01/14] yolotr --- data/scripts/get_coco.sh | 8 ++++--- data/scripts/get_voc.sh | 3 ++- models/common.py | 51 ++++++++++++++++++++++++++++++++++++++++ models/yolo.py | 2 +- models/yolotrs.yaml | 48 +++++++++++++++++++++++++++++++++++++ 5 files changed, 107 insertions(+), 5 deletions(-) create mode 100644 models/yolotrs.yaml diff --git a/data/scripts/get_coco.sh b/data/scripts/get_coco.sh index b0df905c8525..4844d656236f 100755 --- a/data/scripts/get_coco.sh +++ b/data/scripts/get_coco.sh @@ -11,7 +11,8 @@ d='../' # unzip directory url=https://github.com/ultralytics/yolov5/releases/download/v1.0/ f='coco2017labels.zip' # 68 MB -echo 'Downloading' $url$f ' ...' && curl -L $url$f -o $f && unzip -q $f -d $d && rm $f # download, unzip, remove +#echo 'Downloading' $url$f ' ...' && curl -L $url$f -o $f && unzip -q $f -d $d && rm $f # download, unzip, remove +echo 'Downloading' $url$f ' ...' && wget $url$f && unzip -q $f -d $d && rm $f # Download/unzip images d='../coco/images' # unzip directory @@ -20,7 +21,8 @@ f1='train2017.zip' # 19G, 118k images f2='val2017.zip' # 1G, 5k images f3='test2017.zip' # 7G, 41k images (optional) for f in $f1 $f2; do - echo 'Downloading' $url$f '...' && curl -L $url$f -o $f # download, (unzip, remove in background) - unzip -q $f -d $d && rm $f & + #echo 'Downloading' $url$f '...' && curl -L $url$f -o $f # download, (unzip, remove in background) + #echo 'Downloading' $url$f '...' && wget $url$f + unzip -q $f -d $d # && rm $f & done wait # finish background tasks diff --git a/data/scripts/get_voc.sh b/data/scripts/get_voc.sh index 06414b085095..b4c258bc433e 100644 --- a/data/scripts/get_voc.sh +++ b/data/scripts/get_voc.sh @@ -18,7 +18,8 @@ f1=VOCtrainval_06-Nov-2007.zip # 446MB, 5012 images f2=VOCtest_06-Nov-2007.zip # 438MB, 4953 images f3=VOCtrainval_11-May-2012.zip # 1.95GB, 17126 images for f in $f3 $f2 $f1; do - echo 'Downloading' $url$f '...' && curl -L $url$f -o $f # download, (unzip, remove in background) + #echo 'Downloading' $url$f '...' && curl -L $url$f -o $f # download, (unzip, remove in background) + echo 'Downloading' $url$f '...' && wget $url$f unzip -q $f -d $d && rm $f & done wait # finish background tasks diff --git a/models/common.py b/models/common.py index e8adb66293d5..c4210866b5a3 100644 --- a/models/common.py +++ b/models/common.py @@ -40,6 +40,43 @@ def fuseforward(self, x): return self.act(self.conv(x)) +class Transformer(nn.Module): + def __init__(self, c1, c2, num_heads): + super(Transformer, self).__init__() + + self.linear = nn.Linear(c1, c1) + self.ln1 = nn.LayerNorm(c1) + self.q = nn.Linear(c1, c1) + self.k = nn.Linear(c1, c1) + self.v = nn.Linear(c1, c1) + self.ma = nn.MultiheadAttention(embed_dim=c1, num_heads=num_heads) + self.ln2 = nn.LayerNorm(c1) + self.fc1 = nn.Linear(c1, c2) + self.fc2 = nn.Linear(c2, c2) + self.gelu = nn.GELU() + self.c2 = c2 + + def forward(self, x): + b, _, w, h = x.shape + p = x.flatten(2) + p = p.unsqueeze(0) + p = p.transpose(0, 3) + p = p.squeeze(3) + e = self.linear(p) + x = p + e + + x_ = self.ln1(x) + x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x + x = self.ln2(x) + x = self.fc1(x) + x = self.gelu(x) + x = self.fc2(x) + x = x.unsqueeze(3) + x = x.transpose(0, 3) + x = x.reshape(b, self.c2, w, h) + return x + + class Bottleneck(nn.Module): # Standard bottleneck def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion @@ -53,6 +90,13 @@ def forward(self, x): return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) +class BoT(Bottleneck): + def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): + super().__init__(c1, c2, shortcut, g, e) + c_ = int(c2 * e) + self.cv2 = Transformer(c_, c2, 4) + + class BottleneckCSP(nn.Module): # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion @@ -87,6 +131,13 @@ def forward(self, x): return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) +class C3T(C3): + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) + self.m = nn.Sequential(*[BoT(c_, c_, shortcut, g, e=1.0) for _ in range(n)]) + + class SPP(nn.Module): # Spatial pyramid pooling layer used in YOLOv3-SPP def __init__(self, c1, c2, k=(5, 9, 13)): diff --git a/models/yolo.py b/models/yolo.py index 6cf9dde08d1a..f99d4332a795 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -210,7 +210,7 @@ def parse_model(d, ch): # model_dict, input_channels(3) pass n = max(round(n * gd), 1) if n > 1 else n # depth gain - if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]: + if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3T]: c1, c2 = ch[f], args[0] # Normal diff --git a/models/yolotrs.yaml b/models/yolotrs.yaml new file mode 100644 index 000000000000..a3ffbfbb5653 --- /dev/null +++ b/models/yolotrs.yaml @@ -0,0 +1,48 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.50 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Focus, [64, 3]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 9, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 13]]], + [-1, 3, C3T, [1024, False]], # 9 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] From b479d24feee94a3df9f98a9b5716fe4a3a32cedc Mon Sep 17 00:00:00 2001 From: Yiwei Ding Date: Thu, 11 Feb 2021 13:08:23 +0800 Subject: [PATCH 02/14] transformer block --- models/common.py | 58 +++++++++++++++++++++++++++++++++++++++++++++ models/yolo.py | 4 ++-- models/yolotrs.yaml | 2 +- 3 files changed, 61 insertions(+), 3 deletions(-) diff --git a/models/common.py b/models/common.py index c4210866b5a3..20eec645306a 100644 --- a/models/common.py +++ b/models/common.py @@ -77,6 +77,57 @@ def forward(self, x): return x +class TransformerLayer(nn.Module): + def __init__(self, c, num_heads): + super().__init__() + + self.ln1 = nn.LayerNorm(c) + self.q = nn.Linear(c, c) + self.k = nn.Linear(c, c) + self.v = nn.Linear(c, c) + self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads) + self.ln2 = nn.LayerNorm(c) + self.fc1 = nn.Linear(c, c) + self.fc2 = nn.Linear(c, c) + self.act = nn.SiLU() + + def forward(self, x): + x_ = self.ln1(x) + x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x + x = self.ln2(x) + x = self.fc2(self.act(self.fc1(x))) + x + return x + + +class TransformerBlock(nn.Module): + def __init__(self, c1, c2, num_heads, num_layers): + super().__init__() + + self.conv = None + if c1 != c2: + self.conv = Conv(c1, c2) + self.linear = nn.Linear(c2, c2) + self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)]) + self.c2 = c2 + + def forward(self, x): + if self.conv is not None: + x = self.conv(x) + b, _, w, h = x.shape + p = x.flatten(2) + p = p.unsqueeze(0) + p = p.transpose(0, 3) + p = p.squeeze(3) + e = self.linear(p) + x = p + e + + x = self.tr(x) + x = x.unsqueeze(3) + x = x.transpose(0, 3) + x = x.reshape(b, self.c2, w, h) + return x + + class Bottleneck(nn.Module): # Standard bottleneck def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion @@ -138,6 +189,13 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): self.m = nn.Sequential(*[BoT(c_, c_, shortcut, g, e=1.0) for _ in range(n)]) +class C3TR(C3): + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) + self.m = TransformerBlock(c_, c_, 4, n) + + class SPP(nn.Module): # Spatial pyramid pooling layer used in YOLOv3-SPP def __init__(self, c1, c2, k=(5, 9, 13)): diff --git a/models/yolo.py b/models/yolo.py index f99d4332a795..860bae86a82f 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -210,7 +210,7 @@ def parse_model(d, ch): # model_dict, input_channels(3) pass n = max(round(n * gd), 1) if n > 1 else n # depth gain - if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3T]: + if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3T, C3TR]: c1, c2 = ch[f], args[0] # Normal @@ -232,7 +232,7 @@ def parse_model(d, ch): # model_dict, input_channels(3) # c2 = make_divisible(c2, 8) if c2 != no else c2 args = [c1, c2, *args[1:]] - if m in [BottleneckCSP, C3]: + if m in [BottleneckCSP, C3, C3T, C3TR]: args.insert(2, n) n = 1 elif m is nn.BatchNorm2d: diff --git a/models/yolotrs.yaml b/models/yolotrs.yaml index a3ffbfbb5653..624ad3ee4b62 100644 --- a/models/yolotrs.yaml +++ b/models/yolotrs.yaml @@ -21,7 +21,7 @@ backbone: [-1, 9, C3, [512]], [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 [-1, 1, SPP, [1024, [5, 9, 13]]], - [-1, 3, C3T, [1024, False]], # 9 + [-1, 3, C3TR, [1024, False]], # 9 ] # YOLOv5 head From 9a1dee87f0f1b8215aa29d0b36f9f6f2493c819d Mon Sep 17 00:00:00 2001 From: Yiwei Ding Date: Sun, 28 Feb 2021 00:39:47 +0800 Subject: [PATCH 03/14] Remove bias in Transformer --- models/common.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/models/common.py b/models/common.py index b329cf4424dd..935b4f63b5bb 100644 --- a/models/common.py +++ b/models/common.py @@ -83,20 +83,19 @@ def __init__(self, c, num_heads): super().__init__() self.ln1 = nn.LayerNorm(c) - self.q = nn.Linear(c, c) - self.k = nn.Linear(c, c) - self.v = nn.Linear(c, c) + self.q = nn.Linear(c, c, bias=False) + self.k = nn.Linear(c, c, bias=False) + self.v = nn.Linear(c, c, bias=False) self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads) self.ln2 = nn.LayerNorm(c) - self.fc1 = nn.Linear(c, c) - self.fc2 = nn.Linear(c, c) - self.act = nn.SiLU() + self.fc1 = nn.Linear(c, c, bias=False) + self.fc2 = nn.Linear(c, c, bias=False) def forward(self, x): x_ = self.ln1(x) x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x x = self.ln2(x) - x = self.fc2(self.act(self.fc1(x))) + x + x = self.fc2(self.fc1(x)) + x return x From d27066f64d4aaff3e322a3a3afa90ed5780e5542 Mon Sep 17 00:00:00 2001 From: Yiwei Ding Date: Tue, 2 Mar 2021 01:41:13 +0800 Subject: [PATCH 04/14] Remove C3T --- models/common.py | 44 -------------------------------------------- models/yolo.py | 4 ++-- 2 files changed, 2 insertions(+), 46 deletions(-) diff --git a/models/common.py b/models/common.py index 935b4f63b5bb..1dacb4587f02 100644 --- a/models/common.py +++ b/models/common.py @@ -41,43 +41,6 @@ def fuseforward(self, x): return self.act(self.conv(x)) -class Transformer(nn.Module): - def __init__(self, c1, c2, num_heads): - super(Transformer, self).__init__() - - self.linear = nn.Linear(c1, c1) - self.ln1 = nn.LayerNorm(c1) - self.q = nn.Linear(c1, c1) - self.k = nn.Linear(c1, c1) - self.v = nn.Linear(c1, c1) - self.ma = nn.MultiheadAttention(embed_dim=c1, num_heads=num_heads) - self.ln2 = nn.LayerNorm(c1) - self.fc1 = nn.Linear(c1, c2) - self.fc2 = nn.Linear(c2, c2) - self.gelu = nn.GELU() - self.c2 = c2 - - def forward(self, x): - b, _, w, h = x.shape - p = x.flatten(2) - p = p.unsqueeze(0) - p = p.transpose(0, 3) - p = p.squeeze(3) - e = self.linear(p) - x = p + e - - x_ = self.ln1(x) - x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x - x = self.ln2(x) - x = self.fc1(x) - x = self.gelu(x) - x = self.fc2(x) - x = x.unsqueeze(3) - x = x.transpose(0, 3) - x = x.reshape(b, self.c2, w, h) - return x - - class TransformerLayer(nn.Module): def __init__(self, c, num_heads): super().__init__() @@ -182,13 +145,6 @@ def forward(self, x): return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) -class C3T(C3): - def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): - super().__init__(c1, c2, n, shortcut, g, e) - c_ = int(c2 * e) - self.m = nn.Sequential(*[BoT(c_, c_, shortcut, g, e=1.0) for _ in range(n)]) - - class C3TR(C3): def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): super().__init__(c1, c2, n, shortcut, g, e) diff --git a/models/yolo.py b/models/yolo.py index dfbd81b5cf64..1f4c9d5f7430 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -211,7 +211,7 @@ def parse_model(d, ch): # model_dict, input_channels(3) n = max(round(n * gd), 1) if n > 1 else n # depth gain if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, - C3, C3T, C3TR]: + C3, C3TR]: c1, c2 = ch[f], args[0] # Normal @@ -233,7 +233,7 @@ def parse_model(d, ch): # model_dict, input_channels(3) # c2 = make_divisible(c2, 8) if c2 != no else c2 args = [c1, c2, *args[1:]] - if m in [BottleneckCSP, C3, C3T, C3TR]: + if m in [BottleneckCSP, C3, C3TR]: args.insert(2, n) n = 1 elif m is nn.BatchNorm2d: From 3a2b4034f5785be3489c6618ff0b2590e9f610ae Mon Sep 17 00:00:00 2001 From: Yiwei Ding Date: Tue, 2 Mar 2021 02:54:17 +0800 Subject: [PATCH 05/14] Remove a deprecated class --- models/common.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/models/common.py b/models/common.py index 1dacb4587f02..47f794913fa1 100644 --- a/models/common.py +++ b/models/common.py @@ -104,13 +104,6 @@ def forward(self, x): return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) -class BoT(Bottleneck): - def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): - super().__init__(c1, c2, shortcut, g, e) - c_ = int(c2 * e) - self.cv2 = Transformer(c_, c2, 4) - - class BottleneckCSP(nn.Module): # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion From 258f6b0fec2a09da7cc4d4558e55338cefc4df24 Mon Sep 17 00:00:00 2001 From: DingYiwei <846414640@qq.com> Date: Sat, 6 Mar 2021 16:59:49 +0800 Subject: [PATCH 06/14] put the 2nd LayerNorm into the 2nd residual block --- models/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/common.py b/models/common.py index e2900912b745..c4f15c1092bb 100644 --- a/models/common.py +++ b/models/common.py @@ -57,8 +57,8 @@ def __init__(self, c, num_heads): def forward(self, x): x_ = self.ln1(x) x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x - x = self.ln2(x) - x = self.fc2(self.fc1(x)) + x + x_ = self.ln2(x) + x = self.fc2(self.fc1(x_)) + x return x From 2b0fdc12dd8069989ab804e49626d1f07eca12be Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 29 Mar 2021 17:35:26 +0200 Subject: [PATCH 07/14] move example model to models/hub, rename to -transformer --- models/{yolotrs.yaml => hub/yolov5s-transformer.yaml} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename models/{yolotrs.yaml => hub/yolov5s-transformer.yaml} (94%) diff --git a/models/yolotrs.yaml b/models/hub/yolov5s-transformer.yaml similarity index 94% rename from models/yolotrs.yaml rename to models/hub/yolov5s-transformer.yaml index 624ad3ee4b62..f2d666722b30 100644 --- a/models/yolotrs.yaml +++ b/models/hub/yolov5s-transformer.yaml @@ -21,7 +21,7 @@ backbone: [-1, 9, C3, [512]], [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 [-1, 1, SPP, [1024, [5, 9, 13]]], - [-1, 3, C3TR, [1024, False]], # 9 + [-1, 3, C3TR, [1024, False]], # 9 <-------- C3TR() Transformer module ] # YOLOv5 head From 33e17af9316857df9459b56d94961496111a6ff1 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 29 Mar 2021 17:48:18 +0200 Subject: [PATCH 08/14] Add module comments and TODOs --- models/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/models/common.py b/models/common.py index 55906acb3f47..738a79f33cc3 100644 --- a/models/common.py +++ b/models/common.py @@ -1,13 +1,12 @@ # YOLOv5 common modules import math -from pathlib import Path - import numpy as np import requests import torch import torch.nn as nn from PIL import Image +from pathlib import Path from torch.cuda import amp from utils.datasets import letterbox @@ -44,9 +43,9 @@ def fuseforward(self, x): class TransformerLayer(nn.Module): + # TODO: comment/source here ... def __init__(self, c, num_heads): super().__init__() - self.ln1 = nn.LayerNorm(c) self.q = nn.Linear(c, c, bias=False) self.k = nn.Linear(c, c, bias=False) @@ -65,9 +64,9 @@ def forward(self, x): class TransformerBlock(nn.Module): + # TODO: comment/source here ... def __init__(self, c1, c2, num_heads, num_layers): super().__init__() - self.conv = None if c1 != c2: self.conv = Conv(c1, c2) @@ -141,6 +140,7 @@ def forward(self, x): class C3TR(C3): + # C3 module with TransformerBlock() def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): super().__init__(c1, c2, n, shortcut, g, e) c_ = int(c2 * e) From 23c9ee3410180f5e7b425fdd4e82812380b45b0f Mon Sep 17 00:00:00 2001 From: Yiwei Ding Date: Thu, 1 Apr 2021 00:20:26 +0800 Subject: [PATCH 09/14] Remove LN in Transformer --- models/common.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/models/common.py b/models/common.py index 738a79f33cc3..0e760e5d0f77 100644 --- a/models/common.py +++ b/models/common.py @@ -46,20 +46,16 @@ class TransformerLayer(nn.Module): # TODO: comment/source here ... def __init__(self, c, num_heads): super().__init__() - self.ln1 = nn.LayerNorm(c) self.q = nn.Linear(c, c, bias=False) self.k = nn.Linear(c, c, bias=False) self.v = nn.Linear(c, c, bias=False) self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads) - self.ln2 = nn.LayerNorm(c) self.fc1 = nn.Linear(c, c, bias=False) self.fc2 = nn.Linear(c, c, bias=False) def forward(self, x): - x_ = self.ln1(x) - x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x - x_ = self.ln2(x) - x = self.fc2(self.fc1(x_)) + x + x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x + x = self.fc2(self.fc1(x)) + x return x From a5b83467c2a7974b2cdcdb14b8aee6bfc23280ca Mon Sep 17 00:00:00 2001 From: Yiwei Ding Date: Thu, 1 Apr 2021 00:44:39 +0800 Subject: [PATCH 10/14] Add comments for Transformer --- models/common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models/common.py b/models/common.py index 0e760e5d0f77..c2d0ee9ec2f7 100644 --- a/models/common.py +++ b/models/common.py @@ -43,7 +43,7 @@ def fuseforward(self, x): class TransformerLayer(nn.Module): - # TODO: comment/source here ... + # Transformer layer described in https://arxiv.org/abs/2010.11929, with LayerNorm layers removed for better performance def __init__(self, c, num_heads): super().__init__() self.q = nn.Linear(c, c, bias=False) @@ -60,13 +60,13 @@ def forward(self, x): class TransformerBlock(nn.Module): - # TODO: comment/source here ... + # Vision Transformer described in https://arxiv.org/abs/2010.11929 def __init__(self, c1, c2, num_heads, num_layers): super().__init__() self.conv = None if c1 != c2: self.conv = Conv(c1, c2) - self.linear = nn.Linear(c2, c2) + self.linear = nn.Linear(c2, c2) # learnable position embedding self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)]) self.c2 = c2 From e6e5f0ea1704f66c883563fd0e5c7939851cd708 Mon Sep 17 00:00:00 2001 From: Yiwei Ding Date: Thu, 1 Apr 2021 01:26:19 +0800 Subject: [PATCH 11/14] Solve the problem of MA with DDP --- train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index d5b2d1b75c52..cf015b1e1b20 100644 --- a/train.py +++ b/train.py @@ -218,7 +218,10 @@ def train(hyp, opt, device, tb_writer=None): # DDP mode if cuda and rank != -1: - model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank) + # `find_unused_parameters=True` should be passed for the incompatibility of nn.MultiheadAttention with DDP, + # according to https://github.com/pytorch/pytorch/issues/26698 + find_unused_params = False if not [type(layer) for layer in model.modules() if isinstance(layer, nn.MultiheadAttention)] else True + model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank, find_unused_parameters=find_unused_params) # Model parameters hyp['box'] *= 3. / nl # scale to layers From feb4e76c11d2bf3085502d13d16623648fb90f07 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 1 Apr 2021 16:47:37 +0200 Subject: [PATCH 12/14] cleanup --- models/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/common.py b/models/common.py index c2d0ee9ec2f7..c9223967204a 100644 --- a/models/common.py +++ b/models/common.py @@ -43,7 +43,7 @@ def fuseforward(self, x): class TransformerLayer(nn.Module): - # Transformer layer described in https://arxiv.org/abs/2010.11929, with LayerNorm layers removed for better performance + # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance) def __init__(self, c, num_heads): super().__init__() self.q = nn.Linear(c, c, bias=False) @@ -60,7 +60,7 @@ def forward(self, x): class TransformerBlock(nn.Module): - # Vision Transformer described in https://arxiv.org/abs/2010.11929 + # Vision Transformer https://arxiv.org/abs/2010.11929 def __init__(self, c1, c2, num_heads, num_layers): super().__init__() self.conv = None From d38b59b7101734d456051c680db87708a6783d07 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 1 Apr 2021 17:09:55 +0200 Subject: [PATCH 13/14] cleanup find_unused_parameters --- train.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/train.py b/train.py index cf015b1e1b20..fbbdf3452774 100644 --- a/train.py +++ b/train.py @@ -1,14 +1,10 @@ import argparse import logging import math +import numpy as np import os import random import time -from copy import deepcopy -from pathlib import Path -from threading import Thread - -import numpy as np import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F @@ -16,6 +12,9 @@ import torch.optim.lr_scheduler as lr_scheduler import torch.utils.data import yaml +from copy import deepcopy +from pathlib import Path +from threading import Thread from torch.cuda import amp from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -218,10 +217,9 @@ def train(hyp, opt, device, tb_writer=None): # DDP mode if cuda and rank != -1: - # `find_unused_parameters=True` should be passed for the incompatibility of nn.MultiheadAttention with DDP, - # according to https://github.com/pytorch/pytorch/issues/26698 - find_unused_params = False if not [type(layer) for layer in model.modules() if isinstance(layer, nn.MultiheadAttention)] else True - model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank, find_unused_parameters=find_unused_params) + model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank, + # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698 + find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules())) # Model parameters hyp['box'] *= 3. / nl # scale to layers From 5fd5b86b8f92f3b31802310c2e6cacbfffc5ae38 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 1 Apr 2021 17:13:49 +0200 Subject: [PATCH 14/14] PEP8 reformat --- models/common.py | 5 +++-- train.py | 9 +++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/models/common.py b/models/common.py index c9223967204a..a25172dcfcac 100644 --- a/models/common.py +++ b/models/common.py @@ -1,12 +1,13 @@ # YOLOv5 common modules import math +from pathlib import Path + import numpy as np import requests import torch import torch.nn as nn from PIL import Image -from pathlib import Path from torch.cuda import amp from utils.datasets import letterbox @@ -66,7 +67,7 @@ def __init__(self, c1, c2, num_heads, num_layers): self.conv = None if c1 != c2: self.conv = Conv(c1, c2) - self.linear = nn.Linear(c2, c2) # learnable position embedding + self.linear = nn.Linear(c2, c2) # learnable position embedding self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)]) self.c2 = c2 diff --git a/train.py b/train.py index 5dea198ef33f..1f2b467e732b 100644 --- a/train.py +++ b/train.py @@ -1,10 +1,14 @@ import argparse import logging import math -import numpy as np import os import random import time +from copy import deepcopy +from pathlib import Path +from threading import Thread + +import numpy as np import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F @@ -12,9 +16,6 @@ import torch.optim.lr_scheduler as lr_scheduler import torch.utils.data import yaml -from copy import deepcopy -from pathlib import Path -from threading import Thread from torch.cuda import amp from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter