Skip to content

Commit

Permalink
english; this one shouldn't be changed
Browse files Browse the repository at this point in the history
  • Loading branch information
gante committed Jul 15, 2024
1 parent cdbbdde commit 3e4948d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,7 +1689,7 @@ def generate(

# 3. Define model inputs
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_tensor, model_kwargs
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]

Expand Down
7 changes: 4 additions & 3 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3203,26 +3203,27 @@ def test_special_tokens_fall_back_to_model_default(self):
)
test_bos_id = 50

# the model has a BOS token set, and the first generated token is a BOS token
# Sanity-check: the model has a BOS token set, and the first generated token is a BOS token
gen_output = model.generate()
self.assertTrue(model.generation_config.bos_token_id is not None)
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])

# If pass a generation config **with** a BOS token, `generate` will use it
# If we pass a generation config **with** a BOS token, `generate` will use it
generation_config = GenerationConfig(bos_token_id=test_bos_id)
gen_output = model.generate(generation_config=generation_config)
self.assertFalse(model.generation_config.bos_token_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id == gen_output[0, 0])
self.assertTrue(test_bos_id == gen_output[0, 0])

# If pass a generation config **without** a BOS token, `generate` will fetch the BOS token from
# If we pass a generation config **without** a BOS token, `generate` will fetch the BOS token from
# `model.generation_config`
generation_config = GenerationConfig(bos_token_id=None)
gen_output = model.generate(generation_config=generation_config)
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
self.assertFalse(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)

# Changing `model.generation_config` will affect fallback behavior
model.generation_config.bos_token_id = test_bos_id
gen_output = model.generate(generation_config=generation_config)
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
Expand Down

0 comments on commit 3e4948d

Please sign in to comment.