
Failed to import dynamically quantized roberta model in ONNX format. #1656

Open
tsamiss opened this issue Jun 25, 2024 · 1 comment
Labels
bug Something isn't working


tsamiss commented Jun 25, 2024

Describe the bug
I cannot load a dynamically quantized RoBERTa model in ONNX format for CPU inference with the DeepSparse Engine.
The same model loads just fine before quantization. I am currently working on a Vertex AI instance on GCP.

Expected behavior
The Engine is expected to load the model.

Environment
Include all relevant environment information:

  1. OS [e.g. Ubuntu 18.04]: linux 5.10.0-30-cloud-amd64
  2. Python version [e.g. 3.8]: Python 3.10.14
  3. DeepSparse version or commit hash [e.g. 0.1.0, f7245c8]: 1.7.1
  4. ML framework version(s) [e.g. torch 1.7.1]: torch==1.13.1, transformers==4.40.2
  5. Other Python package versions [e.g. SparseML, Sparsify, numpy, ONNX]: onnx==1.14.1, onnxruntime==1.16.0
  6. CPU info - output of deepsparse/src/deepsparse/arch.bin or output of cpu_architecture() as follows:
>>> import deepsparse.cpu
>>> print(deepsparse.cpu.cpu_architecture())

{'L1_data_cache_size': 32768,
'L1_instruction_cache_size': 32768,
'L2_cache_size': 1048576,
'L3_cache_size': 40370176,
'architecture': 'x86_64',
'available_cores_per_socket': 2,
'available_num_cores': 2,
'available_num_hw_threads': 4,
'available_num_numa': 1,
'available_num_sockets': 1,
'available_sockets': 1,
'available_threads_per_core': 2,
'bf16': False,
'cores_per_socket': 2,
'dotprod': False,
'i8mm': False,
'isa': 'avx512',
'num_cores': 2,
'num_hw_threads': 4,
'num_numa': 1,
'num_sockets': 1,
'threads_per_core': 2,
'vbmi': False,
'vbmi2': False,
'vendor': 'GenuineIntel',
'vendor_id': 'Intel',
'vendor_model': 'Intel(R) Xeon(R) CPU @ 2.00GHz',
'vnni': False,
'zen1': False}

To Reproduce

import os
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from onnxruntime.quantization import quantize_dynamic
from deepsparse import Engine


class RobertaWrapper(torch.nn.Module):
    def __init__(self, model):
        super(RobertaWrapper, self).__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, position_ids):
        outputs = self.model(input_ids, attention_mask=attention_mask, position_ids=position_ids)
        return outputs['logits']


tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base')
model.eval()
model = RobertaWrapper(model)

print('Model Loading - OK!')

text_sample = 'This is just a sample text'

inputs = tokenizer(text_sample, return_tensors="pt", truncation=True, padding="max_length", max_length=258)
dummy_input_ids = inputs['input_ids']
dummy_attention_mask = inputs['attention_mask']
dummy_position_ids = torch.arange(0, dummy_input_ids.shape[1]).view(1, -1)


torch.onnx.export(
    model, 
    (dummy_input_ids, dummy_attention_mask, dummy_position_ids), 
    "./model_roberta_base.onnx",
    input_names=['input_ids', 'attention_mask', 'position_ids'], 
    output_names=['logits'],
)

print('Model ONNX Exporting - OK!')

os.system('python -m onnxruntime.quantization.preprocess --input model_roberta_base.onnx --output model_quant_preproc.onnx')

print('Model Quantization Preprocessing - OK!')

model_fp32 = 'model_quant_preproc.onnx'
model_quant = 'model_quant_dyna.onnx'
# quantize_dynamic writes the quantized model to model_quant and returns None
quantize_dynamic(model_fp32, model_quant)

print('Model Quantization - OK!')

model_path = "./model_quant_dyna.onnx"
compiled_model = Engine(model=model_path)

Errors

DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.7.1 COMMUNITY | (3904e8ec) (release) (optimized) (system=avx512, binary=avx512)
DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.7.1 (3904e8ec) (release) (optimized) (system=avx512, binary=avx512)
Date: 06-25-2024 @ 14:51:52 UTC
OS: Linux instance-tsamis 5.10.0-30-cloud-amd64 #1 SMP Debian 5.10.218-1 (2024-06-01)
Arch: x86_64
CPU: GenuineIntel
Vendor: Intel
Cores/sockets/threads: [2, 1, 4]
Available cores/sockets/threads: [2, 1, 4]
L1 cache size data/instruction: 32k/32k
L2 cache size: 1Mb
L3 cache size: 38.5Mb
Total memory: 14.6487G
Free memory: 4.36606G
Thread: 0x7f1514ec7740

Assertion at ./src/include/wand/engine/execution/planner.hpp:118

Backtrace:
0# 0x00007f14744de663:
[440fb6c34c8b25722a1e026a004489e94c89f641b9010000004c89e7e8fc7716]
[02585a84db75084c89e7e86ee2ffff4c89e7e8f677160248833d6e251e020074]
1# 0x00007f14744e0398:
[845a1602b901000000ba76000000488d35136f85fe488d3d2bc27cfee877e2ff]
[ff4889c3c5f877e9edfaffff4889c3e912fbffff4889c3c5f877e9f7faffff48]
2# wand::engine::compiler::compiler::execution_graph_to_linear_order(wand::engine::execution::graph&&) const in /home/jupyter/NLP_Workflows/env/lib/python3.10/site-packages/deepsparse/avx512/libonnxruntime.so.1.15.1
3# wand::engine::compiler::compiler::compile(wand::engine::execution::graph&&) const in /home/jupyter/NLP_Workflows/env/lib/python3.10/site-packages/deepsparse/avx512/libonnxruntime.so.1.15.1
4# wand::engine::compiler::compiler::compile(wand::engine::compute::compute_graph&&) const in /home/jupyter/NLP_Workflows/env/lib/python3.10/site-packages/deepsparse/avx512/libonnxruntime.so.1.15.1
5# wand::engine::compiler::compiler::compile(wand::engine::intake::graph&&) const in /home/jupyter/NLP_Workflows/env/lib/python3.10/site-packages/deepsparse/avx512/libonnxruntime.so.1.15.1
6# 0x00007f14735a79f7:
[ff488bbc24c80000004885ff7405e8269debff4c89ea4889de4c89f7e8a8e509]
[0349c7042400000000bff0000000e8a6930103488b157f1881ff488d4810c5f1]
7# 0x00007f14735adb83:
[850801000083400801488dbbb00000004d89e04c89e94c89f24889dee81c9cff]
[ff488b45b84989c64885c0741f488b1d593011034885db0f85c00000008b4008]
8# 0x00007f14735ef5a0:
[5648488b464c53415450524889fa488bbde0fcffff488985f8fcffffe8ffe3fb]
[ff488bbd08feffff4883c4204885ff7405e85a21e7ff488bbd18fdffff4885ff]
9# 0x00007f14735b8915:
[0000488dbc2490000000488d4c24404c8b44241848897c2420488b33e81a5a03]
[00488b4308c5f9efc0c5f96f942490000000c5f97f8424900000004889442430]
10# 0x00007f14735b9cf0:
[4154c4e1f96eca4989fcc4e3f122c6014883ec104889e6c5f97f0424e81feaff]
[ff4883c4104c89e0415cc3cccccccccc488b07c5fa6f06c5fa7f00c5fa6f4e10]
11# 0x00007f1473e47258:
[02e992ea7f02cccc415741564155415455534883ec38488b0648897c2410ff50]
[28488b4424104c8b7008488b184c89f04829d84889c248c1f80548c1fa034885]
12# 0x00007f1473e489f2:
[488b742470488dbc24700100004c89f24889c148897c241848890424e84de8ff]
[ff488b4310488b0bc5f9efc0c5f96fb42470010000488b530848894424104989]
13# 0x00007f1473e4eaa3:
[488d8550feffff4889c7488985f8fcffff4889c3c5fa7f8538feffffe8dc9cff]
[ff4883bd50feffff000f847803000041b879010000488d0dd0d6e4fe4889de31]
14# 0x00007f1473e5141e:
[89f94c8b0bff75c04c89f24c89ee488bbd58ffffffffb548ffffff50e8b1d2ff]
[ff4883c42048837d800074b6488bb558ffffff41b825020000488d0d6abce5fe]
15# 0x00007f14734976f4:
[ffc5f97f8590feffff488d4838488d85c0feffff5048898550feffffe8ab9a9b]
[004883bd60ffffff00585a0f849b000000418bbf74090000488bb578feffff41]
16# 0x00007f14734a3e68:
[0fb68540fcffff488b9578fcffff4889de4c89ff89c1898560fcffffe89734ff]
[ff4883bdd0fdffff000f84e10000008bbb7409000041b81a0600004c89fe488d]
17# 0x00007f1473462679:
[836f080175e1ebbf0f1f8000000000488b442430488b7c2438488b30e8c60304]
[00488b44245048894424304885c00f84eafcffff488b7c2438e8e97c9c00e9c7]
18# 0x00007f1473476984:
[83c4184c89e05b5d415c415dc30f1f800000000031d24c89ee4889efe81bb2fe]
[ff488b7c24084989c44885c075c648893b4883c4184c89e05b5d415c415dc348]
19# deepsparse::ort_engine::init(wand::arch_t const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>> const&, int, std::shared_ptr<wand::parallel::scheduler_factory_t>, std::optional const&, std::optional const&) in /home/jupyter/NLP_Workflows/env/lib/python3.10/site-packages/deepsparse/avx512/libdeepsparse.so
20# 0x00007f1476774697:
[89e94d89f88b8d2cffffff4c89f7488b9530ffffff488b358d1f0d00e8d8820c]
[00488b7da8595e4885ff7405e8f87d0c00488b7d804885ff0f848f000000807d]
21# 0x00007f14767758ab:
[4518488b4d104c8d4d806a00488bb5f8feffff4489e24889c74989c5e874ebff]
[ff488b7d88585a4885ff7405e8e46b0c004c89ffe8dc700c00e9c1fcffff0f1f]
22# 0x00007f14767b707a:
[feffff41574c8b8538feffff4d89f1488b8d58feffff8bb550feffffe825e4fb]
[ff488bbd78feffff4989c4585a4885ff7405e80f54080080bd18ffffff007432]
23# 0x00007f1476780537:
[33ffff4c8dac24c0010000488b78404c89eee8a2c10b00498b04244c89e7ff50]
[304989c74c89ef4889442478e89835ffff4983ff010f85c40300004983c4684c]
24# cfunction_call at /usr/local/src/conda/python-3.10.14/Objects/methodobject.c:543
25# _PyObject_MakeTpCall.localalias at /usr/local/src/conda/python-3.10.14/Objects/call.c:215
26# method_vectorcall at /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:112
27# method_vectorcall at /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:99
28# method_vectorcall at /usr/local/src/conda/python-3.10.14/Objects/classobject.c:83
29# slot_tp_init at /usr/local/src/conda/python-3.10.14/Objects/typeobject.c:7737
30# type_call at /usr/local/src/conda/python-3.10.14/Objects/typeobject.c:1135
31# pybind11_meta_call in /home/jupyter/NLP_Workflows/env/lib/python3.10/site-packages/onnx/onnx_cpp2py_export.cpython-310-x86_64-linux-gnu.so
32# _PyObject_MakeTpCall.localalias at /usr/local/src/conda/python-3.10.14/Objects/call.c:215

Please email a copy of this stack trace and any additional information to: support@neuralmagic.com

Additional context
When using static quantization this error does not occur.

tsamiss added the bug (Something isn't working) label on Jun 25, 2024
@bfineran
Member

hi @tsamiss, DeepSparse does not support dynamic quantization. Static quantization, including model training and export, is provided in neuralmagic/sparseml.
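The distinction the maintainer is drawing can be sketched in plain Python. With dynamic quantization the activation scale is computed from each input tensor at run time (hence runtime ops like `DynamicQuantizeLinear` in the graph), while with static quantization the scale is a constant fixed from calibration data, which a static compiler can plan around. All names below are illustrative, and `CALIBRATED_SCALE` is a made-up value standing in for a real calibration run:

```python
def quantize(values, scale, zero_point=0):
    """Affine int8 quantization: q = clip(round(v / scale) + zero_point)."""
    return [max(-128, min(127, round(v / scale) + zero_point)) for v in values]

def dynamic_quantize(values):
    """Dynamic scheme: derive the scale from the tensor itself, at run time."""
    scale = max(abs(v) for v in values) / 127 or 1.0
    return quantize(values, scale), scale

# Static scheme: the scale is a graph constant chosen ahead of time
# from calibration data (hypothetical value here).
CALIBRATED_SCALE = 0.05

acts = [0.1, -0.4, 0.25]
q_dyn, s_dyn = dynamic_quantize(acts)   # scale depends on this input
q_stat = quantize(acts, CALIBRATED_SCALE)  # scale known at compile time
print(q_stat)  # [2, -8, 5]
```

Because `s_dyn` changes with every input tensor, the quantization parameters cannot be folded into the compiled plan, which is consistent with the assertion firing inside DeepSparse's execution planner.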
