Skip to content

Commit

Permalink
bug: fix bug when running integrations tests for text descriptives (#…
Browse files Browse the repository at this point in the history
…4614)

<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

When df swith np.nan are concatenated or fillna is applied in
extract_single_field, it cast all the column to floats. This way we make
sure that after filtering the nan values, each value is casted according
to the metadata property type created.


![image](https://github.com/argilla-io/argilla/assets/127759186/6cec5cda-451c-4c9b-9512-6bac182fd7f2)

Closes #4613 

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)
  • Loading branch information
sdiazlor authored Feb 29, 2024
1 parent c48237e commit c234cf6
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions src/argilla/client/feedback/integrations/textdescriptives.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,10 @@ def _create_metadata_property_settings(
return dataset

def _add_text_descriptives_to_metadata(
self, records: List[Union[FeedbackRecord, RemoteFeedbackRecord]], df: pd.DataFrame
self,
records: List[Union[FeedbackRecord, RemoteFeedbackRecord]],
df: pd.DataFrame,
metadata_prop_types: Optional[dict] = None,
) -> List[Union[FeedbackRecord, RemoteFeedbackRecord]]:
"""
Add the text descriptives metrics extracted previously as metadata
Expand All @@ -291,6 +294,7 @@ def _add_text_descriptives_to_metadata(
Args:
records (List[Union[FeedbackRecord, RemoteFeedbackRecord]]): A list of FeedbackDataset or RemoteFeedbackDataset records.
df (pd.DataFrame): The text descriptives dataframe.
metadata_prop_types (Optional[dict]): A dictionary with the name of the metadata property and the type if needed. Defaults to None.
Returns:
List[Union[FeedbackRecord, RemoteFeedbackRecord]]: A list of FeedbackDataset or RemoteFeedbackDataset records with extracted metrics added as metadata.
Expand All @@ -302,6 +306,15 @@ def _add_text_descriptives_to_metadata(
)
for record, metrics in zip(records, df.to_dict("records")):
filtered_metrics = {key: value for key, value in metrics.items() if not pd.isna(value)}
if metadata_prop_types is not None:
filtered_metrics = {
key: int(value)
if metadata_prop_types.get(key) == "integer"
else float(value)
if metadata_prop_types.get(key) == "float"
else value
for key, value in filtered_metrics.items()
}
record.metadata.update(filtered_metrics)
modified_records.append(record)
progress_bar.update(task, advance=1)
Expand All @@ -312,6 +325,7 @@ def update_records(
records: List[Union[FeedbackRecord, RemoteFeedbackRecord]],
fields: Optional[List[str]] = None,
overwrite: Optional[bool] = False,
metadata_prop_types: Optional[dict] = None,
) -> List[Union[FeedbackRecord, RemoteFeedbackRecord]]:
"""
Extract text descriptives metrics from a list of FeedbackDataset or RemoteFeedbackDataset records,
Expand All @@ -321,6 +335,7 @@ def update_records(
records (List[Union[FeedbackRecord, RemoteFeedbackRecord]]): A list of FeedbackDataset or RemoteFeedbackDataset records.
fields (List[str]): A list of fields to extract metrics for. If None, extract metrics for all fields.
overwrite (Optional[bool]): Whether to overwrite existing metadata properties with the same name. Defaults to False.
metadata_prop_types (Optional[dict]): A dictionary with the name of the metadata property and the type if needed. Defaults to None.
Returns:
List[Union[FeedbackRecord, RemoteFeedbackRecord]]: A list of FeedbackDataset or RemoteFeedbackDataset records with text descriptives metrics added as metadata.
Expand Down Expand Up @@ -354,7 +369,12 @@ def update_records(
# Clean column names
extracted_metrics.columns = [self._clean_column_name(col) for col in extracted_metrics.columns]
# Add the metrics to the metadata of the records
modified_records = self._add_text_descriptives_to_metadata(modified_records, extracted_metrics)
if metadata_prop_types is None:
modified_records = self._add_text_descriptives_to_metadata(modified_records, extracted_metrics)
else:
modified_records = self._add_text_descriptives_to_metadata(
modified_records, extracted_metrics, metadata_prop_types
)
return modified_records

def update_dataset(
Expand Down Expand Up @@ -396,7 +416,10 @@ def update_dataset(

# Update the records in the dataset too
if update_records:
records = self.update_records(records=dataset.records, fields=fields, overwrite=overwrite)
metadata_prop_types = {item.name: item.type for item in dataset.metadata_properties}
records = self.update_records(
records=dataset.records, fields=fields, overwrite=overwrite, metadata_prop_types=metadata_prop_types
)
if isinstance(dataset, RemoteFeedbackDataset):
dataset.update_records(records)
return dataset

0 comments on commit c234cf6

Please sign in to comment.