Skip to content

Commit

Permalink
Merge pull request #1873 from opentensor/tests/gus/extend-coverage-ov…
Browse files Browse the repository at this point in the history
…erview

Tests: extends coverage for overview cmd part 1
  • Loading branch information
gus-opentensor committed May 14, 2024
2 parents 760e52b + 124ffbc commit e5c0693
Show file tree
Hide file tree
Showing 2 changed files with 250 additions and 37 deletions.
102 changes: 65 additions & 37 deletions bittensor/commands/overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,12 @@ def run(cli: "bittensor.cli"):
subtensor.close()
bittensor.logging.debug("closing subtensor connection")

def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
r"""Prints an overview for the wallet's colkey."""
console = bittensor.__console__
wallet = bittensor.wallet(config=cli.config)

all_hotkeys = []
total_balance = bittensor.Balance(0)

# We are printing for every coldkey.
@staticmethod
def _get_total_balance(
total_balance: "bittensor.Balance",
subtensor: "bittensor.subtensor",
cli: "bittensor.cli",
) -> Tuple[List["bittensor.wallet"], "bittensor.Balance"]:
if cli.config.get("all", d=None):
cold_wallets = get_coldkey_wallets_for_path(cli.config.wallet.path)
for cold_wallet in tqdm(cold_wallets, desc="Pulling balances"):
Expand All @@ -125,23 +122,59 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
return
all_hotkeys = get_hotkey_wallets_for_wallet(coldkey_wallet)

# We are printing for a select number of hotkeys from all_hotkeys.
return all_hotkeys, total_balance

if cli.config.get("hotkeys", []):
if not cli.config.get("all_hotkeys", False):
# We are only showing hotkeys that are specified.
all_hotkeys = [
hotkey
for hotkey in all_hotkeys
if hotkey.hotkey_str in cli.config.hotkeys
]
else:
# We are excluding the specified hotkeys from all_hotkeys.
all_hotkeys = [
hotkey
for hotkey in all_hotkeys
if hotkey.hotkey_str not in cli.config.hotkeys
]
@staticmethod
def _get_hotkeys(
cli: "bittensor.cli", all_hotkeys: List["bittensor.wallet"]
) -> List["bittensor.wallet"]:
if not cli.config.get("all_hotkeys", False):
# We are only showing hotkeys that are specified.
all_hotkeys = [
hotkey
for hotkey in all_hotkeys
if hotkey.hotkey_str in cli.config.hotkeys
]
else:
# We are excluding the specified hotkeys from all_hotkeys.
all_hotkeys = [
hotkey
for hotkey in all_hotkeys
if hotkey.hotkey_str not in cli.config.hotkeys
]
return all_hotkeys

@staticmethod
def _get_key_address(all_hotkeys: List["bittensor.wallet"]):
hotkey_coldkey_to_hotkey_wallet = {}
for hotkey_wallet in all_hotkeys:
if hotkey_wallet.hotkey.ss58_address not in hotkey_coldkey_to_hotkey_wallet:
hotkey_coldkey_to_hotkey_wallet[hotkey_wallet.hotkey.ss58_address] = {}

hotkey_coldkey_to_hotkey_wallet[hotkey_wallet.hotkey.ss58_address][
hotkey_wallet.coldkeypub.ss58_address
] = hotkey_wallet

all_hotkey_addresses = list(hotkey_coldkey_to_hotkey_wallet.keys())

return all_hotkey_addresses, hotkey_coldkey_to_hotkey_wallet

def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
r"""Prints an overview for the wallet's colkey."""
console = bittensor.__console__
wallet = bittensor.wallet(config=cli.config)

all_hotkeys = []
total_balance = bittensor.Balance(0)

# We are printing for every coldkey.
all_hotkeys, total_balance = OverviewCommand._get_total_balance(
total_balance, subtensor, cli
)

# We are printing for a select number of hotkeys from all_hotkeys.
if cli.config.get("hotkeys"):
all_hotkeys = OverviewCommand._get_hotkeys(cli, all_hotkeys)

# Check we have keys to display.
if len(all_hotkeys) == 0:
Expand All @@ -161,21 +194,16 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
for netuid in netuids:
neurons[str(netuid)] = []

all_wallet_names = set([wallet.name for wallet in all_hotkeys])
all_wallet_names = {wallet.name for wallet in all_hotkeys}
all_coldkey_wallets = [
bittensor.wallet(name=wallet_name) for wallet_name in all_wallet_names
]

hotkey_coldkey_to_hotkey_wallet = {}
for hotkey_wallet in all_hotkeys:
if hotkey_wallet.hotkey.ss58_address not in hotkey_coldkey_to_hotkey_wallet:
hotkey_coldkey_to_hotkey_wallet[hotkey_wallet.hotkey.ss58_address] = {}

hotkey_coldkey_to_hotkey_wallet[hotkey_wallet.hotkey.ss58_address][
hotkey_wallet.coldkeypub.ss58_address
] = hotkey_wallet
(
all_hotkey_addresses,
hotkey_coldkey_to_hotkey_wallet,
) = OverviewCommand._get_key_address(all_hotkeys)

all_hotkey_addresses = list(hotkey_coldkey_to_hotkey_wallet.keys())
with console.status(
":satellite: Syncing with chain: [white]{}[/white] ...".format(
cli.config.subtensor.get(
Expand Down Expand Up @@ -249,7 +277,7 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
)
)

if len(coldkeys_to_check) > 0:
if coldkeys_to_check:
# We have some stake that is not with a registered hotkey.
if "-1" not in neurons:
neurons["-1"] = []
Expand Down Expand Up @@ -294,7 +322,7 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
wallet_.hotkey_ss58 = hotkey_addr
wallet.hotkey_str = hotkey_addr[:5] # Max length of 5 characters
# Indicates a hotkey not on local machine but exists in stake_info obj on-chain
if hotkey_coldkey_to_hotkey_wallet.get(hotkey_addr) == None:
if hotkey_coldkey_to_hotkey_wallet.get(hotkey_addr) is None:
hotkey_coldkey_to_hotkey_wallet[hotkey_addr] = {}
hotkey_coldkey_to_hotkey_wallet[hotkey_addr][
coldkey_wallet.coldkeypub.ss58_address
Expand Down
185 changes: 185 additions & 0 deletions tests/unit_tests/test_overview.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Standard Lib
from copy import deepcopy
from unittest.mock import MagicMock, patch

# Pytest
import pytest

# Bittensor
import bittensor
from bittensor.commands.overview import OverviewCommand


@pytest.fixture
def mock_subtensor():
mock = MagicMock()
mock.get_balance = MagicMock(return_value=100)
return mock


def fake_config(**kwargs):
config = deepcopy(construct_config())
for key, value in kwargs.items():
setattr(config, key, value)
return config


def construct_config():
parser = bittensor.cli.__create_parser__()
defaults = bittensor.config(parser=parser, args=[])
# Parse commands and subcommands
for command in bittensor.ALL_COMMANDS:
if (
command in bittensor.ALL_COMMANDS
and "commands" in bittensor.ALL_COMMANDS[command]
):
for subcommand in bittensor.ALL_COMMANDS[command]["commands"]:
defaults.merge(
bittensor.config(parser=parser, args=[command, subcommand])
)
else:
defaults.merge(bittensor.config(parser=parser, args=[command]))

defaults.netuid = 1
# Always use mock subtensor.
defaults.subtensor.network = "finney"
# Skip version checking.
defaults.no_version_checking = True

return defaults


@pytest.fixture
def mock_wallet():
mock = MagicMock()
mock.coldkeypub_file.exists_on_device = MagicMock(return_value=True)
mock.coldkeypub_file.is_encrypted = MagicMock(return_value=False)
mock.coldkeypub.ss58_address = "fake_address"
return mock


class MockHotkey:
def __init__(self, hotkey_str):
self.hotkey_str = hotkey_str


class MockCli:
def __init__(self, config):
self.config = config


@pytest.mark.parametrize(
"config_all, exists_on_device, is_encrypted, expected_balance, test_id",
[
(True, True, False, 100, "happy_path_all_wallets"),
(False, True, False, 100, "happy_path_single_wallet"),
(True, False, False, 0, "edge_case_no_wallets_found"),
(True, True, True, 0, "edge_case_encrypted_wallet"),
],
)
def test_get_total_balance(
mock_subtensor,
mock_wallet,
config_all,
exists_on_device,
is_encrypted,
expected_balance,
test_id,
):
# Arrange
cli = MockCli(fake_config(all=config_all))
mock_wallet.coldkeypub_file.exists_on_device.return_value = exists_on_device
mock_wallet.coldkeypub_file.is_encrypted.return_value = is_encrypted

with patch(
"bittensor.wallet", return_value=mock_wallet
) as mock_wallet_constructor, patch(
"bittensor.commands.overview.get_coldkey_wallets_for_path",
return_value=[mock_wallet] if config_all else [],
), patch(
"bittensor.commands.overview.get_all_wallets_for_path",
return_value=[mock_wallet],
), patch(
"bittensor.commands.overview.get_hotkey_wallets_for_wallet",
return_value=[mock_wallet],
):
# Act
result_hotkeys, result_balance = OverviewCommand._get_total_balance(
0, mock_subtensor, cli
)

# Assert
assert result_balance == expected_balance, f"Test ID: {test_id}"
assert all(
isinstance(hotkey, MagicMock) for hotkey in result_hotkeys
), f"Test ID: {test_id}"


@pytest.mark.parametrize(
"config, all_hotkeys, expected_result, test_id",
[
# Happy path tests
(
{"all_hotkeys": False, "hotkeys": ["abc123", "xyz456"]},
[MockHotkey("abc123"), MockHotkey("xyz456"), MockHotkey("mno567")],
["abc123", "xyz456"],
"test_happy_path_included",
),
(
{"all_hotkeys": True, "hotkeys": ["abc123", "xyz456"]},
[MockHotkey("abc123"), MockHotkey("xyz456"), MockHotkey("mno567")],
["mno567"],
"test_happy_path_excluded",
),
# Edge cases
(
{"all_hotkeys": False, "hotkeys": []},
[MockHotkey("abc123"), MockHotkey("xyz456")],
[],
"test_edge_no_hotkeys_specified",
),
(
{"all_hotkeys": True, "hotkeys": []},
[MockHotkey("abc123"), MockHotkey("xyz456")],
["abc123", "xyz456"],
"test_edge_all_hotkeys_excluded",
),
(
{"all_hotkeys": False, "hotkeys": ["abc123", "xyz456"]},
[],
[],
"test_edge_no_hotkeys_available",
),
(
{"all_hotkeys": True, "hotkeys": ["abc123", "xyz456"]},
[],
[],
"test_edge_no_hotkeys_available_excluded",
),
],
)
def test_get_hotkeys(config, all_hotkeys, expected_result, test_id):
# Arrange
cli = MockCli(
fake_config(
hotkeys=config.get("hotkeys"), all_hotkeys=config.get("all_hotkeys")
)
)

# Act
result = OverviewCommand._get_hotkeys(cli, all_hotkeys)

# Assert
assert [
hotkey.hotkey_str for hotkey in result
] == expected_result, f"Failed {test_id}"


def test_get_hotkeys_error():
# Arrange
cli = MockCli(fake_config(hotkeys=["abc123", "xyz456"], all_hotkeys=False))
all_hotkeys = None

# Act
with pytest.raises(TypeError):
OverviewCommand._get_hotkeys(cli, all_hotkeys)

0 comments on commit e5c0693

Please sign in to comment.