Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

load_dataset supports loading cards not present in local catalog #929

Merged
merged 1 commit into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions examples/qa_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from unitxt import get_logger
from unitxt.api import evaluate
from unitxt.api import evaluate, load_dataset
from unitxt.blocks import TaskCard
from unitxt.collections_operators import Wrap
from unitxt.inference import (
HFPipelineBasedInferenceEngine,
)
from unitxt.loaders import LoadFromDictionary
from unitxt.standard import StandardRecipe

logger = get_logger()

Expand Down Expand Up @@ -39,7 +38,7 @@
# What is the color of the sky?
# Answer:
# "
dataset = StandardRecipe(card=card, template="templates.qa.open.title")().to_dataset()
dataset = load_dataset(card=card, template="templates.qa.open.title")
test_dataset = dataset["test"]


Expand Down
5 changes: 2 additions & 3 deletions examples/standalone_qa_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from unitxt import get_logger
from unitxt.api import evaluate
from unitxt.api import evaluate, load_dataset
from unitxt.blocks import Task, TaskCard
from unitxt.inference import (
HFPipelineBasedInferenceEngine,
)
from unitxt.loaders import LoadFromDictionary
from unitxt.standard import StandardRecipe
from unitxt.templates import InputOutputTemplate, TemplatesDict

logger = get_logger()
Expand Down Expand Up @@ -44,7 +43,7 @@
)

# Verbalize the dataset using the template
dataset = StandardRecipe(card=card, template_card_index="simple")().to_dataset()
dataset = load_dataset(card=card, template_card_index="simple")
test_dataset = dataset["test"]


Expand Down
64 changes: 62 additions & 2 deletions src/unitxt/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import lru_cache
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

from datasets import DatasetDict

Expand All @@ -8,6 +8,7 @@
from .logging_utils import get_logger
from .metric_utils import _compute
from .operator import SourceOperator
from .standard import StandardRecipe

logger = get_logger()

Expand All @@ -21,12 +22,71 @@ def load(source: Union[SourceOperator, str]) -> DatasetDict:
return source().to_dataset()


def load_dataset(dataset_query: str) -> DatasetDict:
def _load_dataset_from_query(dataset_query: str) -> DatasetDict:
    """Resolve a catalog query string into a materialized dataset.

    The legacy parameter name "sys_prompt" is rewritten to "instruction"
    before the query is resolved against the local catalog.
    """
    normalized_query = dataset_query.replace("sys_prompt", "instruction")
    source = get_dataset_artifact(normalized_query)
    return source().to_dataset()


def _load_dataset_from_dict(dataset_params: Dict[str, Any]) -> DatasetDict:
    """Build a dataset from explicitly given StandardRecipe parameters.

    Args:
        dataset_params: Keyword arguments accepted by StandardRecipe
            (e.g. card, template, loader_limit).

    Returns:
        DatasetDict

    Raises:
        ValueError: If any parameter is not an attribute of StandardRecipe.
    """
    recipe_attributes = list(StandardRecipe.__dict__["__fields__"].keys())
    # Validate with an explicit raise rather than `assert`: asserts are
    # stripped under `python -O`, which would silently skip validation.
    # Collect all invalid names so the caller sees every mistake at once.
    invalid_params = [
        param for param in dataset_params if param not in recipe_attributes
    ]
    if invalid_params:
        raise ValueError(
            f"The parameters {invalid_params} are not attributes of the 'StandardRecipe' class. "
            f"Please check if the names are correct. The available attributes are: '{recipe_attributes}'."
        )
    recipe = StandardRecipe(**dataset_params)
    return recipe().to_dataset()


def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
    """Load a dataset either from the local catalog or from an inline card.

    Exactly one of the two input forms must be used: a catalog query string,
    or keyword arguments describing the dataset (card, template, etc.).

    Args:
        dataset_query (str, optional): A query string selecting a dataset from
            the local catalog, for example:
            "card=cards.wnli,template=templates.classification.multi_class.relation.default".
        **kwargs: Explicit dataset properties used when the card is not
            present in the local catalog.

    Returns:
        DatasetDict

    Raises:
        ValueError: If both or neither of the two input forms are given, or
            if 'dataset_query' is not a string.

    Examples:
        dataset = load_dataset(
            dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
        )  # card must be present in local catalog

        card = TaskCard(...)
        template = Template(...)
        loader_limit = 10
        dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
    """
    query_given = bool(dataset_query)
    params_given = bool(kwargs)

    if query_given and params_given:
        raise ValueError(
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if not query_given and not params_given:
        raise ValueError("Either 'dataset_query' or key-worded arguments must be provided.")

    if params_given:
        return _load_dataset_from_dict(kwargs)

    if not isinstance(dataset_query, str):
        raise ValueError(
            f"If specified, 'dataset_query' must be a string, however, "
            f"'{dataset_query}' was provided instead, which is of type "
            f"'{type(dataset_query)}'."
        )
    return _load_dataset_from_query(dataset_query)


def evaluate(predictions, data) -> List[Dict[str, Any]]:
    """Score *predictions* against the references carried in *data*.

    Thin wrapper over the internal metric computation utility; the dataset
    instances in *data* serve as the references.
    """
    results = _compute(predictions=predictions, references=data)
    return results

Expand Down
37 changes: 37 additions & 0 deletions tests/library/test_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from unitxt.api import evaluate, load_dataset, produce
from unitxt.card import TaskCard
from unitxt.loaders import LoadHF
from unitxt.task import Task
from unitxt.templates import InputOutputTemplate, TemplatesList

from tests.utils import UnitxtTestCase

Expand Down Expand Up @@ -125,3 +129,36 @@ def test_produce_with_recipe_with_list_of_instances(self):
}

self.assertDictEqual(target, result)

def test_load_dataset_from_dict(self):
    """End-to-end check that load_dataset accepts an inline TaskCard."""
    templates = [
        InputOutputTemplate(
            input_format="Sentence1: {sentence1} Sentence2: {sentence2}",
            output_format="{label}",
        ),
        InputOutputTemplate(
            input_format="Sentence2: {sentence2} Sentence1: {sentence1}",
            output_format="{label}",
        ),
    ]
    wnli_card = TaskCard(
        loader=LoadHF(path="glue", name="wnli"),
        task=Task(
            inputs=["sentence1", "sentence2"],
            outputs=["label"],
            metrics=["metrics.accuracy"],
        ),
        templates=TemplatesList(templates),
    )

    dataset = load_dataset(card=wnli_card, template_card_index=1, loader_limit=5)
    train_split = dataset["train"]

    # loader_limit caps each split at 5 instances.
    self.assertEqual(len(train_split), 5)
    expected_source = (
        "Sentence2: The carrot had a hole. "
        "Sentence1: I stuck a pin through a carrot. "
        "When I pulled the pin out, it had a hole."
    )
    self.assertEqual(train_split["source"][0].strip(), expected_source)
    self.assertEqual(train_split["metrics"][0], ["metrics.accuracy"])
Loading