load_dataset supports loading cards not present in local catalog (#929)
Signed-off-by: Paweł Knes <pawel.knes@ibm.com>
pawelknes committed Jun 23, 2024
1 parent 7f25b7f commit f0b2306
Showing 4 changed files with 103 additions and 8 deletions.
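
For orientation before the per-file diffs: a minimal sketch of the two calling modes this commit enables, assembled from the docstring added to src/unitxt/api.py and the new test in tests/library/test_api.py below. The card definition and query string are taken from those additions; treat this as an illustration, not part of the commit itself.

from unitxt.api import load_dataset
from unitxt.card import TaskCard
from unitxt.loaders import LoadHF
from unitxt.task import Task
from unitxt.templates import InputOutputTemplate, TemplatesList

# Mode 1: query a card that is already present in the local catalog.
catalog_dataset = load_dataset(
    dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default"
)

# Mode 2 (new in this commit): pass a card built in code, plus StandardRecipe
# parameters, as keyword arguments; no catalog entry is required.
card = TaskCard(
    loader=LoadHF(path="glue", name="wnli"),
    task=Task(
        inputs=["sentence1", "sentence2"],
        outputs=["label"],
        metrics=["metrics.accuracy"],
    ),
    templates=TemplatesList(
        [
            InputOutputTemplate(
                input_format="Sentence1: {sentence1} Sentence2: {sentence2}",
                output_format="{label}",
            ),
        ]
    ),
)
inline_dataset = load_dataset(card=card, template_card_index=0, loader_limit=5)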
5 changes: 2 additions & 3 deletions examples/qa_evaluation.py
@@ -1,12 +1,11 @@
 from unitxt import get_logger
-from unitxt.api import evaluate
+from unitxt.api import evaluate, load_dataset
 from unitxt.blocks import TaskCard
 from unitxt.collections_operators import Wrap
 from unitxt.inference import (
     HFPipelineBasedInferenceEngine,
 )
 from unitxt.loaders import LoadFromDictionary
-from unitxt.standard import StandardRecipe

 logger = get_logger()

@@ -39,7 +38,7 @@
 # What is the color of the sky?
 # Answer:
 # "
-dataset = StandardRecipe(card=card, template="templates.qa.open.title")().to_dataset()
+dataset = load_dataset(card=card, template="templates.qa.open.title")
 test_dataset = dataset["test"]


5 changes: 2 additions & 3 deletions examples/standalone_qa_evaluation.py
@@ -1,11 +1,10 @@
 from unitxt import get_logger
-from unitxt.api import evaluate
+from unitxt.api import evaluate, load_dataset
 from unitxt.blocks import Task, TaskCard
 from unitxt.inference import (
     HFPipelineBasedInferenceEngine,
 )
 from unitxt.loaders import LoadFromDictionary
-from unitxt.standard import StandardRecipe
 from unitxt.templates import InputOutputTemplate, TemplatesDict

 logger = get_logger()
@@ -44,7 +43,7 @@
 )

 # Verbalize the dataset using the template
-dataset = StandardRecipe(card=card, template_card_index="simple")().to_dataset()
+dataset = load_dataset(card=card, template_card_index="simple")
 test_dataset = dataset["test"]


64 changes: 62 additions & 2 deletions src/unitxt/api.py
@@ -1,5 +1,5 @@
 from functools import lru_cache
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union

 from datasets import DatasetDict

@@ -8,6 +8,7 @@
 from .logging_utils import get_logger
 from .metric_utils import _compute
 from .operator import SourceOperator
+from .standard import StandardRecipe

 logger = get_logger()

@@ -21,12 +22,71 @@ def load(source: Union[SourceOperator, str]) -> DatasetDict:
     return source().to_dataset()


-def load_dataset(dataset_query: str) -> DatasetDict:
+def _load_dataset_from_query(dataset_query: str) -> DatasetDict:
     dataset_query = dataset_query.replace("sys_prompt", "instruction")
     dataset_stream = get_dataset_artifact(dataset_query)
     return dataset_stream().to_dataset()


+def _load_dataset_from_dict(dataset_params: Dict[str, Any]) -> DatasetDict:
+    recipe_attributes = list(StandardRecipe.__dict__["__fields__"].keys())
+    for param in dataset_params.keys():
+        assert param in recipe_attributes, (
+            f"The parameter '{param}' is not an attribute of the 'StandardRecipe' class. "
+            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
+        )
+    recipe = StandardRecipe(**dataset_params)
+    return recipe().to_dataset()
+
+
+def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
+    """Loads dataset.
+
+    If the 'dataset_query' argument is provided, then dataset is loaded from a card in local
+    catalog based on parameters specified in the query.
+    Alternatively, dataset is loaded from a provided card based on explicitly given parameters.
+
+    Args:
+        dataset_query (str, optional): A string query which specifies dataset to load from local catalog.
+            For example:
+            "card=cards.wnli,template=templates.classification.multi_class.relation.default".
+        **kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.
+
+    Returns:
+        DatasetDict
+
+    Examples:
+        dataset = load_dataset(
+            dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
+        )  # card must be present in local catalog
+
+        card = TaskCard(...)
+        template = Template(...)
+        loader_limit = 10
+        dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
+    """
+    if dataset_query and kwargs:
+        raise ValueError(
+            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
+            "If you want to load dataset from a card in local catalog, use query only. "
+            "Otherwise, use key-worded arguments only to specify properties of dataset."
+        )
+
+    if dataset_query:
+        if not isinstance(dataset_query, str):
+            raise ValueError(
+                f"If specified, 'dataset_query' must be a string, however, "
+                f"'{dataset_query}' was provided instead, which is of type "
+                f"'{type(dataset_query)}'."
+            )
+        return _load_dataset_from_query(dataset_query)
+
+    if kwargs:
+        return _load_dataset_from_dict(kwargs)
+
+    raise ValueError("Either 'dataset_query' or key-worded arguments must be provided.")
+
+
 def evaluate(predictions, data) -> List[Dict[str, Any]]:
     return _compute(predictions=predictions, references=data)

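One behavior worth noting from _load_dataset_from_dict above: keyword arguments that are not StandardRecipe fields are rejected by the assertion before any recipe is built or data is loaded. A small sketch of that error path, reusing the card from the sketch near the top of this page and a deliberately misspelled parameter name chosen only for illustration:

try:
    load_dataset(card=card, templte="templates.qa.open.title")  # 'templte' is an intentional typo
except AssertionError as error:
    # The assertion message lists the available StandardRecipe attributes.
    print(error)
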
37 changes: 37 additions & 0 deletions tests/library/test_api.py
@@ -1,4 +1,8 @@
 from unitxt.api import evaluate, load_dataset, produce
+from unitxt.card import TaskCard
+from unitxt.loaders import LoadHF
+from unitxt.task import Task
+from unitxt.templates import InputOutputTemplate, TemplatesList

 from tests.utils import UnitxtTestCase

@@ -125,3 +129,36 @@ def test_produce_with_recipe_with_list_of_instances(self):
         }

         self.assertDictEqual(target, result)
+
+    def test_load_dataset_from_dict(self):
+        card = TaskCard(
+            loader=LoadHF(path="glue", name="wnli"),
+            task=Task(
+                inputs=["sentence1", "sentence2"],
+                outputs=["label"],
+                metrics=["metrics.accuracy"],
+            ),
+            templates=TemplatesList(
+                [
+                    InputOutputTemplate(
+                        input_format="Sentence1: {sentence1} Sentence2: {sentence2}",
+                        output_format="{label}",
+                    ),
+                    InputOutputTemplate(
+                        input_format="Sentence2: {sentence2} Sentence1: {sentence1}",
+                        output_format="{label}",
+                    ),
+                ]
+            ),
+        )
+
+        dataset = load_dataset(card=card, template_card_index=1, loader_limit=5)
+
+        self.assertEqual(len(dataset["train"]), 5)
+        self.assertEqual(
+            dataset["train"]["source"][0].strip(),
+            "Sentence2: The carrot had a hole. "
+            "Sentence1: I stuck a pin through a carrot. "
+            "When I pulled the pin out, it had a hole.",
+        )
+        self.assertEqual(dataset["train"]["metrics"][0], ["metrics.accuracy"])
