Apply Transformer in the backbone #2329

Closed
dingyiwei opened this issue Mar 1, 2021 · 65 comments · Fixed by #2333 or #5645
Labels
documentation (Improvements or additions to documentation), enhancement (New feature or request)

Comments

@dingyiwei
Contributor

🚀 Feature

The Transformer is popular in NLP and is now also being applied to CV. I added C3TR just by replacing the sequential self.m in C3 with a Transformer block, which can reduce GFLOPs and help YOLO achieve a better result.

Motivation

  • Dosovitskiy et al. proposed ViT
  • Facebook applied the Transformer to object detection as an encoder
  • So I thought a Transformer could make YOLO better

Pitch

I added 3 classes in https://github.com/dingyiwei/yolov5/blob/Transformer/models/common.py:

class TransformerLayer(nn.Module):
    # Transformer encoder layer: multi-head self-attention and a feed-forward block, each with a residual connection
    def __init__(self, c, num_heads):
        super().__init__()

        self.ln1 = nn.LayerNorm(c)
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.ln2 = nn.LayerNorm(c)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        x_ = self.ln1(x)
        x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x  # attention output [0] plus residual
        x = self.ln2(x)
        x = self.fc2(self.fc1(x)) + x  # feed-forward with residual
        return x


class TransformerBlock(nn.Module):
    # ViT-style block: flattens the feature map into a token sequence, adds a learnable position embedding
    # and applies num_layers TransformerLayers
    def __init__(self, c1, c2, num_heads, num_layers):
        super().__init__()

        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)
        self.linear = nn.Linear(c2, c2)  # learnable position embedding
        self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
        self.c2 = c2

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        b, _, w, h = x.shape
        p = x.flatten(2)       # b, c, w, h -> b, c, wh
        p = p.unsqueeze(0)     # -> 1, b, c, wh
        p = p.transpose(0, 3)  # -> wh, b, c, 1
        p = p.squeeze(3)       # -> wh, b, c (sequence-first layout expected by nn.MultiheadAttention)
        e = self.linear(p)
        x = p + e              # add position embedding

        x = self.tr(x)
        x = x.unsqueeze(3)     # reverse the flattening: wh, b, c -> b, c2, w, h
        x = x.transpose(0, 3)
        x = x.reshape(b, self.c2, w, h)
        return x


class C3TR(C3):
    # C3 module with its Bottleneck stack self.m replaced by a TransformerBlock
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = TransformerBlock(c_, c_, 4, n)
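
For a quick shape sanity check (my own snippet, not part of the PR; it assumes the classes above are defined and torch is imported as in models/common.py):

import torch

m = TransformerBlock(c1=256, c2=256, num_heads=4, num_layers=1)
y = m(torch.rand(2, 256, 20, 20))  # e.g. a P5/32 feature map at 640x640 input
print(y.shape)  # expected: torch.Size([2, 256, 20, 20]), same spatial shape in and out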

And I just put it at the end of the backbone in place of the last C3 block.

backbone:
  # [from, number, module, args]
  [[-1, 1, Focus, [64, 3]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 9, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 1, SPP, [1024, [5, 9, 13]]],
   [-1, 3, C3TR, [1024, False]],  # 9    <---- here is my modification
  ]

I conducted experiments on 2 Nvidia GTX 1080Ti cards, with depth_multiple and width_multiple the same as Yolov5s. Here are my experimental results with img-size 640. For convenience I named the method in this issue Yolov5TRs.

| Model | Params | GFLOPs |
|---|---|---|
| Yolov5s | 7266973 | 17.0 |
| Yolov5TRs | 7268765 | 16.8 |
| Model | Dataset | TTA | mAP@.5 | mAP@.5:.95 | Speed (ms) |
|---|---|---|---|---|---|
| Yolov5s | coco (val) | N | 0.558 | 0.365 | 4.4 |
| Yolov5TRs | coco (val) | N | 0.568 | 0.363 | 4.4 |
| Yolov5s | coco (test-dev) | N | 0.559 | 0.365 | 4.6 |
| Yolov5TRs | coco (test-dev) | N | 0.567 | 0.365 | 4.5 |
| Yolov5s | coco (test-dev) | Y | 0.568 | 0.378 | 12.0 |
| Yolov5TRs | coco (test-dev) | Y | 0.571 | 0.375 | 11.0 |

We can see that Yolov5TRs gets higher mAP@0.5 scores at a faster speed. (I'm not sure why my Yolov5s results differ from those shown in the README; the model was downloaded from release v4.0.) When depth_multiple and width_multiple are set to larger numbers, C3TR should be more lightweight than C3. Since I don't have much time for this and my machine is not very powerful, I did not run experiments on M, L and X. Maybe someone could conduct further experiments 😄

dingyiwei added the enhancement label Mar 1, 2021

@glenn-jocher
Member

@dingyiwei hey, very cool!! The updates seem a bit faster with a bit fewer FLOPs... I'll have to look at this a little more in depth, but very quickly I would add that the C3TR module you placed at the end of the backbone will primarily affect large objects, so many of the smaller objects may not be significantly affected by the change.

To give a bit of background: the largest C3 modules, like the 1024-channel one you replaced, are responsible for most of the model parameter count but execute very fast (due to the small 20x20 feature grid they sample), whereas the earliest C3 modules like 1-P2/4 and 2-P2/8 have very few parameters but are slow to execute due to their very small stride and large grid, i.e. 160x160 and 80x80.

So it would be interesting to see the effects of replacing the 256- and 512-channel C3 modules as well.

@glenn-jocher
Member

@dingyiwei just checked, we have a multi-GPU instance freeing up soon, so I think we can add a few C3TR runs to the queue to experiment further. Could you submit a PR with your above updates please?

@glenn-jocher
Member

glenn-jocher commented Mar 2, 2021

@dingyiwei I pasted your modules into common.py and added C3TR to the modules list in yolo.py, and I can build a model successfully, but my numbers look a little different from yours:

default YOLOv5s

Model Summary: 283 layers, 7276605 parameters, 7276605 gradients, 17.1 GFLOPS

[-1, 3, C3TR, [1024, False]], # 9

Model Summary: 276 layers, 6686013 parameters, 6686013 gradients, 16.6 GFLOPS

My full C3TR module (with only self.m different):

class C3TR(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(C3TR, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = TransformerBlock(c_, c_, 4, n)

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))

EDIT: had to add C3TR in a second spot in yolo.py, now I match your numbers.

Model Summary: 286 layers, 7277885 parameters, 7277885 gradients, 16.8 GFLOPS
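
For anyone following along, the "second spot" refers to parse_model() in yolo.py. As far as I remember the structure (an approximate sketch, not the exact diff), C3TR has to appear in both membership checks so that the channel arguments are filled in and the repeat count n is passed into the module itself rather than used by nn.Sequential:

if m in [Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, C3, C3TR]:  # first spot: channel handling
    c1, c2 = ch[f], args[0]
    # ... existing width-multiple handling ...
    args = [c1, c2, *args[1:]]
    if m in [BottleneckCSP, C3, C3TR]:  # second spot: pass the repeat count n into the module
        args.insert(2, n)
        n = 1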

@Joker316701882

@dingyiwei @glenn-jocher Applying dropout can greatly improve a Transformer's performance, so I made a slight modification to the TransformerLayer and observed improvements with my own model on COCO val (based on YOLOv5L). I'm not very familiar with the standard Transformer, but most Transformer implementations I have seen apply dropout, so the following TransformerLayer could be a better implementation.

class TransformerLayer(nn.Module):
    def __init__(self, c, num_heads):
        super().__init__()

        self.ln1 = nn.LayerNorm(c)
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.ln2 = nn.LayerNorm(c)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)
        self.dropout = nn.Dropout(0.1)
        self.act = nn.ReLU(True)

    def forward(self, x):
        x_ = self.ln1(x)
        x = self.dropout(self.ma(self.q(x_), self.k(x_), self.v(x_))[0]) + x
        x_ = self.ln2(x)
        x_ = self.fc2(self.dropout(self.act(self.fc1(x_))))
        x = x + self.dropout(x_)
        return x

@dingyiwei
Contributor Author

dingyiwei commented Mar 3, 2021

Hi @Joker316701882 , actually I removed dropout at the beginning since there's no dropout in this codebase 🤣. I'll give it a try now on VOC.
FYI, I tried nn.SiLU before in self.fc2(self.act(self.fc1(x))) but got a worse result, so you may also want to run experiments without activation functions in TransformerLayer.

@NanoCode012
Contributor

Hello @dingyiwei , may I ask whether you trained with the multi-GPU option or a single GPU? I saw that you wrote "2 Nvidia GTX 1080Ti cards" in your first post.

The reason I'm asking is that I set up 2-GPU and 4-GPU runs for 5m/5l using your backbone and got an error around the 110-120th epoch.

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).

Do you perhaps have any clue about this error? I also recall that glenn was planning to do multi-GPU training on this branch as well. Could you tell me if you run into any errors too?

@dingyiwei
Contributor Author

Hi @NanoCode012 , I ran my experiments with python train.py --data coco.yaml --cfg yolotrs.yaml --weights '' --batch-size 64. I saw 2 of my GPUs working so I just left them running. Thus I've never hit this problem, since I didn't use DDP mode.

I guess the problem could be caused by nn.MultiheadAttention, according to the error message. Its forward has 2 outputs, attn_output and attn_output_weights, where the first one is what we need:

    def forward(self, x):
        x_ = self.ln1(x)
        x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x   # <---- here we only use the first output
        x = self.ln2(x)
        x = self.fc2(self.fc1(x)) + x
        return x

I'm going to check it when my last experiment finishes.

@NanoCode012
Contributor

Hello @dingyiwei, I see!

Have you tried just using a single GPU for training instead? From my test on COCO, DP didn't actually speed up training. Maybe you could run two trainings instead of one :)

I found an issue, pytorch/pytorch#26698, which talks about the incompatibility of nn.MultiheadAttention with DDP. I will try their proposed solution below. The author there did mention that it introduced another bug, but I'll have to test it out. I guess we will need a PR for the DDP setup if we decide to include the transformer in the backbone.

passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel` 

Another note: this can introduce some overhead in DDP https://pytorch.org/docs/stable/notes/ddp.html

Forward Pass: The DDP takes the input and passes it to the local model, and then analyzes the output from the local model if find_unused_parameters is set to True.... Note that traversing the autograd graph introduces extra overheads, so applications should only set find_unused_parameters to True when necessary.
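
For reference, the workaround would look roughly like this at the DDP wrapping step (a minimal sketch, assuming the usual single-node setup; model and local_rank stand in for whatever train.py actually uses):

from torch.nn.parallel import DistributedDataParallel as DDP

# Enable unused-parameter detection so DDP tolerates module outputs that never
# contribute to the loss (at the cost of some extra overhead per iteration).
model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)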

@dingyiwei
Contributor Author

Hi @Joker316701882 , I tested dropout and dropout+act on VOC (based on YOLOv5s + Transformer), but there seems to be no obvious improvement. May I ask for your experimental results with dropout?

@glenn-jocher @zhiqwang @NanoCode012 And I found a MISTAKE in my PR #2333 : in a classic Transformer layer, the 2nd LayerNorm should be applied inside the 2nd residual branch (as described in Joker316701882's comment), according to ViT. But I executed x = self.ln2(x) on its own, outside the residual...


Fortunately, so far I haven't seen any damage or benefit from the mistake, but I'm not sure how it will affect larger models.

@jaqub-manuel

Hey @dingyiwei,
I applied your addition to a custom dataset and there was a slight increase, 0.005. Why did you apply it after SPP (1024 channels)? Could you explain a little more for YOLOv5? I applied it before SPP (512 channels) but got lower results.
Thanks...

@dingyiwei
Contributor Author

Hi @jaqub-manuel , usually components with a self-attention mechanism, e.g. Non-local and GCNet, are used to extract global information, so I just put the Transformer at the end of the backbone intuitively.

@glenn-jocher is trying to put the Transformer in different stages of the backbone and in the head of YOLOv5. Maybe his experiments could give us some ideas.

@glenn-jocher
Member

glenn-jocher commented Mar 5, 2021

@dingyiwei @jaqub-manuel I started an experiment run but got sidetracked earlier in the week. I discovered some important information though: the transformer block seems to use up a lot of memory. I created a transformer branch:
https://github.com/ultralytics/yolov5/tree/transformer

And tried to train 8 models: 1 default yolov5m.yaml and then 7 transformer models. Each of the transformer models replaces C3 with C3TR in the location mentioned, i.e. only in layer 2, or only in the backbone, etc.
https://github.com/ultralytics/yolov5/tree/transformer/models

Unfortunately all of the 7 models except the layer 9 model hit CUDA OOM, so I cancelled the training to think a bit. The layers that use the least CUDA memory are the largest-stride layers (P5/32), like layer 9, so this may be why @dingyiwei was using it for the test. I think layer 9 is then probably the best place to implement it, as it uses less memory and affects the whole head. So all I've really learned is that the default test @dingyiwei ran is probably the best for producing a trainable model that doesn't eat too many resources.

@dingyiwei can you update the PR with a fix for the mistake in #2329 (comment)? Then I'll train a YOLOv5m model side by side with the layer 9 replacement, and maybe I can try a layer 9 + P5 head replacement also. The P5 layer itself is the largest mAP contributor at 640 resolution, so it's not all bad news that we can only apply the transformer to that layer to minimize memory usage.
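
To make the memory point concrete, here is a rough back-of-the-envelope illustration (my own sketch, assuming a 640x640 input): self-attention cost grows with the square of the token count w*h, which is why the small-stride layers blow up while the stride-32 layer stays cheap.

img = 640
for level, stride in (("P2", 4), ("P3", 8), ("P4", 16), ("P5", 32)):
    tokens = (img // stride) ** 2  # w * h of the feature grid at this stride
    print(f"{level}/{stride}: grid {img // stride}x{img // stride}, {tokens} tokens, "
          f"attention matrix {tokens}x{tokens} = {tokens * tokens / 1e6:.1f}M entries per head")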

@NanoCode012
Contributor

NanoCode012 commented Mar 12, 2021

Hello, I finished most of my trainings (2 left) testing the Transformer. I noted down my results in wandb. It's my first time using it, so I hope I'm doing it right.

Transformer runs on wandb

My observation was that the Transformer runs (denoted by tr) produced mixed results. They weren't as clear-cut as in @dingyiwei 's first post. Also, the experiment with the 2nd LayerNorm fix, 4_5trmv2, got lower results than the one without the fix, 4_5trmv1.


Edit: Added table here for backup

| Name | batch_size | test mAP@0.5 | test mAP@0.5:0.95 | pycocotools mAP@0.5 | pycocotools mAP@0.5:0.95 |
|---|---|---|---|---|---|
| 1_5m | 64 | 62 | 42.3 | 62.7 | 43.6 |
| 1_5trm | 64 | 61.9 | 42.2 | 62.2 | 43.4 |
| 4_5mv2 | 256 | 62.5 | 42.6 | 63.3 | 43.9 |
| 4_5trmv1 | 256 | 62.2 | 42.6 | 62.9 | 43.8 |
| 4_5trmv2 | 256 | 62.1 | 42.2 | 62.8 | 43.4 |
| 1_5lv2 | 48 | 64 | 45 | 64.7 | 46.2 |
| 1_5trl | 48 | 65.4 | 45.7 | 66 | 46.9 |
| 4_5trlv3 | 128 | 65.3 | 45.8 | 66 | 47 |
| 1_5trx | 32 | - | - | - | - |

@dingyiwei
Contributor Author

dingyiwei commented Mar 16, 2021

Inspired by @NanoCode012 , I tried removing both LayerNorm layers from the Transformer in YOLOv5s, and got a surprise:

| Model | Dataset | TTA | mAP@.5 | mAP@.5:.95 |
|---|---|---|---|---|
| Yolov5s | coco (val) | N | 0.558 | 0.365 |
| Yolov5s + Tr | coco (val) | N | 0.568 | 0.363 |
| Yolov5s + Tr (without LN) | coco (val) | N | 0.571 | 0.366 |

Will run on test-dev and upload the model later.

UPDATE:

Experimental results:

| Model | Dataset | TTA | mAP@.5 | mAP@.5:.95 |
|---|---|---|---|---|
| Yolov5s | coco (val) | N | 0.558 | 0.365 |
| Yolov5s + Tr | coco (val) | N | 0.568 | 0.363 |
| Yolov5s + Tr (without LN) | coco (val) | N | 0.571 | 0.366 |
| Yolov5s | coco (test-dev) | N | 0.559 | 0.365 |
| Yolov5s + Tr | coco (test-dev) | N | 0.567 | 0.365 |
| Yolov5s + Tr (without LN) | coco (test-dev) | N | 0.569 | 0.366 |
| Yolov5s | coco (test-dev) | Y | 0.568 | 0.378 |
| Yolov5s + Tr | coco (test-dev) | Y | 0.571 | 0.375 |
| Yolov5s + Tr (without LN) | coco (test-dev) | Y | 0.573 | 0.377 |

Here is the implementation:

class TransformerLayer(nn.Module):
    # Transformer layer without the two LayerNorm layers
    def __init__(self, c, num_heads):
        super().__init__()

        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
        x = self.fc2(self.fc1(x)) + x
        return x

New model is here.

@Joker316701882

@dingyiwei According to your posted results, mAP@.5 improved but mAP@.5:.95 remains unchanged. Does that mean mAP@.75 actually dropped?

@dingyiwei
Contributor Author

Hi @Joker316701882 , I didn't record mAP@.75 in those experiments. According to @glenn-jocher 's explanation, C3TR at the end of the backbone mainly affects large objects, so I guess mAP@.95 would drop and mAP@.75 might be unchanged.

@jaqub-manuel

Dear @dingyiwei ,
could you upload the new model or share a link or the code? Then I will try it on my custom dataset.

@dingyiwei
Contributor Author

Hi everyone, I updated the experimental results, the implementation and the trained model of C3TR without LN in this comment the day before yesterday. It seems editing a comment does not trigger a notification or an email, so I'm just reminding you about that.

@glenn-jocher
Member

@dingyiwei very interesting result! I think layernorm() is a pretty resource-intensive operation (at least compared to batchnorm). Did removing it reduce the training memory requirements?

@dingyiwei
Contributor Author

Hi @glenn-jocher , in my experiments, yes. For YOLOv5s + TR, gpu_mem showed 6.63G, while for YOLOv5s + TR (without LN), gpu_mem showed 6.61G.

@glenn-jocher
Member

@dingyiwei thanks for the info, so not much of a change in memory from removing layernorm().

@zachluo

zachluo commented Mar 28, 2021

Hi all, did anyone try a position embedding? It seems like the transformer helps classification rather than localization, according to the AP@0.5 and AP@0.5:0.95 results.

@glenn-jocher
Member

glenn-jocher commented Mar 29, 2021

@dingyiwei I'm working on getting the Transformer PR #2333 merged. I merged master to bring it up to date with the latest changes, and I noticed that the TransformerLayer() module in the PR is different from your most recent one in #2329 (comment). Which do you think we should use for the PR? Let me know, thanks!

@glenn-jocher
Member

@dingyiwei also, we should add a one-line comment for each of the 3 new modules that explains a bit or cites a source, if you can please. I've done this with C3TR(), but left the other two up to you.

Once we have these updates and decide on TransformerLayer() then I can merge the PR. Thanks!

glenn-jocher linked a pull request Mar 29, 2021 that will close this issue
@Alex-afka

@glenn-jocher In my experiments, yes. For YOLOv5s + TR, gpu_mem showed 6.63G, while for YOLOv5s + TR (without LN), gpu_mem showed 6.61G.

How do you train this new module? Can you show me the details? Do you train from pretrained weights or from scratch?

@glenn-jocher
Member

@guyiyifeurach there are no transformer pretrained weights, but you can start from the normal pretrained weights instead. To train a YOLOv5s transformer model in our Colab notebook, for example:
https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb

# Train YOLOv5s on COCO128 for 3 epochs
!python train.py --img 640 --batch 16 --epochs 3 --data coco128.yaml --weights yolov5s.pt --cfg yolov5s-transformer.yaml

@qiy20

qiy20 commented Oct 24, 2021

Doesn't this dimensional operation move the batch_size dim? I don't understand why we're doing it this way:

# b,c,w,h-->b,c,wh-->1,b,c,wh-->wh,b,c,1-->wh,b,c 
p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)

I think the right operation is:

p = x.flatten(2).transpose(1, 2)

@dingyiwei

@dingyiwei
Contributor Author

Hi @qiy20 , I forgot why I wrote this piece of code 😂. Feel free to update it if you confirm your version is correct.

@glenn-jocher
Member

@qiy20 @dingyiwei would the right simplification be this?

# b,c,w,h-->b,c,wh-->1,b,c,wh-->wh,b,c,1-->wh,b,c 
p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)

# simplified
p = x.flatten(2).transpose(0, 2)

@dingyiwei
Contributor Author

@glenn-jocher I think not:

# b,c,w,h-->b,c,wh-->1,b,c,wh-->wh,b,c,1-->wh,b,c 
p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)

# b,c,w,h-->b,c,wh-->b,wh,c
p = x.flatten(2).transpose(1, 2)

# b,c,w,h-->b,c,wh-->wh,c,b
p = x.flatten(2).transpose(0, 2)

My original idea was to keep c after b; a single transpose cannot do that.

An alternative is adding batch_first=True to MultiheadAttention; then we could do

p = x.flatten(2).transpose(1, 2)
return self.tr(p + self.linear(p)).transpose(1, 2).reshape(b, self.c2, w, h)

I'll verify it with experiments. Let me know if you have different ideas :)
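
Spelled out, the batch_first variant of TransformerBlock.forward() would look roughly like this (my sketch only; it assumes PyTorch >= 1.9, where nn.MultiheadAttention accepts batch_first=True, and that each TransformerLayer is built with that flag):

    def forward(self, x):  # x: (b, c2, w, h)
        if self.conv is not None:
            x = self.conv(x)
        b, _, w, h = x.shape
        p = x.flatten(2).transpose(1, 2)  # (b, wh, c2) batch-first token sequence
        return self.tr(p + self.linear(p)).transpose(1, 2).reshape(b, self.c2, w, h)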

@glenn-jocher
Member

glenn-jocher commented Nov 10, 2021

@dingyiwei ok, I think I've got it. Yes, you are right, transpose was acting unexpectedly. I had to use permute, but this seems to result in a 2x speedup:

import torch

x = torch.rand(16, 3, 80, 40)
p1 = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)
p2 = x.flatten(2).permute(2,0,1)
print(torch.allclose(p1,p2))  # True

%timeit x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)
# 5.36 µs ± 158 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit x.flatten(2).permute(2,0,1)
# 2.83 µs ± 62 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

@glenn-jocher
Member

@dingyiwei if batch_first=True profiles faster that might be the better solution.

@dingyiwei
Contributor Author

@glenn-jocher Training time and inference time show no difference among the current code, permute, and batch_first=True.

I ran 10 epochs for each solution with python train.py --data data/coco.yaml --cfg models/hub/yolov5s-transformer.yaml --weights '' --batch-size 32 --epochs 10 and tested them with python val.py --data data/coco.yaml --weights runs/train/exp/weights/best.pt --img 640 on one 2080ti.

| Model | Training time (hours) | Inference time (ms) |
|---|---|---|
| Current | 5.053 | 2.6 |
| Permute | 5.053 | 2.6 |
| Batch first | 5.053 | 2.6 |

But permute is more elegant and readable, so I'll submit a pull request for it:

p = x.flatten(2).permute(2, 0, 1)
return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
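
As a quick check that the permute round-trip is lossless (my own snippet, not from the PR):

import torch

x = torch.rand(2, 64, 20, 20)
p = x.flatten(2).permute(2, 0, 1)               # (wh, b, c) as in the new forward()
x2 = p.permute(1, 2, 0).reshape(2, 64, 20, 20)  # inverse mapping used after self.tr
print(torch.equal(x, x2))  # True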

@glenn-jocher
Member

@dingyiwei understood! Yes please submit a PR for permute().

@glenn-jocher
Member

@dingyiwei PR #5645 is merged, replacing multiple transpose ops with a single permute in TransformerBlock(). Thank you for your contributions to YOLOv5 🚀 and Vision AI ⭐

@qiy20

qiy20 commented Nov 17, 2021

Sorry for the delay. @dingyiwei is right! I overlooked the arg batch_first=False.

@qiy20

qiy20 commented Nov 17, 2021

But I have another question, about the pos embedding:
self.linear = nn.Linear(c2, c2) # learnable position embedding
It seems to give different values for different feature maps: if the feature map values change, the pos embedding changes too. This is different from ViT or the original Transformer. Why did you design it this way?
@dingyiwei @glenn-jocher

@dingyiwei
Contributor Author

Good question 😂 Indeed, ViT uses 1D learnable randomly-initialized parameters as the pos embedding. I knew more about CV than NLP, so I was unfamiliar with the pos embedding at that time and applied a common operation from CV - something like a residual Linear layer.

Detection is different from classification. It's hard to say whether a residual layer or standalone parameters works better as the pos embedding for YOLO. I'll try to conduct experiments on this and post results here.

@qiy20

qiy20 commented Nov 18, 2021

I think the pos embedding should reflect the distance between feature points, so standalone parameters may be better; Linear(x) doesn't contain much position information.
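
For comparison, a ViT-style standalone embedding would look roughly like the sketch below (hypothetical, not part of any PR; it reuses Conv and TransformerLayer from earlier in this thread and assumes a fixed feature-map size such as the 20x20 P5 grid, a constraint the Linear-based version does not have):

class TransformerBlockFixedPos(nn.Module):
    # Hypothetical variant: standalone learnable position embedding (ViT-style nn.Parameter)
    # instead of the residual Linear layer
    def __init__(self, c1, c2, num_heads, num_layers, grid=20):  # grid: expected feature-map side length
        super().__init__()
        self.conv = Conv(c1, c2) if c1 != c2 else None
        self.pos = nn.Parameter(torch.zeros(grid * grid, 1, c2))  # (wh, 1, c), broadcast over the batch dim
        self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
        self.c2 = c2

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        b, _, w, h = x.shape
        p = x.flatten(2).permute(2, 0, 1)  # (wh, b, c)
        return self.tr(p + self.pos).permute(1, 2, 0).reshape(b, self.c2, w, h)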

@sakurasakura1996

@dingyiwei I have a question: why does the transformer block only include an encoder and not a decoder? Is the encoder more suitable for classification tasks?

@nrupatunga
Contributor

nrupatunga commented Nov 30, 2021

@sakurasakura1996

My understanding is that the intention of adding the transformer block here is to get better features (by attending to different parts of the image), which might result in better box/class predictions compared to other modules (e.g. C3).

@iscyy

iscyy commented Dec 2, 2021

@dingyiwei hi, I have a question: if the transformer module is added, does it mean that the previous pure-CNN pretrained weights can no longer be used?

@dingyiwei
Contributor Author

@Him-wen Yes, you have to train the model from scratch.

@mx2013713828

@Him-wen Yes, you have to train the model from scratch.

Can you provide a pretrained transformer model? Thanks!!!

@dingyiwei
Contributor Author

@mx2013713828 You may find an outdated model here, trained with this commit. There are no official pretrained models for Yolov5s-transformer.

glenn-jocher added the documentation label and removed the Stale label May 5, 2022
@zhangweida2080

zhangweida2080 commented Aug 5, 2022

@dingyiwei Do you have a reference for this kind of structure?

  • Yours seems to use the Transformer only in the last C3 layer; why not in other layers?
  • You did not use 4 times the hidden neurons in the feed-forward linear layers.

@dingyiwei
Contributor Author

dingyiwei commented Aug 7, 2022

@zhangweida2080 You may want to take a look at my first few comments in this thread.

  • My original idea came from the work of Google (ViT).
  • If I put the component in other layers, the model would become very large and could hardly be trained.
  • I actually haven't worked on CV for a long time, and this thread was started more than a year ago... But I'm a consequentialist, so I cared less about what I should put in the network than about how I could reach a better result. That is why I applied the Transformer. Is there a fixed rule in this area that we must use 4 times the hidden neurons in the feed-forward layer (in a Transformer structure)? If it really works, you're welcome to post a better experimental result here :) (see the sketch below)
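
For anyone who wants to try it, a ViT-style feed-forward with 4x expansion only changes the two fc layers. A minimal sketch (my illustration, untested here; the class name is hypothetical), based on the LN-free layer above:

class TransformerLayer4x(nn.Module):
    # Hypothetical variant with a ViT-style MLP: hidden width = 4 * c (the thread's version keeps it at c)
    def __init__(self, c, num_heads):
        super().__init__()
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, 4 * c, bias=False)  # expand
        self.fc2 = nn.Linear(4 * c, c, bias=False)  # project back
        self.act = nn.GELU()

    def forward(self, x):
        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
        x = self.fc2(self.act(self.fc1(x))) + x
        return x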

@zhangweida2080

zhangweida2080 commented Aug 8, 2022

@dingyiwei Thank you for your reply. There is no fixed rule about the usage in different settings.
However, since your original idea is from ViT (https://arxiv.org/pdf/2010.11929.pdf), I supposed you would follow the implementation of ViT.
There are some differences:

Thanks a lot.

@dingyiwei
Contributor Author

@zhangweida2080 For the first 2 questions: I had to work out a way to get a better result in a very short time due to my personal requirements, so I built a much simpler structure than the Transformer in that paper (but it really worked on COCO anyway) and shared it here. If I had more time and more resources, I would try more structures and conduct more experiments.
For the 3rd question: it's really common when you apply a popular model to a customized dataset, since models tend to be tuned for the popular datasets. I cannot help with your specific problem, but I would suggest starting from a model pre-trained on a large dataset, collecting as much data as you can, and doing your best on data augmentation. Good luck :)
