Skip to content

Commit

Permalink
Update wikitq, tab_fact taskcards (#963)
Browse files Browse the repository at this point in the history
Co-authored-by: Rajmohan <rajmohanc1@in.ibm.com>
Co-authored-by: Elron Bandel <elronbandel@gmail.com>
  • Loading branch information
3 people committed Jun 30, 2024
1 parent 094b1a1 commit b155523
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 28 deletions.
5 changes: 3 additions & 2 deletions prepare/cards/tab_fact.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# Set unitxt.settings.allow_unverified_code=True or environment variable: UNITXT_ALLOW_UNVERIFIED_CODE to True

card = TaskCard(
loader=LoadHF(path="ibm/tab_fact", streaming=False),
loader=LoadHF(
path="ibm/tab_fact", streaming=False, data_classification_policy=["public"]
),
preprocess_steps=[
"splitters.small_no_test",
SerializeTableAsIndexedRowMajor(field_to_field=[["table", "table_serialized"]]),
RenameFields(
field_to_field={"table_serialized": "text_a", "statement": "text_b"}
Expand Down
26 changes: 18 additions & 8 deletions prepare/cards/wikitq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,33 @@
SerializeTableAsIndexedRowMajor,
Set,
TaskCard,
TruncateTableCells,
TruncateTableRows,
)
from unitxt.catalog import add_to_catalog
from unitxt.templates import MultiReferenceTemplate, TemplatesList
from unitxt.test_utils.card import test_card

card = TaskCard(
loader=LoadHF(path="wikitablequestions"),
loader=LoadHF(path="wikitablequestions", data_classification_policy=["public"]),
preprocess_steps=[
"splitters.small_no_test",
Set({"context_type": "table"}),
TruncateTableCells(max_length=15, table="table", text_output="answers"),
TruncateTableRows(field="table", rows_to_keep=50),
## truncate only if needed as it can impact evaluation results.
# TruncateTableCells(max_length=15, table="table", text_output="answers"),
# TruncateTableRows(field="table", rows_to_keep=50),
SerializeTableAsIndexedRowMajor(field_to_field=[["table", "context"]]),
],
task="tasks.qa.with_context.extractive",
templates="templates.qa.with_context.all",
task="tasks.qa.with_context.extractive[metrics=[metrics.unsorted_list_exact_match]]",
templates=TemplatesList(
[
MultiReferenceTemplate(
input_format="Based on this {context_type}: {context}\nAnswer the question: {question}",
references_field="answers",
postprocessors=[
"processors.to_list_by_comma_space",
"processors.str_to_float_format",
],
),
]
),
__description__=(
"This WikiTableQuestions dataset is a large-scale dataset for the task of question answering on semi-structured tables… See the full description on the dataset page: https://huggingface.co/datasets/wikitablequestions"
),
Expand Down
13 changes: 12 additions & 1 deletion prepare/processors/to_list_by_comma.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unitxt import add_to_catalog
from unitxt.operator import SequentialOperator
from unitxt.processors import ToListByComma
from unitxt.processors import ToListByComma, ToListByCommaSpace

add_to_catalog(
SequentialOperator(
Expand All @@ -18,3 +18,14 @@
"processors.to_list_by_comma_from_references",
overwrite=True,
)

# Register a post-processor that turns "a, b, c"-style strings into lists:
# the model prediction is one string to split, while each entry in the
# references list is split individually (process_every_value=True).
_to_list_by_comma_space_steps = [
    ToListByCommaSpace(field="prediction", process_every_value=False),
    ToListByCommaSpace(field="references", process_every_value=True),
]
add_to_catalog(
    SequentialOperator(steps=_to_list_by_comma_space_steps),
    "processors.to_list_by_comma_space",
    overwrite=True,
)
6 changes: 4 additions & 2 deletions src/unitxt/catalog/cards/tab_fact.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
"loader": {
"__type__": "load_hf",
"path": "ibm/tab_fact",
"streaming": false
"streaming": false,
"data_classification_policy": [
"public"
]
},
"preprocess_steps": [
"splitters.small_no_test",
{
"__type__": "serialize_table_as_indexed_row_major",
"field_to_field": [
Expand Down
34 changes: 19 additions & 15 deletions src/unitxt/catalog/cards/wikitq.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,18 @@
"__type__": "task_card",
"loader": {
"__type__": "load_hf",
"path": "wikitablequestions"
"path": "wikitablequestions",
"data_classification_policy": [
"public"
]
},
"preprocess_steps": [
"splitters.small_no_test",
{
"__type__": "set",
"fields": {
"context_type": "table"
}
},
{
"__type__": "truncate_table_cells",
"max_length": 15,
"table": "table",
"text_output": "answers"
},
{
"__type__": "truncate_table_rows",
"field": "table",
"rows_to_keep": 50
},
{
"__type__": "serialize_table_as_indexed_row_major",
"field_to_field": [
Expand All @@ -33,8 +24,21 @@
]
}
],
"task": "tasks.qa.with_context.extractive",
"templates": "templates.qa.with_context.all",
"task": "tasks.qa.with_context.extractive[metrics=[metrics.unsorted_list_exact_match]]",
"templates": {
"__type__": "templates_list",
"items": [
{
"__type__": "multi_reference_template",
"input_format": "Based on this {context_type}: {context}\nAnswer the question: {question}",
"references_field": "answers",
"postprocessors": [
"processors.to_list_by_comma_space",
"processors.str_to_float_format"
]
}
]
},
"__description__": "This WikiTableQuestions dataset is a large-scale dataset for the task of question answering on semi-structured tables… See the full description on the dataset page: https://huggingface.co/datasets/wikitablequestions",
"__tags__": {
"annotations_creators": "crowdsourced",
Expand Down
15 changes: 15 additions & 0 deletions src/unitxt/catalog/processors/to_list_by_comma_space.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"__type__": "sequential_operator",
"steps": [
{
"__type__": "to_list_by_comma_space",
"field": "prediction",
"process_every_value": false
},
{
"__type__": "to_list_by_comma_space",
"field": "references",
"process_every_value": true
}
]
}
5 changes: 5 additions & 0 deletions src/unitxt/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ class ToListByComma(SplitStrip):
strip_every_element = True


class ToListByCommaSpace(SplitStrip):
    """Split a string into a list on the two-character delimiter ", ".

    Configures the SplitStrip base operator with a comma-plus-space
    delimiter and stripping of surrounding whitespace from every
    resulting element (per `strip_every_element`). Companion to
    ToListByComma for outputs formatted as "a, b, c".
    """

    delimiter = ", "
    strip_every_element = True


class RegexParser(FieldOperator):
"""A processor that uses regex in order to parse a string."""

Expand Down

0 comments on commit b155523

Please sign in to comment.