No effect after “fuse_bn_recursively” #7

Open
Ironteen opened this issue Mar 22, 2020 · 3 comments

Ironteen commented Mar 22, 2020

I loaded MobileNetV2 and passed it through the fuse_bn_recursively function, then printed the network structures of the two models. The fused network looks the same as the initial network. Is this because I am using it incorrectly?
```python
import torch
from bn_fusion import fuse_bn_recursively
from pytorchcv.model_provider import get_model as ptcv_get_model

if __name__ == '__main__':
    net = ptcv_get_model('mobilenetv2_w1', pretrained=True)

    net1 = fuse_bn_recursively(net)
    net1.eval()

    # Collect the parameters of both models by name.
    net_dict1 = dict(net.named_parameters())
    net_dict2 = dict(net1.named_parameters())

    # Count parameters whose shape differs between the two models.
    diff_cnt = 0
    for name in net_dict1:
        if net_dict1[name].shape != net_dict2[name].shape:
            diff_cnt += 1
    print("diff params:", diff_cnt)
```
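
A note on the check itself: folding a BatchNorm layer into the preceding convolution rescales the weight values and adds a bias, but it does not change the weight's shape, so a shape-only comparison can report zero differences even when fusion happened. The NumPy sketch below illustrates this with made-up tensors standing in for a real layer. The helper `fold_bn_into_conv` and the parameter dictionaries are hypothetical and not part of `bn_fusion`. Separately, if `fuse_bn_recursively` modifies the model in place (an assumption, not checked here), `net` and `net1` would be the same object, and taking a `copy.deepcopy` of the model before fusing would rule that out.

```python
import numpy as np

def fold_bn_into_conv(weight, gamma, beta, mean, var, eps=1e-5):
    """Fold BatchNorm statistics into a conv weight laid out as (out, in, kh, kw).

    Hypothetical helper for illustration; assumes the conv has no bias.
    """
    scale = gamma / np.sqrt(var + eps)            # one factor per output channel
    fused_weight = weight * scale[:, None, None, None]
    fused_bias = beta - mean * scale              # bias created by the fold
    return fused_weight, fused_bias

rng = np.random.default_rng(0)
weight = rng.normal(size=(8, 3, 3, 3))            # stand-in conv weight
gamma = np.linspace(0.5, 2.0, 8)                  # stand-in BN parameters
beta = np.linspace(-1.0, 1.0, 8)
mean = np.linspace(-0.3, 0.3, 8)
var = np.linspace(0.5, 1.5, 8)

before = {"conv.weight": weight}
fused_w, fused_b = fold_bn_into_conv(weight, gamma, beta, mean, var)
after = {"conv.weight": fused_w, "conv.bias": fused_b}

# Shape comparison (what the script above does) versus value comparison.
shape_diffs = sum(before[k].shape != after[k].shape for k in before)
value_diffs = sum(not np.allclose(before[k], after[k]) for k in before)
new_keys = sorted(set(after) - set(before))

print("shape diffs:", shape_diffs)
print("value diffs:", value_diffs)
print("new keys:", new_keys)
```

Comparing values, the set of parameter names, or the number of BatchNorm modules is therefore a more telling check than comparing shapes.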

Ironteen changed the title from “No difference” to “No effect after “fuse_bn_recursively”” on Mar 22, 2020

lext (Collaborator) commented Mar 22, 2020

Please take a look at the recent example. You need some bells and whistles to make it work.

Ironteen (Author) commented

> Please take a look at the recent example. You need some bells and whistles to make it work.

Thank you very much for your prompt reply. I think I understand what you mean. This project is very enlightening for me.

jiangyuqi1017 commented Jun 23, 2020

Actually, I think the BN fusion in this repo is for depthwise convolution, not standard convolution. Correct me if I am wrong. Please check the implementation in kito:

```python
if conv_layer_type == 'Conv2D':
    for i in range(conv_weights.shape[-1]):
        conv_weights[:, :, :, i] *= A[i]
elif conv_layer_type == 'Conv2DTranspose':
    for i in range(conv_weights.shape[-2]):
        conv_weights[:, :, i, :] *= A[i]
elif conv_layer_type == 'DepthwiseConv2D':
    for i in range(conv_weights.shape[-2]):
        conv_weights[:, :, i, :] *= A[i]
elif conv_layer_type == 'Conv3D':
    for i in range(conv_weights.shape[-1]):
        conv_weights[:, :, :, :, i] *= A[i]
```
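
For comparing the two code bases, it may help to keep the weight layouts in mind. The kito snippet works on Keras kernels, where `Conv2D` stores `(kh, kw, in, out)` and `DepthwiseConv2D` stores `(kh, kw, in, depth_multiplier)`, so the scaled axis differs by layer type. PyTorch stores `Conv2d` weights as `(out, in/groups, kh, kw)`, so the per-channel factor lands on axis 0 for both standard and depthwise convolutions. The NumPy sketch below checks that equivalence on random tensors. The helper `scale_axis` and all tensors are hypothetical, and the sketch says nothing about what this repo's code actually does.

```python
import numpy as np

def scale_axis(kernel, A, axis):
    """Multiply `kernel` by the vector `A` along `axis` (same effect as the loops above)."""
    shape = [1] * kernel.ndim
    shape[axis] = -1
    return kernel * A.reshape(shape)

rng = np.random.default_rng(0)

# Standard convolution: Keras layout (kh, kw, in, out), 4 inputs, 6 outputs.
keras_conv = rng.normal(size=(3, 3, 4, 6))
A_conv = np.linspace(0.5, 2.0, 6)                 # one factor per output channel
torch_conv = keras_conv.transpose(3, 2, 0, 1)     # -> (out, in, kh, kw)

conv_match = np.allclose(
    scale_axis(keras_conv, A_conv, -1).transpose(3, 2, 0, 1),
    scale_axis(torch_conv, A_conv, 0),
)

# Depthwise convolution: Keras layout (kh, kw, in, depth_multiplier=1).
keras_dw = rng.normal(size=(3, 3, 4, 1))
A_dw = np.linspace(0.5, 2.0, 4)                   # one factor per channel
torch_dw = keras_dw.transpose(2, 3, 0, 1)         # -> (in, 1, kh, kw)

dw_match = np.allclose(
    scale_axis(keras_dw, A_dw, -2).transpose(2, 3, 0, 1),
    scale_axis(torch_dw, A_dw, 0),
)

print("standard conv layouts agree:", conv_match)
print("depthwise conv layouts agree:", dw_match)
```

Under these layout assumptions, scaling along axis 0 of a PyTorch weight tensor corresponds to the `Conv2D` branch for standard convolutions and to the `DepthwiseConv2D` branch for depthwise ones.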
