
How do I use an ONNX model to retrieve embeddings #3511

Open
brianferrell787 opened this issue Jul 24, 2024 · 0 comments
Labels
question Further information is requested

Comments


brianferrell787 commented Jul 24, 2024

Question

Hello, I was wondering how the process works for taking a BERT model, converting it to ONNX, quantizing it, and then using it for embedding retrieval. I have been using this BERT model (FinBERT): https://huggingface.co/yiyanghkust/finbert-pretrain. I have tried multiple combinations, but I will just show this code; maybe you can provide some direction, since I don't know much about these sorts of techniques:

Export Model to ONNX

from transformers.convert_graph_to_onnx import convert
from transformers import BertTokenizer, BertModel
model_name = 'yiyanghkust/finbert-pretrain'
onnx_model_path = 'path/to/newfinbertmodel.onnx'

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name).to('cpu')

# Convert the model to ONNX
convert(framework='pt', model=model, output=onnx_model_path, opset=11, tokenizer=tokenizer)

Optimize the Model Using onnxruntime

from onnxruntime.transformers import optimizer
from onnxruntime.transformers.onnx_model_bert import BertOptimizationOptions
optimized_model_path = 'path/to/optimized_model.onnx'

# Define optimization options
opt_options = BertOptimizationOptions('bert')
opt_options.enable_embed_layer_norm = False

# Optimize the model
opt_model = optimizer.optimize_model(
    onnx_model_path,
    'bert',
    num_heads=12,
    hidden_size=768,
    optimization_options=opt_options
)

opt_model.save_model_to_file(optimized_model_path)
print(f"Optimized model saved to: {optimized_model_path}")

Quantize the Model Using onnxruntime

from onnxruntime.quantization import quantize_dynamic, QuantType
quantized_onnx_model_path = 'path/to/quantized_model.onnx'

# Quantize
quantize_dynamic(optimized_model_path, quantized_onnx_model_path, weight_type=QuantType.QInt8)
print(f"Quantized model saved to: {quantized_onnx_model_path}")

Use the Model with Flair

from flair.embeddings import TransformerOnnxWordEmbeddings
from flair.data import Sentence

finbert_embeddings = TransformerOnnxWordEmbeddings(
    onnx_model=quantized_onnx_model_path,
    name="finbert-onnx",
    tokenizer=tokenizer,
    embedding_length=768,
    context_length=0,
    context_dropout=0.0,
    respect_document_boundaries=False,
    stride=0,
    allow_long_sentences=False,
    fine_tune=False,
    truncate=True,
    use_lang_emb=False,
    is_document_embedding=False,
    is_token_embedding=True,
    force_max_length=False,
    feature_extractor=None,
    needs_manual_ocr=False,
    use_context_separator=True,
    providers=["CPUExecutionProvider"] 
)

sentence = Sentence('Flair is a great NLP library.')

finbert_embeddings.embed(sentence)
for token in sentence:
    print(f'Token: {token.text} - Embedding: {token.embedding}')

print(f'Sentence embedding: {sentence.get_embedding()}')

This code gives me an error: "invalid argument ONNXRuntimeError invalid argument invalid feed input name: token_lengths". I know this indicates that the ONNX model I am using does not have an input named token_lengths. I tried going a route of fixing that, but came up short as well. As I said, I have tried a combination of things and have seen multiple errors no matter what I do, so I am not necessarily looking for a solution to this specific error; rather, maybe someone can shed some light on the proper direction. Thanks!
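One direction worth checking: recent Flair versions can export the ONNX graph from the embedding object itself, so the exported input names (including Flair-specific feeds like token_lengths) match what `TransformerOnnxWordEmbeddings` sends at `embed()` time, sidestepping the mismatch entirely. A hedged sketch, assuming a Flair version whose `TransformerWordEmbeddings` provides `export_onnx` (verify against your installed version's API before relying on it):

```python
from flair.data import Sentence
from flair.embeddings import TransformerWordEmbeddings

# Load the HF model through Flair so the export traces Flair's own inputs.
base = TransformerWordEmbeddings("yiyanghkust/finbert-pretrain")

# Example sentences are used to trace the graph during export.
sentences = [Sentence("FinBERT is a financial domain language model.")]

# export_onnx returns a TransformerOnnxWordEmbeddings wired to the new
# file, so its feed names match what Flair supplies at inference.
onnx_embeddings = base.export_onnx(
    "finbert.onnx", sentences, providers=["CPUExecutionProvider"]
)

onnx_embeddings.embed(sentences[0])
```

If this export path works, the quantize_dynamic step from above can then be applied to the resulting file instead of one produced by `transformers.convert_graph_to_onnx`.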

@brianferrell787 brianferrell787 added the question Further information is requested label Jul 24, 2024