
[Build] RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running MatMul node. Name:'/MatMul_7' Status Message: /onnxruntime_src/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorShape&) status.IsOK() was false. Shape mismatch attempting to re-use buffer. {1,1,512} != {1,32,512}. Validate usage of dim_value (values should be > 0) and dim_param (all values with the same string should equate to the same size) in shapes in the model. #21320

Closed
kabyanil opened this issue Jul 11, 2024 · 7 comments
Labels
build build issues; typically submitted using template

Comments

@kabyanil

kabyanil commented Jul 11, 2024

Describe the issue

I am trying to convert a PyTorch transformer model to ONNX. My model architecture consists of multiple nn.Modules, so I am converting each one to ONNX separately. I have to use a combination of torch.onnx.export() and torch.onnx.dynamo_export(), because some of the module conversions do not support dynamo_export yet.

I am able to convert all the modules to ONNX. However, when I run an inference session through the decoder module, I get the error mentioned in the title. For reference, here is my Decoder class:

class Decoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)

Here is my code for the ONNX conversion of the module:

dummy_decoder_input = torch.randint(low=0, high=60, size=(1, 1, 512), dtype=torch.float)
dummy_encoder_output = torch.randint(low=0, high=60, size=(1, 32, 512), dtype=torch.float)
dummy_src_mask = torch.tensor([[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0]]]], dtype=torch.int)
dummy_tgt_mask = torch.tensor([[1]], dtype=torch.int)

args = (dummy_decoder_input, dummy_encoder_output, dummy_src_mask, dummy_tgt_mask)
# matches Decoder.forward(x, encoder_output, src_mask, tgt_mask)
dynamic_axes = {
    'decoder_input': {0: 'batch_size', 1: 'seq_len', 2: 'embed_dim'},
    'encoder_output': {0: 'batch_size', 1: 'seq_len', 2: 'embed_dim'},
    'src_mask': {3: 'seq_len'},
    'tgt_mask': {0: 'seq_len', 1: 'seq_len'},
    'output': {0: 'batch_size', 1: 'sequence_length'}
}

torch.onnx.export(test_scripted_decoder,
                  args=args,
                  f="./onnx/decoder.onnx",
                  input_names=['decoder_input', 'encoder_output', 'src_mask', 'tgt_mask'],
                  output_names=['output'],
                  dynamic_axes=dynamic_axes,
                  verbose=True
                )
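
To verify which symbolic dimension names (the dim_param values that the error message refers to) actually ended up in the exported graph, a quick inspection along these lines can help (a minimal sketch, assuming the onnx package is installed):

import onnx

model = onnx.load("./onnx/decoder.onnx")
for inp in model.graph.input:
    # dim_param holds the symbolic name (e.g. 'seq_len'); dim_value holds a fixed size
    dims = [d.dim_param if d.dim_param else d.dim_value
            for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)

Every dimension that shares the same dim_param string must resolve to the same size at run time.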

Here is my ONNX inference code:

def run_inference(input_string):
  input_tensor = encode_input(eng_tokenizer, input_string)
  encoder_input = prepare_encoder_input(eng_tokenizer, input_tensor)
  src_mask = prepare_encoder_mask(eng_tokenizer, encoder_input)

  # Run encoder
  src_embed_output = src_embed_layer.run(None, {'l_x_': encoder_input.numpy()})[0]

  src_pos_output = src_pos_layer.run(None, {'l_x_': src_embed_output})[0]

  encoder_output = src_encoder_layer.run(None, {'input_1': src_pos_output, 'input_2': src_mask.numpy()})[0]


  # Run decoder
  tgt_input = torch.tensor([[eng_tokenizer.encode('<')[0]]], dtype=torch.int32)

  while True:
      if tgt_input.size(1) == 32:
          break

      tgt_mask = causal_mask(tgt_input.size(1)).numpy().astype(np.int32)


      tgt_embed_output = tgt_embed_layer.run(None, {'l_x_': tgt_input.numpy().astype(np.int32)})[0]


      tgt_pos_output = tgt_pos_layer.run(None, {'input_1': tgt_embed_output})[0]

# ERROR OCCURS ON THE NEXT LINE
      decoder_output = tgt_decoder_layer.run(None, {'decoder_input': tgt_pos_output, 'encoder_output': encoder_output, 'src_mask': src_mask.numpy(), 'tgt_mask': tgt_mask})[0]

      last_dim = decoder_output[:, -1]

      prob = tgt_projection_layer.run(None, {'l_x_': last_dim})[0]
      next_word = np.argmax(prob, axis=1)[0]

      next_word_tensor = torch.tensor([[next_word]], dtype=torch.int64)
      tgt_input = torch.cat((tgt_input, next_word_tensor), dim=1)

      if next_word == eng_tokenizer.encode('>')[0]:
          break

  output_tokens = tgt_input.squeeze().tolist()
  output_string = eng_tokenizer.decode(output_tokens)
  return output_string

run_inference("hello")
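
The causal_mask helper is not shown above; a minimal sketch of what it is assumed to do (build a lower-triangular mask so position i can only attend to positions <= i):

def causal_mask(size: int) -> torch.Tensor:
    # 1 on and below the diagonal, 0 above it; for size=1 this matches
    # the (1, 1) shape of dummy_tgt_mask used at export time
    return torch.tril(torch.ones(size, size, dtype=torch.int32))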

I have marked the line that throws the error above. Here is the full traceback:

---------------------------------------------------------------------------
RuntimeException                          Traceback (most recent call last)
<ipython-input-41-ce3f1a0c8e40> in <cell line: 47>()
     45   return output_string
     46 
---> 47 run_inference("hello")

1 frames
/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py in run(self, output_names, input_feed, run_options)
    218             output_names = [output.name for output in self._outputs_meta]
    219         try:
--> 220             return self._sess.run(output_names, input_feed, run_options)
    221         except C.EPFail as err:
    222             if self._enable_fallback:

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running MatMul node. Name:'/MatMul_7' Status Message: /onnxruntime_src/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorShape&) status.IsOK() was false. Shape mismatch attempting to re-use buffer. {1,1,512} != {1,32,512}. Validate usage of dim_value (values should be > 0) and dim_param (all values with the same string should equate to the same size) in shapes in the model.

Any help in resolving this issue would be appreciated. Thanks.

Urgency

The issue is urgent for me. I am building a project in which I need to deploy the models on web, mobile and desktop, and I chose ONNX because it is under active development.

While I wait for a resolution, I am considering experimenting with ExecuTorch.

Target platform

Google Colab Ubuntu 22.04.3 LTS

Build script

See the conversion code in the description above.

Error / output


Same traceback as shown in the description above.

Visual Studio Version

No response

GCC / Compiler Version

11.4.0

@kabyanil kabyanil added the build build issues; typically submitted using template label Jul 11, 2024
@github-actions github-actions bot added the platform:mobile issues related to ONNX Runtime mobile; typically submitted using template label Jul 11, 2024
@yufenglee
Member

@kabyanil, it looks like a model issue based on the error. It is not possible to Reshape a tensor with shape {1,32,512} to {1,1,8,64}. The former has 32*512 = 16,384 elements while the target only has 8*64 = 512 elements. I guess it is intended to convert {1, 32, 512} to {1, 32, 8, 64}. Please check the model.

@kabyanil
Author

@kabyanil, it looks like a model issue based on the error. It is not possible to Reshape a tensor with shape {1,32,512} to {1,1,8,64}. The former has 32*512 = 16,384 elements while the target only has 8*64 = 512 elements. I guess it is intended to convert {1, 32, 512} to {1, 32, 8, 64}. Please check the model.

I am able to run inference in Python using the same code that I used for the ONNX conversion. What could be the issue then?

@kabyanil
Author

@kabyanil, it looks like a model issue based on the error. It is not possible to Reshape a tensor with shape {1,32,512} to {1,1,8,64}. The former has 32*512 = 16,384 elements while the target only has 8*64 = 512 elements. I guess it is intended to convert {1, 32, 512} to {1, 32, 8, 64}. Please check the model.

I have updated the error message. Can you please check now?

@kabyanil kabyanil changed the title RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'/Reshape_5' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:45 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, onnxruntime::TensorShapeVector&, bool) input_shape_size == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{1,32,512}, requested shape:{1,1,8,64} [Build] [Build] RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running MatMul node. Name:'/MatMul_7' Status Message: /onnxruntime_src/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorShape&) status.IsOK() was false. Shape mismatch attempting to re-use buffer. {1,1,512} != {1,32,512}. Validate usage of dim_value (values should be > 0) and dim_param (all values with the same string should equate to the same size) in shapes in the model. Jul 11, 2024
@tianleiwu
Contributor

@kabyanil,

I guess decoder_input and encoder_output should have the same shape (batch_size, seq_len, hidden_size). It seems that you use different shapes as dummy inputs in your code:

dummy_decoder_input = torch.randint(low=0, high=60, size=(1, 1, 512), dtype=torch.float)
dummy_encoder_output = torch.randint(low=0, high=60, size=(1, 32, 512), dtype=torch.float)

@sophies927 sophies927 removed the platform:mobile issues related to ONNX Runtime mobile; typically submitted using template label Jul 11, 2024
@kabyanil
Author

kabyanil commented Jul 12, 2024

@tianleiwu During inference, the encoder encodes the input to (batch_size, seq_len, embed_dim), where batch_size=1, seq_len=32 and embed_dim=512. Inputs shorter than seq_len are padded to length 32. In the decoder, on the other hand, the input starts at seq_len=1. At every decoder output, the next token is selected using torch.max() and appended to the decoder's earlier input, so the decoder's sequence grows from length 1 until the EOS token is hit.

To mimic this behaviour, I fixed dummy_encoder_output at (1, 32, 512) and gave dummy_decoder_input the initial shape (1, 1, 512).

Is this a mistake?

@tianleiwu
Contributor

tianleiwu commented Jul 12, 2024

@kabyanil,
If they are different, you should use different strings in the dynamic axes, like:

dynamic_axes = {
    'decoder_input': {0: 'batch_size', 1: 'decoder_seq_len', 2: 'embed_dim'},
    'encoder_output': {0: 'batch_size', 1: 'encoder_seq_len', 2: 'embed_dim'},
...
}

1.18.1 merges Shape nodes when it finds that their symbolic shapes are the same:
#19832

That change was reverted in the main branch, but it is still in the 1.18.* releases. You can try an older version like 1.17.* to work around it.
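
For reference, a complete dynamic_axes mapping consistent with the shapes in this thread might look like this (a sketch, assuming src_mask tracks the encoder sequence length and tgt_mask the decoder sequence length):

dynamic_axes = {
    'decoder_input':  {0: 'batch_size', 1: 'decoder_seq_len', 2: 'embed_dim'},
    'encoder_output': {0: 'batch_size', 1: 'encoder_seq_len', 2: 'embed_dim'},
    'src_mask':       {3: 'encoder_seq_len'},
    'tgt_mask':       {0: 'decoder_seq_len', 1: 'decoder_seq_len'},
    'output':         {0: 'batch_size', 1: 'decoder_seq_len'},
}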

@kabyanil
Author

kabyanil commented Jul 13, 2024

Thanks so much, using different names for the dynamic axes solved the issue.
