Commit

Merge pull request #103 from SSamDav/main
Fix error when calling _prepare_decoder_input_ids_for_generation
jalammar committed Aug 10, 2023
2 parents 47740ee + 1913fc1 commit 78abb38
Showing 1 changed file with 11 additions and 2 deletions.
src/ecco/lm.py
@@ -196,12 +196,21 @@ def generate(self, input_str: str,
             raise ValueError(
                 "max_length set to {} while input token has more tokens ({}). Consider increasing max_length" \
                     .format(max_length, cur_len))
 
         # Get decoder input ids
         if self.model_type == 'enc-dec':  # FIXME: only done because causal LMs like GPT-2 have the _prepare_decoder_input_ids_for_generation method but do not use it
             assert len(input_ids.size()) == 2  # will break otherwise
             if version.parse(transformers.__version__) >= version.parse('4.13'):
-                decoder_input_ids = self.model._prepare_decoder_input_ids_for_generation(input_ids.shape[0], None, None)
+                # following the code in https://github.com/huggingface/transformers/blob/d0c1aebea467af499331234e7b285a6bf91ea073/tests/generation/test_utils.py#L2099
+                model_kwargs = self.model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
+                decoder_input_ids, model_kwargs = self.model._prepare_decoder_input_ids_for_generation(
+                    batch_size=input_ids.shape[0],
+                    model_input_name=self.model.main_input_name,
+                    model_kwargs=model_kwargs,
+                    decoder_start_token_id=self.model.config.decoder_start_token_id,
+                    bos_token_id=self.model.config.bos_token_id,
+                )
             else:
                 decoder_input_ids = self.model._prepare_decoder_input_ids_for_generation(input_ids, None, None)
         else:
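For reference, the patched call path can be exercised on its own, outside of Ecco. Below is a minimal sketch, assuming a transformers release in which these private GenerationMixin helpers keep the signatures used in the diff above; the t5-small checkpoint and the example prompt are illustrative assumptions, not part of the commit:

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    # Illustrative encoder-decoder checkpoint; any enc-dec model would do.
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    input_ids = tokenizer(
        "translate English to German: Hello", return_tensors="pt"
    ).input_ids

    # Mirrors the patched branch: run the encoder and collect its outputs
    # into model_kwargs first, then let the model derive the decoder
    # input ids from them.
    model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
    decoder_input_ids, model_kwargs = model._prepare_decoder_input_ids_for_generation(
        batch_size=input_ids.shape[0],
        model_input_name=model.main_input_name,
        model_kwargs=model_kwargs,
        decoder_start_token_id=model.config.decoder_start_token_id,
        bos_token_id=model.config.bos_token_id,
    )
    print(decoder_input_ids)  # typically one decoder_start_token_id per batch row

The substance of the fix is visible in the diff: newer transformers versions expect the encoder kwargs to be prepared and passed in explicitly, and the helper returns a (decoder_input_ids, model_kwargs) pair, so the old positional call with (input_ids.shape[0], None, None) raised an error.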
