Add prepare_tensors_to_dict (#39)
* add prepare_tensors_to_dict

* add the function to documentation

* add the last formatting case

* add pydantic version check for match error case

* fix linter
chainyo committed Aug 19, 2023
1 parent 39e5bc0 commit 399cbd5
Showing 7 changed files with 101 additions and 2 deletions.
Empty file removed docs/api/default_response.md
Empty file.
5 changes: 5 additions & 0 deletions docs/api/prepare_tensors.md
@@ -0,0 +1,5 @@
# Utility functions

These functions prepare tensors for use in the `TensorShare` class.

::: src.tensorshare.utils.prepare_tensors_to_dict
15 changes: 15 additions & 0 deletions docs/usage/tensorshare.md
@@ -219,3 +219,18 @@ You must have the desired backend installed in your project to deserialize the t
# Get a dict of tensorflow.Tensor
tensors_tensorflow = ts.to_tensors(backend="tensorflow") # or backend=Backend.TENSORFLOW
```

## Lazy tensors formatting

If you don't want to handle the formatting of the tensors yourself, we provide
a utility function that prepares them for use in the `TensorShare` class.

``` python
from typing import Any

from tensorshare import prepare_tensors_to_dict

tensors_in_any_format: Any = ...
tensors = prepare_tensors_to_dict(tensors_in_any_format)
# {"embeddings_0": ..., "embeddings_1": ..., ...}
```
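
For illustration, here is a minimal sketch of the naming behaviour, using plain Python lists and an int in place of real tensors (the inputs and outputs mirror the unit tests added in this commit):

``` python
from tensorshare import prepare_tensors_to_dict

# A dictionary keeps its keys (converted to strings) as tensor names.
prepare_tensors_to_dict({"embeddings": [1, 2, 3], "labels": [4, 5, 6]})
# {"embeddings": [1, 2, 3], "labels": [4, 5, 6]}

# Any iterable (list, tuple, set, generator, ...) is enumerated.
prepare_tensors_to_dict(([1, 2, 3], [4, 5, 6]))
# {"embeddings_0": [1, 2, 3], "embeddings_1": [4, 5, 6]}

# A non-iterable single item falls back to the "embeddings" key.
prepare_tensors_to_dict(1)
# {"embeddings": 1}
```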

Check the [utils documentation](../api/prepare_tensors) for more information.
1 change: 1 addition & 0 deletions src/tensorshare/__init__.py
@@ -14,6 +14,7 @@
"schema": ["DefaultResponse", "TensorShare", "TensorShareServer"],
"serialization": ["Backend", "TensorProcessor", "TensorType"],
"server": ["create_tensorshare_router", "create_async_tensorshare_router"],
"utils": ["prepare_tensors_to_dict"],
}

sys.modules[__name__] = _LazyModule(
36 changes: 36 additions & 0 deletions src/tensorshare/utils.py
@@ -1,5 +1,8 @@
"""Utility functions for the Client module."""

import types
from typing import Any, Dict

from pydantic import ByteSize

from tensorshare.schema import TensorShare
@@ -22,3 +25,36 @@ def fake_tensorshare_data() -> TensorShare:
),
size=ByteSize(84),
)


def prepare_tensors_to_dict(data: Any) -> Dict[str, Any]:
    """Prepare tensors from any framework as a dictionary.

    Use this function as a lazy way to build the dictionary that will be used
    to create a TensorShare object.

    If the data is a dictionary, its keys are used as the tensor names.
    If the data is a list, a set, a tuple, a generator or any other iterable,
    the tensors are named "embeddings_{i}" where i is the index of the tensor
    in the iterable. Any other single item is stored under the "embeddings" key.

    Args:
        data (Any):
            The tensors to prepare. It can be a dictionary, a generator,
            a list, a set, a tuple or any other single item.

    Returns:
        Dict[str, Any]: The prepared data.
    """
    if isinstance(data, dict):
        return {str(key): tensor for key, tensor in data.items()}

    elif isinstance(data, (types.GeneratorType, list, set, tuple)) or hasattr(
        data, "__iter__"
    ):
        # The `hasattr(data, "__iter__")` check catches any other kind of
        # iterable not explicitly listed above.
        return {f"embeddings_{i}": tensor for i, tensor in enumerate(data)}

    else:
        return {"embeddings": data}
4 changes: 3 additions & 1 deletion tests/test_schema.py
@@ -6,6 +6,7 @@
import jax.numpy as jnp
import numpy as np
import paddle
import pkg_resources
import pytest

# import tensorflow as tf
@@ -385,13 +386,14 @@ def test_schema_with_mock_response(self) -> None:

    def test_schema_validation_error(self) -> None:
        """Test the schema validation error."""
        pydantic_version = pkg_resources.get_distribution("pydantic").version[:-2]
        with pytest.raises(
            ValueError,
            match=re.escape(
                "1 validation error for TensorShareServer\nurl\n URL scheme should be"
                " 'http' or 'https' [type=url_scheme, input_value='localhost:8000',"
                " input_type=str]\n For further information visit"
                " https://errors.pydantic.dev/2.1/v/url_scheme"
                f" https://errors.pydantic.dev/{pydantic_version}/v/url_scheme"
            ),
        ):
            TensorShareServer(
42 changes: 41 additions & 1 deletion tests/test_utils.py
@@ -1,9 +1,12 @@
"""Test the utils functions of the Client module."""

from typing import Any, Dict

import pytest
from pydantic import ByteSize

from tensorshare.schema import TensorShare
from tensorshare.utils import fake_tensorshare_data
from tensorshare.utils import fake_tensorshare_data, prepare_tensors_to_dict


class TestClientUtils:
@@ -20,3 +23,40 @@ def test_fake_tensorshare_data(self) -> None:
b" \x00\x00\x00\x00"
)
assert fake_ts.size == ByteSize(84)

    @pytest.mark.parametrize(
        "data, expected",
        [
            ({"embeddings": [1, 2, 3]}, {"embeddings": [1, 2, 3]}),
            (
                {"embeddings": [1, 2, 3], "labels": [1, 2, 3]},
                {"embeddings": [1, 2, 3], "labels": [1, 2, 3]},
            ),
            (
                [[1, 2, 3], [1, 2, 3]],
                {"embeddings_0": [1, 2, 3], "embeddings_1": [1, 2, 3]},
            ),
            (
                ([1, 2, 3], [1, 2, 3]),
                {"embeddings_0": [1, 2, 3], "embeddings_1": [1, 2, 3]},
            ),
            (
                {0: [1, 2, 3], 1: [1, 2, 3]},
                {"0": [1, 2, 3], "1": [1, 2, 3]},
            ),
            (
                (lambda: (i for i in [[1, 2, 3], [1, 2, 3], [1, 2, 3]])),
                {
                    "embeddings_0": [1, 2, 3],
                    "embeddings_1": [1, 2, 3],
                    "embeddings_2": [1, 2, 3],
                },
            ),
            (1, {"embeddings": 1}),
        ],
    )
    def test_prepare_tensors_to_dict(self, data: Any, expected: Dict[str, Any]) -> None:
        """Test the prepare_tensors_to_dict function."""
        data = data() if callable(data) else data

        assert prepare_tensors_to_dict(data) == expected
