Skip to content

Commit

Permalink
Improve tests logic for Part 3 of subtensor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-opentensor committed May 23, 2024
1 parent a38091e commit afa165a
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion tests/unit_tests/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@

# Application
import bittensor
from bittensor.subtensor import subtensor as Subtensor, _logger, Balance
from bittensor.subtensor import (
subtensor as Subtensor,
_logger,
Balance,
U16_NORMALIZED_FLOAT,
U64_NORMALIZED_FLOAT,
)
from bittensor import subtensor_module


Expand Down Expand Up @@ -402,6 +408,10 @@ def test_hyper_parameter_success_calls(
# Prep
subtensor._get_hyperparameter = mocker.MagicMock(return_value=value)

spy_u16_normalized_float = mocker.spy(subtensor_module, "U16_NORMALIZED_FLOAT")
spy_u64_normalized_float = mocker.spy(subtensor_module, "U64_NORMALIZED_FLOAT")
spy_balance_from_rao = mocker.spy(Balance, "from_rao")

# Call
subtensor_method = getattr(subtensor, method)
result = subtensor_method(netuid=7, block=707)
Expand All @@ -413,6 +423,21 @@ def test_hyper_parameter_success_calls(
# if we change the methods logic in the future we have to be make sure tha returned type is correct
assert isinstance(result, expected_result_type)

# Special cases
if method in [
"kappa",
"validator_logits_divergence",
"validator_exclude_quantile",
"max_weight_limit",
]:
spy_u16_normalized_float.assert_called_once()

if method in ["adjustment_alpha", "bonds_moving_avg"]:
spy_u64_normalized_float.assert_called_once()

if method in ["recycle"]:
spy_balance_from_rao.assert_called_once()


def test_blocks_since_last_update_success_calls(subtensor, mocker):
"""Tests the weights_rate_limit method to ensure it correctly fetches the LastUpdate hyperparameter."""
Expand All @@ -427,6 +452,7 @@ def test_blocks_since_last_update_success_calls(subtensor, mocker):
result = subtensor.blocks_since_last_update(netuid=7, uid=uid)

# Assertions
subtensor.get_current_block.assert_called_once()
subtensor._get_hyperparameter.assert_called_once_with(
param_name="LastUpdate", netuid=7
)
Expand Down

0 comments on commit afa165a

Please sign in to comment.