Skip to content

Commit

Permalink
Merge pull request #1769 from opentensor/tests/gus/extends-coverage-o…
Browse files Browse the repository at this point in the history
…n-axon

Tests: Extends coverage on axon methods
  • Loading branch information
gus-opentensor committed Apr 5, 2024
2 parents bdfa793 + e9e15bc commit 6474c02
Showing 1 changed file with 191 additions and 10 deletions.
201 changes: 191 additions & 10 deletions tests/unit_tests/test_axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,20 @@
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# Standard Lib
import pytest
import unittest
import bittensor
from typing import Any
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock

# Third Party
from starlette.requests import Request
from unittest.mock import MagicMock

# Bittensor
import bittensor
from bittensor.axon import AxonMiddleware
from bittensor.axon import axon as Axon


def test_attach():
Expand Down Expand Up @@ -127,7 +133,47 @@ def test_create_error_response():
assert response.body == b'{"message":"Error"}'


# Mock synapse class for testing
# Fixtures
@pytest.fixture
def middleware():
# Mock AxonMiddleware instance with empty axon object
axon = AxonMock()
return AxonMiddleware(None, axon)


@pytest.fixture
def mock_request():
request = AsyncMock(spec=Request)
request.body = AsyncMock(return_value=b'{"field1": "value1", "field2": "value2"}')
request.url.path = "/test_endpoint"
request.headers = {"computed_body_hash": "correct_hash"}
return request


@pytest.fixture
def axon_instance():
axon = Axon()
axon.required_hash_fields = {"test_endpoint": ["field1", "field2"]}
axon.forward_class_types = {
"test_endpoint": MagicMock(return_value=MagicMock(body_hash="correct_hash"))
}
return axon


# Mocks
class MockWallet:
def __init__(self, hotkey):
self.hotkey = hotkey


class MockHotkey:
def __init__(self, ss58_address):
self.ss58_address = ss58_address


class MockInfo:
def to_string(self):
return "MockInfoString"


class AxonMock:
Expand Down Expand Up @@ -169,13 +215,6 @@ def priority_fn_timeout(synapse) -> float:
return 2.0


@pytest.fixture
def middleware():
# Mock AxonMiddleware instance with empty axon object
axon = AxonMock()
return AxonMiddleware(None, axon)


@pytest.mark.asyncio
async def test_verify_pass(middleware):
synapse = SynapseMock()
Expand Down Expand Up @@ -218,6 +257,148 @@ async def test_priority_pass(middleware):
assert synapse.axon.status_code != 408


@pytest.mark.parametrize(
"body, expected",
[
(
b'{"field1": "value1", "field2": "value2"}',
{"field1": "value1", "field2": "value2"},
),
(
b'{"field1": "different_value", "field2": "another_value"}',
{"field1": "different_value", "field2": "another_value"},
),
],
)
async def test_verify_body_integrity_happy_path(
mock_request, axon_instance, body, expected
):
# Arrange
mock_request.body.return_value = body

# Act
result = await axon_instance.verify_body_integrity(mock_request)

# Assert
assert result == expected, "The parsed body should match the expected dictionary."


@pytest.mark.parametrize(
"body, expected_exception_message",
[
(b"", "EOFError"), # Empty body
(b"not_json", "JSONDecodeError"), # Non-JSON body
],
ids=["empty_body", "non_json_body"],
)
async def test_verify_body_integrity_edge_cases(
mock_request, axon_instance, body, expected_exception_message
):
# Arrange
mock_request.body.return_value = body

# Act & Assert
with pytest.raises(Exception) as exc_info:
await axon_instance.verify_body_integrity(mock_request)
assert expected_exception_message in str(
exc_info.value
), "Expected specific exception message."


@pytest.mark.parametrize(
"computed_hash, expected_error",
[
("incorrect_hash", ValueError),
],
)
async def test_verify_body_integrity_error_cases(
mock_request, axon_instance, computed_hash, expected_error
):
# Arrange
mock_request.headers["computed_body_hash"] = computed_hash

# Act & Assert
with pytest.raises(expected_error) as exc_info:
await axon_instance.verify_body_integrity(mock_request)
assert "Hash mismatch" in str(exc_info.value), "Expected a hash mismatch error."


@pytest.mark.parametrize(
"info_return, expected_output, test_id",
[
(MockInfo(), "MockInfoString", "happy_path_basic"),
(MockInfo(), "MockInfoString", "edge_case_empty_string"),
],
)
def test_to_string(info_return, expected_output, test_id, mocker):
# Arrange
axon = Axon()
mocker.patch.object(axon, "info", return_value=info_return)

# Act
output = axon.to_string()

# Assert
assert output == expected_output, f"Test ID: {test_id}"


@pytest.mark.parametrize(
"ip, port, ss58_address, started, forward_fns, expected_str, test_id",
[
# Happy path
(
"127.0.0.1",
8080,
"5G9RtsTbiYJYQYJzUfTCs...",
True,
{"fn1": None},
"Axon(127.0.0.1, 8080, 5G9RtsTbiYJYQYJzUfTCs..., started, ['fn1'])",
"happy_path_started_with_forward_fn",
),
(
"192.168.1.1",
3030,
"5HqUkGuo62b5...",
False,
{},
"Axon(192.168.1.1, 3030, 5HqUkGuo62b5..., stopped, [])",
"happy_path_stopped_no_forward_fn",
),
# Edge cases
("", 0, "", False, {}, "Axon(, 0, , stopped, [])", "edge_empty_values"),
(
"255.255.255.255",
65535,
"5G9RtsTbiYJYQYJzUfTCs...",
True,
{"fn1": None, "fn2": None},
"Axon(255.255.255.255, 65535, 5G9RtsTbiYJYQYJzUfTCs..., started, ['fn1', 'fn2'])",
"edge_max_values",
),
],
)
def test_axon_str_representation(
ip, port, ss58_address, started, forward_fns, expected_str, test_id
):
# Arrange
hotkey = MockHotkey(ss58_address)
wallet = MockWallet(hotkey)
axon = Axon()
axon.ip = ip
axon.port = port
axon.wallet = wallet
axon.started = started
axon.forward_fns = forward_fns

# Act
result_dunder_str = axon.__str__()
result_dunder_repr = axon.__repr__()

# Assert
assert result_dunder_str == expected_str, f"Test ID: {test_id}"
assert result_dunder_repr == expected_str, f"Test ID: {test_id}"


class TestAxonMiddleware(IsolatedAsyncioTestCase):
def setUp(self):
# Create a mock app
Expand Down

0 comments on commit 6474c02

Please sign in to comment.