Skip to content

Commit

Permalink
Merge pull request #1818 from backend-developers-ltd/require_required…
Browse files Browse the repository at this point in the history
…_hash_fields

Require required hash fields (security improvement)
  • Loading branch information
gus-opentensor committed May 22, 2024
2 parents 782fde3 + cb4e3a6 commit 87df079
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 56 deletions.
9 changes: 0 additions & 9 deletions bittensor/axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,6 @@ def __init__(
self.priority_fns: Dict[str, Optional[Callable]] = {}
self.forward_fns: Dict[str, Optional[Callable]] = {}
self.verify_fns: Dict[str, Optional[Callable]] = {}
self.required_hash_fields: Dict[str, str] = {}

# Instantiate FastAPI
self.app = FastAPI()
Expand Down Expand Up @@ -566,12 +565,6 @@ def verify_custom(synapse: MyCustomSynapse):
) # Use 'default_verify' if 'verify_fn' is None
self.forward_fns[request_name] = forward_fn

# Parse required hash fields from the forward function protocol defaults
required_hash_fields = request_class.__dict__["model_fields"][
"required_hash_fields"
].default
self.required_hash_fields[request_name] = required_hash_fields

return self

@classmethod
Expand Down Expand Up @@ -696,9 +689,7 @@ async def verify_body_integrity(self, request: Request):
body = await request.body()
request_body = body.decode() if isinstance(body, bytes) else body

# Gather the required field names from the axon's required_hash_fields dict
request_name = request.url.path.split("/")[1]
required_hash_fields = self.required_hash_fields[request_name]

# Load the body dict and check if all required field hashes match
body_dict = json.loads(request_body)
Expand Down
55 changes: 34 additions & 21 deletions bittensor/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import base64
import json
import sys
import typing
import warnings

from pydantic import (
BaseModel,
Expand All @@ -29,7 +31,7 @@
model_validator,
)
import bittensor
from typing import Optional, List, Any, Dict
from typing import Optional, Any, Dict


def get_size(obj, seen=None) -> int:
Expand Down Expand Up @@ -301,6 +303,8 @@ class Synapse(BaseModel):
5. Body Hash Computation (``computed_body_hash``, ``required_hash_fields``):
Ensures data integrity and security by computing hashes of transmitted data. Provides users with a
mechanism to verify data integrity and detect any tampering during transmission.
It is recommended that names of fields in `required_hash_fields` are listed in the order they are
defined in the class.
6. Serialization and Deserialization Methods:
Facilitates the conversion of Synapse objects to and from a format suitable for network transmission.
Expand Down Expand Up @@ -478,14 +482,7 @@ def set_name_type(cls, values) -> dict:
repr=False,
)

required_hash_fields: Optional[List[str]] = Field(
title="required_hash_fields",
description="The list of required fields to compute the body hash.",
examples=["roles", "messages"],
default=[],
frozen=True,
repr=False,
)
required_hash_fields: typing.ClassVar[typing.Tuple[str, ...]] = ()

_extract_total_size = field_validator("total_size", mode="before")(cast_int)

Expand Down Expand Up @@ -692,21 +689,37 @@ def body_hash(self) -> str:
Returns:
str: The SHA3-256 hash as a hexadecimal string, providing a fingerprint of the Synapse instance's data for integrity checks.
"""
# Hash the body for verification
hashes = []

# Getting the fields of the instance
instance_fields = self.model_dump()
hash_fields_field = self.model_fields.get("required_hash_fields")
instance_fields = None
if hash_fields_field:
warnings.warn(
"The 'required_hash_fields' field handling deprecated and will be removed. "
"Please update Synapse class definition to use 'required_hash_fields' class variable instead.",
DeprecationWarning,
)
required_hash_fields = hash_fields_field.default

if required_hash_fields:
instance_fields = self.model_dump()
# Preserve backward compatibility in which fields will added in .dict() order
# instead of the order one from `self.required_hash_fields`
required_hash_fields = [
field for field in instance_fields if field in required_hash_fields
]

# Hack to cache the required hash fields names
if len(required_hash_fields) == len(required_hash_fields):
self.__class__.required_hash_fields = tuple(required_hash_fields)
else:
required_hash_fields = self.__class__.required_hash_fields

if required_hash_fields:
instance_fields = instance_fields or self.dict()
for field in required_hash_fields:
hashes.append(bittensor.utils.hash(str(instance_fields[field])))

for field, value in instance_fields.items():
# If the field is required in the subclass schema, hash and add it.
if (
self.required_hash_fields is not None
and field in self.required_hash_fields
):
hashes.append(bittensor.utils.hash(str(value)))

# Hash and return the hashes that have been concatenated
return bittensor.utils.hash("".join(hashes))

@classmethod
Expand Down
83 changes: 57 additions & 26 deletions tests/unit_tests/test_synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@
# DEALINGS IN THE SOFTWARE.
import json
import base64
from typing import List, Optional
import typing
from typing import Optional

import pydantic_core
import pytest
import bittensor


def test_parse_headers_to_inputs():
class Test(bittensor.Synapse):
key1: List[int]
key1: list[int]

# Define a mock headers dictionary to use for testing
headers = {
Expand Down Expand Up @@ -60,7 +60,7 @@ class Test(bittensor.Synapse):

def test_from_headers():
class Test(bittensor.Synapse):
key1: List[int]
key1: list[int]

# Define a mock headers dictionary to use for testing
headers = {
Expand Down Expand Up @@ -131,13 +131,13 @@ class Test(bittensor.Synapse):
a: int # Carried through because required.
b: int = None # Not carried through headers
c: Optional[int] # Required, carried through headers, cannot be None
d: Optional[List[int]] # Required, carried though headers, cannot be None
e: List[int] # Carried through headers
d: Optional[list[int]] # Required, carried though headers, cannot be None
e: list[int] # Carried through headers
f: Optional[
int
] = None # Not Required, Not carried through headers, can be None
g: Optional[
List[int]
list[int]
] = None # Not Required, Not carried though headers, can be None

# Create an instance of the custom Synapse subclass
Expand All @@ -152,12 +152,12 @@ class Test(bittensor.Synapse):
assert isinstance(synapse, Test)
assert synapse.name == "Test"
assert synapse.a == 1
assert synapse.b == None
assert synapse.b is None
assert synapse.c == 3
assert synapse.d == [1, 2, 3, 4]
assert synapse.e == [1, 2, 3, 4]
assert synapse.f == None
assert synapse.g == None
assert synapse.f is None
assert synapse.g is None

# Convert the Test instance to a headers dictionary
headers = synapse.to_headers()
Expand All @@ -169,12 +169,12 @@ class Test(bittensor.Synapse):
# Create a new Test from the headers and check its properties
next_synapse = synapse.from_headers(synapse.to_headers())
assert next_synapse.a == 0 # Default value is 0
assert next_synapse.b == None
assert next_synapse.b is None
assert next_synapse.c == 0 # Default is 0
assert next_synapse.d == [] # Default is []
assert next_synapse.e == [] # Empty list is default for list types
assert next_synapse.f == None
assert next_synapse.g == None
assert next_synapse.f is None
assert next_synapse.g is None


def test_body_hash_override():
Expand All @@ -189,18 +189,6 @@ def test_body_hash_override():
synapse_instance.body_hash = []


def test_required_fields_override():
# Create a Synapse instance
synapse_instance = bittensor.Synapse()

# Try to set the required_hash_fields property and expect a TypeError
with pytest.raises(
pydantic_core.ValidationError,
match="required_hash_fields\n Field is frozen",
):
synapse_instance.required_hash_fields = []


def test_default_instance_fields_dict_consistency():
synapse_instance = bittensor.Synapse()
assert synapse_instance.dict() == {
Expand Down Expand Up @@ -233,5 +221,48 @@ def test_default_instance_fields_dict_consistency():
"signature": None,
},
"computed_body_hash": "",
"required_hash_fields": [],
}


class LegacyHashedSynapse(bittensor.Synapse):
"""Legacy Synapse subclass that serialized `required_hash_fields`."""

a: int
b: int
c: Optional[int] = None
d: Optional[list[str]] = None
required_hash_fields: Optional[list[str]] = ["b", "a", "d"]


class HashedSynapse(bittensor.Synapse):
a: int
b: int
c: Optional[int] = None
d: Optional[list[str]] = None
required_hash_fields: typing.ClassVar[tuple[str, ...]] = ("a", "b", "d")


@pytest.mark.parametrize("synapse_cls", [LegacyHashedSynapse, HashedSynapse])
def test_synapse_body_hash(synapse_cls):
synapse_instance = synapse_cls(a=1, b=2, d=["foobar"])
assert (
synapse_instance.body_hash
== "ae06397d08f30f75c91395c59f05c62ac3b62b88250eb78b109213258e6ced0c"
)

# Extra non-hashed values should not influence the body hash
synapse_instance_slightly_different = synapse_cls(d=["foobar"], c=3, a=1, b=2)
assert synapse_instance.body_hash == synapse_instance_slightly_different.body_hash

# Even if someone tries to override the required_hash_fields, it should still be the same
synapse_instance_try_override_hash_fields = synapse_cls(
a=1, b=2, d=["foobar"], required_hash_fields=["a"]
)
assert (
synapse_instance.body_hash
== synapse_instance_try_override_hash_fields.body_hash
)

# Different hashed values should result in different body hashes
synapse_different = synapse_cls(a=1, b=2)
assert synapse_instance.body_hash != synapse_different.body_hash

0 comments on commit 87df079

Please sign in to comment.