[TextGeneration] Update pipeline inputs to support GenerationConfig #1250

Merged

19 commits merged into main from update_inputs on Sep 22, 2023

Conversation

dsikka (Contributor)

@dsikka dsikka commented Sep 17, 2023

Summary:

  • Update the pipeline input to support a transformers.GenerationConfig. This config replaces many of the inputs we previously provided as separate fields (such as num_return_sequences, top_k, etc.).
  • Supports a path to a JSON config file, a dictionary, or a transformers.GenerationConfig object.
  • The user can provide the config at either the pipeline level or the input level. If an input-level config is provided, it overrides the pipeline-level config and is used for generation. Otherwise, the config provided during pipeline creation is used. If neither is given, the defaults set in the GenerationDefaults class are used.
  • If the user provides a generation config (at either level) but does not set all of its values, the missing values fall back to the defaults defined by GenerationConfig itself, not by the GenerationDefaults class.
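The resolution order described above can be sketched roughly as follows. Note that resolve_generation_config and the GenerationDefaults attributes here are illustrative only, not the actual deepsparse implementation:

```python
# Illustrative sketch of the config-resolution order; names are hypothetical.
import json
from pathlib import Path
from typing import Optional, Union


class GenerationDefaults:
    """Fallback values, used only when no config is provided at all."""

    num_return_sequences = 1
    max_length = 1024


def resolve_generation_config(
    pipeline_config: Optional[Union[str, Path, dict]],
    input_config: Optional[Union[str, Path, dict]],
) -> dict:
    # An input-level config takes priority over the pipeline-level one.
    config = input_config if input_config is not None else pipeline_config
    if config is None:
        # Neither level provided a config: fall back to GenerationDefaults.
        return {
            "num_return_sequences": GenerationDefaults.num_return_sequences,
            "max_length": GenerationDefaults.max_length,
        }
    # A str/Path is treated as a path to a JSON config file.
    if isinstance(config, (str, Path)):
        config = json.loads(Path(config).read_text())
    # When a config IS provided, any fields it leaves unset fall back to
    # GenerationConfig's own defaults (not modeled in this sketch).
    return dict(config)
```

This mirrors the precedence in the bullets: input-level config, then pipeline-level config, then GenerationDefaults.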

Test Cases

Dictionary:

from deepsparse import Pipeline

pipeline = Pipeline.create(
   task="text_generation",
   model_path="/home/dsikka/.cache/sparsezoo/neuralmagic/opt-1.3b-opt_pretrain-quantW8A8/deployment",
   engine_type="onnxruntime"
)
generation_config = {
   "num_return_sequences": 2,
   "max_length": 100
}
inference = pipeline(sequences=["hello?", "cool"], generation_config=generation_config)
print(next(inference))

String or Path

from deepsparse import Pipeline

pipeline = Pipeline.create(
   task="text_generation",
   model_path="/home/dsikka/.cache/sparsezoo/neuralmagic/opt-1.3b-opt_pretrain-quantW8A8/deployment",
   engine_type="onnxruntime"
)
generation_config_path = "/home/dsikka/llama_run/current_config.json"
inference = pipeline(sequences=["hello?", "cool"], generation_config=generation_config_path)
print(next(inference))
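For reference, a config passed as a path is just the GenerationConfig fields serialized to JSON. A hypothetical current_config.json (using the same fields as the dictionary example above) might contain:

```json
{
  "num_return_sequences": 2,
  "max_length": 100
}
```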

GenerationConfig object

from deepsparse import Pipeline
from transformers import GenerationConfig

pipeline = Pipeline.create(
   task="text_generation",
   model_path="/home/dsikka/.cache/sparsezoo/neuralmagic/opt-1.3b-opt_pretrain-quantW8A8/deployment",
   engine_type="onnxruntime"
)

generation_config_obj = GenerationConfig(
   num_return_sequences=3,
   max_length=100,
   output_scores=True
)
inference = pipeline(sequences=["hello?", "cool"], generation_config=generation_config_obj)
print(next(inference))

None: no generation config is provided; the GenerationDefaults will be used instead

from deepsparse import Pipeline

pipeline = Pipeline.create(
   task="text_generation",
   model_path="/home/dsikka/.cache/sparsezoo/neuralmagic/opt-1.3b-opt_pretrain-quantW8A8/deployment",
   engine_type="onnxruntime"
)

inference = pipeline(sequences=["hello?", "cool"])
for out in inference:
   print(out)
   print("\n")

Set GenerationConfig on the pipeline level

  • The config will be used for every input passed to the pipeline
from deepsparse import Pipeline
from transformers import GenerationConfig

generation_config_obj = GenerationConfig(
   num_return_sequences=3,
   max_length=100,
)

pipeline = Pipeline.create(
   task="text_generation",
   model_path="/home/dsikka/.cache/sparsezoo/neuralmagic/opt-1.3b-opt_pretrain-quantW8A8/deployment",
   engine_type="onnxruntime",
   generation_config=generation_config_obj
)

inference = pipeline(sequences=["hello?"])
for out in inference:
   print(out)
   print("\n")

inference = pipeline(sequences=["cool"])
for out in inference:
   print(out)
   print("\n")
  • Output: 3 text generations per prompt, as set by the pipeline-level config
2023-09-19 11:31:14 deepsparse.transformers.pipelines.text_generation INFO     Generation config provided for pipline. This will be used for all inputs unless and input-specific config is provided. 
created=datetime.datetime(2023, 9, 19, 11, 32, 43, 387168) prompts=['hello?'] generations=[[GeneratedText(text='\n\nI am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am', score=None, finished=True, finished_reason='length'), GeneratedText(text='\n\nI am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am', score=None, finished=True, finished_reason='length'), GeneratedText(text='\n\nI am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am', score=None, finished=True, finished_reason='length')]] session_id=None


created=datetime.datetime(2023, 9, 19, 11, 32, 54, 414539) prompts=['cool'] generations=[[GeneratedText(text=", i'll be there.\nI'll be there too.", score=None, finished=True, finished_reason='stop'), GeneratedText(text=", i'll be there.\nI'll be there too.", score=None, finished=True, finished_reason='stop'), GeneratedText(text=", i'll be there.\nI'll be there too.", score=None, finished=True, finished_reason='stop')]] session_id=None

Set GenerationConfig on the pipeline level, override on the input level

from deepsparse import Pipeline
from transformers import GenerationConfig

generation_config_obj = GenerationConfig(
   num_return_sequences=3,
   max_length=100,
)

pipeline = Pipeline.create(
   task="text_generation",
   model_path="/home/dsikka/.cache/sparsezoo/neuralmagic/opt-1.3b-opt_pretrain-quantW8A8/deployment",
   engine_type="onnxruntime",
   generation_config=generation_config_obj
)

generation_config_obj_input = GenerationConfig(
   num_return_sequences=2,
   max_length=50,
)


inference = pipeline(sequences=["hello?"], generation_config=generation_config_obj_input)
for out in inference:
   print(out)
   print("\n")

inference = pipeline(sequences=["cool"])
for out in inference:
   print(out)
   print("\n")
  • Output:
    For the first prompt, the pipeline config is overridden by the config given with the input, resulting in 2 text generations. For the second input, as no config is given, the pipeline config is used, resulting in 3 text generations.
created=datetime.datetime(2023, 9, 19, 11, 37, 5, 159853) prompts=['hello?'] generations=[[GeneratedText(text='\n\nI am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help.', score=None, finished=True, finished_reason='length'), GeneratedText(text='\n\nI am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help. I am a new member of the forum and I am looking for some help.', score=None, finished=True, finished_reason='length')]] session_id=None


created=datetime.datetime(2023, 9, 19, 11, 37, 17, 76668) prompts=['cool'] generations=[[GeneratedText(text=", i'll be there.\nI'll be there too.", score=None, finished=True, finished_reason='stop'), GeneratedText(text=", i'll be there.\nI'll be there too.", score=None, finished=True, finished_reason='stop'), GeneratedText(text=", i'll be there.\nI'll be there too.", score=None, finished=True, finished_reason='stop')]] session_id=None

src/deepsparse/transformers/pipelines/text_generation.py: review comments resolved
bfineran (Member) previously approved these changes Sep 21, 2023

@bfineran left a comment:

LGTM - can we check in any unit tests for this?

Satrat
Satrat previously approved these changes Sep 21, 2023
Base automatically changed from enable_streaming to main September 21, 2023 20:11
@bfineran bfineran dismissed stale reviews from Satrat and themself September 21, 2023 20:11

The base branch was changed.

bfineran
bfineran previously approved these changes Sep 21, 2023
@bfineran bfineran merged commit b309fa4 into main Sep 22, 2023
13 checks passed
@bfineran bfineran deleted the update_inputs branch September 22, 2023 14:53
5 participants