Skip to content

Commit

Permalink
Add CloseTextSampler and FixedIndicesSampler (#1034)
Browse files Browse the repository at this point in the history
* Add CloseTextSampler

That returns demos that are textually close to the current instance.

Signed-off-by: Yoav Katz <katz@il.ibm.com>

* Make sampler call pass the current instance

Added an end-to-end test of a sampler that depends on output

Signed-off-by: Yoav Katz <katz@il.ibm.com>

* Added FixedIndicesSampler(Sampler):

Selects a fixed set of samples based on a list of indices from the demo pool

Signed-off-by: Yoav Katz <katz@il.ibm.com>

* Made splitter currently use random_generators

Signed-off-by: Yoav Katz <katz@il.ibm.com>

* Changed all Sample randomization

To use common code to create randomizer per instance

Signed-off-by: Yoav Katz <katz@il.ibm.com>

* Updated demos in test

After a non-backward-compatible change

Signed-off-by: Yoav Katz <katz@il.ibm.com>

* Updated demos in test

After a non-backward-compatible change

Signed-off-by: Yoav Katz <katz@il.ibm.com>

---------

Signed-off-by: Yoav Katz <katz@il.ibm.com>
  • Loading branch information
yoavkatz committed Jul 22, 2024
1 parent 94daea3 commit 904c8a5
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 50 deletions.
96 changes: 80 additions & 16 deletions src/unitxt/splitters.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import itertools
from abc import abstractmethod
from copy import deepcopy
from random import Random
from typing import Dict, List
from difflib import get_close_matches
from typing import Dict, List, Optional

from .artifact import Artifact
from .dict_utils import dict_get
from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator
from .random_utils import new_random_generator
from .split_utils import (
Expand Down Expand Up @@ -109,7 +110,6 @@ def process(self, multi_stream: MultiStream) -> MultiStream:

class Sampler(Artifact):
sample_size: int = None
random_generator: Random = new_random_generator(sub_seed="Sampler")

def prepare(self):
super().prepare()
Expand All @@ -123,17 +123,15 @@ def set_size(self, size):
size = int(size)
self.sample_size = size

def init_new_random_generator(self):
self.random_generator = new_random_generator(
sub_seed="init_new_random_generator"
)

@abstractmethod
def sample(
self, instances_pool: List[Dict[str, object]]
self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
) -> List[Dict[str, object]]:
pass

def get_random_generator_based_on_instance(self, instance):
return new_random_generator(sub_seed={**instance["input_fields"]})

def filter_source_by_instance(
self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
) -> List[Dict[str, object]]:
Expand All @@ -151,11 +149,74 @@ def filter_source_by_instance(


class RandomSampler(Sampler):
    """Draws a pseudo-random subset of the instance pool."""

    def sample(
        self,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        """Return 'sample_size' instances picked at random from the pool.

        The random generator is obtained per call from the current instance
        (via get_random_generator_based_on_instance), rather than from state
        kept on the sampler.

        Args:
            instances_pool: Candidate instances to draw from.
            instance: The instance for which the sample is drawn.

        Returns:
            A list of 'sample_size' instances taken from the pool.
        """
        # Materialize the pool first, so the generator samples from a concrete list.
        candidates = [*instances_pool]
        rng = self.get_random_generator_based_on_instance(instance)
        return rng.sample(candidates, self.sample_size)


class FixedIndicesSampler(Sampler):
    """Selects a fixed set of samples based on a list of indices.

    The same pool positions are returned for every instance, in the order
    given by 'indices'. Negative indices count from the end of the pool, as
    in regular Python indexing.
    """

    indices: List[int]

    def sample(
        self,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        """Return the pool entries located at the configured indices.

        Args:
            instances_pool: Indexable pool of candidate instances.
            instance: The instance being processed; not used by this sampler.

        Returns:
            The instances at positions 'indices', in that order.

        Raises:
            ValueError: If an index falls outside the bounds of the pool.
        """
        num_instances = len(instances_pool)

        # Check both bounds: an index below -num_instances would otherwise
        # surface as a bare IndexError instead of the descriptive error below.
        for index in self.indices:
            if not -num_instances <= index < num_instances:
                raise ValueError(
                    f"FixedIndicesSampler 'indices' field contains index ({index}) which is out of bounds of the instance pool ( of size {num_instances})"
                )
        return [instances_pool[index] for index in self.indices]


class CloseTextSampler(Sampler):
    """Selects the instances whose text is the closest match to the given instance.

    The comparison is done on a single input field, named by 'field', which
    is read from "input_fields/<field>" in both the current instance and each
    instance of the pool.
    """

    field: str

    def sample(
        self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
    ) -> List[Dict[str, object]]:
        """Return 'sample_size' pool instances that are textually closest to 'instance'.

        Args:
            instances_pool: Candidate instances to choose from.
            instance: The instance whose field value is matched against the pool.

        Returns:
            A list of 'sample_size' instances whose field value is among the
            closest matches. The order of the returned list is randomized
            (seeded from the instance), not sorted by closeness.
        """
        field = f"input_fields/{self.field}"
        value = dict_get(instance, field)

        instances_pool = list(instances_pool)

        # Collect the compared text of every pool instance once; it is reused
        # below when filtering the pool.
        options = [
            dict_get(instance_in_pool, field) for instance_in_pool in instances_pool
        ]

        # Get the 'sample_size' closest matching texts based on the field.
        # cutoff=0 disables the similarity threshold, so no candidate is
        # dropped for scoring too low.
        closest_matches = get_close_matches(
            value, options, n=self.sample_size, cutoff=0
        )

        # Randomly select 'sample_size' instances from those whose text is one
        # of the closest matches. (There may be multiple instances with the
        # same text in the given field, and the order returned is also
        # randomized.)
        matching_instances = [
            instance_in_pool
            for instance_in_pool, text in zip(instances_pool, options)
            if text in closest_matches
        ]
        random_generator = self.get_random_generator_based_on_instance(instance)
        return random_generator.sample(matching_instances, self.sample_size)


class DiverseLabelsSampler(Sampler):
Expand Down Expand Up @@ -237,12 +298,15 @@ def divide_by_repr(self, exemplars_pool):
return labels

def sample(
self, instances_pool: List[Dict[str, object]]
self,
instances_pool: List[Dict[str, object]],
instance: Optional[Dict[str, object]],
) -> List[Dict[str, object]]:
if self.labels_cache is None:
self.labels_cache = self.divide_by_repr(instances_pool)
all_labels = list(self.labels_cache.keys())
self.random_generator.shuffle(all_labels)
random_generator = self.get_random_generator_based_on_instance(instance)
random_generator.shuffle(all_labels)
from collections import Counter

if self.sample_size > len(instances_pool):
Expand All @@ -263,10 +327,10 @@ def sample(

result = []
for label, allocation in allocations.items():
sample = self.random_generator.sample(self.labels_cache[label], allocation)
sample = random_generator.sample(self.labels_cache[label], allocation)
result.extend(sample)

self.random_generator.shuffle(result)
random_generator.shuffle(result)
return result


Expand Down Expand Up @@ -300,7 +364,7 @@ def process(
raise ValueError(
f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {self.sampler.sample_size}."
)
sampled_instances = self.sampler.sample(source_stream)
sampled_instances = self.sampler.sample(source_stream, instance)
instance[self.target_field] = sampled_instances
return instance
except FaultyStreamError as e:
Expand Down
4 changes: 1 addition & 3 deletions src/unitxt/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ class BaseRecipe(Recipe, SourceSequentialOperator):

def before_process_multi_stream(self):
super().before_process_multi_stream()
if self.sampler: # e.g. when num_demos is 0, the sampler may not be initialized
self.sampler.init_new_random_generator()

def verify(self):
super().verify()
Expand Down Expand Up @@ -362,7 +360,7 @@ class StandardRecipe(StandardRecipeWithIndexes):
demos_taken_from (str, optional): Specifies from where the demos are taken. Default is "train".
demos_field (str, optional): Field name for demos. Default is "demos".
demos_removed_from_data (bool, optional): whether to remove the demos from the source data, Default is True
sampler (Sampler, optional): Sampler object to be used in the recipe.
sampler (Sampler, optional): The Sampler used to select the demonstrations when num_demos > 0.
steps (List[StreamingOperator], optional): List of StreamingOperator objects to be used in the recipe.
augmentor (Augmentor) : Augmentor to be used to pseudo randomly augment the source text
instruction_card_index (int, optional): Index of instruction card to be used
Expand Down
18 changes: 16 additions & 2 deletions tests/library/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,14 @@ def test_produce_with_recipe(self):

target = {
"metrics": ["metrics.f1_micro", "metrics.accuracy", "metrics.f1_macro"],
"source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\npremise: Steve follows Fred's example in everything. He influences him hugely.\nhypothesis: Steve influences him hugely.\nThe entailment class is entailment\n\npremise: The police arrested all of the gang members. They were trying to stop the drug trade in the neighborhood.\nhypothesis: The police were trying to stop the drug trade in the neighborhood.\nThe entailment class is not entailment\n\npremise: It works perfectly\nhypothesis: It works!\nThe entailment class is ",
"source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\n"
"premise: When Tatyana reached the cabin, her mother was sleeping. "
"She was careful not to disturb her, undressing and climbing back "
"into her berth.\n"
"hypothesis: mother was careful not to disturb her, undressing and "
"climbing back into her berth.\n"
"The entailment class is entailment\n\n"
"premise: Steve follows Fred's example in everything. He influences him hugely.\nhypothesis: Steve influences him hugely.\nThe entailment class is entailment\n\npremise: It works perfectly\nhypothesis: It works!\nThe entailment class is ",
"target": "?",
"references": ["?"],
"task_data": '{"text_a": "It works perfectly", '
Expand Down Expand Up @@ -164,7 +171,14 @@ def test_produce_with_recipe_with_list_of_instances(self):

target = {
"metrics": ["metrics.f1_micro", "metrics.accuracy", "metrics.f1_macro"],
"source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\npremise: Steve follows Fred's example in everything. He influences him hugely.\nhypothesis: Steve influences him hugely.\nThe entailment class is entailment\n\npremise: The police arrested all of the gang members. They were trying to stop the drug trade in the neighborhood.\nhypothesis: The police were trying to stop the drug trade in the neighborhood.\nThe entailment class is not entailment\n\npremise: It works perfectly\nhypothesis: It works!\nThe entailment class is ",
"source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\n"
"premise: When Tatyana reached the cabin, her mother was sleeping. "
"She was careful not to disturb her, undressing and climbing back "
"into her berth.\n"
"hypothesis: mother was careful not to disturb her, undressing and "
"climbing back into her berth.\n"
"The entailment class is entailment\n\n"
"premise: Steve follows Fred's example in everything. He influences him hugely.\nhypothesis: Steve influences him hugely.\nThe entailment class is entailment\n\npremise: It works perfectly\nhypothesis: It works!\nThe entailment class is ",
"target": "?",
"references": ["?"],
"task_data": '{"text_a": "It works perfectly", '
Expand Down
73 changes: 48 additions & 25 deletions tests/library/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,54 @@ def test_standard_recipe_production_with_demos(self):

target = {
"metrics": ["metrics.accuracy"],
"source": "<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n\n\n\nUser: The following are multiple choice questions (with answers) about marketing.\n\nThe single group within society that is most vulnerable to reference group influence is:\nA. The older consumer who feels somewhat left out of things.\nB. The married women, many of whom feel a need for stability in their lives.\nC. New immigrants who really want to assimilate into their new culture.\nD. Children, who base most of their buying decisions on outside influences.\nAnswer:\nAgent: D\n\nUser: The following are multiple choice questions (with answers) about marketing.\n\n Which of the following is an assumption in Maslow's hierarchy of needs?\nA. Needs are dependent on culture and also on social class.\nB. Lower-level needs must be at least partially satisfied before higher needs can affect behaviour.\nC. Needs are not prioritized or arranged in any particular order.\nD. Satisfied needs are motivators, and new needs emerge when current needs remain unmet.\nAnswer:\nAgent: B\n\nUser: The following are multiple choice questions (with answers) about marketing.\n\nIn an organization, the group of people tasked with buying decisions is referred to as the _______________.\nA. Outsourcing unit.\nB. Procurement centre.\nC. Chief executive unit.\nD. Decision-making unit.\nAnswer:\nAgent: D\n\n\nUser:The following are multiple choice questions (with answers) about testing.\n\nwhat?\nA. yes\nB. not\nC. maybe\nAnswer:\nAgent:",
"source": """<<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>
User: The following are multiple choice questions (with answers) about marketing.
Although the content and quality can be as controlled as direct mail, response rates of this medium are lower because of the lack of a personal address mechanism. This media format is known as:
A. Care lines.
B. Direct mail.
C. Inserts.
D. Door to door.
Answer:
Agent: D
User: The following are multiple choice questions (with answers) about marketing.
_____________ is a natural outcome when combining demographic and geographic variables.
A. Geodemographics
B. Product differentiation.
C. ANSOFF matrix.
D. Brand management.
Answer:
Agent: A
User: The following are multiple choice questions (with answers) about marketing.
In an organization, the group of people tasked with buying decisions is referred to as the _______________.
A. Outsourcing unit.
B. Procurement centre.
C. Chief executive unit.
D. Decision-making unit.
Answer:
Agent: D
User:The following are multiple choice questions (with answers) about testing.
what?
A. yes
B. not
C. maybe
Answer:
Agent:""",
"target": " C",
"references": [" C"],
"task_data": '{"topic": "testing",'
Expand Down Expand Up @@ -544,30 +591,6 @@ def test_recipe_loaded_from_arguments_and_overwrites_only(self):
first_inst = next(iterator)
self.assertListEqual(["metrics.accuracy"], first_inst["metrics"])

def test_standard_recipe_with_a_sampler(self):
"""Check that the sampler is re-initialized before processing a recipe.
To do so, save the random generator within the sampler before activating the recipe,
and compare it to the random generator within the sampler after the revipe was called.
The two generators should be different objects, indicating that the sampler was properly
re-initialized during the preparation of the recipe.
"""
recipe = StandardRecipeWithIndexes(
card="cards.sst2",
template_card_index=0,
max_train_instances=0,
max_test_instances=2,
num_demos=1,
demos_pool_size=10,
)
sampler = recipe.card.sampler

random_generator1 = sampler.random_generator
recipe()
random_generator2 = sampler.random_generator

self.assertNotEqual(random_generator1, random_generator2)

def test_standard_recipe_with_a_missing_sampler(self):
"""Check that initializing a recipe with a card that does not have a sampler raises an exception."""
task_card, _ = copy.deepcopy(fetch_artifact("cards.sst2"))
Expand Down
Loading

0 comments on commit 904c8a5

Please sign in to comment.