Skip to content

Commit

Permalink
Add serializers to templates and reorganize and unite all templates (#…
Browse files Browse the repository at this point in the history
…1195)

* Add serializers to templates and reorganize and unite all templates

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fixes

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fix tests

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fix tests

Signed-off-by: elronbandel <elronbandel@gmail.com>

---------

Signed-off-by: elronbandel <elronbandel@gmail.com>
  • Loading branch information
elronbandel committed Sep 8, 2024
1 parent 19accb4 commit 10d26ed
Show file tree
Hide file tree
Showing 7 changed files with 479 additions and 187 deletions.
1 change: 1 addition & 0 deletions src/unitxt/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .recipe import __file__ as _
from .register import __file__ as _
from .schema import __file__ as _
from .serializers import __file__ as _
from .settings_utils import get_constants
from .span_lableing_operators import __file__ as _
from .split_utils import __file__ as _
Expand Down
1 change: 1 addition & 0 deletions src/unitxt/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .recipe import __file__ as _
from .register import __file__ as _
from .schema import __file__ as _
from .serializers import __file__ as _
from .settings_utils import get_constants
from .span_lableing_operators import __file__ as _
from .split_utils import __file__ as _
Expand Down
112 changes: 112 additions & 0 deletions src/unitxt/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import csv
import io
from abc import abstractmethod
from typing import Any, Dict, Union

from .operators import InstanceFieldOperator
from .type_utils import isoftype
from .types import Dialog, Image, Number, Table, Text


class Serializer(InstanceFieldOperator):
def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str:
return self.serialize(value, instance)

@abstractmethod
def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
pass


class DefaultSerializer(Serializer):
def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
return str(value)


class DefaultListSerializer(Serializer):
def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
if isinstance(value, list):
return ", ".join(str(item) for item in value)
return str(value)


class DialogSerializer(Serializer):
def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str:
# Convert the Dialog into a string representation, typically combining roles and content
return "\n".join(f"{turn['role']}: {turn['content']}" for turn in value)


class NumberSerializer(Serializer):
def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
# Check if the value is an integer or a float
if isinstance(value, int):
return str(value)
# For floats, format to one decimal place
if isinstance(value, float):
return f"{value:.1f}"
raise ValueError("Unsupported type for NumberSerializer")


class NumberQuantizingSerializer(NumberSerializer):
quantum: Union[float, int] = 0.1

def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
if isoftype(value, Number):
quantized_value = round(value / self.quantum) / (1 / self.quantum)
if isinstance(self.quantum, int):
quantized_value = int(quantized_value)
return str(quantized_value)
raise ValueError("Unsupported type for NumberSerializer")


class TableSerializer(Serializer):
def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
output = io.StringIO()
writer = csv.writer(output, lineterminator="\n")

# Write the header and rows to the CSV writer
writer.writerow(value["header"])
writer.writerows(value["rows"])

# Retrieve the CSV string
return output.getvalue().strip()


class ImageSerializer(Serializer):
def serialize(self, value: Image, instance: Dict[str, Any]) -> str:
if "media" not in instance:
instance["media"] = {}
if "images" not in instance["media"]:
instance["media"]["images"] = []
idx = len(instance["media"]["images"])
instance["media"]["images"].append(value)
return f'<img src="media/images/{idx}">'


class DynamicSerializer(Serializer):
image: Serializer = ImageSerializer()
number: Serializer = DefaultSerializer()
table: Serializer = TableSerializer()
dialog: Serializer = DialogSerializer()
text: Serializer = DefaultSerializer()
list: Serializer = DefaultSerializer()

def serialize(self, value: Any, instance: Dict[str, Any]) -> Any:
if isoftype(value, Image):
return self.image.serialize(value, instance)

if isoftype(value, Table):
return self.table.serialize(value, instance)

if isoftype(value, Dialog) and len(value) > 0:
return self.dialog.serialize(value, instance)

if isoftype(value, Text):
return self.text.serialize(value, instance)

if isoftype(value, Number):
return self.number.serialize(value, instance)

if isinstance(value, list):
return self.list.serialize(value, instance)

return str(value)
Loading

0 comments on commit 10d26ed

Please sign in to comment.