-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into serializers
- Loading branch information
Showing
28 changed files
with
1,146 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from copy import deepcopy | ||
|
||
from unitxt import add_to_catalog | ||
from unitxt.blocks import ( | ||
LoadHF, | ||
SplitRandomMix, | ||
TaskCard, | ||
TemplatesDict, | ||
) | ||
from unitxt.dialog_operators import SerializeOpenAiFormatDialog | ||
from unitxt.operators import Copy, Set, Shuffle | ||
from unitxt.test_utils.card import test_card | ||
|
||
# Each card is published in two flavours, keyed by name:
#   "train"    - re-partitions the source "test" split into train / validation /
#                test portions (presumably 60% / 20% / 20% - confirm against
#                SplitRandomMix semantics);
#   "standard" - passes the source "test" split through as "test".
splits_random_mixes = {
    "train": SplitRandomMix(
        {
            "train": "test[0.6]",
            "validation": "test[0.2]",
            "test": "test[0.2]",
        }
    ),
    "standard": SplitRandomMix({"test": "test"}),
}

# ChatRAG-Bench configurations for which a card is generated.
subsets = ["doqa_travel", "doqa_cooking", "doqa_movies", "doc2dial", "hybridial"]

for split, split_mix in splits_random_mixes.items():
    # Only the "train" flavour gets an extra "train." segment in its catalog name.
    name_infix = "train." if split == "train" else ""

    for subset in subsets:
        loader = LoadHF(path="nvidia/ChatRAG-Bench", name=subset, split="test")

        # The split-mix operator instance is shared by every subset of a flavour;
        # all other steps are created afresh for each card.
        steps = [
            split_mix,
            Shuffle(),
            # Rename the raw dataset fields to the names used by the later steps
            # and by the task.
            Copy(
                field_to_field={
                    "ctxs/*/text": "contexts",
                    "messages": "dialog",
                    "answers": "reference_answers",
                }
            ),
            # No context ids are copied from the data; an empty list is set instead.
            Set(fields={"contexts_ids": []}),
            # Serialize the dialog into the "question" field.
            # NOTE(review): "dummy" looks like a throwaway target for the last
            # response - confirm against SerializeOpenAiFormatDialog.
            SerializeOpenAiFormatDialog(
                field="dialog",
                to_field="question",
                format="formats.user_assistant",
                slice_first_and_last_turns_format=True,
                last_response_to_field="dummy",
            ),
        ]

        templates = TemplatesDict(
            {"default": "templates.rag.response_generation.please_respond_chat"}
        )

        card = TaskCard(
            loader=loader,
            preprocess_steps=steps,
            task="tasks.rag.response_generation",
            templates=templates,
        )

        # Running the card test with the bert-score metric is too slow, so the
        # test runs on a copy whose metrics are reduced to rouge only; the card
        # stored in the catalog keeps its original metrics.
        card_for_test = deepcopy(card)
        card_for_test.task.metrics = ["metrics.rouge"]

        test_card(card_for_test, strict=True, demos_taken_from="test")

        catalog_name = (
            "cards.rag.response_generation.chat_rag_bench."
            f"{name_infix}user_assistant_format.{subset}"
        )
        add_to_catalog(card, catalog_name, overwrite=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from copy import deepcopy | ||
|
||
from unitxt import add_to_catalog | ||
from unitxt.blocks import ( | ||
LoadHF, | ||
SplitRandomMix, | ||
TaskCard, | ||
TemplatesDict, | ||
) | ||
from unitxt.operators import ( | ||
Copy, | ||
ListFieldValues, | ||
Shuffle, | ||
) | ||
from unitxt.test_utils.card import test_card | ||
|
||
# Card for the open-australian-legal-qa dataset, registered in the catalog as a
# RAG response-generation card.
loader = LoadHF(
    path="umarbutler/open-australian-legal-qa",
)

steps = [
    # Partition the source "train" split into train / validation / test portions
    # (presumably 50% / 20% / 30% - confirm against SplitRandomMix semantics).
    SplitRandomMix(
        {"train": "train[0.5]", "validation": "train[0.2]", "test": "train[0.3]"}
    ),
    Shuffle(),
    # Copy the raw dataset fields into the field names used by the task.
    Copy(
        field_to_field={
            "source/text": "contexts",
            "answer": "reference_answers",
            "source/citation": "contexts_ids",
        }
    ),
    # Apply ListFieldValues to each copied field in place, in this order
    # (presumably wrapping the single value in a list - confirm).
    *[
        ListFieldValues(fields=[field_name], to_field=field_name)
        for field_name in ("reference_answers", "contexts", "contexts_ids")
    ],
]

templates = TemplatesDict(
    {"default": "templates.rag.response_generation.please_respond_chat"}
)

card = TaskCard(
    loader=loader,
    preprocess_steps=steps,
    task="tasks.rag.response_generation",
    templates=templates,
)

# Running the card test with the bert-score metric is too slow, so the test runs
# on a copy whose metrics are reduced to rouge only; the card stored in the
# catalog keeps its original metrics.
card_for_test = deepcopy(card)
card_for_test.task.metrics = ["metrics.rouge"]

test_card(card_for_test, strict=True, demos_taken_from="test")

add_to_catalog(
    card,
    "cards.rag.response_generation.train.open_australian_legal_qa",
    overwrite=True,
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
48 changes: 48 additions & 0 deletions
48
...og/cards/rag/response_generation/chat_rag_bench/train/user_assistant_format/doc2dial.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
{ | ||
"__type__": "task_card", | ||
"loader": { | ||
"__type__": "load_hf", | ||
"path": "nvidia/ChatRAG-Bench", | ||
"name": "doc2dial", | ||
"split": "test" | ||
}, | ||
"preprocess_steps": [ | ||
{ | ||
"__type__": "split_random_mix", | ||
"mix": { | ||
"train": "test[0.6]", | ||
"validation": "test[0.2]", | ||
"test": "test[0.2]" | ||
} | ||
}, | ||
{ | ||
"__type__": "shuffle" | ||
}, | ||
{ | ||
"__type__": "copy", | ||
"field_to_field": { | ||
"ctxs/*/text": "contexts", | ||
"messages": "dialog", | ||
"answers": "reference_answers" | ||
} | ||
}, | ||
{ | ||
"__type__": "set", | ||
"fields": { | ||
"contexts_ids": [] | ||
} | ||
}, | ||
{ | ||
"__type__": "serialize_open_ai_format_dialog", | ||
"field": "dialog", | ||
"to_field": "question", | ||
"format": "formats.user_assistant", | ||
"slice_first_and_last_turns_format": true, | ||
"last_response_to_field": "dummy" | ||
} | ||
], | ||
"task": "tasks.rag.response_generation", | ||
"templates": { | ||
"default": "templates.rag.response_generation.please_respond_chat" | ||
} | ||
} |
48 changes: 48 additions & 0 deletions
48
...ards/rag/response_generation/chat_rag_bench/train/user_assistant_format/doqa_cooking.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
{ | ||
"__type__": "task_card", | ||
"loader": { | ||
"__type__": "load_hf", | ||
"path": "nvidia/ChatRAG-Bench", | ||
"name": "doqa_cooking", | ||
"split": "test" | ||
}, | ||
"preprocess_steps": [ | ||
{ | ||
"__type__": "split_random_mix", | ||
"mix": { | ||
"train": "test[0.6]", | ||
"validation": "test[0.2]", | ||
"test": "test[0.2]" | ||
} | ||
}, | ||
{ | ||
"__type__": "shuffle" | ||
}, | ||
{ | ||
"__type__": "copy", | ||
"field_to_field": { | ||
"ctxs/*/text": "contexts", | ||
"messages": "dialog", | ||
"answers": "reference_answers" | ||
} | ||
}, | ||
{ | ||
"__type__": "set", | ||
"fields": { | ||
"contexts_ids": [] | ||
} | ||
}, | ||
{ | ||
"__type__": "serialize_open_ai_format_dialog", | ||
"field": "dialog", | ||
"to_field": "question", | ||
"format": "formats.user_assistant", | ||
"slice_first_and_last_turns_format": true, | ||
"last_response_to_field": "dummy" | ||
} | ||
], | ||
"task": "tasks.rag.response_generation", | ||
"templates": { | ||
"default": "templates.rag.response_generation.please_respond_chat" | ||
} | ||
} |
Oops, something went wrong.