Fix encoder attention mask input order for ORT (#1181)
Fix the encoder attention mask input order: append encoder_attention_mask after encoder_hidden_states when building the decoder's model_inputs, so the tensors line up with the expected ONNX input order under IO binding.
fxmarty authored Jul 11, 2023
1 parent 2678e74 commit 414d4c3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions optimum/onnxruntime/base.py
@@ -319,7 +319,7 @@ def forward(
 )
 
 # TODO: fix transformers generate to have contiguous input_ids here already
-# For an unknown reason, calling `contigous()` here is necessary to not have errors
+# For an unknown reason, calling `contiguous()` here is necessary to not have errors
 # on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding.
 # I suspect the reason is the contiguous python list that messes something up?
 model_inputs = [input_ids.contiguous()]
@@ -533,12 +533,12 @@ def forward(
 
 model_inputs = [input_ids]
 
-if "encoder_attention_mask" in self.input_names:
-    model_inputs.append(encoder_attention_mask)
-
 if "encoder_hidden_states" in self.input_names:
     model_inputs.append(encoder_hidden_states)
 
+if "encoder_attention_mask" in self.input_names:
+    model_inputs.append(encoder_attention_mask)
+
 if past_key_values is not None:
     model_inputs += past_key_values
 
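For context on why the append order matters: under IO binding, optimum builds model_inputs as a plain positional list, and each tensor is paired with the ONNX input name at the same position, so appending encoder_attention_mask before encoder_hidden_states silently swaps the two bindings. Below is a minimal sketch of that failure mode; the bind_inputs helper, the input-name list, and the shapes are made up for illustration and are not optimum's actual _prepare_io_binding.

    # Minimal sketch, not optimum's actual code: inputs are matched to input
    # names by position, so a wrong append order binds tensors to the wrong
    # ONNX inputs. `bind_inputs`, the names, and the shapes are hypothetical.
    import numpy as np

    def bind_inputs(ordered_input_names, model_inputs):
        # Hypothetical helper: pair each tensor with the name at the same index.
        assert len(ordered_input_names) == len(model_inputs)
        return dict(zip(ordered_input_names, model_inputs))

    ordered_input_names = ["input_ids", "encoder_hidden_states", "encoder_attention_mask"]

    input_ids = np.zeros((1, 4), dtype=np.int64)
    encoder_hidden_states = np.zeros((1, 8, 512), dtype=np.float32)
    encoder_attention_mask = np.ones((1, 8), dtype=np.int64)

    # Old order: encoder_attention_mask appended before encoder_hidden_states.
    buggy = bind_inputs(ordered_input_names, [input_ids, encoder_attention_mask, encoder_hidden_states])
    print(buggy["encoder_hidden_states"].shape)  # (1, 8) -- the mask landed on the wrong input

    # Fixed order: matches ordered_input_names, every tensor lands on the right input.
    fixed = bind_inputs(ordered_input_names, [input_ids, encoder_hidden_states, encoder_attention_mask])
    print(fixed["encoder_hidden_states"].shape)  # (1, 8, 512)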

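The comment touched up in the first hunk concerns a separate quirk: IO binding hands ONNX Runtime a raw data_ptr(), which only describes a tensor correctly when its memory is contiguous. The following is a small, self-contained illustration of standard torch behavior, not code from this commit.

    import torch

    x = torch.arange(12).reshape(3, 4)
    t = x.t()                               # transposed view shares x's storage
    print(t.is_contiguous())                # False: strides no longer match row-major layout
    # Binding t through t.data_ptr() would expose the original row-major buffer,
    # so a consumer reading it as a (4, 3) array would see wrongly ordered values.
    c = t.contiguous()                      # copies into a fresh row-major buffer
    print(c.is_contiguous())                # True
    print(torch.equal(c, t))                # True: same values, different memory layout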