Add production-time recipe processing capability to unitxt #739

Merged · 5 commits · Apr 10, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -150,5 +150,7 @@ ibmcos_datasets/
kaggle.json

src/unitxt/catalog_back/*
src/unitxt/catalog/metrics/example/accuracy.json
src/unitxt/catalog/processors/example/to_string.json
prod_env/*
benchmark_output/*
58 changes: 58 additions & 0 deletions docs/docs/production.rst
@@ -0,0 +1,58 @@
.. _production:

.. note::

To use this tutorial, you need to :ref:`install unitxt <install_unitxt>`.

=====================================
Using in production
=====================================

Unitxt can be used to process data in production. First, define a recipe:
OfirArviv marked this conversation as resolved.

.. code-block:: python

recipe = "card=cards.wnli,template=templates.classification.multi_class.relation.default,demos_pool_size=5,num_demos=2"


Second, prepare an instance matching the exact schema of the task used by that recipe:


.. code-block:: python

instance = {
Review comment (Collaborator): Is there a way to know from a recipe that it comes from a specific task? I want the metric to only accept cards from a specific task, so can I know the recipe beforehand?

"label": "?",
"text_a": "It works perfectly",
"text_b": "It works!",
"classes": ["entailment", "not entailment"],
"type_of_relation": "entailment",
"text_a_type": "premise",
"text_b_type": "hypothesis",
}

Then you can produce the model-ready data with the `produce` function:

.. code-block:: python

from unitxt import produce

result = produce([instance], recipe)

The result now contains the production-ready instance. If you run `print(result[0]["source"])`, you will get:

.. code-block::

Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.

premise: When Tatyana reached the cabin, her mother was sleeping. She was careful not to disturb her, undressing and climbing back into her berth., hypothesis: mother was careful not to disturb her, undressing and climbing back into her berth.
The entailment class is entailment

premise: The police arrested all of the gang members. They were trying to stop the drug trade in the neighborhood., hypothesis: The police were trying to stop the drug trade in the neighborhood.
The entailment class is not entailment

premise: It works perfectly, hypothesis: It works!
The entailment class is




1 change: 1 addition & 0 deletions docs/index.rst
@@ -182,6 +182,7 @@ Welcome!
docs/operators
docs/contributors_guide
docs/saving_and_loading_from_catalog
docs/production
docs/helm
modules
catalog
2 changes: 1 addition & 1 deletion src/unitxt/__init__.py
@@ -1,6 +1,6 @@
import random

from .api import evaluate, load, load_dataset
from .api import evaluate, load, load_dataset, produce
from .catalog import add_to_catalog, get_from_catalog
from .logging_utils import get_logger
from .register import register_all_artifacts, register_local_catalog
16 changes: 16 additions & 0 deletions src/unitxt/api.py
@@ -1,3 +1,4 @@
from functools import lru_cache
from typing import Any, Dict, List, Union

from datasets import DatasetDict
@@ -28,3 +29,18 @@ def load_dataset(dataset_query: str) -> DatasetDict:

def evaluate(predictions, data) -> List[Dict[str, Any]]:
return _compute(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(recipe_query):
return get_dataset_artifact(recipe_query).produce


def produce(instance_or_instances, recipe_query):
is_list = isinstance(instance_or_instances, list)
if not is_list:
instance_or_instances = [instance_or_instances]
result = _get_produce_with_cache(recipe_query)(instance_or_instances)
if not is_list:
result = result[0]
return result
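The `produce` wrapper above can be illustrated with a self-contained sketch of the same pattern: only the hashable recipe query string serves as the `lru_cache` key, the unhashable instance dicts pass straight through, and a single-dict input is unwrapped on return. The names `lookups` and `_get_processor` below are hypothetical stand-ins, not unitxt APIs:

```python
from functools import lru_cache

lookups = []  # records each "expensive" artifact lookup


@lru_cache
def _get_processor(recipe_query: str):
    """Hypothetical stand-in for get_dataset_artifact(recipe_query).produce."""
    lookups.append(recipe_query)
    # Return a callable that processes a list of instance dicts.
    return lambda instances: [dict(inst, recipe=recipe_query) for inst in instances]


def produce(instance_or_instances, recipe_query):
    # Accept a single dict or a list of dicts, mirroring the API above.
    is_list = isinstance(instance_or_instances, list)
    instances = instance_or_instances if is_list else [instance_or_instances]
    result = _get_processor(recipe_query)(instances)
    return result if is_list else result[0]


single = produce({"text": "hi"}, "card=demo")                  # returns a dict
batch = produce([{"text": "a"}, {"text": "b"}], "card=demo")   # returns a list
```

The second call reuses the cached processor, so the expensive lookup runs only once per distinct recipe query.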
14 changes: 12 additions & 2 deletions src/unitxt/operator.py
@@ -166,6 +166,14 @@ def process(self, *args, **kwargs) -> MultiStream:
pass


def instance_generator(instance):
yield instance


def stream_single(instance: Dict[str, Any]) -> Stream:
return Stream(generator=instance_generator, gen_kwargs={"instance": instance})


class MultiStreamOperator(StreamingOperator):
"""A class representing a multi-stream operator in the streaming system.

@@ -198,7 +206,7 @@ def process(self, multi_stream: MultiStream) -> MultiStream:
pass

def process_instance(self, instance, stream_name="tmp"):
multi_stream = MultiStream({stream_name: [instance]})
multi_stream = MultiStream({stream_name: stream_single(instance)})
processed_multi_stream = self(multi_stream)
return next(iter(processed_multi_stream[stream_name]))

@@ -269,7 +277,9 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato
pass

def process_instance(self, instance, stream_name="tmp"):
processed_stream = self._process_single_stream([instance], stream_name)
processed_stream = self._process_single_stream(
stream_single(instance), stream_name
)
return next(iter(processed_stream))


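The new `stream_single` helper wraps one instance in a generator-backed `Stream` instead of a one-element list, keeping the stream lazy and re-iterable. A minimal sketch with a toy `Stream` class (the real unitxt `Stream` has more machinery; this stand-in only mimics the assumed re-invoke-the-generator behavior):

```python
from typing import Any, Dict


class Stream:
    """Toy stand-in: re-invokes its generator on every iteration."""

    def __init__(self, generator, gen_kwargs):
        self.generator = generator
        self.gen_kwargs = gen_kwargs

    def __iter__(self):
        # A fresh generator per __iter__ makes the stream re-iterable.
        return self.generator(**self.gen_kwargs)


def instance_generator(instance):
    yield instance


def stream_single(instance: Dict[str, Any]) -> Stream:
    return Stream(generator=instance_generator, gen_kwargs={"instance": instance})


s = stream_single({"text": "hello"})
first_pass = [x["text"] for x in s]
second_pass = [x["text"] for x in s]  # restarts, unlike a raw generator object
```

This is why the diff replaces `[instance]` with `stream_single(instance)`: downstream operators consume instances through the same lazy `Stream` interface whether the data came from a loader or from a single production-time instance.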
114 changes: 90 additions & 24 deletions src/unitxt/standard.py
@@ -4,11 +4,12 @@
from .dataclass import Field, InternalField, NonPositionalField, OptionalField
from .formats import Format, SystemFormat
from .logging_utils import get_logger
from .operator import SourceSequentialOperator, StreamingOperator
from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
from .operators import AddFields, Augmentor, NullAugmentor, StreamRefiner
from .recipe import Recipe
from .schema import ToUnitxtGroup
from .splitters import Sampler, SeparateSplit, SpreadSplit
from .stream import MultiStream
from .system_prompts import EmptySystemPrompt, SystemPrompt
from .templates import Template

@@ -99,15 +100,15 @@ def verify(self):
def prepare_refiners(self):
self.train_refiner.max_instances = self.max_train_instances
self.train_refiner.apply_to_streams = ["train"]
self.steps.append(self.train_refiner)
self.processing.steps.append(self.train_refiner)

self.validation_refiner.max_instances = self.max_validation_instances
self.validation_refiner.apply_to_streams = ["validation"]
self.steps.append(self.validation_refiner)
self.processing.steps.append(self.validation_refiner)

self.test_refiner.max_instances = self.max_test_instances
self.test_refiner.apply_to_streams = ["test"]
self.steps.append(self.test_refiner)
self.processing.steps.append(self.test_refiner)

def prepare_metrics_and_postprocessors(self):
if self.postprocessors is None:
@@ -121,9 +122,80 @@ def prepare_metrics_and_postprocessors(self):
metrics = self.metrics
return metrics, postprocessors

def prepare(self):
def set_pipelines(self):
self.loading = SequentialOperator()
self.metadata = SequentialOperator()
Review comment (Member): I think this should be called "add_metadata".

self.standardization = SequentialOperator()
self.processing = SequentialOperator()
Review comment (Member): It's not clear what the difference is between "standardization" and "processing". Maybe "standardize_to_task" and "process_task_fields"?

self.verblization = SequentialOperator()
self.finalize = SequentialOperator()

self.steps = [
self.card.loader,
self.loading,
self.metadata,
self.standardization,
self.processing,
self.verblization,
self.finalize,
]

self.inference_instance = SequentialOperator()

self.inference_instance.steps = [
self.metadata,
self.processing,
]

self.inference_demos = SourceSequentialOperator()

self.inference_demos.steps = [
self.loading,
self.metadata,
self.standardization,
self.processing,
]

self.inference = SequentialOperator()

self.inference.steps = [self.verblization, self.finalize]

self._demos_pool_cache = None

def production_preprocess(self, task_instances):
ms = MultiStream.from_iterables({"__inference__": task_instances})
return list(self.inference_instance(ms)["__inference__"])

def production_demos_pool(self):
if self.num_demos > 0:
if self._demos_pool_cache is None:
self._demos_pool_cache = list(
self.inference_demos()[self.demos_pool_name]
)
return self._demos_pool_cache
return []

def produce(self, task_instances):
"""Use the recipe in production to produce a model-ready query from a standard task instance."""
self.before_process_multi_stream()
multi_stream = MultiStream.from_iterables(
{
"__inference__": self.production_preprocess(task_instances),
self.demos_pool_name: self.production_demos_pool(),
}
)
multi_stream = self.inference(multi_stream)
return list(multi_stream["__inference__"])

def prepare(self):
self.set_pipelines()

loader = self.card.loader
if self.loader_limit:
loader.loader_limit = self.loader_limit
logger.info(f"Loader line limit was set to {self.loader_limit}")
self.loading.steps.append(loader)

self.metadata.steps.append(
AddFields(
fields={
"recipe_metadata": {
@@ -133,25 +205,19 @@ def prepare(self):
"format": self.format,
}
}
),
]

if self.loader_limit:
self.card.loader.loader_limit = self.loader_limit
logger.info(f"Loader line limit was set to {self.loader_limit}")
self.steps.append(StreamRefiner(max_instances=self.loader_limit))
)
)

if self.card.preprocess_steps is not None:
self.steps.extend(self.card.preprocess_steps)
self.standardization.steps.extend(self.card.preprocess_steps)

self.steps.append(self.card.task)
self.processing.steps.append(self.card.task)

if self.augmentor.augment_task_input:
self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
self.steps.append(self.augmentor)
self.processing.steps.append(self.augmentor)

if self.demos_pool_size is not None:
self.steps.append(
self.processing.steps.append(
CreateDemosPool(
from_split=self.demos_taken_from,
to_split_names=[self.demos_pool_name, self.demos_taken_from],
@@ -173,23 +239,23 @@ def prepare(self):

self.prepare_refiners()

self.steps.append(self.template)
self.verblization.steps.append(self.template)
if self.num_demos > 0:
self.steps.append(
self.verblization.steps.append(
AddDemosField(
source_stream=self.demos_pool_name,
target_field=self.demos_field,
sampler=self.sampler,
)
)
self.steps.append(self.system_prompt)
self.steps.append(self.format)
self.verblization.steps.append(self.system_prompt)
self.verblization.steps.append(self.format)
if self.augmentor.augment_model_input:
self.steps.append(self.augmentor)
self.verblization.steps.append(self.augmentor)

metrics, postprocessors = self.prepare_metrics_and_postprocessors()

self.steps.append(
self.finalize.steps.append(
ToUnitxtGroup(
group="unitxt",
metrics=metrics,
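The restructuring in standard.py splits the single `steps` list into named sub-pipelines (loading, metadata, standardization, processing, verblization, finalize) that can be recombined: training runs all of them in order, while `produce()` feeds an incoming instance only through the metadata/processing and verblization/finalize stages, reusing the same operator objects. A toy model of that composition, using plain functions instead of unitxt operators (all names here are illustrative):

```python
class SequentialOperator:
    """Toy stand-in: applies its steps (plain functions here) in order."""

    def __init__(self):
        self.steps = []

    def __call__(self, data):
        for step in self.steps:
            data = step(data)
        return data


# Shared sub-pipelines, mirroring set_pipelines() above.
metadata = SequentialOperator()
processing = SequentialOperator()
verbalization = SequentialOperator()

metadata.steps.append(lambda d: dict(d, meta="added"))
processing.steps.append(lambda d: dict(d, task="verified"))
verbalization.steps.append(lambda d: dict(d, source=f"prompt for {d['text']}"))

# Production skips loading/standardization for the incoming instance
# but reuses the very same sub-pipeline objects as the full pipeline.
inference_instance = [metadata, processing]
inference = [verbalization]

instance = {"text": "It works perfectly"}
for stage in inference_instance + inference:
    instance = stage(instance)
```

Because the sub-pipelines are shared objects, anything `prepare()` appends to them (refiners, the task, the template) is automatically picked up by both the training pipeline and the production path.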
62 changes: 61 additions & 1 deletion tests/library/test_api.py
@@ -1,4 +1,4 @@
from unitxt.api import evaluate, load_dataset
from unitxt.api import evaluate, load_dataset, produce

from tests.utils import UnitxtTestCase

@@ -39,3 +39,63 @@ def test_evaluate_with_metrics_external_setup(self):
predictions = ["2.5", "2.5", "2.2", "3", "4"]
results = evaluate(predictions, dataset["train"])
self.assertAlmostEqual(results[0]["score"]["global"]["score"], 0.2, 3)

def test_produce_with_recipe(self):
result = produce(
{
"label": "?",
"text_a": "It works perfectly",
"text_b": "It works!",
"classes": ["entailment", "not entailment"],
"type_of_relation": "entailment",
"text_a_type": "premise",
"text_b_type": "hypothesis",
},
"card=cards.wnli,template=templates.classification.multi_class.relation.default,demos_pool_size=5,num_demos=2",
)

target = {
"metrics": ["metrics.f1_micro", "metrics.accuracy", "metrics.f1_macro"],
"source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.premise: Steve follows Fred's example in everything. He influences him hugely., hypothesis: Steve influences him hugely.\nThe entailment class is entailment\n\npremise: The police arrested all of the gang members. They were trying to stop the drug trade in the neighborhood., hypothesis: The police were trying to stop the drug trade in the neighborhood.\nThe entailment class is not entailment\n\npremise: It works perfectly, hypothesis: It works!\nThe entailment class is ",
"target": "?",
"references": ["?"],
"task_data": '{"text_a": "It works perfectly", "text_a_type": "premise", "text_b": "It works!", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "label": "?"}',
"group": "unitxt",
"postprocessors": [
"processors.take_first_non_empty_line",
"processors.lower_case_till_punc",
],
}

self.assertDictEqual(target, result)

def test_produce_with_recipe_with_list_of_instances(self):
result = produce(
[
{
"label": "?",
"text_a": "It works perfectly",
"text_b": "It works!",
"classes": ["entailment", "not entailment"],
"type_of_relation": "entailment",
"text_a_type": "premise",
"text_b_type": "hypothesis",
}
],
"card=cards.wnli,template=templates.classification.multi_class.relation.default,demos_pool_size=5,num_demos=2",
)[0]

target = {
"metrics": ["metrics.f1_micro", "metrics.accuracy", "metrics.f1_macro"],
"source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.premise: Steve follows Fred's example in everything. He influences him hugely., hypothesis: Steve influences him hugely.\nThe entailment class is entailment\n\npremise: The police arrested all of the gang members. They were trying to stop the drug trade in the neighborhood., hypothesis: The police were trying to stop the drug trade in the neighborhood.\nThe entailment class is not entailment\n\npremise: It works perfectly, hypothesis: It works!\nThe entailment class is ",
"target": "?",
"references": ["?"],
"task_data": '{"text_a": "It works perfectly", "text_a_type": "premise", "text_b": "It works!", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "label": "?"}',
"group": "unitxt",
"postprocessors": [
"processors.take_first_non_empty_line",
"processors.lower_case_till_punc",
],
}

self.assertDictEqual(target, result)