Add converters #1

Merged 6 commits on Aug 6, 2023
45 changes: 44 additions & 1 deletion pyproject.toml
@@ -14,7 +14,10 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = []
dependencies = [
"pydantic>=2.1.1",
"safetensors>=0.3.1",
]
description = "🤝 Trade any tensors over the network"
dynamic = ["version"]
keywords = ["tensors", "machine learning", "deep learning", "artificial intelligence"]
@@ -32,6 +35,11 @@ Source = "https://github.com/chainyo/tensorshare"
path = "src/tensorshare/__init__.py"

[project.optional-dependencies]
torch = ["torch>=1.10"]
numpy = ["numpy>=1.21.6"]
# tensorflow = ["tensorflow"] Disabled until 2.14 is released for pydantic v2 compatibility
jax = ["flax>=0.6.3", "jax<0.4.14", "jaxlib<0.4.14"]
paddlepaddle = ["paddlepaddle>=2.4.1"]
docs = [
"mkdocs~=1.4.0",
"mkdocs-material~=8.5.4",
@@ -44,7 +52,42 @@ quality = [
"pre-commit~=2.20.0",
]
tests = [
"flax>=0.6.3",
"jax<0.4.14",
"jaxlib<0.4.14",
"numpy>=1.21.6",
"paddlepaddle>=2.4.1",
"pytest~=7.1.2",
# "tensorflow",
"torch>=1.10",
]
# To install all optional framework dependencies: pip install tensorshare[all]
all = [
"flax>=0.6.3",
"jax<0.4.14",
"jaxlib<0.4.14",
"numpy>=1.21.6",
"paddlepaddle>=2.4.1",
# "tensorflow",
"torch>=1.10",
]
# To install all the dependencies for development purposes: pip install -e ".[dev]"
dev = [
"black~=22.10.0",
"flax>=0.6.3",
"jax<0.4.14",
"jaxlib<0.4.14",
"mkdocs~=1.4.0",
"mkdocs-material~=8.5.4",
"mkdocs-git-revision-date-localized-plugin~=1.1.0",
"mkdocstrings[python]~=0.19.0",
"pre-commit~=2.20.0",
"pytest~=7.1.2",
"ruff~=0.0.263",
"numpy>=1.21.6",
"paddlepaddle>=2.4.1",
# "tensorflow",
"torch>=1.10",
]

[tool.hatch.envs.quality]
5 changes: 5 additions & 0 deletions src/tensorshare/__init__.py
@@ -15,3 +15,8 @@

__author__ = "Thomas Chaigneau <t.chaigneau.tc@gmail.com>"
__version__ = "0.0.1"


from tensorshare.schema import TensorShare

__all__ = ["TensorShare"]
14 changes: 14 additions & 0 deletions src/tensorshare/converter/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2023 Thomas Chaigneau. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Converter is used to convert any tensors format to safetensors format"""
126 changes: 126 additions & 0 deletions src/tensorshare/converter/utils.py
@@ -0,0 +1,126 @@
# Copyright 2023 Thomas Chaigneau. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils functions for the converter module."""

from typing import Dict, Optional

import numpy as np
import paddle

# import tensorflow as tf
import torch
from jax import Array
from safetensors.flax import save as flax_save
from safetensors.numpy import save as np_save
from safetensors.paddle import save as paddle_save

# from safetensors.tensorflow import save as tf_save
from safetensors.torch import save as torch_save

from tensorshare.import_utils import require_backend


@require_backend("flax", "jax", "jaxlib")
def convert_flax_to_safetensors(
tensors: Dict[str, Array], metadata: Optional[Dict[str, str]] = None
) -> bytes:
"""
Convert flax tensors to safetensors format using `safetensors.flax.save`.

Args:
tensors (Dict[str, Array]):
Flax tensors stored in a dictionary with their name as key.
metadata (Optional[Dict[str, str]], optional):
Metadata to add to the safetensors file. Defaults to None.

Returns:
bytes: Tensors formatted with their metadata if any.
"""
return flax_save(tensors, metadata=metadata)


@require_backend("numpy")
def convert_numpy_to_safetensors(
tensors: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]] = None
) -> bytes:
"""
Convert numpy tensors to safetensors format using `safetensors.numpy.save`.

Args:
tensors (Dict[str, np.ndarray]):
Numpy tensors stored in a dictionary with their name as key.
metadata (Optional[Dict[str, str]], optional):
Metadata to add to the safetensors file. Defaults to None.

Returns:
bytes: Tensors formatted with their metadata if any.
"""
return np_save(tensors, metadata=metadata)


@require_backend("paddlepaddle")
def convert_paddle_to_safetensors(
tensors: Dict[str, paddle.Tensor], metadata: Optional[Dict[str, str]] = None
) -> bytes:
"""
Convert paddle tensors to safetensors format using `safetensors.paddle.save`.

Args:
tensors (Dict[str, paddle.Tensor]):
Paddle tensors stored in a dictionary with their name as key.
metadata (Optional[Dict[str, str]], optional):
Metadata to add to the safetensors file. Defaults to None.

Returns:
bytes: Tensors formatted with their metadata if any.
"""
return paddle_save(tensors, metadata=metadata)


# @require_backend("tensorflow")
# def convert_tensorflow_to_safetensors(
# tensors: Dict[str, tf.Tensor], metadata: Optional[Dict[str, str]] = None
# ) -> bytes:
# """
# Convert tensorflow tensors to safetensors format using `safetensors.tensorflow.save`.

# Args:
# tensors (Dict[str, tf.Tensor]):
# Tensorflow tensors stored in a dictionary with their name as key.
# metadata (Optional[Dict[str, str]], optional):
# Metadata to add to the safetensors file. Defaults to None.

# Returns:
# bytes: Tensors formatted with their metadata if any.
# """
# return tf_save(tensors, metadata=metadata)


@require_backend("torch")
def convert_torch_to_safetensors(
tensors: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
) -> bytes:
"""
Convert torch tensors to safetensors format using `safetensors.torch.save`.

Args:
tensors (Dict[str, torch.Tensor]):
Torch tensors stored in a dictionary with their name as key.
metadata (Optional[Dict[str, str]], optional):
Metadata to add to the safetensors file. Defaults to None.

Returns:
bytes: Tensors formatted with their metadata if any.
"""
return torch_save(tensors, metadata=metadata)
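
Below is a minimal usage sketch (not part of this diff) showing how the torch converter above could be called; the tensor names and the metadata values are illustrative assumptions.

import torch

from tensorshare.converter.utils import convert_torch_to_safetensors

# A dictionary of named torch tensors is the expected input format.
tensors = {"embeddings": torch.zeros((2, 4)), "logits": torch.ones(3)}

# Returns raw safetensors bytes, with the optional metadata embedded in the header.
data = convert_torch_to_safetensors(tensors, metadata={"framework": "torch"})
print(type(data), len(data))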
90 changes: 90 additions & 0 deletions src/tensorshare/import_utils.py
@@ -0,0 +1,90 @@
# Copyright 2023 Thomas Chaigneau. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Regroup all the utilities for importing modules/libraries in a single file."""

import importlib.util
from functools import lru_cache
from typing import Callable


# This function is taken from the Hugging Face Transformers library:
# https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py#L41
@lru_cache(maxsize=None)
def _is_package_available(package_name: str) -> bool:
"""
Check if package is available.

Args:
package_name (str):
Package name to check availability for.

Returns:
bool: Whether the package is available.
"""
if importlib.util.find_spec(package_name) is not None:
return True

return False


@lru_cache(maxsize=None)
def _is_paddle_available() -> bool:
"""
Check if paddle is available.

Returns:
bool: Whether the package is available.
"""
try:
import paddle # noqa: F401

return True
except ImportError:
return False


def require_backend(*backend_names: str) -> Callable[[Callable], Callable]:
"""
A decorator that checks if the required backends are available.

Args:
*backend_names (str): Names of the backend packages required by the decorated function.

Returns:
Callable[[Callable], Callable]: Decorator.
"""

def decorator(func: Callable) -> Callable:
"""Decorator."""

def wrapper(*args, **kwargs) -> Callable:
"""Wrapper."""
for backend_name in backend_names:
if backend_name == "paddlepaddle":
if not _is_paddle_available():
raise ImportError(
f"`{func.__name__}` requires `paddle` to be installed."
)
else:
if not _is_package_available(backend_name):
raise ImportError(
f"`{func.__name__}` requires `{backend_name}` to be"
" installed."
)

return func(*args, **kwargs)

return wrapper

return decorator
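
As a hedged illustration (not part of this diff), the require_backend decorator defined above could guard any function that needs an optional backend; mean_of is a hypothetical helper used only for this sketch.

from tensorshare.import_utils import require_backend


@require_backend("numpy")
def mean_of(values):
    # Safe to import here: the decorator has already verified numpy is installed.
    import numpy as np

    return float(np.mean(values))


try:
    print(mean_of([1.0, 2.0, 3.0]))
except ImportError as err:
    # Raised by the wrapper when numpy is missing from the environment.
    print(err)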
23 changes: 23 additions & 0 deletions src/tensorshare/schema.py
@@ -0,0 +1,23 @@
# Copyright 2023 Thomas Chaigneau. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pydantic schemas to enable resilient and secure tensor sharing."""

from pydantic import BaseModel, ByteSize


class TensorShare(BaseModel):
"""Base model for tensor sharing."""

tensors: bytes
size: ByteSize
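
A short sketch (not part of this diff) of how the TensorShare schema could wrap converted tensors; the field values are illustrative and assume the numpy converter added in this PR.

import numpy as np

from tensorshare import TensorShare
from tensorshare.converter.utils import convert_numpy_to_safetensors

# Serialize a named numpy tensor to safetensors bytes, then validate it through the schema.
data = convert_numpy_to_safetensors({"weights": np.ones((2, 2), dtype=np.float32)})
ts = TensorShare(tensors=data, size=len(data))

# pydantic's ByteSize coerces the int and offers human-readable formatting.
print(ts.size.human_readable())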