bug: 5123 metrics #5245

Merged
merged 5 commits on Jul 18, 2024
4 changes: 3 additions & 1 deletion argilla/src/argilla/client/feedback/dataset/mixins.py
@@ -85,6 +85,7 @@
self,
metric_names: Union[str, List[str]] = None,
question_name: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion, RankingQuestion] = None,
field_name: Union[str, List[str]] = None,
) -> Union["AgreementMetricResult", List["AgreementMetricResult"]]:
"""Compute agreement or reliability of annotation metrics.

@@ -94,6 +95,7 @@
Args:
metric_names: Metric name or list of metric names of the metrics, dependent on the question type.
question_name: Question for which we want to compute the metrics.
field_name: Name of the fields related to the question we want to analyse the agreement.

Note:
Currently, TextQuestion is not supported.
@@ -104,7 +106,7 @@
"""
from argilla.client.feedback.metrics.agreement_metrics import AgreementMetric

return AgreementMetric(self, question_name).compute(metric_names)
return AgreementMetric(self, question_name, field_name).compute(metric_names)


class UnificationMixin:
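As a rough sketch of how the updated mixin might be called once this change lands — the dataset name, workspace, question name, and field names below are illustrative placeholders, not values taken from this PR:

```python
import argilla as rg

# Assumes rg.init(...) has already been called to connect to an Argilla server.
dataset = rg.FeedbackDataset.from_argilla("my-dataset", workspace="my-workspace")

# field_name now accepts either a single field name or a list of field names
# whose contents are used as the annotated item when grouping responses.
report = dataset.compute_agreement_metrics(
    metric_names="alpha",
    question_name="label",
    field_name=["text", "label"],
)
print(report)  # e.g. AgreementMetricResult(metric_name='alpha', count=..., result=...)
```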
15 changes: 12 additions & 3 deletions argilla/src/argilla/client/feedback/metrics/agreement_metrics.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains metrics to gather information related to inter-Annotator agreement. """
"""This module contains metrics to gather information related to inter-Annotator agreement."""

import warnings
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

@@ -52,6 +53,7 @@
def prepare_dataset_for_annotation_task(
dataset: Union["FeedbackDataset", "RemoteFeedbackDataset"],
question_name: str,
field_name: Union[str, List[str]],
filter_by: Optional[Dict[str, Union["ResponseStatusFilter", List["ResponseStatusFilter"]]]] = None,
sort_by: Optional[List["SortBy"]] = None,
max_records: Optional[int] = None,
@@ -74,6 +76,7 @@
Args:
dataset: FeedbackDataset to compute the metrics.
question_name: Name of the question for which we want to analyse the agreement.
field_name: Name of the fields related to the question we want to analyse the agreement.
filter_by: A dict with key the field to filter by, and values the filters to apply.
Can be one of: draft, pending, submitted, and discarded. If set to None,
no filter will be applied. Defaults to None (no filter is applied).
@@ -108,7 +111,9 @@

for row in hf_dataset:
responses_ = row[question_name]
question_text = row["text"]
question_text = (
    " ".join([row[field] for field in field_name]) if isinstance(field_name, list) else row[field_name]
)
for response in responses_:
user_id = response["user_id"]
if user_id is None:
@@ -181,7 +186,7 @@
Example:
>>> import argilla as rg
>>> from argilla.client.feedback.metrics import AgreementMetric
>>> metric = AgreementMetric(dataset=dataset, question_name=question, filter_by={"response_status": "submitted"})
>>> metric = AgreementMetric(dataset=dataset, question_name=question, field_name=field, filter_by={"response_status": "submitted"})
>>> metrics_report = metric.compute("alpha")

"""
@@ -190,6 +195,7 @@
self,
dataset: FeedbackDataset,
question_name: str,
field_name: Union[str, List[str]],
filter_by: Optional[Dict[str, Union["ResponseStatusFilter", List["ResponseStatusFilter"]]]] = None,
sort_by: Optional[List["SortBy"]] = None,
max_records: Optional[int] = None,
@@ -199,6 +205,7 @@
Args:
dataset: FeedbackDataset to compute the metrics.
question_name: Name of the question for which we want to analyse the agreement.
field_name: Name of the fields related to the question we want to analyse the agreement.
filter_by: A dict with key the field to filter by, and values the filters to apply.
Can be one of: draft, pending, submitted, and discarded. If set to None,
no filter will be applied. Defaults to None (no filter is applied).
@@ -207,6 +214,7 @@
max_records: The maximum number of records to use for training. Defaults to None.
"""
self._metrics_per_question = METRICS_PER_QUESTION
self._field_name = field_name

super().__init__(dataset, question_name)
self._filter_by = filter_by
self._sort_by = sort_by
@@ -231,6 +239,7 @@
dataset = prepare_dataset_for_annotation_task(
self._dataset,
self._question_name,
self._field_name,
filter_by=self._filter_by,
sort_by=self._sort_by,
max_records=self._max_records,
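The core of the change in `prepare_dataset_for_annotation_task` is how the annotated item text is built from one or several fields. A minimal standalone sketch of that resolution logic, using a made-up record dict rather than the real fixtures, looks like this:

```python
from typing import List, Union


def resolve_item_text(row: dict, field_name: Union[str, List[str]]) -> str:
    """Mirror of the new field handling: join several fields with a space,
    or return the single field's value unchanged."""
    if isinstance(field_name, list):
        return " ".join(row[field] for field in field_name)
    return row[field_name]


# Illustrative record, not taken from the test fixtures.
row = {"text": "The movie was great", "label": "positive"}
assert resolve_item_text(row, "text") == "The movie was great"
assert resolve_item_text(row, ["text", "label"]) == "The movie was great positive"
```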
@@ -65,18 +65,20 @@ def test_allowed_metrics(
)
dataset.add_records(records=feedback_dataset_records_with_paired_suggestions)

metric = AgreementMetric(dataset=dataset, question_name=question)
metric = AgreementMetric(
dataset=dataset, question_name=question, field_name=[field.name for field in feedback_dataset_fields]
)
assert set(metric.allowed_metrics) == metric_names


@pytest.mark.parametrize(
"question, num_items, type_of_data",
"field, question, num_items, type_of_data",
[
("question-1", None, None),
("question-2", 12, int),
("question-3", 12, str),
("question-4", 12, frozenset),
("question-5", 12, tuple),
(["text"], "question-1", None, None),
(["text", "label"], "question-2", 12, int),
(["text", "label"], "question-3", 12, str),
(["text"], "question-4", 12, FrozenSet),
(["label"], "question-5", 12, Tuple),
],
)
@pytest.mark.usefixtures(
@@ -91,6 +93,7 @@ def test_prepare_dataset_for_annotation_task(
feedback_dataset_questions: List["AllowedQuestionTypes"],
feedback_dataset_records_with_paired_suggestions: List[FeedbackRecord],
question: str,
field: Union[str, List[str]],
num_items: int,
type_of_data: Union[str, int, FrozenSet, Tuple[str]],
):
@@ -101,19 +104,24 @@
)
dataset.add_records(records=feedback_dataset_records_with_paired_suggestions)

if question in ("question-1",):
if question == "question-1":
with pytest.raises(NotImplementedError, match=r"^Question '"):
prepare_dataset_for_annotation_task(dataset, question)
prepare_dataset_for_annotation_task(dataset, question, field)
else:
formatted_dataset = prepare_dataset_for_annotation_task(dataset, question)
formatted_dataset = prepare_dataset_for_annotation_task(dataset, question, field)
assert isinstance(formatted_dataset, list)
assert len(formatted_dataset) == num_items
item = formatted_dataset[0]
assert isinstance(item, tuple)
assert isinstance(item[0], str)
assert item[0].startswith("00000000-") # beginning of our uuid for tests
assert isinstance(item[1], str)
assert item[1] == feedback_dataset_records_with_paired_suggestions[0].fields["text"]
expected_field_value = (
" ".join([feedback_dataset_records_with_paired_suggestions[0].fields[f] for f in field])
if isinstance(field, list)
else feedback_dataset_records_with_paired_suggestions[0].fields[field]
)
assert item[1] == expected_field_value
assert isinstance(item[2], type_of_data)


@@ -156,9 +164,17 @@ def test_agreement_metrics(

if question in ("question-1",):
with pytest.raises(NotImplementedError, match=r"^No metrics are defined currently for"):
AgreementMetric(dataset=dataset, question_name=question)
AgreementMetric(
dataset=dataset,
question_name=question,
field_name=[field.name for field in feedback_dataset_fields],
)
else:
metric = AgreementMetric(dataset=dataset, question_name=question)
metric = AgreementMetric(
dataset=dataset,
question_name=question,
field_name=[field.name for field in feedback_dataset_fields],
)
# Test for repr method
assert repr(metric) == f"AgreementMetric(question_name={question})"
metrics_report = metric.compute(metric_names)
@@ -173,19 +189,19 @@

@pytest.mark.asyncio
@pytest.mark.parametrize(
"question, metric_names",
"field, question, metric_names",
[
# TextQuestion
("question-1", None),
(["text"], "question-1", None),
# RatingQuestion
("question-2", "alpha"),
("question-2", ["alpha"]),
(["text", "label"], "question-2", "alpha"),
(["text", "label"], "question-2", ["alpha"]),
# LabelQuestion
("question-3", "alpha"),
("text", "question-3", "alpha"),
# MultiLabelQuestion
("question-4", "alpha"),
("label", "question-4", "alpha"),
# RankingQuestion
("question-5", "alpha"),
(["text", "label"], "question-5", "alpha"),
],
)
@pytest.mark.usefixtures(
@@ -200,6 +216,7 @@ async def test_agreement_metrics_remote(
feedback_dataset_questions: List["AllowedQuestionTypes"],
feedback_dataset_records_with_paired_suggestions: List[FeedbackRecord],
question: str,
field: Union[str, List[str]],
metric_names: Union[str, List[str]],
owner: User,
):
@@ -219,9 +236,17 @@

if question in ("question-1",):
with pytest.raises(NotImplementedError, match=r"^No metrics are defined currently for"):
AgreementMetric(dataset=remote, question_name=question)
AgreementMetric(
dataset=remote,
question_name=question,
field_name=field,
)
else:
metric = AgreementMetric(dataset=remote, question_name=question)
metric = AgreementMetric(
dataset=remote,
question_name=question,
field_name=field,
)
# Test for repr method
assert repr(metric) == f"AgreementMetric(question_name={question})"
metrics_report = metric.compute(metric_names)
Expand All @@ -235,19 +260,19 @@ async def test_agreement_metrics_remote(


@pytest.mark.parametrize(
"question, metric_names",
"field, question, metric_names",
[
# TextQuestion
("question-1", None),
(["text"], "question-1", None),
# RatingQuestion
("question-2", "alpha"),
("question-2", ["alpha"]),
(["text", "label"], "question-2", "alpha"),
(["text", "label"], "question-2", ["alpha"]),
# LabelQuestion
("question-3", "alpha"),
("text", "question-3", "alpha"),
# MultiLabelQuestion
("question-4", "alpha"),
("label", "question-4", "alpha"),
# RankingQuestion
("question-5", "alpha"),
(["text", "label"], "question-5", "alpha"),
],
)
@pytest.mark.usefixtures(
@@ -262,6 +287,7 @@ def test_agreement_metrics_from_feedback_dataset(
feedback_dataset_questions: List["AllowedQuestionTypes"],
feedback_dataset_records_with_paired_suggestions: List[FeedbackRecord],
question: str,
field: Union[str, List[str]],
metric_names: Union[str, List[str]],
):
dataset = FeedbackDataset(
@@ -273,9 +299,11 @@

if question in ("question-1",):
with pytest.raises(NotImplementedError, match=r"^No metrics are defined currently for"):
dataset.compute_agreement_metrics(question_name=question, metric_names=metric_names)
dataset.compute_agreement_metrics(question_name=question, field_name=field, metric_names=metric_names)
else:
metrics_report = dataset.compute_agreement_metrics(question_name=question, metric_names=metric_names)
metrics_report = dataset.compute_agreement_metrics(
question_name=question, field_name=field, metric_names=metric_names
)

if isinstance(metric_names, str):
metrics_report = [metrics_report]
@@ -288,19 +316,19 @@

@pytest.mark.asyncio
@pytest.mark.parametrize(
"question, metric_names",
"field, question, metric_names",
[
# TextQuestion
("question-1", None),
(["text"], "question-1", None),
# RatingQuestion
("question-2", "alpha"),
("question-2", ["alpha"]),
(["text", "label"], "question-2", "alpha"),
(["text", "label"], "question-2", ["alpha"]),
# LabelQuestion
("question-3", "alpha"),
("text", "question-3", "alpha"),
# MultiLabelQuestion
("question-4", "alpha"),
("label", "question-4", "alpha"),
# RankingQuestion
("question-5", "alpha"),
(["text", "label"], "question-5", "alpha"),
],
)
@pytest.mark.usefixtures(
@@ -315,6 +343,7 @@ async def test_agreement_metrics_from_remote_feedback_dataset(
feedback_dataset_questions: List["AllowedQuestionTypes"],
feedback_dataset_records_with_paired_suggestions: List[FeedbackRecord],
question: str,
field: Union[str, List[str]],
metric_names: Union[str, List[str]],
owner: User,
) -> None:
@@ -335,9 +364,11 @@

if question in ("question-1",):
with pytest.raises(NotImplementedError, match=r"^No metrics are defined currently for"):
remote.compute_agreement_metrics(question_name=question, metric_names=metric_names)
remote.compute_agreement_metrics(question_name=question, field_name=field, metric_names=metric_names)
else:
metrics_report = remote.compute_agreement_metrics(question_name=question, metric_names=metric_names)
metrics_report = remote.compute_agreement_metrics(
question_name=question, field_name=field, metric_names=metric_names
)

if isinstance(metric_names, str):
metrics_report = [metrics_report]
4 changes: 2 additions & 2 deletions docs/_source/practical_guides/collect_responses.md
@@ -141,7 +141,7 @@ import argilla as rg
from argilla.client.feedback.metrics import AgreementMetric

feedback_dataset = rg.FeedbackDataset.from_argilla("...", workspace="...")
metric = AgreementMetric(dataset=feedback_dataset, question_name="question_name")
metric = AgreementMetric(dataset=feedback_dataset, field_name="text", question_name="question_name")
agreement_metrics = metric.compute("alpha")
# >>> agreement_metrics
# [AgreementMetricResult(metric_name='alpha', count=1000, result=0.467889)]
@@ -156,7 +156,7 @@ import argilla as rg

#dataset = rg.FeedbackDataset.from_huggingface("argilla/go_emotions_raw")

agreement_metrics = dataset.compute_agreement_metrics(question_name="label", metric_names="alpha")
agreement_metrics = dataset.compute_agreement_metrics(question_name="label", field_name="text", metric_names="alpha")
agreement_metrics

# AgreementMetricResult(metric_name='alpha', count=191792, result=0.2703263452657748)
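The documentation examples above pass a single field, but the new parameter also accepts a list of fields. A hedged sketch of that variant — the dataset reference and field names are placeholders for whatever your dataset defines:

```python
import argilla as rg
from argilla.client.feedback.metrics import AgreementMetric

feedback_dataset = rg.FeedbackDataset.from_argilla("...", workspace="...")

# Several fields can be combined into the annotated item text.
metric = AgreementMetric(
    dataset=feedback_dataset,
    question_name="question_name",
    field_name=["text", "label"],  # assumed field names, adjust to your dataset
)
agreement_metrics = metric.compute("alpha")
```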