Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

thewhaleking/further-legacy-torch-improvements #1919

Merged
merged 2 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 4 additions & 26 deletions bittensor/commands/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from rich.prompt import Prompt
from rich.table import Table
from .utils import get_delegates_details, DelegatesDetails
from bittensor.utils.registration import torch, use_torch

from . import defaults

Expand Down Expand Up @@ -283,7 +282,6 @@ def run(cli: "bittensor.cli"):
def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
r"""Set weights for root network."""
wallet = bittensor.wallet(config=cli.config)
subnets: List[bittensor.SubnetInfo] = subtensor.get_all_subnets_info()

root = subtensor.metagraph(0, lite=False)
try:
Expand All @@ -301,11 +299,7 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
f"Boosting weight for netuid {cli.config.netuid} from {prev_weight} -> {new_weight}"
)
my_weights[cli.config.netuid] = new_weight
all_netuids = (
torch.tensor(list(range(len(my_weights))))
if use_torch()
else np.arange(len(my_weights))
)
all_netuids = np.arange(len(my_weights))

bittensor.__console__.print("Setting root weights...")
subtensor.root_set_weights(
Expand Down Expand Up @@ -405,7 +399,6 @@ def run(cli: "bittensor.cli"):
@staticmethod
def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
wallet = bittensor.wallet(config=cli.config)
subnets: List[bittensor.SubnetInfo] = subtensor.get_all_subnets_info()

bittensor.__console__.print(
"Slashing weight for subnet: {} by amount: {}".format(
Expand All @@ -423,11 +416,7 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
my_weights = root.weights[my_uid]
my_weights[cli.config.netuid] -= cli.config.amount
my_weights[my_weights < 0] = 0 # Ensure weights don't go negative
all_netuids = (
torch.tensor(list(range(len(my_weights))))
if use_torch()
else np.arange(len(my_weights))
)
all_netuids = np.arange(len(my_weights))

subtensor.root_set_weights(
wallet=wallet,
Expand Down Expand Up @@ -529,23 +518,12 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):

# Parse from string
matched_netuids = list(map(int, re.split(r"[ ,]+", cli.config.netuids)))
netuids = (
torch.tensor(matched_netuids, dtype=torch.long)
if use_torch()
else np.array(matched_netuids, dtype=np.int64)
)
netuids = np.array(matched_netuids, dtype=np.int64)

matched_weights = [
float(weight) for weight in re.split(r"[ ,]+", cli.config.weights)
]
weights = (
torch.tensor(matched_weights, dtype=torch.float32)
if use_torch()
else np.array(
matched_weights,
dtype=np.float32,
)
)
weights = np.array(matched_weights, dtype=np.float32)

# Run the set weights operation.
subtensor.root_set_weights(
Expand Down
29 changes: 8 additions & 21 deletions bittensor/extrinsics/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing import Union, List
import bittensor.utils.weight_utils as weight_utils
from bittensor.btlogging.defines import BITTENSOR_LOGGER_NAME
from bittensor.utils.registration import torch, use_torch
from bittensor.utils.registration import torch, legacy_torch_api_compat

logger = logging.getLogger(BITTENSOR_LOGGER_NAME)

Expand Down Expand Up @@ -100,6 +100,7 @@ def root_register_extrinsic(
)


@legacy_torch_api_compat
def set_root_weights_extrinsic(
subtensor: "bittensor.subtensor",
wallet: "bittensor.wallet",
Expand Down Expand Up @@ -133,36 +134,22 @@ def set_root_weights_extrinsic(
"""
# First convert types.
if isinstance(netuids, list):
netuids = (
torch.tensor(netuids, dtype=torch.int64)
if use_torch()
else np.array(netuids, dtype=np.int64)
)
netuids = np.array(netuids, dtype=np.int64)
if isinstance(weights, list):
weights = (
torch.tensor(weights, dtype=torch.float32)
if use_torch()
else np.array(weights, dtype=np.float32)
)
weights = np.array(weights, dtype=np.float32)

# Get weight restrictions.
min_allowed_weights = subtensor.min_allowed_weights(netuid=0)
max_weight_limit = subtensor.max_weight_limit(netuid=0)

# Get non zero values.
non_zero_weight_idx = (
torch.argwhere(weights > 0).squeeze(dim=1)
if use_torch()
else np.argwhere(weights > 0).squeeze(axis=1)
)
non_zero_weight_idx = np.argwhere(weights > 0).squeeze(axis=1)
non_zero_weight_uids = netuids[non_zero_weight_idx]
non_zero_weights = weights[non_zero_weight_idx]
non_zero_weights_size = (
non_zero_weights.numel() if use_torch() else non_zero_weights.size
)
if non_zero_weights_size < min_allowed_weights:
if non_zero_weights.size < min_allowed_weights:
raise ValueError(
"The minimum number of weights required to set weights is {}, got {}".format(
min_allowed_weights, non_zero_weights_size
min_allowed_weights, non_zero_weights.size
)
)

Expand Down
2 changes: 2 additions & 0 deletions bittensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
from .utils.balance import Balance
from .utils.registration import POWSolution
from .utils.subtensor import get_subtensor_errors
from .utils.registration import legacy_torch_api_compat


KEY_NONCE: Dict[str, int] = {}
Expand Down Expand Up @@ -2347,6 +2348,7 @@ def make_substrate_call_with_retry():

return make_substrate_call_with_retry()

@legacy_torch_api_compat
def root_set_weights(
self,
wallet: "bittensor.wallet",
Expand Down
94 changes: 94 additions & 0 deletions tests/unit_tests/extrinsics/test_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,97 @@ def test_set_root_weights_extrinsic(
mock_confirm.assert_called_once()
else:
mock_confirm.assert_not_called()


@pytest.mark.parametrize(
    "wait_for_inclusion, wait_for_finalization, netuids, weights, prompt, user_response, expected_success",
    [
        (True, False, [1, 2], [0.5, 0.5], True, True, True),  # Success - weights set
        (
            False,
            False,
            [1, 2],
            [0.5, 0.5],
            False,
            None,
            True,
        ),  # Success - weights set no wait
        (
            True,
            False,
            [1, 2],
            [2000, 20],
            True,
            True,
            True,
        ),  # Success - large value to be normalized
        (
            True,
            False,
            [1, 2],
            [2000, 0],
            True,
            True,
            True,
        ),  # Success - single large value
        (
            True,
            False,
            [1, 2],
            [0.5, 0.5],
            True,
            False,
            False,
        ),  # Failure - prompt declined
        (
            True,
            False,
            [1, 2],
            [0.5, 0.5],
            False,
            None,
            False,
        ),  # Failure - setting weights failed
        (
            True,
            False,
            [],
            [],
            None,
            False,
            False,
        ),  # Exception caught - ValueError 'min() arg is an empty sequence'
    ],
    ids=[
        "success-weights-set",
        "success-not-wait",
        "success-large-value",
        "success-single-value",
        "failure-user-declines",
        "failure-setting-weights",
        "failure-value-error-exception",
    ],
)
def test_set_root_weights_extrinsic_torch(
    mock_subtensor,
    mock_wallet,
    wait_for_inclusion,
    wait_for_finalization,
    netuids,
    weights,
    prompt,
    user_response,
    expected_success,
    force_legacy_torch_compat_api,
):
    """Re-run the ``test_set_root_weights_extrinsic`` scenarios with the
    ``force_legacy_torch_compat_api`` fixture requested.

    Delegates to ``test_set_root_weights_extrinsic`` with identical
    parameters, so every scenario that passes on the default (numpy) path
    must also pass with the fixture active.
    """
    # NOTE(review): presumably force_legacy_torch_compat_api flips the code
    # under test onto the torch compatibility path — confirm its definition
    # in conftest; it is only requested here, never referenced in the body.
    test_set_root_weights_extrinsic(
        mock_subtensor,
        mock_wallet,
        wait_for_inclusion,
        wait_for_finalization,
        netuids,
        weights,
        prompt,
        user_response,
        expected_success,
    )
8 changes: 4 additions & 4 deletions tests/unit_tests/utils/test_weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,15 @@ def test_normalize_with_max_weight__legacy_torch_api_compat(
wn = weight_utils.normalize_max_weight(weights, limit=1)
assert torch.isclose(wn, weights / weights.sum(), atol=1e-08, rtol=0).all()

# Check for eplison changes
eplison = 0.01
# Check for epsilon changes
epsilon = 0.01
weights, _ = torch.sort(torch.rand(100))
x = weights / weights.sum()
limit = x[-10]
change = eplison * limit
change = epsilon * limit
y = weight_utils.normalize_max_weight(x, limit=limit - change)
z = weight_utils.normalize_max_weight(x, limit=limit + change)
assert (y - z).abs().sum() < eplison
assert (y - z).abs().sum() < epsilon


@pytest.mark.parametrize(
Expand Down