Skip to content

Commit

Permalink
Add control over metrics and postprocessors through the recipe
Browse files Browse the repository at this point in the history
Signed-off-by: Elron Bandel <elron.bandel@ibm.com>
  • Loading branch information
elronbandel committed Mar 14, 2024
1 parent 9363a9c commit b2106c1
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions src/unitxt/standard.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List

from .card import TaskCard
from .dataclass import Field, InternalField, OptionalField
from .dataclass import Field, InternalField, NonPositionalField, OptionalField
from .formats import Format, SystemFormat
from .logging_utils import get_logger
from .operator import SourceSequentialOperator, StreamingOperator
Expand Down Expand Up @@ -29,6 +29,8 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
template: Template = None
system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
format: Format = Field(default_factory=SystemFormat)
metrics: List[str] = NonPositionalField(default=None)
postprocessors: List[str] = NonPositionalField(default=None)

loader_limit: int = None

Expand Down Expand Up @@ -107,6 +109,18 @@ def prepare_refiners(self):
self.test_refiner.apply_to_streams = ["test"]
self.steps.append(self.test_refiner)

def prepare_metrics_and_postprocessors(self):
if self.postprocessors is None:
postprocessors = self.template.get_postprocessors()
else:
postprocessors = self.postprocessors

Check warning on line 116 in src/unitxt/standard.py

View check run for this annotation

Codecov / codecov/patch

src/unitxt/standard.py#L116

Added line #L116 was not covered by tests

if self.metrics is None:
metrics = self.card.task.metrics
else:
metrics = self.metrics

Check warning on line 121 in src/unitxt/standard.py

View check run for this annotation

Codecov / codecov/patch

src/unitxt/standard.py#L121

Added line #L121 was not covered by tests
return metrics, postprocessors

def prepare(self):
self.steps = [
self.card.loader,
Expand Down Expand Up @@ -173,12 +187,12 @@ def prepare(self):
if self.augmentor.augment_model_input:
self.steps.append(self.augmentor)

postprocessors = self.template.get_postprocessors()
metrics, postprocessors = self.prepare_metrics_and_postprocessors()

self.steps.append(
ToUnitxtGroup(
group="unitxt",
metrics=self.card.task.metrics,
metrics=metrics,
postprocessors=postprocessors,
)
)
Expand Down Expand Up @@ -222,6 +236,8 @@ class StandardRecipe(StandardRecipeWithIndexes):
system_prompt (SystemPrompt, optional): SystemPrompt object to be used for the recipe.
loader_limit (int, optional): Specifies the maximum number of instances per stream to be returned from the loader (used to reduce loading time in large datasets)
format (SystemFormat, optional): SystemFormat object to be used for the recipe.
metrics (List[str]): list of catalog metrics to use with this recipe.
postprocessors (List[str]): list of catalog processors to apply at post processing. (Not recommended to use from here)
train_refiner (StreamRefiner, optional): Train refiner to be used in the recipe.
max_train_instances (int, optional): Maximum training instances for the refiner.
validation_refiner (StreamRefiner, optional): Validation refiner to be used in the recipe.
Expand Down

0 comments on commit b2106c1

Please sign in to comment.