Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require required hash fields (security improvement) #1818

9 changes: 0 additions & 9 deletions bittensor/axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,6 @@ def __init__(
self.priority_fns: Dict[str, Optional[Callable]] = {}
self.forward_fns: Dict[str, Optional[Callable]] = {}
self.verify_fns: Dict[str, Optional[Callable]] = {}
self.required_hash_fields: Dict[str, str] = {}

# Instantiate FastAPI
self.app = FastAPI()
Expand Down Expand Up @@ -566,12 +565,6 @@ def verify_custom(synapse: MyCustomSynapse):
) # Use 'default_verify' if 'verify_fn' is None
self.forward_fns[request_name] = forward_fn

# Parse required hash fields from the forward function protocol defaults
required_hash_fields = request_class.__dict__["__fields__"][
"required_hash_fields"
].default
self.required_hash_fields[request_name] = required_hash_fields

return self

@classmethod
Expand Down Expand Up @@ -696,9 +689,7 @@ async def verify_body_integrity(self, request: Request):
body = await request.body()
request_body = body.decode() if isinstance(body, bytes) else body

# Gather the required field names from the axon's required_hash_fields dict
request_name = request.url.path.split("/")[1]
required_hash_fields = self.required_hash_fields[request_name]

# Load the body dict and check if all required field hashes match
body_dict = json.loads(request_body)
Expand Down
55 changes: 34 additions & 21 deletions bittensor/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import base64
import json
import sys
import typing
import warnings

import pydantic
from pydantic.schema import schema
import bittensor
from typing import Optional, List, Any, Dict
from typing import Optional, Any, Dict


def get_size(obj, seen=None) -> int:
Expand Down Expand Up @@ -293,6 +295,8 @@ class Synapse(pydantic.BaseModel):
5. Body Hash Computation (``computed_body_hash``, ``required_hash_fields``):
Ensures data integrity and security by computing hashes of transmitted data. Provides users with a
mechanism to verify data integrity and detect any tampering during transmission.
It is recommended that names of fields in `required_hash_fields` are listed in the order they are
defined in the class.

6. Serialization and Deserialization Methods:
Facilitates the conversion of Synapse objects to and from a format suitable for network transmission.
Expand Down Expand Up @@ -480,14 +484,7 @@ def set_name_type(cls, values) -> dict:
repr=False,
)

required_hash_fields: Optional[List[str]] = pydantic.Field(
mjurbanski-reef marked this conversation as resolved.
Show resolved Hide resolved
title="required_hash_fields",
description="The list of required fields to compute the body hash.",
examples=["roles", "messages"],
default=[],
allow_mutation=False,
repr=False,
)
required_hash_fields: typing.ClassVar[typing.Tuple[str, ...]] = ()

def __setattr__(self, name: str, value: Any):
"""
Expand Down Expand Up @@ -683,21 +680,37 @@ def body_hash(self) -> str:
Returns:
str: The SHA3-256 hash as a hexadecimal string, providing a fingerprint of the Synapse instance's data for integrity checks.
"""
# Hash the body for verification
hashes = []

# Getting the fields of the instance
instance_fields = self.dict()
hash_fields_field = self.__class__.__fields__.get("required_hash_fields")
instance_fields = None
if hash_fields_field:
warnings.warn(
"The 'required_hash_fields' field handling deprecated and will be removed. "
"Please update Synapse class definition to use 'required_hash_fields' class variable instead.",
DeprecationWarning,
)
required_hash_fields = hash_fields_field.default

if required_hash_fields:
instance_fields = self.dict()
# Preserve backward compatibility: legacy hashing added fields in .dict() order
# rather than in the order given by `required_hash_fields`
required_hash_fields = [
field for field in instance_fields if field in required_hash_fields
]

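# Hack to cache the required hash field names on the class.
# NOTE(review): the condition below compares a length with itself, so it is
# always true -- confirm which two lengths were meant to be compared.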
if len(required_hash_fields) == len(required_hash_fields):
self.__class__.required_hash_fields = tuple(required_hash_fields)
else:
required_hash_fields = self.__class__.required_hash_fields

if required_hash_fields:
instance_fields = instance_fields or self.dict()
for field in required_hash_fields:
hashes.append(bittensor.utils.hash(str(instance_fields[field])))

for field, value in instance_fields.items():
# If the field is required in the subclass schema, hash and add it.
if (
self.required_hash_fields is not None
and field in self.required_hash_fields
):
hashes.append(bittensor.utils.hash(str(value)))

# Hash and return the hashes that have been concatenated
return bittensor.utils.hash("".join(hashes))

@classmethod
Expand Down
57 changes: 44 additions & 13 deletions tests/unit_tests/test_synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,18 +177,6 @@ def test_body_hash_override():
synapse_instance.body_hash = []


def test_required_fields_override():
    """Assigning ``required_hash_fields`` on an instance must raise ``TypeError``."""
    # Message the assignment is expected to fail with (used as a regex by pytest).
    expected_message = '"required_hash_fields" has allow_mutation set to False and cannot be assigned'

    synapse = bittensor.Synapse()

    with pytest.raises(TypeError, match=expected_message):
        synapse.required_hash_fields = []


def test_default_instance_fields_dict_consistency():
synapse_instance = bittensor.Synapse()
assert synapse_instance.dict() == {
Expand Down Expand Up @@ -221,5 +209,48 @@ def test_default_instance_fields_dict_consistency():
"signature": None,
},
"computed_body_hash": "",
"required_hash_fields": [],
}


class LegacyHashedSynapse(bittensor.Synapse):
    """Legacy Synapse subclass that serialized `required_hash_fields`.

    Test fixture: `required_hash_fields` is annotated here as an ordinary
    (non-ClassVar) attribute with a list default, in contrast to
    `HashedSynapse`, which declares it as a `typing.ClassVar` tuple.
    """

    a: int
    b: int
    c: typing.Optional[int]  # not named in required_hash_fields below
    d: typing.Optional[typing.List[str]]
    # Names are listed as b, a, d -- deliberately not the a, b, d declaration
    # order above (presumably to exercise order handling; confirm against
    # Synapse.body_hash).
    required_hash_fields: typing.Optional[typing.List[str]] = ["b", "a", "d"]


class HashedSynapse(bittensor.Synapse):
    """Synapse subclass that declares `required_hash_fields` as a class variable.

    Test fixture: the names of the fields to hash are given as a
    `typing.ClassVar` tuple rather than as a regular annotated attribute.
    """

    a: int
    b: int
    c: typing.Optional[int]  # not named in required_hash_fields below
    d: typing.Optional[typing.List[str]]
    # Variable-length tuple of names: `Tuple[str]` would describe a 1-tuple,
    # which does not match the three names assigned here.
    required_hash_fields: typing.ClassVar[typing.Tuple[str, ...]] = ("a", "b", "d")


@pytest.mark.parametrize("synapse_cls", [LegacyHashedSynapse, HashedSynapse])
def test_synapse_body_hash(synapse_cls):
    """Check ``body_hash`` for both fixture classes.

    Covers: the pinned digest for ``a=1, b=2, d=["foobar"]``; independence from
    the non-hashed field ``c`` and from keyword order; independence from a
    ``required_hash_fields`` value passed to the constructor; and sensitivity
    to a change in a hashed field.
    """
    baseline = synapse_cls(a=1, b=2, d=["foobar"])
    pinned_digest = "ae06397d08f30f75c91395c59f05c62ac3b62b88250eb78b109213258e6ced0c"
    assert baseline.body_hash == pinned_digest

    # A value for the non-hashed field `c` (and a different keyword order)
    # must leave the hash unchanged.
    with_unhashed_extra = synapse_cls(d=["foobar"], c=3, a=1, b=2)
    baseline_hash = baseline.body_hash
    extra_hash = with_unhashed_extra.body_hash
    assert baseline_hash == extra_hash

    # Passing `required_hash_fields` to the constructor must not change
    # which fields get hashed.
    with_override_attempt = synapse_cls(
        a=1, b=2, d=["foobar"], required_hash_fields=["a"]
    )
    baseline_hash = baseline.body_hash
    override_hash = with_override_attempt.body_hash
    assert baseline_hash == override_hash

    # Leaving the hashed field `d` unset must produce a different hash.
    without_d = synapse_cls(a=1, b=2)
    baseline_hash = baseline.body_hash
    without_d_hash = without_d.body_hash
    assert baseline_hash != without_d_hash