add deserialization + tests to TensorProcessor
chainyo committed Aug 8, 2023
1 parent afd715f commit 1e92a17
Showing 8 changed files with 429 additions and 110 deletions.
10 changes: 9 additions & 1 deletion src/tensorshare/serialization/__init__.py
@@ -1 +1,9 @@
"""Converter is used to convert any tensors format to safetensors format"""
"""Serialization module for tensorshare."""

from .constants import Backend
from .processor import TensorProcessor

__all__ = [
"Backend",
"TensorProcessor",
]
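
With these re-exports, both the Backend enum and the TensorProcessor are importable directly from the serialization package. A minimal usage sketch, assuming serialize accepts optional backend and metadata keywords as its docstring in processor.py below suggests (the zeros tensor is purely illustrative):

import numpy as np

from tensorshare.serialization import Backend, TensorProcessor

# Serialize a dict of numpy arrays; the backend can be passed explicitly
# or omitted and inferred from the tensor types.
data, size = TensorProcessor.serialize(
    {"embeddings": np.zeros((2, 2))}, backend=Backend.NUMPY
)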
21 changes: 18 additions & 3 deletions src/tensorshare/serialization/constants.py
@@ -1,4 +1,4 @@
"""Constants for tensorshare."""
"""Constants for tensorshare.serialization."""

from enum import Enum
from typing import Callable, Dict, OrderedDict, Union
@@ -11,6 +11,11 @@
import torch

from tensorshare.serialization.utils import (
deserialize_flax,
deserialize_numpy,
deserialize_paddle,
# deserialize_tensorflow,
deserialize_torch,
serialize_flax,
serialize_numpy,
serialize_paddle,
@@ -29,8 +34,8 @@ class Backend(str, Enum):
TORCH = "torch"


# Mapping between backend and conversion function
BACKENDS_FUNC_MAPPING: Dict[Backend, Callable] = OrderedDict(
# Mapping between backend and serialization function
BACKEND_SER_FUNC_MAPPING: Dict[Backend, Callable] = OrderedDict(
[
(Backend.FLAX, serialize_flax),
(Backend.NUMPY, serialize_numpy),
@@ -39,6 +44,16 @@ class Backend(str, Enum):
(Backend.TORCH, serialize_torch),
]
)
# Mapping between backend and deserialization function
BACKEND_DESER_FUNC_MAPPING: Dict[Backend, Callable] = OrderedDict(
[
(Backend.FLAX, deserialize_flax),
(Backend.NUMPY, deserialize_numpy),
(Backend.PADDLEPADDLE, deserialize_paddle),
# (Backend.TENSORFLOW, deserialize_tensorflow),
(Backend.TORCH, deserialize_torch),
]
)
# Mapping between tensor type and backend
TENSOR_TYPE_MAPPING: Dict[
Union[jaxlib.xla_extension.ArrayImpl, np.ndarray, paddle.Tensor, torch.Tensor],
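The two mappings above let the processor dispatch on a Backend value with a plain dictionary lookup instead of a per-framework if/elif chain. A short sketch of that lookup pattern, using only names defined in this module (the sample tensor is illustrative):

import numpy as np

from tensorshare.serialization.constants import (
    BACKEND_DESER_FUNC_MAPPING,
    BACKEND_SER_FUNC_MAPPING,
    Backend,
)

# Both mappings are keyed by Backend; because Backend is a str enum,
# the plain strings used in the tests ("numpy", "torch", ...) work too.
serialize_fn = BACKEND_SER_FUNC_MAPPING[Backend.NUMPY]
deserialize_fn = BACKEND_DESER_FUNC_MAPPING[Backend.NUMPY]

serialized = serialize_fn({"embeddings": np.zeros((2, 2))})
restored = deserialize_fn(serialized)
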
72 changes: 61 additions & 11 deletions src/tensorshare/serialization/processor.py
@@ -11,7 +11,8 @@
from pydantic import ByteSize

from tensorshare.serialization.constants import (
BACKENDS_FUNC_MAPPING,
BACKEND_DESER_FUNC_MAPPING,
BACKEND_SER_FUNC_MAPPING,
TENSOR_TYPE_MAPPING,
Backend,
)
@@ -50,7 +51,7 @@ def _infer_backend(
raise TypeError(
f"Unsupported tensor type {first_tensor_type}. Supported types are"
f" {list(TENSOR_TYPE_MAPPING.keys())}\nThe supported backends are"
f" {list(BACKENDS_FUNC_MAPPING.keys())}."
f" {list(BACKEND_SER_FUNC_MAPPING.keys())}."
)

return TENSOR_TYPE_MAPPING[first_tensor_type]
@@ -67,8 +68,9 @@ def serialize(
) -> Tuple[bytes, ByteSize]:
"""Serialize a dictionary of tensors to a TensorShare object.
This method will convert a dictionary of tensors to a TensorShare object using the specified backend
if provided, otherwise it will try to infer the backend from the tensors format.
This method will convert a dictionary of tensors to a tuple containing the serialized tensors
and the size of the serialized tensors. It will use the backend if provided, otherwise it will
try to infer the backend from the tensors format.
Args:
tensors (Dict[str, Union[Array, np.ndarray, paddle.Tensor, torch.Tensor]]):
@@ -113,18 +115,66 @@ def serialize(
elif not isinstance(backend, Backend):
raise TypeError(
"Backend must be a string or an instance of Backend enum, got"
f" `{type(backend)}` instead. Use `tensorshare.schema.Backend` to"
" access the Backend enum. If you don't specify a backend, it will"
" be inferred from the tensors format."
f" `{type(backend)}` instead. Use"
" `tensorshare.serialization.Backend` to access the Backend enum."
" If you don't specify a backend, it will be inferred from the"
" tensors format."
)
else:
_backend = _infer_backend(tensors)

_tensors = BACKENDS_FUNC_MAPPING[_backend](tensors, metadata=metadata)
_tensors = BACKEND_SER_FUNC_MAPPING[_backend](tensors, metadata=metadata)

return _tensors, ByteSize(len(_tensors))

@staticmethod
def deserialize() -> None:
""""""
pass
def deserialize(
data: bytes,
backend: Union[str, Backend],
) -> Dict[str, Union[Array, np.ndarray, paddle.Tensor, torch.Tensor]]:
"""Deserialize bytes to a dictionary of tensors.
This method will convert TensorShare.tensors to a dictionary of tensors with their name as key.
The backend must be specified in order to deserialize the data.
Args:
data (bytes):
The serialized tensors to deserialize.
backend (Union[str, Backend]):
The backend to use for the conversion. Must be one of the following:
- Backend.FLAX or 'flax'
- Backend.NUMPY or 'numpy'
- Backend.PADDLEPADDLE or 'paddlepaddle'
- Backend.TENSORFLOW or 'tensorflow'
- Backend.TORCH or 'torch'
Raises:
TypeError: If data is not bytes.
TypeError: If backend is not a string or an instance of Backend enum.
KeyError: If backend is not one of the supported backends.
Returns:
Dict[str, Union[Array, np.ndarray, paddle.Tensor, torch.Tensor]]:
A dictionary of tensors in the specified backend with their name as key.
"""
if not isinstance(data, bytes):
raise TypeError(f"Data must be bytes, got `{type(data)}` instead.")

if isinstance(backend, str):
try:
_backend = Backend[backend.upper()]
except KeyError as e:
raise KeyError(
f"Invalid backend `{backend}`. Must be one of"
f" {list(Backend.__members__)}."
) from e
elif not isinstance(backend, Backend):
raise TypeError(
"Backend must be a string or an instance of Backend enum, got"
f" `{type(backend)}` instead. Use `tensorshare.serialization.Backend`"
" to access the Backend enum."
)

tensors = BACKEND_DESER_FUNC_MAPPING[_backend](data)

return tensors # type: ignore
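
With deserialize in place, the processor now supports a full round trip. A sketch of that round trip with the numpy backend, assuming the keyword names shown in the signatures above (the sample tensor is illustrative):

import numpy as np

from tensorshare.serialization import Backend, TensorProcessor

tensors = {"embeddings": np.zeros((2, 2))}

# serialize returns the raw safetensors bytes plus their ByteSize.
data, size = TensorProcessor.serialize(tensors, backend=Backend.NUMPY)

# deserialize needs the backend spelled out; it cannot be inferred from the bytes.
restored = TensorProcessor.deserialize(data, backend=Backend.NUMPY)
assert np.array_equal(restored["embeddings"], tensors["embeddings"])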
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -14,7 +14,7 @@


@pytest.fixture
def converted_fixed_numpy_tensors() -> bytes:
def serialized_fixed_numpy_tensors() -> bytes:
"""Return a serialized numpy tensor."""
_tensor = {"embeddings": np.zeros((2, 2))}
return serialize_numpy(_tensor)
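The renamed fixture lines up with the new serialize/deserialize vocabulary. A hypothetical test built on top of it could look like the following sketch (the test name and placement are illustrative, not part of the commit):

import numpy as np

from tensorshare.serialization import TensorProcessor


def test_deserialize_serialized_numpy_fixture(
    serialized_fixed_numpy_tensors: bytes,
) -> None:
    """Round-trip the serialized fixture back into numpy arrays."""
    tensors = TensorProcessor.deserialize(
        serialized_fixed_numpy_tensors, backend="numpy"
    )
    assert isinstance(tensors, dict)
    assert np.array_equal(tensors["embeddings"], np.zeros((2, 2)))
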
95 changes: 95 additions & 0 deletions tests/serialization/test_constants.py
@@ -0,0 +1,95 @@
"""Test the tensorshare.serialization.constants module."""

import jaxlib
import numpy as np
import paddle

# import tensorflow as tf
import torch

from tensorshare.serialization.constants import (
BACKEND_DESER_FUNC_MAPPING,
BACKEND_SER_FUNC_MAPPING,
TENSOR_TYPE_MAPPING,
Backend,
)
from tensorshare.serialization.utils import (
deserialize_flax,
deserialize_numpy,
deserialize_paddle,
# deserialize_tensorflow,
deserialize_torch,
serialize_flax,
serialize_numpy,
serialize_paddle,
# serialize_tensorflow,
serialize_torch,
)


class TestBackendEnum:
"""Tests for the backend enum."""

def test_backend_enum(self) -> None:
"""Test the backend enum."""
assert len(Backend) == 4
assert Backend.FLAX == "flax"
assert Backend.NUMPY == "numpy"
assert Backend.PADDLEPADDLE == "paddlepaddle"
# assert Backend.TENSORFLOW == "tensorflow"
assert Backend.TORCH == "torch"


class TestProcessorConstants:
"""Tests for the serialization constants and mappings."""

def test_backend_ser_func_mapping(self) -> None:
"""Test the backend serialization function mapping."""
assert isinstance(BACKEND_SER_FUNC_MAPPING, dict)
assert len(BACKEND_SER_FUNC_MAPPING) > 0

assert "flax" in BACKEND_SER_FUNC_MAPPING
assert "numpy" in BACKEND_SER_FUNC_MAPPING
assert "paddlepaddle" in BACKEND_SER_FUNC_MAPPING
# assert "tensorflow" in BACKEND_SER_FUNC_MAPPING
assert "torch" in BACKEND_SER_FUNC_MAPPING

assert BACKEND_SER_FUNC_MAPPING["flax"] == serialize_flax
assert BACKEND_SER_FUNC_MAPPING["numpy"] == serialize_numpy
assert BACKEND_SER_FUNC_MAPPING["paddlepaddle"] == serialize_paddle
# assert BACKEND_SER_FUNC_MAPPING["tensorflow"] == serialize_tensorflow
assert BACKEND_SER_FUNC_MAPPING["torch"] == serialize_torch

def test_backend_deser_func_mapping(self) -> None:
"""Test the backend deserialization function mapping."""
assert isinstance(BACKEND_DESER_FUNC_MAPPING, dict)
assert len(BACKEND_DESER_FUNC_MAPPING) > 0

assert "flax" in BACKEND_DESER_FUNC_MAPPING
assert "numpy" in BACKEND_DESER_FUNC_MAPPING
assert "paddlepaddle" in BACKEND_DESER_FUNC_MAPPING
# assert "tensorflow" in BACKEND_DESER_FUNC_MAPPING
assert "torch" in BACKEND_DESER_FUNC_MAPPING

assert BACKEND_DESER_FUNC_MAPPING["flax"] == deserialize_flax
assert BACKEND_DESER_FUNC_MAPPING["numpy"] == deserialize_numpy
assert BACKEND_DESER_FUNC_MAPPING["paddlepaddle"] == deserialize_paddle
# assert BACKEND_DESER_FUNC_MAPPING["tensorflow"] == deserialize_tensorflow
assert BACKEND_DESER_FUNC_MAPPING["torch"] == deserialize_torch

def test_tensor_type_mapping(self) -> None:
"""Test the tensor type mapping."""
assert isinstance(TENSOR_TYPE_MAPPING, dict)
assert len(TENSOR_TYPE_MAPPING) > 0

assert jaxlib.xla_extension.ArrayImpl in TENSOR_TYPE_MAPPING
assert np.ndarray in TENSOR_TYPE_MAPPING
assert paddle.Tensor in TENSOR_TYPE_MAPPING
# assert tf.Tensor in TENSOR_TYPE_MAPPING
assert torch.Tensor in TENSOR_TYPE_MAPPING

assert TENSOR_TYPE_MAPPING[jaxlib.xla_extension.ArrayImpl] == "flax"
assert TENSOR_TYPE_MAPPING[np.ndarray] == "numpy"
assert TENSOR_TYPE_MAPPING[paddle.Tensor] == "paddlepaddle"
# assert TENSOR_TYPE_MAPPING[tf.Tensor] == "tensorflow"
assert TENSOR_TYPE_MAPPING[torch.Tensor] == "torch"