Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added tests to cover all compatibility layer cases. #1937

Merged
merged 4 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 67 additions & 58 deletions bittensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,54 @@
from bittensor.utils.registration import torch, use_torch
from pydantic import ConfigDict, BaseModel, Field, field_validator

NUMPY_DTYPES = {
"float16": np.float16,
"float32": np.float32,
"float64": np.float64,
"uint8": np.uint8,
"int16": np.int16,
"int8": np.int8,
"int32": np.int32,
"int64": np.int64,
"bool": bool,
}

if use_torch():
TORCH_DTYPES = {
"torch.float16": torch.float16,
"torch.float32": torch.float32,
"torch.float64": torch.float64,
"torch.uint8": torch.uint8,
"torch.int16": torch.int16,
"torch.int8": torch.int8,
"torch.int32": torch.int32,
"torch.int64": torch.int64,
"torch.bool": torch.bool,
}


def cast_dtype(raw: Union[None, np.dtype, "torch.dtype", str]) -> str:

class DTypes(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.torch: bool = False
self.update(
{
"float16": np.float16,
"float32": np.float32,
"float64": np.float64,
"uint8": np.uint8,
"int16": np.int16,
"int8": np.int8,
"int32": np.int32,
"int64": np.int64,
"bool": bool,
}
)

def __getitem__(self, key):
self._add_torch()
return super().__getitem__(key)

def __contains__(self, key):
self._add_torch()
return super().__contains__(key)

def _add_torch(self):
if self.torch is False:
torch_dtypes = {
"torch.float16": torch.float16,
"torch.float32": torch.float32,
"torch.float64": torch.float64,
"torch.uint8": torch.uint8,
"torch.int16": torch.int16,
"torch.int8": torch.int8,
"torch.int32": torch.int32,
"torch.int64": torch.int64,
"torch.bool": torch.bool,
}
self.update(torch_dtypes)
self.torch = True


dtypes = DTypes()


def cast_dtype(raw: Union[None, np.dtype, "torch.dtype", str]) -> Optional[str]:
"""
Casts the raw value to a string representing the
`numpy data type <https://numpy.org/doc/stable/user/basics.types.html>`_, or the
Expand All @@ -67,29 +88,24 @@ def cast_dtype(raw: Union[None, np.dtype, "torch.dtype", str]) -> str:
"""
if not raw:
return None
if isinstance(raw, np.dtype):
return NUMPY_DTYPES[raw]
elif use_torch():
if isinstance(raw, torch.dtype):
return TORCH_DTYPES[raw]
if use_torch() and isinstance(raw, torch.dtype):
return dtypes[raw]
elif isinstance(raw, np.dtype):
return dtypes[raw]
elif isinstance(raw, str):
if use_torch():
assert (
raw in TORCH_DTYPES
), f"{raw} not a valid torch type in dict {TORCH_DTYPES}"
assert raw in dtypes, f"{raw} not a valid torch type in dict {dtypes}"
return raw
else:
assert (
raw in NUMPY_DTYPES
), f"{raw} not a valid numpy type in dict {NUMPY_DTYPES}"
assert raw in dtypes, f"{raw} not a valid numpy type in dict {dtypes}"
return raw
else:
raise Exception(
f"{raw} of type {type(raw)} does not have a valid type in Union[None, numpy.dtype, torch.dtype, str]"
)


def cast_shape(raw: Union[None, List[int], str]) -> str:
def cast_shape(raw: Union[None, List[int], str]) -> Optional[Union[str, list]]:
"""
Casts the raw value to a string representing the tensor shape.

Expand All @@ -105,9 +121,7 @@ def cast_shape(raw: Union[None, List[int], str]) -> str:
if not raw:
return None
elif isinstance(raw, list):
if len(raw) == 0:
return raw
elif isinstance(raw[0], int):
if len(raw) == 0 or isinstance(raw[0], int):
return raw
else:
raise Exception(f"{raw} list elements are not of type int")
Expand All @@ -124,7 +138,7 @@ class tensor:
def __new__(cls, tensor: Union[list, np.ndarray, "torch.Tensor"]):
if isinstance(tensor, list) or isinstance(tensor, np.ndarray):
tensor = torch.tensor(tensor) if use_torch() else np.array(tensor)
return Tensor.serialize(tensor=tensor)
return Tensor.serialize(tensor_=tensor)


class Tensor(BaseModel):
Expand Down Expand Up @@ -170,40 +184,35 @@ def deserialize(self) -> Union["np.ndarray", "torch.Tensor"]:
# Reshape does not work for (0) or [0]
if not (len(shape) == 1 and shape[0] == 0):
torch_object = torch_object.reshape(shape)
return torch_object.type(TORCH_DTYPES[self.dtype])
return torch_object.type(dtypes[self.dtype])
else:
# Reshape does not work for (0) or [0]
if not (len(shape) == 1 and shape[0] == 0):
numpy_object = numpy_object.reshape(shape)
return numpy_object.astype(NUMPY_DTYPES[self.dtype])
return numpy_object.astype(dtypes[self.dtype])

@staticmethod
def serialize(tensor: Union["np.ndarray", "torch.Tensor"]) -> "Tensor":
def serialize(tensor_: Union["np.ndarray", "torch.Tensor"]) -> "Tensor":
"""
Serializes the given tensor.

Args:
tensor (np.array or torch.Tensor): The tensor to serialize.
tensor_ (np.array or torch.Tensor): The tensor to serialize.

Returns:
Tensor: The serialized tensor.

Raises:
Exception: If the serialization process encounters an error.
"""
dtype = str(tensor.dtype)
shape = list(tensor.shape)
dtype = str(tensor_.dtype)
shape = list(tensor_.shape)
if len(shape) == 0:
shape = [0]
if use_torch():
torch_numpy = tensor.cpu().detach().numpy().copy()
data_buffer = base64.b64encode(
msgpack.packb(torch_numpy, default=msgpack_numpy.encode)
).decode("utf-8")
else:
data_buffer = base64.b64encode(
msgpack.packb(tensor, default=msgpack_numpy.encode)
).decode("utf-8")
tensor__ = tensor_.cpu().detach().numpy().copy() if use_torch() else tensor_
data_buffer = base64.b64encode(
msgpack.packb(tensor__, default=msgpack_numpy.encode)
).decode("utf-8")
return Tensor(buffer=data_buffer, shape=shape, dtype=dtype)

# Represents the tensor buffer data.
Expand Down
2 changes: 1 addition & 1 deletion bittensor/utils/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def normalize_max_weight(
if estimation.max() <= limit:
return weights / weights.sum()

# Find the cumlative sum and sorted tensor
# Find the cumulative sum and sorted tensor
cumsum = np.cumsum(estimation, 0)

# Determine the index of cutoff
Expand Down
67 changes: 67 additions & 0 deletions tests/unit_tests/test_chain_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import bittensor
import torch
from bittensor.chain_data import AxonInfo, ChainDataType, DelegateInfo, NeuronInfo

SS58_FORMAT = bittensor.__ss58_format__
Expand Down Expand Up @@ -204,6 +205,36 @@ def test_to_parameter_dict(axon_info, test_case):
assert result[key] == value, f"Test case: {test_case}"


@pytest.mark.parametrize(
"axon_info, test_case",
[
(
AxonInfo(
version=1,
ip="127.0.0.1",
port=8080,
ip_type=4,
hotkey="hot",
coldkey="cold",
),
"ID_to_parameter_dict",
),
],
)
def test_to_parameter_dict_torch(
axon_info,
test_case,
force_legacy_torch_compat_api,
):
result = axon_info.to_parameter_dict()

# Assert
assert isinstance(result, torch.nn.ParameterDict)
for key, value in axon_info.__dict__.items():
assert key in result
assert result[key] == value, f"Test case: {test_case}"


@pytest.mark.parametrize(
"parameter_dict, expected, test_case",
[
Expand Down Expand Up @@ -236,6 +267,42 @@ def test_from_parameter_dict(parameter_dict, expected, test_case):
assert result == expected, f"Test case: {test_case}"


@pytest.mark.parametrize(
"parameter_dict, expected, test_case",
[
(
torch.nn.ParameterDict(
{
"version": 1,
"ip": "127.0.0.1",
"port": 8080,
"ip_type": 4,
"hotkey": "hot",
"coldkey": "cold",
}
),
AxonInfo(
version=1,
ip="127.0.0.1",
port=8080,
ip_type=4,
hotkey="hot",
coldkey="cold",
),
"ID_from_parameter_dict",
),
],
)
def test_from_parameter_dict_torch(
parameter_dict, expected, test_case, force_legacy_torch_compat_api
):
# Act
result = AxonInfo.from_parameter_dict(parameter_dict)

# Assert
assert result == expected, f"Test case: {test_case}"


def create_neuron_info_decoded(
hotkey,
coldkey,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_metagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_set_metagraph_attributes(mock_environment):
metagraph.consensus,
np.array([neuron.consensus for neuron in neurons], dtype=np.float32),
)
== True
is True
)
# Similarly for other attributes...

Expand Down
Loading