From 904c8a5fa122350f310901f1cc5f501509772002 Mon Sep 17 00:00:00 2001 From: Yoav Katz <68273864+yoavkatz@users.noreply.github.com> Date: Mon, 22 Jul 2024 20:45:28 +0300 Subject: [PATCH] Add CloseTextSampler and FixedIndicesSampler (#1034) * Add CloseTextSampler That returns demos that are textually close to the current instance. Signed-off-by: Yoav Katz * Make sampler call pass current instance Added end 2 end test of sampler that depends on output Signed-off-by: Yoav Katz * Added FixedIndicesSampler(Sampler): Selects a fix set of samples based on a list of indices from the demo pool Signed-off-by: Yoav Katz * Made splitter currently use random_generators Signed-off-by: Yoav Katz * Changed all Sample randomization To use common code to create randomizer per instance Signed-off-by: Yoav Katz * Updated demos in test After a non backward compatible change Signed-off-by: Yoav Katz * Updated demos in test After a non backward compatible change Signed-off-by: Yoav Katz --------- Signed-off-by: Yoav Katz --- src/unitxt/splitters.py | 96 +++++++++++++++---- src/unitxt/standard.py | 4 +- tests/library/test_api.py | 18 +++- tests/library/test_recipe.py | 73 ++++++++++----- tests/library/test_splitters.py | 157 +++++++++++++++++++++++++++++++- 5 files changed, 298 insertions(+), 50 deletions(-) diff --git a/src/unitxt/splitters.py b/src/unitxt/splitters.py index f181d147c..524b467df 100644 --- a/src/unitxt/splitters.py +++ b/src/unitxt/splitters.py @@ -1,10 +1,11 @@ import itertools from abc import abstractmethod from copy import deepcopy -from random import Random -from typing import Dict, List +from difflib import get_close_matches +from typing import Dict, List, Optional from .artifact import Artifact +from .dict_utils import dict_get from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator from .random_utils import new_random_generator from .split_utils import ( @@ -109,7 +110,6 @@ def process(self, multi_stream: MultiStream) -> MultiStream: class Sampler(Artifact): sample_size: int = None - random_generator: Random = new_random_generator(sub_seed="Sampler") def prepare(self): super().prepare() @@ -123,17 +123,15 @@ def set_size(self, size): size = int(size) self.sample_size = size - def init_new_random_generator(self): - self.random_generator = new_random_generator( - sub_seed="init_new_random_generator" - ) - @abstractmethod def sample( - self, instances_pool: List[Dict[str, object]] + self, instances_pool: List[Dict[str, object]], instance: Dict[str, object] ) -> List[Dict[str, object]]: pass + def get_random_generator_based_on_instance(self, instance): + return new_random_generator(sub_seed={**instance["input_fields"]}) + def filter_source_by_instance( self, instances_pool: List[Dict[str, object]], instance: Dict[str, object] ) -> List[Dict[str, object]]: @@ -151,11 +149,74 @@ def filter_source_by_instance( class RandomSampler(Sampler): + """Selects a random sample of instances.""" + + def sample( + self, + instances_pool: List[Dict[str, object]], + instance: Optional[Dict[str, object]], + ) -> List[Dict[str, object]]: + instances_pool = list(instances_pool) + random_generator = self.get_random_generator_based_on_instance(instance) + return random_generator.sample(instances_pool, self.sample_size) + + +class FixedIndicesSampler(Sampler): + """Selects a fix set of samples based on a list of indices.""" + + indices: List[int] + + def sample( + self, + instances_pool: List[Dict[str, object]], + instance: Optional[Dict[str, object]], + ) -> List[Dict[str, object]]: + num_instances = len(instances_pool) + + instances = [] + for index in self.indices: + if index >= num_instances: + raise ValueError( + f"FixedIndicesSampler 'indices' field contains index ({index}) which is out of bounds of the instance pool ( of size {num_instances})" + ) + instances.append(instances_pool[index]) + return instances + + +class CloseTextSampler(Sampler): + """Selects the samples of instances which are the closest textual match to the given instance. + + Comparison is done based on a given field in the instance. + + """ + + field: str + def sample( - self, instances_pool: List[Dict[str, object]] + self, instances_pool: List[Dict[str, object]], instance: Dict[str, object] ) -> List[Dict[str, object]]: + field = f"input_fields/{self.field}" + value = dict_get(instance, field) + instances_pool = list(instances_pool) - return self.random_generator.sample(instances_pool, self.sample_size) + + # Get 'sample_size' closest matchest texts based on field + options = [] + for instance_in_pool in instances_pool: + options.append(dict_get(instance_in_pool, field)) + closest_matches = get_close_matches( + value, options, n=self.sample_size, cutoff=0 + ) + # Randmly select 'sample_size' instances that are from the closest matches text + # (There may be multiple instance with same text in the given field, and the order returned is + # is also randomized ) + instances_pool = [ + instance_in_pool + for instance_in_pool in instances_pool + if dict_get(instance_in_pool, field) in closest_matches + ] + random_generator = self.get_random_generator_based_on_instance(instance) + return random_generator.sample(instances_pool, self.sample_size) class DiverseLabelsSampler(Sampler): @@ -237,12 +298,15 @@ def divide_by_repr(self, exemplars_pool): return labels def sample( - self, instances_pool: List[Dict[str, object]] + self, + instances_pool: List[Dict[str, object]], + instance: Optional[Dict[str, object]], ) -> List[Dict[str, object]]: if self.labels_cache is None: self.labels_cache = self.divide_by_repr(instances_pool) all_labels = list(self.labels_cache.keys()) - self.random_generator.shuffle(all_labels) + random_generator = self.get_random_generator_based_on_instance(instance) + random_generator.shuffle(all_labels) from collections import Counter if self.sample_size > len(instances_pool): @@ -263,10 +327,10 @@ def sample( result = [] for label, allocation in allocations.items(): - sample = self.random_generator.sample(self.labels_cache[label], allocation) + sample = random_generator.sample(self.labels_cache[label], allocation) result.extend(sample) - self.random_generator.shuffle(result) + random_generator.shuffle(result) return result @@ -300,7 +364,7 @@ def process( raise ValueError( f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {self.sampler.sample_size}." ) - sampled_instances = self.sampler.sample(source_stream) + sampled_instances = self.sampler.sample(source_stream, instance) instance[self.target_field] = sampled_instances return instance except FaultyStreamError as e: diff --git a/src/unitxt/standard.py b/src/unitxt/standard.py index eed5cde04..9d86c46b6 100644 --- a/src/unitxt/standard.py +++ b/src/unitxt/standard.py @@ -58,8 +58,6 @@ class BaseRecipe(Recipe, SourceSequentialOperator): def before_process_multi_stream(self): super().before_process_multi_stream() - if self.sampler: # e.g. when num_demos is 0, the sampler may not be initialized - self.sampler.init_new_random_generator() def verify(self): super().verify() @@ -362,7 +360,7 @@ class StandardRecipe(StandardRecipeWithIndexes): demos_taken_from (str, optional): Specifies from where the demos are taken. Default is "train". demos_field (str, optional): Field name for demos. Default is "demos". demos_removed_from_data (bool, optional): whether to remove the demos from the source data, Default is True - sampler (Sampler, optional): Sampler object to be used in the recipe. + sampler (Sampler, optional): The Sampler used to select the demonstrations when num_demos > 0. steps (List[StreamingOperator], optional): List of StreamingOperator objects to be used in the recipe. augmentor (Augmentor) : Augmentor to be used to pseudo randomly augment the source text instruction_card_index (int, optional): Index of instruction card to be used diff --git a/tests/library/test_api.py b/tests/library/test_api.py index aa2421eee..a7601e463 100644 --- a/tests/library/test_api.py +++ b/tests/library/test_api.py @@ -125,7 +125,14 @@ def test_produce_with_recipe(self): target = { "metrics": ["metrics.f1_micro", "metrics.accuracy", "metrics.f1_macro"], - "source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\npremise: Steve follows Fred's example in everything. He influences him hugely.\nhypothesis: Steve influences him hugely.\nThe entailment class is entailment\n\npremise: The police arrested all of the gang members. They were trying to stop the drug trade in the neighborhood.\nhypothesis: The police were trying to stop the drug trade in the neighborhood.\nThe entailment class is not entailment\n\npremise: It works perfectly\nhypothesis: It works!\nThe entailment class is ", + "source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\n" + "premise: When Tatyana reached the cabin, her mother was sleeping. " + "She was careful not to disturb her, undressing and climbing back " + "into her berth.\n" + "hypothesis: mother was careful not to disturb her, undressing and " + "climbing back into her berth.\n" + "The entailment class is entailment\n\n" + "premise: Steve follows Fred's example in everything. He influences him hugely.\nhypothesis: Steve influences him hugely.\nThe entailment class is entailment\n\npremise: It works perfectly\nhypothesis: It works!\nThe entailment class is ", "target": "?", "references": ["?"], "task_data": '{"text_a": "It works perfectly", ' @@ -164,7 +171,14 @@ def test_produce_with_recipe_with_list_of_instances(self): target = { "metrics": ["metrics.f1_micro", "metrics.accuracy", "metrics.f1_macro"], - "source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\npremise: Steve follows Fred's example in everything. He influences him hugely.\nhypothesis: Steve influences him hugely.\nThe entailment class is entailment\n\npremise: The police arrested all of the gang members. They were trying to stop the drug trade in the neighborhood.\nhypothesis: The police were trying to stop the drug trade in the neighborhood.\nThe entailment class is not entailment\n\npremise: It works perfectly\nhypothesis: It works!\nThe entailment class is ", + "source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\n" + "premise: When Tatyana reached the cabin, her mother was sleeping. " + "She was careful not to disturb her, undressing and climbing back " + "into her berth.\n" + "hypothesis: mother was careful not to disturb her, undressing and " + "climbing back into her berth.\n" + "The entailment class is entailment\n\n" + "premise: Steve follows Fred's example in everything. He influences him hugely.\nhypothesis: Steve influences him hugely.\nThe entailment class is entailment\n\npremise: It works perfectly\nhypothesis: It works!\nThe entailment class is ", "target": "?", "references": ["?"], "task_data": '{"text_a": "It works perfectly", ' diff --git a/tests/library/test_recipe.py b/tests/library/test_recipe.py index a4d336cfc..19f169738 100644 --- a/tests/library/test_recipe.py +++ b/tests/library/test_recipe.py @@ -168,7 +168,54 @@ def test_standard_recipe_production_with_demos(self): target = { "metrics": ["metrics.accuracy"], - "source": "<>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<>\n\n\n\n\nUser: The following are multiple choice questions (with answers) about marketing.\n\nThe single group within society that is most vulnerable to reference group influence is:\nA. The older consumer who feels somewhat left out of things.\nB. The married women, many of whom feel a need for stability in their lives.\nC. New immigrants who really want to assimilate into their new culture.\nD. Children, who base most of their buying decisions on outside influences.\nAnswer:\nAgent: D\n\nUser: The following are multiple choice questions (with answers) about marketing.\n\n Which of the following is an assumption in Maslow's hierarchy of needs?\nA. Needs are dependent on culture and also on social class.\nB. Lower-level needs must be at least partially satisfied before higher needs can affect behaviour.\nC. Needs are not prioritized or arranged in any particular order.\nD. Satisfied needs are motivators, and new needs emerge when current needs remain unmet.\nAnswer:\nAgent: B\n\nUser: The following are multiple choice questions (with answers) about marketing.\n\nIn an organization, the group of people tasked with buying decisions is referred to as the _______________.\nA. Outsourcing unit.\nB. Procurement centre.\nC. Chief executive unit.\nD. Decision-making unit.\nAnswer:\nAgent: D\n\n\nUser:The following are multiple choice questions (with answers) about testing.\n\nwhat?\nA. yes\nB. not\nC. maybe\nAnswer:\nAgent:", + "source": """<> +You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. +<> + + + + +User: The following are multiple choice questions (with answers) about marketing. + +Although the content and quality can be as controlled as direct mail, response rates of this medium are lower because of the lack of a personal address mechanism. This media format is known as: +A. Care lines. +B. Direct mail. +C. Inserts. +D. Door to door. +Answer: +Agent: D + +User: The following are multiple choice questions (with answers) about marketing. + + _____________ is a natural outcome when combining demographic and geographic variables. +A. Geodemographics +B. Product differentiation. +C. ANSOFF matrix. +D. Brand management. +Answer: +Agent: A + +User: The following are multiple choice questions (with answers) about marketing. + +In an organization, the group of people tasked with buying decisions is referred to as the _______________. +A. Outsourcing unit. +B. Procurement centre. +C. Chief executive unit. +D. Decision-making unit. +Answer: +Agent: D + + +User:The following are multiple choice questions (with answers) about testing. + +what? +A. yes +B. not +C. maybe +Answer: +Agent:""", "target": " C", "references": [" C"], "task_data": '{"topic": "testing",' @@ -544,30 +591,6 @@ def test_recipe_loaded_from_arguments_and_overwrites_only(self): first_inst = next(iterator) self.assertListEqual(["metrics.accuracy"], first_inst["metrics"]) - def test_standard_recipe_with_a_sampler(self): - """Check that the sampler is re-initialized before processing a recipe. - - To do so, save the random generator within the sampler before activating the recipe, - and compare it to the random generator within the sampler after the revipe was called. - The two generators should be different objects, indicating that the sampler was properly - re-initialized during the preparation of the recipe. - """ - recipe = StandardRecipeWithIndexes( - card="cards.sst2", - template_card_index=0, - max_train_instances=0, - max_test_instances=2, - num_demos=1, - demos_pool_size=10, - ) - sampler = recipe.card.sampler - - random_generator1 = sampler.random_generator - recipe() - random_generator2 = sampler.random_generator - - self.assertNotEqual(random_generator1, random_generator2) - def test_standard_recipe_with_a_missing_sampler(self): """Check that initializing a recipe with a card that does not have a sampler raises an exception.""" task_card, _ = copy.deepcopy(fetch_artifact("cards.sst2")) diff --git a/tests/library/test_splitters.py b/tests/library/test_splitters.py index bac1943f9..c4833bc83 100644 --- a/tests/library/test_splitters.py +++ b/tests/library/test_splitters.py @@ -1,6 +1,10 @@ import copy -from unitxt.splitters import DiverseLabelsSampler +from unitxt.api import load_dataset +from unitxt.blocks import TaskCard +from unitxt.collections_operators import Wrap +from unitxt.loaders import LoadFromDictionary +from unitxt.splitters import CloseTextSampler, DiverseLabelsSampler, FixedIndicesSampler from tests.utils import UnitxtTestCase @@ -35,7 +39,10 @@ def test_sample(self): self.new_exemplar(choices, ["cow"], "Moo1"), self.new_exemplar(choices, ["duck"], "Quack"), ] - result = sampler.sample(instances) + result = sampler.sample( + instances, + self.new_exemplar(choices, ["any"], "any"), + ) from collections import Counter @@ -59,7 +66,10 @@ def test_sample_no_empty_labels(self): self.new_exemplar(choices, ["cow"], "Moo1"), self.new_exemplar(choices, ["duck"], "Quack"), ] - result = sampler.sample(instances) + result = sampler.sample( + instances, + self.new_exemplar(choices, ["any"], "any"), + ) from collections import Counter @@ -79,7 +89,9 @@ def test_sample_list(self): self.new_exemplar(choices, ["dog"], "Bark2"), self.new_exemplar(choices, ["duck"], "Quack"), ] - result = sampler.sample(instances) + result = sampler.sample( + instances, self.new_exemplar(choices, ["any"], "any") + ) from collections import Counter counts = Counter() @@ -146,3 +158,140 @@ def test_filter_with_bad_input(self): f"'input_fields' field is missing from '{instance}'.", str(cm.exception), ) + + +class TestCloseTextSampler(UnitxtTestCase): + """Tests for the CloseTextSampler object.""" + + @staticmethod + def new_exemplar(question: str, answer: str): + """Return an exemplar in a correct format.""" + return { + "input_fields": {"question": question, "answer": answer}, + } + + def test_sample(self): + instances = [ + self.new_exemplar("What is your name?", "John"), + self.new_exemplar("In which country is Paris located?", "France"), + self.new_exemplar("What's the time?", "22:00"), + self.new_exemplar("What is your name, please?", "Mary"), + ] + + num_samples = 2 + sampler = CloseTextSampler(num_samples, field="question") + + results = sampler.sample( + instances, self.new_exemplar("What's your name?", "don't know") + ) + self.assertEqual(results, [instances[0], instances[3]]) + + results = sampler.sample( + instances, self.new_exemplar("What is the time?", "don't know") + ) + self.assertEqual(results, [instances[2], instances[0]]) + + num_samples = 1 + sampler = CloseTextSampler(num_samples, field="answer") + results = sampler.sample( + instances, self.new_exemplar("Who do I love?", "Mary Lu") + ) + self.assertEqual(results, [instances[3]]) + + def test_filter_with_wrong_field(self): + num_samples = 2 + sampler = CloseTextSampler(num_samples, field="wrong_field") + instances = [ + self.new_exemplar("What is your name?", "John"), + ] + instance = self.new_exemplar("What's your name?", "don't know") + with self.assertRaises(ValueError) as cm: + sampler.sample(instances, instance) + self.assertIn( + 'query "input_fields/wrong_field" did not match any item in dict', + str(cm.exception), + ) + + def test_end2end(self): + data = { + "train": [ + {"question": "What is your name?", "answer": "John"}, + {"question": "In which country is Paris located?", "answer": "France"}, + {"question": "At what time do we they eat dinner?", "answer": "22:00"}, + {"question": "What's your name, please?", "answer": "Mary"}, + {"question": "Is this your car?", "answer": "yes"}, + {"question": "What is your name?", "answer": "Sunny"}, + ], + "test": [ + {"question": "What's your name?", "answer": "John"}, + ], + } + + card = TaskCard( + loader=LoadFromDictionary(data=data), + task="tasks.qa.open", + preprocess_steps=[Wrap(field="answer", inside="list", to_field="answers")], + ) + + dataset = load_dataset( + card=card, + template="templates.qa.open.title", + demos_pool_size=5, + num_demos=2, + sampler=CloseTextSampler(field="question"), + ) + expected_output = """Answer the question. +Question: +What is your name? +Answer: +John + +Question: +What's your name, please? +Answer: +Mary + +Question: +What's your name? +Answer: +""" + self.assertEqual(dataset["test"][0]["source"], expected_output) + + +class TestFixedIndicesSampler(UnitxtTestCase): + """Tests for the FixedIndicesSampler object.""" + + @staticmethod + def new_exemplar(question: str, answer: str): + """Return an exemplar in a correct format.""" + return { + "input_fields": {"question": question, "answer": answer}, + } + + def test_sample(self): + instances = [ + self.new_exemplar("What is your name?", "John"), + self.new_exemplar("In which country is Paris located?", "France"), + self.new_exemplar("What's the time?", "22:00"), + self.new_exemplar("What is your name, please?", "Mary"), + ] + instance = self.new_exemplar("What's your name?", "don't know") + sampler = FixedIndicesSampler(indices=[2, 0]) + + results = sampler.sample(instances, instance) + self.assertEqual(results, [instances[2], instances[0]]) + + def test_out_of_bound_sample(self): + instances = [ + self.new_exemplar("What is your name?", "John"), + self.new_exemplar("In which country is Paris located?", "France"), + ] + + instance = self.new_exemplar("What's your name?", "don't know") + sampler = FixedIndicesSampler(indices=[2]) + with self.assertRaises(ValueError) as cm: + sampler.sample(instances, instance) + self.assertIn( + "FixedIndicesSampler 'indices' field contains index (2) which is out of bounds of the instance pool ( of size 2)", + str(cm.exception), + )