From 10d26ed44f70f02eae5012996e3fd61d9d2f3d63 Mon Sep 17 00:00:00 2001 From: Elron Bandel Date: Sun, 8 Sep 2024 19:58:27 +0300 Subject: [PATCH] Add serializers to templates and reorganize and unite all templates (#1195) * Add serializers to templates and reorganize and unite all templates Signed-off-by: elronbandel * Fixes Signed-off-by: elronbandel * Fix tests Signed-off-by: elronbandel * Fix tests Signed-off-by: elronbandel --------- Signed-off-by: elronbandel --- src/unitxt/dataset.py | 1 + src/unitxt/metric.py | 1 + src/unitxt/serializers.py | 112 ++++++++++++ src/unitxt/templates.py | 279 ++++++++++++++---------------- src/unitxt/types.py | 4 +- tests/library/test_serializers.py | 138 +++++++++++++++ tests/library/test_templates.py | 131 ++++++++++---- 7 files changed, 479 insertions(+), 187 deletions(-) create mode 100644 src/unitxt/serializers.py create mode 100644 tests/library/test_serializers.py diff --git a/src/unitxt/dataset.py b/src/unitxt/dataset.py index dc7026372..df1c0e521 100644 --- a/src/unitxt/dataset.py +++ b/src/unitxt/dataset.py @@ -40,6 +40,7 @@ from .recipe import __file__ as _ from .register import __file__ as _ from .schema import __file__ as _ +from .serializers import __file__ as _ from .settings_utils import get_constants from .span_lableing_operators import __file__ as _ from .split_utils import __file__ as _ diff --git a/src/unitxt/metric.py b/src/unitxt/metric.py index 5260e6b00..126915197 100644 --- a/src/unitxt/metric.py +++ b/src/unitxt/metric.py @@ -39,6 +39,7 @@ from .recipe import __file__ as _ from .register import __file__ as _ from .schema import __file__ as _ +from .serializers import __file__ as _ from .settings_utils import get_constants from .span_lableing_operators import __file__ as _ from .split_utils import __file__ as _ diff --git a/src/unitxt/serializers.py b/src/unitxt/serializers.py new file mode 100644 index 000000000..6a895af70 --- /dev/null +++ b/src/unitxt/serializers.py @@ -0,0 +1,112 @@ +import csv +import io +from abc import abstractmethod +from typing import Any, Dict, Union + +from .operators import InstanceFieldOperator +from .type_utils import isoftype +from .types import Dialog, Image, Number, Table, Text + + +class Serializer(InstanceFieldOperator): + def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str: + return self.serialize(value, instance) + + @abstractmethod + def serialize(self, value: Any, instance: Dict[str, Any]) -> str: + pass + + +class DefaultSerializer(Serializer): + def serialize(self, value: Any, instance: Dict[str, Any]) -> str: + return str(value) + + +class DefaultListSerializer(Serializer): + def serialize(self, value: Any, instance: Dict[str, Any]) -> str: + if isinstance(value, list): + return ", ".join(str(item) for item in value) + return str(value) + + +class DialogSerializer(Serializer): + def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str: + # Convert the Dialog into a string representation, typically combining roles and content + return "\n".join(f"{turn['role']}: {turn['content']}" for turn in value) + + +class NumberSerializer(Serializer): + def serialize(self, value: Number, instance: Dict[str, Any]) -> str: + # Check if the value is an integer or a float + if isinstance(value, int): + return str(value) + # For floats, format to one decimal place + if isinstance(value, float): + return f"{value:.1f}" + raise ValueError("Unsupported type for NumberSerializer") + + +class NumberQuantizingSerializer(NumberSerializer): + quantum: Union[float, int] = 0.1 + + def serialize(self, value: Number, instance: Dict[str, Any]) -> str: + if isoftype(value, Number): + quantized_value = round(value / self.quantum) / (1 / self.quantum) + if isinstance(self.quantum, int): + quantized_value = int(quantized_value) + return str(quantized_value) + raise ValueError("Unsupported type for NumberSerializer") + + +class TableSerializer(Serializer): + def serialize(self, value: Table, instance: Dict[str, Any]) -> str: + output = io.StringIO() + writer = csv.writer(output, lineterminator="\n") + + # Write the header and rows to the CSV writer + writer.writerow(value["header"]) + writer.writerows(value["rows"]) + + # Retrieve the CSV string + return output.getvalue().strip() + + +class ImageSerializer(Serializer): + def serialize(self, value: Image, instance: Dict[str, Any]) -> str: + if "media" not in instance: + instance["media"] = {} + if "images" not in instance["media"]: + instance["media"]["images"] = [] + idx = len(instance["media"]["images"]) + instance["media"]["images"].append(value) + return f'' + + +class DynamicSerializer(Serializer): + image: Serializer = ImageSerializer() + number: Serializer = DefaultSerializer() + table: Serializer = TableSerializer() + dialog: Serializer = DialogSerializer() + text: Serializer = DefaultSerializer() + list: Serializer = DefaultSerializer() + + def serialize(self, value: Any, instance: Dict[str, Any]) -> Any: + if isoftype(value, Image): + return self.image.serialize(value, instance) + + if isoftype(value, Table): + return self.table.serialize(value, instance) + + if isoftype(value, Dialog) and len(value) > 0: + return self.dialog.serialize(value, instance) + + if isoftype(value, Text): + return self.text.serialize(value, instance) + + if isoftype(value, Number): + return self.number.serialize(value, instance) + + if isinstance(value, list): + return self.list.serialize(value, instance) + + return str(value) diff --git a/src/unitxt/templates.py b/src/unitxt/templates.py index b99fd7286..4fc557aab 100644 --- a/src/unitxt/templates.py +++ b/src/unitxt/templates.py @@ -10,6 +10,12 @@ from .error_utils import Documentation, UnitxtError from .operator import InstanceOperator from .random_utils import new_random_generator +from .serializers import ( + DefaultListSerializer, + DynamicSerializer, + NumberQuantizingSerializer, + Serializer, +) from .settings_utils import get_constants from .type_utils import isoftype @@ -46,17 +52,19 @@ class Template(InstanceOperator): instruction: str = NonPositionalField(default="") target_prefix: str = NonPositionalField(default="") title_fields: List[str] = NonPositionalField(default_factory=list) + serializer: Serializer = NonPositionalField( + default_factory=lambda: DynamicSerializer(list=DefaultListSerializer()) + ) def input_fields_to_instruction_and_target_prefix(self, input_fields): instruction = self.apply_formatting( - input_fields, "input field", self.instruction, "instruction", serialize=True + input_fields, "input field", self.instruction, "instruction" ) target_prefix = self.apply_formatting( input_fields, "input field", self.target_prefix, "target_prefix", - serialize=True, ) return instruction, target_prefix @@ -65,6 +73,12 @@ def preprocess_input_and_reference_fields( ) -> Tuple[Dict[str, Any], Dict[str, Any]]: return input_fields, reference_fields + def preprocess_input_fields(self, input_fields: Dict[str, Any]): + return input_fields + + def preprocess_reference_fields(self, reference_fields: Dict[str, Any]): + return reference_fields + def process( self, instance: Dict[str, Any], stream_name: Optional[str] = None ) -> Dict[str, Any]: @@ -78,14 +92,21 @@ def process( input_fields = instance.get("input_fields") reference_fields = instance.get("reference_fields") - input_fields, reference_fields = self.preprocess_input_and_reference_fields( - input_fields, reference_fields - ) + + if stream_name != constants.inference_stream: + input_fields, reference_fields = self.preprocess_input_and_reference_fields( + input_fields, reference_fields + ) + + input_fields = self.preprocess_input_fields(input_fields) self.set_titles(input_fields) - source = self.input_fields_to_source(input_fields) + + serialized_inputs = self.serialize(input_fields, instance) + + source = self.input_fields_to_source(serialized_inputs) instruction, target_prefix = self.input_fields_to_instruction_and_target_prefix( - input_fields + serialized_inputs ) result = { @@ -97,19 +118,33 @@ def process( } if stream_name == constants.inference_stream: - return result + return self.post_process_instance(result) if reference_fields is None: raise ValueError("Should have reference_fields") + reference_fields = self.preprocess_reference_fields(reference_fields) + + serialized_references = self.serialize( + reference_fields, instance + ) # Dict[str, str] + target, references = self.reference_fields_to_target_and_references( - reference_fields + serialized_references ) result["target"] = target result["references"] = references - return result + return self.post_process_instance(result) + + def post_process_instance(self, instance): + return instance + + def serialize( + self, data: Dict[str, Any], instance: Dict[str, Any] + ) -> Dict[str, str]: + return {k: self.serializer.serialize(v, instance) for k, v in data.items()} @abstractmethod def input_fields_to_source(self, input_fields: Dict[str, object]) -> str: @@ -125,21 +160,13 @@ def reference_fields_to_target_and_references( ) -> Tuple[str, List[str]]: pass - def serialize_data(self, data): - return { - k: ", ".join(str(t) for t in v) if isinstance(v, list) else v - for k, v in data.items() - } - def apply_formatting( - self, data, data_type, format_str, format_name, serialize=False + self, data: Dict[str, Any], data_type: str, format_str: str, format_name: str ) -> str: - if serialize: - data = self.serialize_data(data) try: if format_str is None: raise UnitxtError( - f"Required field 'output_format' of class {self.__class__.__name__} not set in {self.__class__.__name__}", + f"Required field '{format_name}' of class {self.__class__.__name__} not set in {self.__class__.__name__}", Documentation.ADDING_TEMPLATE, ) return format_str.format(**data) @@ -197,26 +224,21 @@ def get_template(self, instance: Dict[str, Any]) -> Template: return random_generator.choice(self.templates) -class InputOutputTemplate(Template): - """Generate field 'source' from fields designated as input, and fields 'target' and 'references' from fields designated as output, of the processed instance. - - Args specify the formatting strings with which to glue together the input and reference fields of the processed instance into one string ('source' and 'target'), and into a list of strings ('references'). - """ - +class InputFormatTemplate(Template): input_format: str - output_format: str = None - def input_fields_to_source( - self, input_fields: Dict[str, object] - ) -> Tuple[str, str]: + def input_fields_to_source(self, input_fields: Dict[str, object]) -> str: return self.apply_formatting( input_fields, "input field", self.input_format, "input_format", - serialize=True, ) + +class OutputFormatTemplate(Template): + output_format: str = None + def reference_fields_to_target_and_references( self, reference_fields: Dict[str, object] ) -> str: @@ -225,12 +247,20 @@ def reference_fields_to_target_and_references( "reference field", self.output_format, "output_format", - serialize=True, ) references = [target] return target, references +class InputOutputTemplate(InputFormatTemplate, OutputFormatTemplate): + """Generate field 'source' from fields designated as input, and fields 'target' and 'references' from fields designated as output, of the processed instance. + + Args specify the formatting strings with which to glue together the input and reference fields of the processed instance into one string ('source' and 'target'), and into a list of strings ('references'). + """ + + pass + + class InputOutputTemplateWithCustomTarget(InputOutputTemplate): reference: str @@ -242,14 +272,12 @@ def reference_fields_to_target_and_references( "reference field", self.output_format, "output_format", - serialize=True, ) reference = self.apply_formatting( reference_fields, "reference field", self.reference, "reference", - serialize=True, ) return target, [reference] @@ -374,22 +402,12 @@ def process_dialog(self, input_fields: Dict[str, object]): input_fields[dialog_fields.dialog_field] = dialog_str return input_fields - def preprocess_input_and_reference_fields( - self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any] - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - return self.process_dialog(input_fields), reference_fields + def preprocess_input_fields(self, input_fields: Dict[str, Any]): + return self.process_dialog(input_fields) class DialogPairwiseChoiceTemplate(DialogTemplate, PairwiseChoiceTemplate): - def preprocess_input_and_reference_fields( - self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any] - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - inputs, reference_fields = DialogTemplate.preprocess_input_and_reference_fields( - self, input_fields, reference_fields - ) - return PairwiseChoiceTemplate.preprocess_input_and_reference_fields( - self, input_fields, reference_fields - ) + pass class PairwiseComparativeRatingTemplate(InputOutputTemplate): @@ -448,10 +466,9 @@ def preprocess_input_and_reference_fields( return input_fields, reference_fields -class MultipleChoiceTemplate(Template): +class MultipleChoiceTemplate(InputFormatTemplate): """Formats the input (that specifies the question), the multiple choices to select the answer from, and specifies the field with the correct answer.""" - input_format: str target_prefix: str = "" choices_field: str = "choices" target_field: str = "label" @@ -493,7 +510,7 @@ def prepare(self): "XX", ] - def inputs_to_choices(self, data: Dict[str, object], choice_format: str) -> str: + def inputs_to_choices(self, data: Dict[str, Any], choice_format: str) -> str: choices = data[self.choices_field] enumrated_choices = [] for i, choice in enumerate(choices): @@ -505,12 +522,12 @@ def inputs_to_choices(self, data: Dict[str, object], choice_format: str) -> str: ) return enumrated_choices - def inputs_to_numerals(self, input_fields: Dict[str, object]) -> Tuple[str, str]: + def inputs_to_numerals(self, input_fields: Dict[str, Any]) -> Tuple[str, str]: return self.inputs_to_choices(input_fields, "{choice_numeral}") def prepare_multiple_choice_inputs( - self, input_fields: Dict[str, object] - ) -> Dict[str, object]: + self, input_fields: Dict[str, Any] + ) -> Dict[str, Any]: choices = self.inputs_to_choices(input_fields, self.source_choice_format) return { "numerals": self.inputs_to_numerals(input_fields), @@ -518,23 +535,10 @@ def prepare_multiple_choice_inputs( self.choices_field: self.choices_separator.join(choices), } - def input_fields_to_source( - self, input_fields: Dict[str, object] - ) -> Tuple[str, str]: - input_fields = self.prepare_multiple_choice_inputs(input_fields) - return self.apply_formatting( - input_fields, - "input field", - self.input_format, - "input_format", - serialize=True, - ) + def preprocess_input_fields(self, input_fields: Dict[str, Any]) -> Dict[str, Any]: + return self.prepare_multiple_choice_inputs(input_fields) - def input_fields_to_instruction_and_target_prefix(self, input_fields): - input_fields = self.prepare_multiple_choice_inputs(input_fields) - return super().input_fields_to_instruction_and_target_prefix(input_fields) - - def outputs_to_target_index(self, reference_fields: Dict[str, object]) -> str: + def outputs_to_target_index(self, reference_fields: Dict[str, object]) -> int: target = reference_fields[self.target_field] if not isinstance(target, int): @@ -547,9 +551,7 @@ def outputs_to_target_index(self, reference_fields: Dict[str, object]) -> str: ) from e return target - def reference_fields_to_target_and_references( - self, reference_fields: Dict[str, object] - ) -> str: + def preprocess_reference_fields(self, reference_fields: Dict[str, Any]): target = reference_fields[self.target_field] if not isinstance(target, int): @@ -571,51 +573,40 @@ def reference_fields_to_target_and_references( Documentation.ADDING_TEMPLATE, ) from e + return {self.target_field: target} + + def reference_fields_to_target_and_references( + self, reference_fields: Dict[str, object] + ) -> str: + target = reference_fields[self.target_field] return target, [target] - def _shuffle_choices(self, instance, stream_name): - if stream_name != constants.inference_stream: - target_index = self.outputs_to_target_index(instance["reference_fields"]) - original_label_choice = instance["reference_fields"][self.choices_field][ - target_index - ] - choices = instance["input_fields"][self.choices_field] + def preprocess_input_and_reference_fields( + self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any] + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + if self.shuffle_choices: + target_index = self.outputs_to_target_index(reference_fields) + original_label_choice = reference_fields[self.choices_field][target_index] + choices = input_fields[self.choices_field] + random_seed = {**input_fields} - random_seed = {**instance["input_fields"]} + random_generator = new_random_generator(random_seed) + random_generator.shuffle(choices) + input_fields[self.choices_field] = choices - random_generator = new_random_generator(random_seed) - random_generator.shuffle(choices) - instance["input_fields"][self.choices_field] = choices + reference_fields[self.choices_field] = choices + reference_fields[self.target_field] = choices.index(original_label_choice) - if stream_name == constants.inference_stream: - return instance + return input_fields, reference_fields - instance["reference_fields"][self.choices_field] = choices - instance["reference_fields"][self.target_field] = choices.index( - original_label_choice + def post_process_instance(self, instance): + instance["input_fields"]["options"] = self.inputs_to_choices( + instance["input_fields"], self.target_choice_format ) - return instance - def process( - self, instance: Dict[str, Any], stream_name: Optional[str] = None - ) -> Dict[str, Any]: - if self.shuffle_choices: - instance = self._shuffle_choices(instance, stream_name) - result = super().process(instance, stream_name) - if stream_name == constants.inference_stream: - result["input_fields"]["options"] = self.inputs_to_choices( - instance["input_fields"], self.target_choice_format - ) - else: - if "options" not in result["reference_fields"]: - result["reference_fields"]["options"] = self.inputs_to_choices( - instance["reference_fields"], self.target_choice_format - ) - return result - -class YesNoTemplate(Template): +class YesNoTemplate(InputFormatTemplate): """A template for generating binary Yes/No questions asking whether an input text is of a specific class. input_format: @@ -641,17 +632,6 @@ class YesNoTemplate(Template): yes_answer: str = "Yes" no_answer: str = "No" - def input_fields_to_source( - self, input_fields: Dict[str, object] - ) -> Tuple[str, str]: - return self.apply_formatting( - input_fields, - "input field", - self.input_format, - "input_format", - serialize=True, - ) - def reference_fields_to_target_and_references( self, reference_fields: Dict[str, object] ) -> str: @@ -695,16 +675,13 @@ class KeyValTemplate(Template): def process_dict( self, data: Dict[str, object], key_val_sep, pairs_sep, use_keys ) -> str: - data = self.serialize_data(data) pairs = [] for key, val in data.items(): key_val = [key, str(val)] if use_keys else [str(val)] pairs.append(key_val_sep.join(key_val)) return pairs_sep.join(pairs) - def input_fields_to_source( - self, input_fields: Dict[str, object] - ) -> Tuple[str, str]: + def input_fields_to_source(self, input_fields: Dict[str, object]) -> str: return self.process_dict( input_fields, key_val_sep=self.key_val_separator, @@ -725,25 +702,14 @@ def reference_fields_to_target_and_references( class OutputQuantizingTemplate(InputOutputTemplate): - quantum: Union[float, int] = 0.1 # Now supports both int and float + serializer: DynamicSerializer = NonPositionalField( + default_factory=DynamicSerializer + ) + quantum: Union[float, int] = 0.1 - def reference_fields_to_target_and_references( - self, reference_fields: Dict[str, object] - ) -> str: - if isinstance(self.quantum, int): - # When quantum is an int, format quantized values as ints - quantized_outputs = { - key: f"{int(round(value / self.quantum) * self.quantum)}" - for key, value in reference_fields.items() - } - else: - # When quantum is a float, format quantized values with precision based on quantum - quantum_str = f"{self.quantum:.10f}".rstrip("0").rstrip(".") - quantized_outputs = { - key: f"{round(value / self.quantum) * self.quantum:{quantum_str}}" - for key, value in reference_fields.items() - } - return super().reference_fields_to_target_and_references(quantized_outputs) + def prepare(self): + super().prepare() + self.serializer.number = NumberQuantizingSerializer(quantum=self.quantum) class MultiLabelTemplate(InputOutputTemplate): @@ -753,9 +719,9 @@ class MultiLabelTemplate(InputOutputTemplate): output_format: str = "{labels}" empty_label: str = "None" - def reference_fields_to_target_and_references( - self, reference_fields: Dict[str, object] - ) -> str: + def preprocess_reference_fields( + self, reference_fields: Dict[str, Any] + ) -> Dict[str, Any]: labels = reference_fields[self.labels_field] if not isinstance(labels, list): raise UnitxtError( @@ -765,18 +731,29 @@ def reference_fields_to_target_and_references( if len(labels) == 0: labels = [self.empty_label] labels_str = self.labels_separator.join(labels) - return super().reference_fields_to_target_and_references( - {self.labels_field: labels_str} - ) + return {self.labels_field: labels_str} class MultiReferenceTemplate(InputOutputTemplate): references_field: str = "references" random_reference: bool = False + serializer: Serializer = NonPositionalField(default_factory=DynamicSerializer) + + def serialize( + self, data: Dict[str, Any], instance: Dict[str, Any] + ) -> Dict[str, str]: + result = {} + for k, v in data.items(): + if k == self.references_field: + v = [self.serializer.serialize(item, instance) for item in v] + else: + v = self.serializer.serialize(v, instance) + result[k] = v + return result def reference_fields_to_target_and_references( self, reference_fields: Dict[str, object] - ) -> List[str]: + ) -> Tuple[str, List[str]]: references = reference_fields[self.references_field] if not isoftype(references, List[str]): raise UnitxtError( @@ -825,12 +802,12 @@ def extract_span_label_pairs(self, reference_fields): if self.labels_support is None or span[3] in self.labels_support: yield span[2], span[3] - def reference_fields_to_target_and_references( - self, reference_fields: Dict[str, object] - ) -> Dict[str, object]: + def preprocess_reference_fields( + self, reference_fields: Dict[str, Any] + ) -> Dict[str, Any]: span_labels_pairs = self.extract_span_label_pairs(reference_fields) targets = self.span_label_pairs_to_targets(span_labels_pairs) - return super().reference_fields_to_target_and_references({"labels": targets}) + return super().preprocess_reference_fields({"labels": targets}) @abstractmethod def span_label_pairs_to_targets(self, pairs): diff --git a/src/unitxt/types.py b/src/unitxt/types.py index d3817577a..4a76b8002 100644 --- a/src/unitxt/types.py +++ b/src/unitxt/types.py @@ -1,8 +1,9 @@ -from typing import Any, List, Literal, NewType, TypedDict +from typing import Any, List, Literal, NewType, TypedDict, Union from .type_utils import register_type Text = NewType("Text", str) +Number = NewType("Number", Union[float, int]) class Turn(TypedDict): @@ -27,6 +28,7 @@ class Table(TypedDict): register_type(Text) +register_type(Number) register_type(Turn) register_type(Dialog) register_type(Table) diff --git a/tests/library/test_serializers.py b/tests/library/test_serializers.py new file mode 100644 index 000000000..58b506b59 --- /dev/null +++ b/tests/library/test_serializers.py @@ -0,0 +1,138 @@ +from unitxt.serializers import ( + DefaultSerializer, + DialogSerializer, + DynamicSerializer, + NumberQuantizingSerializer, + NumberSerializer, + TableSerializer, +) +from unitxt.types import Dialog, Image, Number, Table, Text, Turn + +from tests.utils import UnitxtTestCase + + +class TestSerializers(UnitxtTestCase): + def setUp(self): + self.default_serializer = DefaultSerializer() + self.dialog_serializer = DialogSerializer() + self.number_serializer = NumberSerializer() + self.table_serializer = TableSerializer() + self.custom_serializer = DynamicSerializer() + self.custom_serializer_with_number = DynamicSerializer( + number=NumberSerializer() + ) + self.number_quantizing_serializer = NumberQuantizingSerializer(quantum=0.2) + + def test_default_serializer_with_string(self): + result = self.default_serializer.serialize("test", {}) + self.assertEqual(result, "test") + + def test_default_serializer_with_number_no_serialization(self): + result = self.default_serializer.serialize(123, {}) + self.assertEqual(result, "123") + + def test_default_serializer_with_number(self): + result = self.custom_serializer_with_number.serialize(123, {}) + self.assertEqual(result, "123") + + def test_default_serializer_with_dict(self): + test_dict = {"key": "value"} + result = self.default_serializer.serialize(test_dict, {}) + self.assertEqual(result, "{'key': 'value'}") + + def test_dialog_serializer(self): + dialog_data = Dialog( + [Turn(role="user", content="Hello"), Turn(role="agent", content="Hi there")] + ) + expected_output = "user: Hello\nagent: Hi there" + result = self.dialog_serializer.serialize(dialog_data, {}) + self.assertEqual(result, expected_output) + + def test_number_serializer_with_integer(self): + number_data = Number(42) + result = self.number_serializer.serialize(number_data, {}) + self.assertEqual(result, "42") + + def test_number_serializer_with_float(self): + number_data = Number(42.123) + result = self.number_serializer.serialize(number_data, {}) + self.assertEqual(result, "42.1") + + def test_number_quantizing_serializer_with_int_quantum(self): + serializer = NumberQuantizingSerializer(quantum=2) + result = serializer.serialize(31, {}) + self.assertEqual(result, "32") + serializer = NumberQuantizingSerializer(quantum=1) + result = serializer.serialize(31, {}) + self.assertEqual(result, "31") + serializer = NumberQuantizingSerializer(quantum=2) + result = serializer.serialize(31.1, {}) + self.assertEqual(result, "32") + serializer = NumberQuantizingSerializer(quantum=1) + result = serializer.serialize(31.1, {}) + self.assertEqual(result, "31") + + def test_number_quantizing_serializer_with_float_quantum(self): + serializer = NumberQuantizingSerializer(quantum=0.2) + result = serializer.serialize(31, {}) + self.assertEqual(result, "31.0") + serializer = NumberQuantizingSerializer(quantum=0.2) + result = serializer.serialize(31.5, {}) + self.assertEqual(result, "31.6") + serializer = NumberQuantizingSerializer(quantum=0.2) + result = serializer.serialize(31.1, {}) + self.assertEqual(result, "31.2") + serializer = NumberQuantizingSerializer(quantum=0.2) + result = serializer.serialize(29.999, {}) + self.assertEqual(result, "30.0") + + def test_table_serializer(self): + table_data = Table(header=["col1", "col2"], rows=[[1, 2], [3, 4]]) + expected_output = "col1,col2\n1,2\n3,4" + result = self.table_serializer.serialize(table_data, {}) + self.assertEqual(result, expected_output) + + def test_custom_serializer_with_image(self): + image_data = Image(image="fake_image_data") + instance = {} + result = self.custom_serializer.serialize(image_data, instance) + self.assertEqual( + result, '' + ) # Using default serialization + self.assertEqual( + instance, {"media": {"images": [{"image": "fake_image_data"}]}} + ) + + def test_custom_serializer_with_table(self): + table_data = Table(header=["col1", "col2"], rows=[[1, 2], [3, 4]]) + expected_output = "col1,col2\n1,2\n3,4" + result = self.custom_serializer.serialize(table_data, {}) + self.assertEqual(result, expected_output) + + def test_custom_serializer_with_dialog(self): + dialog_data = Dialog( + [Turn(role="user", content="Hello"), Turn(role="agent", content="Hi there")] + ) + result = self.custom_serializer.serialize(dialog_data, {}) + self.assertEqual( + result, "user: Hello\nagent: Hi there" + ) # Using default serialization + + def test_custom_serializer_with_text(self): + text_data = Text("Sample text") + result = self.custom_serializer.serialize(text_data, {}) + self.assertEqual( + result, "Sample text" + ) # Since Text is a NewType of str, it should return the string directly + + def test_custom_serializer_with_unrecognized_type(self): + custom_object = {"key": "value"} + result = self.custom_serializer.serialize(custom_object, {}) + self.assertEqual( + result, "{'key': 'value'}" + ) # Should fall back to str conversion + + def test_custom_serializer_with_number(self): + number_data = Number(42) + result = self.custom_serializer.serialize(number_data, {}) + self.assertEqual(result, "42") # Should return the number as a string diff --git a/tests/library/test_templates.py b/tests/library/test_templates.py index 2f343ca2f..bf13cc977 100644 --- a/tests/library/test_templates.py +++ b/tests/library/test_templates.py @@ -11,6 +11,7 @@ MultiLabelTemplate, MultipleChoiceTemplate, MultiReferenceTemplate, + OutputQuantizingTemplate, SpanLabelingJsonTemplate, SpanLabelingTemplate, Template, @@ -132,6 +133,47 @@ def test_multi_label_template(self): check_operator(template, inputs, targets, tester=self) + def test_output_quantizing_template(self): + template = OutputQuantizingTemplate( + input_format="{text}", output_format="{label}", quantum=0.5 + ) + + inputs = [ + { + "input_fields": {"text": "hello world"}, + "reference_fields": {"label": 3.4}, + }, + { + "input_fields": {"text": "hello world"}, + "reference_fields": {"label": 1}, + }, + ] + + targets = [ + { + "input_fields": {"text": "hello world"}, + "reference_fields": {"label": 3.4}, + "source": "hello world", + "target": "3.5", + "references": ["3.5"], + "instruction": "", + "target_prefix": "", + "postprocessors": ["processors.to_string_stripped"], + }, + { + "input_fields": {"text": "hello world"}, + "reference_fields": {"label": 1}, + "source": "hello world", + "target": "1.0", + "references": ["1.0"], + "instruction": "", + "target_prefix": "", + "postprocessors": ["processors.to_string_stripped"], + }, + ] + + check_operator(template, inputs, targets, tester=self) + def test_apply_single_template(self): base_template = MultiLabelTemplate(input_format="{text}") template = ApplySingleTemplate(template=base_template, demos_field="demos") @@ -295,7 +337,9 @@ def test_apply_random_template(self): check_operator(template, inputs, targets, tester=self) - def _test_multi_reference_template(self, target, random_reference): + def _test_multi_reference_template( + self, target, random_reference, references=("Dan", "Yossi") + ): template = MultiReferenceTemplate( input_format="This is my sentence: {text}", references_field="answer", @@ -305,17 +349,17 @@ def _test_multi_reference_template(self, target, random_reference): inputs = [ { "input_fields": {"text": "who was he?"}, - "reference_fields": {"answer": ["Dan", "Yossi"]}, + "reference_fields": {"answer": list(references)}, } ] targets = [ { "input_fields": {"text": "who was he?"}, - "reference_fields": {"answer": ["Dan", "Yossi"]}, + "reference_fields": {"answer": list(references)}, "source": "This is my sentence: who was he?", "target": target, - "references": ["Dan", "Yossi"], + "references": [str(r) for r in references], "instruction": "", "target_prefix": "", "postprocessors": ["processors.to_string_stripped"], @@ -352,9 +396,8 @@ def test_multi_reference_template_with_empty_references(self): ) def test_multi_reference_template_with_wrong_references_type(self): - self._test_multi_reference_template_with_exception( - references=[0, "dkd"], - expected_exception_message="MultiReferenceTemplate requires references field 'answer' to be List[str]. Got answer: [0, 'dkd']", + self._test_multi_reference_template( + target="0", references=[0, "dkd"], random_reference=False ) def test_input_output_template_and_standard_template(self): @@ -474,7 +517,7 @@ def test_input_output_template_and_standard_template(self): err_output_template = InputOutputTemplate(output_format="{label}") err_output_template.process(inputs[0]) self.assertIn( - "Required field 'input_format' of class InputOutputTemplate not set in InputOutputTemplate", + "Required field 'input_format' of class InputFormatTemplate not set in InputOutputTemplate", str(ke.exception), ) @@ -814,7 +857,7 @@ def test_span_labeling_json_template(self): check_operator(template, inputs, targets, tester=self) def test_multiple_choice_template(self): - enumerators = ["capitals", "lowercase", "numbers", "roman"] + enumerators = ["capitals"] # , "lowercase", "numbers", "roman"] firsts = ["A", "a", "1", "I"] seconds = ["B", "b", "2", "II"] for enumerator, first, second in zip(enumerators, firsts, seconds): @@ -843,11 +886,14 @@ def test_multiple_choice_template(self): targets = [ { - "input_fields": {"choices": choices, "text": "example A"}, + "input_fields": { + "choices": choices, + "text": "example A", + "options": [f"{first}", f"{second}"], + }, "reference_fields": { "choices": choices, "label": 0, - "options": [f"{first}", f"{second}"], }, "source": f"Text: example A, Choices: {first}. True, {second}. False.", "target": f"{first}", @@ -857,11 +903,14 @@ def test_multiple_choice_template(self): "postprocessors": ["processors.to_string_stripped"], }, { - "input_fields": {"choices": choices, "text": "example A"}, + "input_fields": { + "choices": choices, + "text": "example A", + "options": [f"{first}", f"{second}"], + }, "reference_fields": { "choices": choices, "label": "False", - "options": [f"{first}", f"{second}"], }, "source": f"Text: example A, Choices: {first}. True, {second}. False.", "target": f"{second}", @@ -871,11 +920,14 @@ def test_multiple_choice_template(self): "postprocessors": ["processors.to_string_stripped"], }, { - "input_fields": {"choices": ["True", "small"], "text": "example A"}, + "input_fields": { + "choices": ["True", "small"], + "text": "example A", + "options": [f"{first}", f"{second}"], + }, "reference_fields": { "choices": ["True", "small"], "label": "small", - "options": [f"{first}", f"{second}"], }, "source": f"Text: example A, Choices: {first}. True, {second}. small.", "target": f"{second}", @@ -933,11 +985,14 @@ def test_multiple_choice_template_with_shuffle(self): targets = [ { - "input_fields": {"choices": ["False", "True"], "text": "example A"}, + "input_fields": { + "choices": ["False", "True"], + "text": "example A", + "options": [f"{first}", f"{second}"], + }, "reference_fields": { "choices": ["False", "True"], "label": 1, - "options": [f"{first}", f"{second}"], }, "source": f"Text: example A, Choices: {first}. False, {second}. True.", "target": f"{second}", @@ -947,11 +1002,14 @@ def test_multiple_choice_template_with_shuffle(self): "postprocessors": ["processors.to_string_stripped"], }, { - "input_fields": {"choices": ["False", "True"], "text": "example A"}, + "input_fields": { + "choices": ["False", "True"], + "text": "example A", + "options": [f"{first}", f"{second}"], + }, "reference_fields": { "choices": ["False", "True"], "label": 0, - "options": [f"{first}", f"{second}"], }, "source": f"Text: example A, Choices: {first}. False, {second}. True.", "target": f"{first}", @@ -961,11 +1019,14 @@ def test_multiple_choice_template_with_shuffle(self): "postprocessors": ["processors.to_string_stripped"], }, { - "input_fields": {"choices": [temp, "True"], "text": "example A"}, + "input_fields": { + "choices": [temp, "True"], + "text": "example A", + "options": [f"{first}", f"{second}"], + }, "reference_fields": { "choices": [temp, "True"], "label": 0, - "options": [f"{first}", f"{second}"], }, "source": f"Text: example A, Choices: {first}. {temp}, {second}. True.", "target": f"{first}", @@ -993,25 +1054,25 @@ def test_multiple_choice_template_with_shuffle(self): def test_key_val_template_simple(self): template = KeyValTemplate() + instance = { + "input_fields": {"hello": "world", "str_list": ["djjd", "djjd"]}, + "reference_fields": {"label": "negative"}, + } + result = template.process_instance(instance) - dic = {"hello": "world", "str_list": ["djjd", "djjd"]} - - result = template.process_dict( - dic, key_val_sep=": ", pairs_sep=", ", use_keys=True - ) - target = "hello: world, str_list: djjd, djjd" - self.assertEqual(result, target) + self.assertEqual(result["target"], "negative") + self.assertEqual(result["source"], "hello: world, str_list: djjd, djjd") def test_key_val_template_int_list(self): template = KeyValTemplate() + instance = { + "input_fields": {"hello": "world", "int_list": [0, 1]}, + "reference_fields": {"label": "negative"}, + } + result = template.process_instance(instance) - dic = {"hello": "world", "int_list": [0, 1]} - - result = template.process_dict( - dic, key_val_sep=": ", pairs_sep=", ", use_keys=True - ) - target = "hello: world, int_list: 0, 1" - self.assertEqual(result, target) + self.assertEqual(result["target"], "negative") + self.assertEqual(result["source"], "hello: world, int_list: 0, 1") def test_render_template(self): instance = {