Skip to content

Commit

Permalink
Merge pull request #1919 from opentensor/thewhaleking/further-legacy-…
Browse files Browse the repository at this point in the history
…torch-improvements

thewhaleking/further-legacy-torch-improvements
  • Loading branch information
thewhaleking committed May 22, 2024
2 parents f54157f + c822a6e commit c39fc78
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 51 deletions.
30 changes: 4 additions & 26 deletions bittensor/commands/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from rich.prompt import Prompt
from rich.table import Table
from .utils import get_delegates_details, DelegatesDetails
from bittensor.utils.registration import torch, use_torch

from . import defaults

Expand Down Expand Up @@ -283,7 +282,6 @@ def run(cli: "bittensor.cli"):
def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
r"""Set weights for root network."""
wallet = bittensor.wallet(config=cli.config)
subnets: List[bittensor.SubnetInfo] = subtensor.get_all_subnets_info()

root = subtensor.metagraph(0, lite=False)
try:
Expand All @@ -301,11 +299,7 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
f"Boosting weight for netuid {cli.config.netuid} from {prev_weight} -> {new_weight}"
)
my_weights[cli.config.netuid] = new_weight
all_netuids = (
torch.tensor(list(range(len(my_weights))))
if use_torch()
else np.arange(len(my_weights))
)
all_netuids = np.arange(len(my_weights))

bittensor.__console__.print("Setting root weights...")
subtensor.root_set_weights(
Expand Down Expand Up @@ -405,7 +399,6 @@ def run(cli: "bittensor.cli"):
@staticmethod
def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
wallet = bittensor.wallet(config=cli.config)
subnets: List[bittensor.SubnetInfo] = subtensor.get_all_subnets_info()

bittensor.__console__.print(
"Slashing weight for subnet: {} by amount: {}".format(
Expand All @@ -423,11 +416,7 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
my_weights = root.weights[my_uid]
my_weights[cli.config.netuid] -= cli.config.amount
my_weights[my_weights < 0] = 0 # Ensure weights don't go negative
all_netuids = (
torch.tensor(list(range(len(my_weights))))
if use_torch()
else np.arange(len(my_weights))
)
all_netuids = np.arange(len(my_weights))

subtensor.root_set_weights(
wallet=wallet,
Expand Down Expand Up @@ -529,23 +518,12 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):

# Parse from string
matched_netuids = list(map(int, re.split(r"[ ,]+", cli.config.netuids)))
netuids = (
torch.tensor(matched_netuids, dtype=torch.long)
if use_torch()
else np.array(matched_netuids, dtype=np.int64)
)
netuids = np.array(matched_netuids, dtype=np.int64)

matched_weights = [
float(weight) for weight in re.split(r"[ ,]+", cli.config.weights)
]
weights = (
torch.tensor(matched_weights, dtype=torch.float32)
if use_torch()
else np.array(
matched_weights,
dtype=np.float32,
)
)
weights = np.array(matched_weights, dtype=np.float32)

# Run the set weights operation.
subtensor.root_set_weights(
Expand Down
29 changes: 8 additions & 21 deletions bittensor/extrinsics/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing import Union, List
import bittensor.utils.weight_utils as weight_utils
from bittensor.btlogging.defines import BITTENSOR_LOGGER_NAME
from bittensor.utils.registration import torch, use_torch
from bittensor.utils.registration import torch, legacy_torch_api_compat

logger = logging.getLogger(BITTENSOR_LOGGER_NAME)

Expand Down Expand Up @@ -100,6 +100,7 @@ def root_register_extrinsic(
)


@legacy_torch_api_compat
def set_root_weights_extrinsic(
subtensor: "bittensor.subtensor",
wallet: "bittensor.wallet",
Expand Down Expand Up @@ -133,36 +134,22 @@ def set_root_weights_extrinsic(
"""
# First convert types.
if isinstance(netuids, list):
netuids = (
torch.tensor(netuids, dtype=torch.int64)
if use_torch()
else np.array(netuids, dtype=np.int64)
)
netuids = np.array(netuids, dtype=np.int64)
if isinstance(weights, list):
weights = (
torch.tensor(weights, dtype=torch.float32)
if use_torch()
else np.array(weights, dtype=np.float32)
)
weights = np.array(weights, dtype=np.float32)

# Get weight restrictions.
min_allowed_weights = subtensor.min_allowed_weights(netuid=0)
max_weight_limit = subtensor.max_weight_limit(netuid=0)

# Get non zero values.
non_zero_weight_idx = (
torch.argwhere(weights > 0).squeeze(dim=1)
if use_torch()
else np.argwhere(weights > 0).squeeze(axis=1)
)
non_zero_weight_idx = np.argwhere(weights > 0).squeeze(axis=1)
non_zero_weight_uids = netuids[non_zero_weight_idx]
non_zero_weights = weights[non_zero_weight_idx]
non_zero_weights_size = (
non_zero_weights.numel() if use_torch() else non_zero_weights.size
)
if non_zero_weights_size < min_allowed_weights:
if non_zero_weights.size < min_allowed_weights:
raise ValueError(
"The minimum number of weights required to set weights is {}, got {}".format(
min_allowed_weights, non_zero_weights_size
min_allowed_weights, non_zero_weights.size
)
)

Expand Down
2 changes: 2 additions & 0 deletions bittensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
from .utils.balance import Balance
from .utils.registration import POWSolution
from .utils.subtensor import get_subtensor_errors
from .utils.registration import legacy_torch_api_compat


KEY_NONCE: Dict[str, int] = {}
Expand Down Expand Up @@ -2347,6 +2348,7 @@ def make_substrate_call_with_retry():

return make_substrate_call_with_retry()

@legacy_torch_api_compat
def root_set_weights(
self,
wallet: "bittensor.wallet",
Expand Down
94 changes: 94 additions & 0 deletions tests/unit_tests/extrinsics/test_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,97 @@ def test_set_root_weights_extrinsic(
mock_confirm.assert_called_once()
else:
mock_confirm.assert_not_called()


@pytest.mark.parametrize(
"wait_for_inclusion, wait_for_finalization, netuids, weights, prompt, user_response, expected_success",
[
(True, False, [1, 2], [0.5, 0.5], True, True, True), # Success - weights set
(
False,
False,
[1, 2],
[0.5, 0.5],
False,
None,
True,
), # Success - weights set no wait
(
True,
False,
[1, 2],
[2000, 20],
True,
True,
True,
), # Success - large value to be normalized
(
True,
False,
[1, 2],
[2000, 0],
True,
True,
True,
), # Success - single large value
(
True,
False,
[1, 2],
[0.5, 0.5],
True,
False,
False,
), # Failure - prompt declined
(
True,
False,
[1, 2],
[0.5, 0.5],
False,
None,
False,
), # Failure - setting weights failed
(
True,
False,
[],
[],
None,
False,
False,
), # Exception catched - ValueError 'min() arg is an empty sequence'
],
ids=[
"success-weights-set",
"success-not-wait",
"success-large-value",
"success-single-value",
"failure-user-declines",
"failure-setting-weights",
"failure-value-error-exception",
],
)
def test_set_root_weights_extrinsic_torch(
mock_subtensor,
mock_wallet,
wait_for_inclusion,
wait_for_finalization,
netuids,
weights,
prompt,
user_response,
expected_success,
force_legacy_torch_compat_api,
):
test_set_root_weights_extrinsic(
mock_subtensor,
mock_wallet,
wait_for_inclusion,
wait_for_finalization,
netuids,
weights,
prompt,
user_response,
expected_success,
)
8 changes: 4 additions & 4 deletions tests/unit_tests/utils/test_weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,15 @@ def test_normalize_with_max_weight__legacy_torch_api_compat(
wn = weight_utils.normalize_max_weight(weights, limit=1)
assert torch.isclose(wn, weights / weights.sum(), atol=1e-08, rtol=0).all()

# Check for eplison changes
eplison = 0.01
# Check for epsilon changes
epsilon = 0.01
weights, _ = torch.sort(torch.rand(100))
x = weights / weights.sum()
limit = x[-10]
change = eplison * limit
change = epsilon * limit
y = weight_utils.normalize_max_weight(x, limit=limit - change)
z = weight_utils.normalize_max_weight(x, limit=limit + change)
assert (y - z).abs().sum() < eplison
assert (y - z).abs().sum() < epsilon


@pytest.mark.parametrize(
Expand Down

0 comments on commit c39fc78

Please sign in to comment.