update helper functions to include all generation config handling and overriding
dsikka committed Sep 21, 2023
1 parent 0609801 commit 93c012d
Showing 2 changed files with 138 additions and 68 deletions.
83 changes: 16 additions & 67 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -13,7 +13,6 @@
# limitations under the License.

import datetime
import json
import logging
import os
import pathlib
@@ -42,8 +41,11 @@
from deepsparse.transformers.engines import NLDecoderEngine
from deepsparse.transformers.pipelines import TransformersPipeline
from deepsparse.transformers.utils.helpers import (
check_and_return_generation_config,
create_causal_mask,
override_config,
pad_to_fixed_length,
process_generation_config,
repeat_inputs,
)
from deepsparse.transformers.utils.timings import TextGenerationTimings
@@ -141,6 +143,13 @@ class Config:
"top_k, repetition_penalty.",
)

kwargs: Optional[Dict] = Field(
default=None,
description="Any arguments to override generation_config arguments. Refer to "
"the generation_config argument for a full list of supported variables. Only "
"valid when generation_config is not None.",
)
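
As a hedged usage sketch (the model stub, prompt text, and override values below are placeholders, not from this commit), the new kwargs field lets a caller override individual generation_config values per request:

# Illustrative only: assumes a deepsparse text_generation pipeline created with
# a baseline generation_config; stub path and override values are arbitrary.
from deepsparse import Pipeline

pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:example/stub",          # hypothetical model stub
    generation_config={"max_length": 64},   # baseline config as a dict
)
output = pipeline(
    sequences=["The capital of France is"],
    kwargs={"top_k": 40, "repetition_penalty": 1.2},  # per-request overrides
)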


class GeneratedText(BaseModel):
text: str = Field(
Expand Down Expand Up @@ -264,7 +273,7 @@ def __init__(
self.engine, self.multitoken_engine = self.initialize_engines()
self.streaming = False

self.generation_config = self._process_generation_config(generation_config)
self.generation_config = process_generation_config(generation_config)
if self.generation_config:
_LOGGER.info(
"Generation config provided for pipline. This will be used "
@@ -408,79 +417,19 @@ def output_schema(self) -> Type[BaseModel]:
"""
return TextGenerationOutput

def _process_generation_config(
self, generation_config: [None, str, pathlib.Path, Dict, GenerationConfig]
) -> Union[GenerationConfig, None]:
"""
Process and return a GenerationConfig. The function can take in a filepath
pointing to a json consisting of the config values, a dictionary with the config
values, or a loaded GenerationConfig object. If None is given, the pipeline
GenerationConfig is used, if provided. If both are None, defaults are used for
generation.
:param generation_config: either a json filepath, dictionary or loaded
GenerationConfig object
:return: GenerationConfig object or None
"""
if isinstance(generation_config, GenerationConfig):
return generation_config

if not generation_config:
return None

# TODO: move to tmp folder
if isinstance(generation_config, dict):
config_dir = os.getcwd()
config_name = "generation_config.json"
local_config_path = os.path.join(config_dir, config_name)
_LOGGER.info(
"Dictionary provided for the generation config. Creating temporary "
" generation_config.json"
)
with open(local_config_path, "w") as f:
json.dump(generation_config, f)

if isinstance(generation_config, (str, pathlib.Path)):
generation_config = pathlib.Path(generation_config)
config_dir = generation_config.parent.absolute()
config_name = generation_config.name

generation_config = GenerationConfig.from_pretrained(config_dir, config_name)
return generation_config

def _check_and_return_generation_config(
self, input_generation_config: [None, str, pathlib.Path, Dict, GenerationConfig]
) -> Union[GenerationConfig, None]:
generation_config = self._process_generation_config(input_generation_config)
if generation_config is None:
if self.generation_config:
generation_config = self.generation_config
else:
_LOGGER.info(
"Input generation config detected. This will override any"
" config provided during pipeline creation."
)

if not generation_config:
_LOGGER.info(
" No GenerationConfig detected. Using GenerationDefaults values"
)
generation_config = GenerationDefaults()
return generation_config

def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
"""
Convert the input schema for the pipeline to the inputs for the engine.
:param inputs: the input schema for the pipeline
:return: the inputs for the engine
"""
generation_config = self._check_and_return_generation_config(
inputs.generation_config
generation_config = check_and_return_generation_config(
self.generation_config, inputs.generation_config, GenerationDefaults()
)

generation_config = override_config(inputs.kwargs, generation_config)

self.streaming = inputs.streaming
if not self.cache_support_enabled and generation_config.max_length > 1:
raise ValueError(
@@ -712,7 +661,7 @@ def engine_forward(
if max_new_tokens:
max_tokens = max_new_tokens + len(generated_tokens)
else:
max_tokens = generation_config.max_tokens
max_tokens = generation_config.max_length
max_tokens = (
max_tokens if max_tokens > 0 else (100 * self.sequence_length)
)
123 changes: 122 additions & 1 deletion src/deepsparse/transformers/utils/helpers.py
@@ -11,18 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import pathlib
import uuid
from typing import List, Union
from typing import Dict, List, Optional, Union

import numpy
from transformers import GenerationConfig


__all__ = [
"generate_session_id",
"pad_to_fixed_length",
"create_causal_mask",
"repeat_inputs",
"check_and_return_generation_config",
"override_config",
"process_generation_config",
]

_LOGGER = logging.getLogger(__name__)
@@ -37,6 +44,120 @@ def generate_session_id() -> str:
return session_id


def process_generation_config(
generation_config: Union[None, str, pathlib.Path, Dict, GenerationConfig]
) -> Union[GenerationConfig, None]:
"""
Process and return a GenerationConfig. The function can take in a filepath
pointing to a json consisting of the config values, a dictionary with the config
values, or a loaded GenerationConfig object. If None is given, the pipeline
GenerationConfig is used, if provided. If both are None, defaults are used for
generation.
:param generation_config: either a json filepath, dictionary or loaded
GenerationConfig object
:return: GenerationConfig object or None
"""
if isinstance(generation_config, GenerationConfig):
return generation_config

if not generation_config:
return None

# TODO: move to tmp folder
if isinstance(generation_config, dict):
config_dir = os.getcwd()
config_name = "generation_config.json"
local_config_path = os.path.join(config_dir, config_name)
_LOGGER.info(
"Dictionary provided for the generation config. Creating temporary "
" generation_config.json"
)
with open(local_config_path, "w") as f:
json.dump(generation_config, f)

if isinstance(generation_config, (str, pathlib.Path)):
generation_config = pathlib.Path(generation_config)
config_dir = generation_config.parent.absolute()
config_name = generation_config.name

generation_config = GenerationConfig.from_pretrained(config_dir, config_name)
return generation_config
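
A quick sketch of the input forms process_generation_config accepts, per its docstring (the JSON path below is hypothetical):

from transformers import GenerationConfig

# Each call returns a loaded GenerationConfig; passing None returns None.
cfg_from_obj = process_generation_config(GenerationConfig(max_length=128))
cfg_from_dict = process_generation_config({"max_length": 128, "top_k": 50})
cfg_from_file = process_generation_config("path/to/generation_config.json")  # hypothetical path
assert process_generation_config(None) is None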


def check_and_return_generation_config(
pipeline_generation_config: [None, str, pathlib.Path, Dict, GenerationConfig],
input_generation_config: [None, str, pathlib.Path, Dict, GenerationConfig],
defaults: "GenerationDefaults", # noqa F821
) -> Union[GenerationConfig, None]:
"""
Check if an input generation config is provided. If not, check if a pipeline
generation config exists. If neither exists, use the default generation configs,
either the deepsparse defaults or the Hugging Face defaults. If a pipeline config exists
and an input config exists, use the input config.
:param pipeline_generation_config: either a json filepath, dictionary or loaded
GenerationConfig object provided by the user during pipeline creation
:param input_generation_config: either a json filepath, dictionary or loaded
GenerationConfig object provided by the user during inference
:param defaults: defaults to use for the GenerationConfig if a config is not
provided during inference or pipeline creation.
:return: GenerationConfig object or None
"""
generation_config = process_generation_config(input_generation_config)
if generation_config is None:
if pipeline_generation_config:
generation_config = pipeline_generation_config
else:
_LOGGER.info(
"Input generation config detected. This will override any"
" config provided during pipeline creation."
)

if not generation_config:
_LOGGER.info(" No GenerationConfig detected. Using GenerationDefaults values")
generation_config = defaults
return generation_config
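
The precedence this helper implements, shown as a small sketch (config values are illustrative; the GenerationDefaults import path is assumed from the pipeline module):

from transformers import GenerationConfig
from deepsparse.transformers.pipelines.text_generation import GenerationDefaults  # assumed import path

defaults = GenerationDefaults()

# An input config, when provided, wins over the pipeline-level config.
cfg = check_and_return_generation_config(
    GenerationConfig(max_length=32),   # pipeline_generation_config
    {"max_length": 128},               # input_generation_config
    defaults,
)
# -> GenerationConfig with max_length == 128

# With neither config provided, the defaults object is returned as-is.
assert check_and_return_generation_config(None, None, defaults) is defaults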


def override_config(
overrides: Optional[Dict], generation_config: GenerationConfig
) -> GenerationConfig:
"""
Override any generation config properties using the `kwargs` argument in
TextGenerationInput. If None, the generation config is returned unchanged.
If provided, update all attributes stored in the dictionary. An error will be
raised if the dictionary contains a key which is not a GenerationConfig
attribute.
:param overrides: dictionary containing GenerationConfig attributes to update
:param generation_config: GenerationConfig to update
:return: GenerationConfig object
"""
if overrides is None:
return generation_config

for k, v in overrides.items():
try:
if getattr(generation_config, k):
setattr(generation_config, k, v)
_LOGGER.info(f"Overriding attribute {k} in the generation config")
except AttributeError as exception:
raise AttributeError(
"Argument provided for GenerationConfig is not "
"valid. Refer to the TextGenerationInput for supported attributes. "
) from exception

return generation_config
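
A short sketch of override_config with illustrative values; an unknown key should raise AttributeError per the docstring:

from transformers import GenerationConfig

base = GenerationConfig(max_length=64, top_k=50)

# Known attributes are updated in place and the config is returned.
updated = override_config({"top_k": 20, "max_length": 128}, base)
# updated.top_k == 20 and updated.max_length == 128

override_config({"not_a_generation_config_field": 1}, base)  # raises AttributeError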


def repeat_inputs(
input_sequences: List[str], num_generated_predictions: int
) -> List[str]:
