Improved huggingface batch logic (#1336)
ajsalow committed Sep 5, 2023
1 parent 50745f5 commit 93b2bcb
Showing 2 changed files with 55 additions and 4 deletions.
16 changes: 12 additions & 4 deletions runtimes/huggingface/mlserver_huggingface/common.py
@@ -3,6 +3,7 @@
 
 from typing import Callable
 from functools import partial
+from mlserver.logging import logger
 from mlserver.settings import ModelSettings
 
 import torch
@@ -59,11 +60,18 @@ def load_pipeline_from_settings(
         framework=hf_settings.framework,
     )
 
-    # If max_batch_size > 0 we need to ensure tokens are padded
-    if settings.max_batch_size:
+    # If max_batch_size > 1 we need to ensure tokens are padded
+    if settings.max_batch_size > 1:
         model = hf_pipeline.model
-        eos_token_id = model.config.eos_token_id  # type: ignore
-        hf_pipeline.tokenizer.pad_token_id = [str(eos_token_id)]  # type: ignore
+        if not hf_pipeline.tokenizer.pad_token_id:
+            eos_token_id = model.config.eos_token_id  # type: ignore
+            if eos_token_id:
+                hf_pipeline.tokenizer.pad_token_id = [str(eos_token_id)]  # type: ignore
+            else:
+                logger.warning(
+                    "Model has neither pad_token or eos_token, setting batch size to 1"
+                )
+                hf_pipeline._batch_size = 1
 
     return hf_pipeline

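For reference, the fallback this patch implements can be reproduced outside MLServer with the transformers pipeline API directly. The sketch below is illustrative only: the model name (gpt2, which defines an eos_token but no pad_token), the requested batch size, and the explicit "is None" checks (the patch itself tests truthiness) are assumptions, not part of the commit.

# Standalone sketch of the padding fallback (illustrative, not MLServer code).
# gpt2 is assumed here because it has an eos_token but no pad_token.
from transformers import pipeline

pipe = pipeline("text-generation", model="gpt2")

batch_size = 8  # assumed caller-requested batch size
if pipe.tokenizer.pad_token_id is None:
    eos_token_id = pipe.model.config.eos_token_id
    if eos_token_id is not None:
        # Reuse EOS for padding so prompts of unequal length can be collated
        pipe.tokenizer.pad_token_id = eos_token_id
    else:
        # Nothing to pad with: fall back to unbatched execution
        batch_size = 1

outputs = pipe(["Hello", "A somewhat longer prompt"], batch_size=batch_size)

Without a pad token, a batch of unequal-length prompts cannot be collated into a single tensor, which is why the patch downgrades the batch size up front rather than letting the pipeline fail at request time.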
43 changes: 43 additions & 0 deletions runtimes/huggingface/tests/test_common.py
@@ -102,3 +102,46 @@ def test_pipeline_is_initialised_with_correct_model_param(
     pipeline_call_args = mock_pipeline_factory.return_value.call_args
 
     assert pipeline_call_args.kwargs["model"] == expected
+
+
+@pytest.mark.parametrize(
+    "pretrained_model, task, input_batch_size, expected_batch_size",
+    [
+        (
+            "hf-internal-testing/tiny-bert-for-token-classification",
+            "token-classification",
+            1,
+            1,
+        ),
+        (
+            "hf-internal-testing/tiny-bert-for-token-classification",
+            "token-classification",
+            0,
+            1,
+        ),
+        (
+            "hf-internal-testing/tiny-bert-for-token-classification",
+            "token-classification",
+            10,
+            1,
+        ),  # Neither pad_token nor eos_token defined revert to 1
+    ],
+)
+def test_pipeline_checks_for_eos_and_pad_token(
+    pretrained_model: str,
+    task: Optional[str],
+    input_batch_size: Optional[int],
+    expected_batch_size: Optional[int],
+):
+    hf_settings = HuggingFaceSettings(pretrained_model=pretrained_model, task=task)
+    model_params = ModelParameters()
+    model_settings = ModelSettings(
+        name="foo",
+        implementation=HuggingFaceRuntime,
+        parameters=model_params,
+        max_batch_size=input_batch_size,
+    )
+
+    m = load_pipeline_from_settings(hf_settings, model_settings)
+
+    assert m._batch_size == expected_batch_size
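The parametrized cases above all resolve to a batch size of 1, so the EOS-token fallback path itself is not exercised. Below is a hedged sketch of that path, driven through the same helper: the model name (distilgpt2, which has an eos_token but no pad_token), the batch size, and the import paths are assumptions and may need adjusting to the installed mlserver_huggingface version.

# Hedged sketch: a model with an eos_token should keep the requested batch size,
# since the runtime can fall back to EOS for padding. Import paths assumed.
from mlserver.settings import ModelParameters, ModelSettings
from mlserver_huggingface import HuggingFaceRuntime
from mlserver_huggingface.common import load_pipeline_from_settings
from mlserver_huggingface.settings import HuggingFaceSettings

hf_settings = HuggingFaceSettings(
    pretrained_model="distilgpt2", task="text-generation"
)
model_settings = ModelSettings(
    name="batched-generator",
    implementation=HuggingFaceRuntime,
    parameters=ModelParameters(),
    max_batch_size=8,
)

hf_pipeline = load_pipeline_from_settings(hf_settings, model_settings)

# Expected: the requested size survives, because pad_token_id was back-filled
# from eos_token_id instead of batching being disabled.
print(hf_pipeline._batch_size)  # anticipated: 8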
