Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic EfficientNet Weight Conversion #73

Merged
merged 27 commits into from
Feb 12, 2023
Merged

Conversation

DavidLandup0
Copy link
Owner

@DavidLandup0 DavidLandup0 commented Feb 11, 2023

Basic usage:

dummy_input_tf = tf.ones([1, 224, 224, 3])
dummy_input_torch = torch.ones(1, 3, 224, 224)

tf_model = deepvision.models.EfficientNetV2B0(include_top=False,
                                          pooling='avg',
                                          input_shape=(224, 224, 3),
                                          backend='tensorflow')

tf_model.save('effnet.h5')

from deepvision.models.classification.efficientnet import efficientnet_weight_mapper
pt_model = efficientnet_weight_mapper.load_tf_to_pt(filepath='effnet.h5', dummy_input=dummy_input_tf)

print(tf_model(dummy_input_tf)['output'].numpy())
print(pt_model(dummy_input_torch).detach().cpu().numpy())
# True
np.allclose(tf_model(dummy_input_tf)['output'].numpy(), pt_model(dummy_input_torch).detach().cpu().numpy())
pt_model = deepvision.models.EfficientNetV2B0(include_top=False,
                                          pooling='avg',
                                          input_shape=(3, 224, 224),
                                          backend='pytorch')
torch.save(pt_model.state_dict(), 'effnet.pt')

from deepvision.models.classification.efficientnet import efficientnet_weight_mapper

kwargs = {'include_top': False, 'pooling':'avg', 'input_shape':(3, 224, 224)}
tf_model = efficientnet_weight_mapper.load_pt_to_tf(filepath='effnet.pt',
                                architecture='EfficientNetV2B0',
                                kwargs=kwargs,
                                dummy_input=dummy_input_torch)


pt_model.eval()
print(pt_model(dummy_input_torch).detach().cpu().numpy())
print(tf_model(dummy_input_tf)['output'].numpy())
# True
np.allclose(tf_model(dummy_input_tf)['output'].numpy(), pt_model(dummy_input_torch).detach().cpu().numpy())

@DavidLandup0 DavidLandup0 changed the title EfficientNet Weight Conversion Automatic EfficientNet Weight Conversion Feb 12, 2023
@DavidLandup0 DavidLandup0 deleted the weight_conversion branch February 12, 2023 20:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant