
Improve support for Keras H5 models #280

Open
charlesmartin14 opened this issue Aug 29, 2023 · 2 comments

Comments

@charlesmartin14
Member

charlesmartin14 commented Aug 29, 2023

We have basic support for Keras H5, but it can only handle very simple models.

It cannot read many older formats, so it needs to:

  • fail gracefully if it cannot read the model
  • tell the user what to do instead (i.e., re-save with model.save_weights())
  • add support for reading weights directly

In particular, Keras cannot save subclassed models with model.save('model.h5'), so I assume we cannot read them either.
Instead, we would need to read the weights only, so we need a keras_h5 weights iterator.
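The bullets above could be sketched as a small generator: fail with a clear message pointing the user at save_weights when the file is unreadable, and otherwise yield every weight array in the file. This is only a sketch of the proposed keras_h5 weights iterator; the function name is hypothetical.

```python
import h5py

def iterate_h5_weights(model_path):
    """Yield (name, array) pairs for every weight dataset in a Keras H5 file.

    A minimal sketch of the proposed keras_h5 weights iterator. Fails with a
    clear message (pointing the user at model.save_weights) when the file
    cannot be read.
    """
    try:
        f = h5py.File(model_path, 'r')
    except OSError as e:
        raise ValueError(
            f"Could not read {model_path} as an H5 file: {e}. "
            "Try re-saving the model with model.save_weights()."
        ) from e
    with f:
        # Collect all dataset paths first, then yield their contents
        names = []
        f.visititems(lambda name, item: names.append(name)
                     if isinstance(item, h5py.Dataset) else None)
        for name in names:
            yield name, f[name][()]
```

Consumers could then build the layer-by-layer analysis from these (name, array) pairs without ever deserializing the full model.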

@charlesmartin14
Member Author

charlesmartin14 commented Aug 29, 2023

Here is some code that prints the H5 structure of a basic BERT model, where all we did was save the weights:


import h5py

model_path = 'bert_weights.h5'  # path to the saved weights file (adjust as needed)

def print_h5_structure(group, level=0):
    # Print the name of the current group
    print("  " * level + f"Group: {group.name}")

    # Recursively iterate over all items (groups and datasets) in the current group
    for name, item in group.items():
        if isinstance(item, h5py.Group):
            print_h5_structure(item, level + 1)
        elif isinstance(item, h5py.Dataset):
            print("  " * (level + 1) + f"Dataset: {name}")

# Open the H5 file
with h5py.File(model_path, 'r') as f:
    # Print the structure recursively starting from the root group
    print_h5_structure(f)

and here is the beginning of the output (truncated):


Group: /
  Group: /bert
    Group: /bert/tf_bert_model
      Group: /bert/tf_bert_model/bert
        Group: /bert/tf_bert_model/bert/embeddings
          Group: /bert/tf_bert_model/bert/embeddings/LayerNorm
            Dataset: beta:0
            Dataset: gamma:0
          Group: /bert/tf_bert_model/bert/embeddings/position_embeddings
            Dataset: embeddings:0
          Group: /bert/tf_bert_model/bert/embeddings/token_type_embeddings
            Dataset: embeddings:0
          Group: /bert/tf_bert_model/bert/embeddings/word_embeddings
            Dataset: weight:0
        Group: /bert/tf_bert_model/bert/encoder
          Group: /bert/tf_bert_model/bert/encoder/layer_._0
            Group: /bert/tf_bert_model/bert/encoder/layer_._0/attention
              Group: /bert/tf_bert_model/bert/encoder/layer_._0/attention/output
                Group: /bert/tf_bert_model/bert/encoder/layer_._0/attention/output/LayerNorm
                  Dataset: beta:0
                  Dataset: gamma:0
                Group: /bert/tf_bert_model/bert/encoder/layer_._0/attention/output/dense
                  Dataset: bias:0
                  Dataset: kernel:0
              Group: /bert/tf_bert_model/bert/encoder/layer_._0/attention/self
                Group: /bert/tf_bert_model/bert/encoder/layer_._0/attention/self/key
                  Dataset: bias:0
                  Dataset: kernel:0
                Group: /bert/tf_bert_model/bert/encoder/layer_._0/attention/self/query
                  Dataset: bias:0
                  Dataset: kernel:0
                Group: /bert/tf_bert_model/bert/encoder/layer_._0/attention/self/value
                  Dataset: bias:0
                  Dataset: kernel:0

@charlesmartin14
Member Author

And here is the code to read the H5 file and load the weights into a fresh BERT model:

import h5py
import numpy as np
from transformers import TFBertModel, BertConfig

model_path = 'bert_weights.h5'  # path to the saved weights file (adjust as needed)

# 1. Create the BERT model architecture
config = BertConfig()  # Adjust the config as necessary
model = TFBertModel(config)

# Build the model with dummy data
dummy_input_data = np.zeros((1, 128), dtype=np.int32)  # Assuming a sequence length of 128 for dummy data
model(dummy_input_data)

# 2. Load the weights manually
with h5py.File(model_path, 'r') as f:
    # Recursively load weights
    def load_weights_from_group(layer, group):
        # Check if the layer has sub-layers (is a model)
        if hasattr(layer, 'layers'):
            for sub_layer in layer.layers:
                # Only try to load weights for layers with a matching name in the file
                if sub_layer.name in group:
                    sub_group = group[sub_layer.name]
                    load_weights_from_group(sub_layer, sub_group)
        # Load weights for this layer, matching datasets to the layer's
        # expected variable order (h5py iterates names alphabetically, which
        # would otherwise swap e.g. bias and kernel)
        if 'kernel:0' in group or 'weight:0' in group:
            weights = [group[w.name.split('/')[-1]][:] for w in layer.weights]
            layer.set_weights(weights)

    # Start from the top-level group
    load_weights_from_group(model, f)

print("Weights loaded successfully!")
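After the load, a rough sanity check is to compare the number of datasets in the file against the number of model variables. This is a hypothetical helper and assumes a weights-only file that mirrors the model one-to-one:

```python
import h5py

def count_h5_datasets(model_path):
    """Count the weight datasets in an H5 file (sketch for a post-load check)."""
    count = 0
    def visit(name, item):
        nonlocal count
        if isinstance(item, h5py.Dataset):
            count += 1
    with h5py.File(model_path, 'r') as f:
        f.visititems(visit)
    return count

# Rough post-load check, assuming the file exactly mirrors the model:
# assert count_h5_datasets(model_path) == len(model.weights)
```

A mismatch would indicate that some layers were silently skipped by the recursive loader above.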
