Merge pull request #1917 from backend-developers-ltd/improve_no_torch
fix PoW and other functions not working with USE_TORCH=0 despite torch being available
thewhaleking committed May 22, 2024
2 parents bd4bf0b + 03303de commit f54157f
Showing 11 changed files with 193 additions and 89 deletions.
14 changes: 10 additions & 4 deletions bittensor/extrinsics/registration.py
@@ -20,7 +20,12 @@
 import time
 from rich.prompt import Confirm
 from typing import List, Union, Optional, Tuple
-from bittensor.utils.registration import POWSolution, create_pow, torch, use_torch
+from bittensor.utils.registration import (
+    POWSolution,
+    create_pow,
+    torch,
+    log_no_torch_error,
+)
 
 
 def register_extrinsic(
@@ -100,7 +105,8 @@ def register_extrinsic(
     ):
         return False
 
-    if not use_torch():
+    if not torch:
+        log_no_torch_error()
         return False
 
     # Attempt rolling registration.
@@ -380,8 +386,8 @@ def run_faucet_extrinsic(
     ):
         return False, ""
 
-    if not use_torch():
-        torch.error()
+    if not torch:
+        log_no_torch_error()
         return False, "Requires torch"
 
     # Unlock coldkey
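Note: the `torch` name exported by bittensor.utils.registration is now a lazy proxy (see that file's diff below) whose truthiness reports whether the real torch can be imported, so the extrinsics guard with a plain `if not torch:` instead of consulting the USE_TORCH flag. A minimal sketch of the guard pattern, assuming only what this diff shows (`pow_register_sketch` is a hypothetical stand-in for the real extrinsic, not its actual signature):

from bittensor.utils.registration import log_no_torch_error, torch


def pow_register_sketch() -> bool:
    # LazyLoadedTorch is falsy only when the real torch cannot be imported,
    # so this check no longer depends on the USE_TORCH environment variable.
    if not torch:
        log_no_torch_error()  # logs the `pip install bittensor[torch]` hint
        return False
    # Attribute access past this point is forwarded to the real torch module.
    print(torch.__version__)
    return True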
101 changes: 61 additions & 40 deletions bittensor/utils/registration.py
@@ -1,16 +1,20 @@
 import binascii
+import functools
 import hashlib
 import math
 import multiprocessing
 import os
 import random
 import time
+import typing
 from dataclasses import dataclass
 from datetime import timedelta
 from queue import Empty, Full
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import backoff
+import numpy
+
 import bittensor
 from Crypto.Hash import keccak
 from rich import console as rich_console
@@ -19,62 +23,79 @@
 from .formatting import get_human_readable, millify
 from ._register_cuda import solve_cuda
 
-try:
-    import torch
-except ImportError:
-    torch = None
-
-
 def use_torch() -> bool:
     """Force the use of torch over numpy for certain operations."""
     return True if os.getenv("USE_TORCH") == "1" else False
 
 
-class Torch:
-    def __init__(self):
-        self._transformed = False
-
-    @staticmethod
-    def _error():
-        bittensor.logging.warning(
-            "This command requires torch. You can install torch for bittensor"
-            ' with `pip install bittensor[torch]` or `pip install ".[torch]"`'
-            " if installing from source, and then run the command with USE_TORCH=1 {command}"
-        )
-
-    def error(self):
-        self._error()
-
-    def _transform(self):
-        try:
-            import torch as real_torch
-
-            self.__dict__.update(real_torch.__dict__)
-            self._transformed = True
-        except ImportError:
-            self._error()
-
-    def __bool__(self):
-        return False
-
-    def __getattr__(self, name):
-        if not self._transformed and use_torch():
-            self._transform()
-        if self._transformed:
-            return getattr(self, name)
-        else:
-            self._error()
-
-    def __call__(self, *args, **kwargs):
-        if not self._transformed and use_torch():
-            self._transform()
-        if self._transformed:
-            return self(*args, **kwargs)
-        else:
-            self._error()
-
-
-if not torch or not use_torch():
-    torch = Torch()
+def legacy_torch_api_compat(func):
+    """
+    Convert function operating on numpy Input&Output to legacy torch Input&Output API if `use_torch()` is True.
+
+    Args:
+        func (function):
+            Function with numpy Input/Output to be decorated.
+    Returns:
+        decorated (function):
+            Decorated function.
+    """
+
+    @functools.wraps(func)
+    def decorated(*args, **kwargs):
+        if use_torch():
+            # if argument is a Torch tensor, convert it to numpy
+            args = [
+                arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg
+                for arg in args
+            ]
+            kwargs = {
+                key: value.cpu().numpy() if isinstance(value, torch.Tensor) else value
+                for key, value in kwargs.items()
+            }
+        ret = func(*args, **kwargs)
+        if use_torch():
+            # if return value is a numpy array, convert it to Torch tensor
+            if isinstance(ret, numpy.ndarray):
+                ret = torch.from_numpy(ret)
+        return ret
+
+    return decorated
+
+
+@functools.cache
+def _get_real_torch():
+    try:
+        import torch as _real_torch
+    except ImportError:
+        _real_torch = None
+    return _real_torch
+
+
+def log_no_torch_error():
+    bittensor.btlogging.error(
+        "This command requires torch. You can install torch for bittensor"
+        ' with `pip install bittensor[torch]` or `pip install ".[torch]"`'
+        " if installing from source, and then run the command with USE_TORCH=1 {command}"
+    )
+
+
+class LazyLoadedTorch:
+    def __bool__(self):
+        return bool(_get_real_torch())
+
+    def __getattr__(self, name):
+        if real_torch := _get_real_torch():
+            return getattr(real_torch, name)
+        else:
+            log_no_torch_error()
+            raise ImportError("torch not installed")
+
+
+if typing.TYPE_CHECKING:
+    import torch
+else:
+    torch = LazyLoadedTorch()
 
 
 class CUDAException(Exception):
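Note: after this rewrite, torch is never imported at module-import time. `_get_real_torch()` tries the import once (memoized by `functools.cache`), `LazyLoadedTorch` forwards attribute access to the real module or logs and raises `ImportError`, and `use_torch()` is reduced to a pure flag for the legacy torch-based API. A short sketch of the resulting caller-side behavior, assuming only what the diff above shows:

from bittensor.utils.registration import torch, use_torch

if torch:  # True whenever the real torch is importable, regardless of USE_TORCH
    t = torch.zeros(4)  # attribute access is delegated to the real module
else:
    # Any attribute access here would call log_no_torch_error() and
    # raise ImportError("torch not installed").
    print("torch is not installed")

# USE_TORCH now only selects the legacy torch-based I/O conventions:
print(use_torch())  # True only when the USE_TORCH env var is exactly "1"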
26 changes: 7 additions & 19 deletions bittensor/utils/weight_utils.py
@@ -22,12 +22,13 @@
 import bittensor
 from numpy.typing import NDArray
 from typing import Tuple, List, Union
-from bittensor.utils.registration import torch, use_torch
+from bittensor.utils.registration import torch, use_torch, legacy_torch_api_compat
 
 U32_MAX = 4294967295
 U16_MAX = 65535
 
 
+@legacy_torch_api_compat
 def normalize_max_weight(
     x: Union[NDArray[np.float32], "torch.FloatTensor"], limit: float = 0.1
 ) -> Union[NDArray[np.float32], "torch.FloatTensor"]:
@@ -43,14 +44,8 @@ def normalize_max_weight(
     """
     epsilon = 1e-7  # For numerical stability after normalization
 
-    weights = x.clone() if use_torch() else x.copy()
-    if use_torch():
-        values, _ = torch.sort(weights)
-    else:
-        values = np.sort(weights)
-
-    if use_torch() and x.sum() == 0 or len(x) * limit <= 1:
-        return torch.ones_like(x) / x.size(0)
+    weights = x.copy()
+    values = np.sort(weights)
 
     if x.sum() == 0 or x.shape[0] * limit <= 1:
         return np.ones_like(x) / x.shape[0]
@@ -61,18 +56,11 @@ def normalize_max_weight(
         return weights / weights.sum()
 
     # Find the cumlative sum and sorted tensor
-    cumsum = (
-        torch.cumsum(estimation, 0) if use_torch() else np.cumsum(estimation, 0)
-    )
+    cumsum = np.cumsum(estimation, 0)
 
     # Determine the index of cutoff
-    estimation_sum_data = [
-        (len(values) - i - 1) * estimation[i] for i in range(len(values))
-    ]
-    estimation_sum = (
-        torch.tensor(estimation_sum_data)
-        if use_torch()
-        else np.array(estimation_sum_data)
+    estimation_sum = np.array(
+        [(len(values) - i - 1) * estimation[i] for i in range(len(values))]
     )
     n_values = (estimation / (estimation_sum + cumsum + epsilon) < limit).sum()
 
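Note: `normalize_max_weight` now has a single numpy implementation, and `@legacy_torch_api_compat` handles torch callers at the boundary: tensor arguments are converted to numpy on the way in, and numpy results are converted back to tensors on the way out when `use_torch()` is true. A sketch of both call paths (the input values are illustrative only):

import numpy as np

from bittensor.utils.weight_utils import normalize_max_weight

# Default (numpy) path, with USE_TORCH unset or 0:
w = normalize_max_weight(np.array([0.5, 0.3, 0.2], dtype=np.float32), limit=0.4)
# w sums to 1 and no entry exceeds the limit (up to the internal epsilon).

# With USE_TORCH=1 and torch installed, the same call accepts a
# torch.FloatTensor and returns a torch.Tensor via the decorator.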
11 changes: 5 additions & 6 deletions example.env
@@ -1,6 +1,5 @@
-# To use Torch functionality in bittensor, you must set the USE_TORCH flag to 1:
-USE_TORCH=1
-
-# If set to 0 (or anything else), you will use the numpy functions.
-# This is generally what you want unless you have a specific reason for using torch
-# such as POW registration or legacy interoperability.
+# To use legacy Torch-based of bittensor, you must set USE_TORCH=1
+USE_TORCH=0
+# If set to 0 (or anything else than 1), it will use current, numpy-based, bittensor interface.
+# This is generally what you want unless you want legacy interoperability.
+# Please note that the legacy interface is deprecated, and is not tested nearly as much.
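Note: because `use_torch()` reads the environment variable at call time (see the registration.py diff above), the effect of this flag can be observed directly; a quick check, assuming a bittensor install:

import os

from bittensor.utils.registration import use_torch

os.environ["USE_TORCH"] = "0"
print(use_torch())  # False: current numpy-based interface

os.environ["USE_TORCH"] = "1"
print(use_torch())  # True: deprecated legacy torch-based interface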
2 changes: 0 additions & 2 deletions tests/integration_tests/test_cli.py
@@ -2090,7 +2090,6 @@ def test_register(self, _):
     def test_pow_register(self, _):
         # Not the best way to do this, but I need to finish these tests, and unittest doesn't make this
         # as simple as pytest
-        os.environ["USE_TORCH"] = "1"
         config = self.config
         config.command = "subnets"
         config.subcommand = "pow_register"
@@ -2114,7 +2113,6 @@ class MockException(Exception):
         mock_create_wallet.assert_called_once()
 
         self.assertEqual(mock_is_stale.call_count, 1)
-        del os.environ["USE_TORCH"]
 
     def test_stake(self, _):
         amount_to_stake: Balance = Balance.from_tao(0.5)
8 changes: 0 additions & 8 deletions tests/integration_tests/test_subtensor_integration.py
@@ -423,7 +423,6 @@ def test_is_hotkey_registered_not_registered(self):
         self.assertFalse(registered, msg="Hotkey should not be registered")
 
     def test_registration_multiprocessed_already_registered(self):
-        os.environ["USE_TORCH"] = "1"
         workblocks_before_is_registered = random.randint(5, 10)
         # return False each work block but return True after a random number of blocks
         is_registered_return_values = (
@@ -477,10 +476,8 @@ def test_registration_multiprocessed_already_registered(self):
             self.subtensor.is_hotkey_registered.call_count
             == workblocks_before_is_registered + 2
         )
-        del os.environ["USE_TORCH"]
 
     def test_registration_partly_failed(self):
-        os.environ["USE_TORCH"] = "1"
         do_pow_register_mock = MagicMock(
             side_effect=[(False, "Failed"), (False, "Failed"), (True, None)]
         )
@@ -514,10 +511,8 @@ def is_registered_side_effect(*args, **kwargs):
             ),
             msg="Registration should succeed",
         )
-        del os.environ["USE_TORCH"]
 
     def test_registration_failed(self):
-        os.environ["USE_TORCH"] = "1"
         is_registered_return_values = [False for _ in range(100)]
         current_block = [i for i in range(0, 100)]
         mock_neuron = MagicMock()
@@ -551,11 +546,9 @@ def test_registration_failed(self):
             msg="Registration should fail",
         )
         self.assertEqual(mock_create_pow.call_count, 3)
-        del os.environ["USE_TORCH"]
 
     def test_registration_stale_then_continue(self):
         # verify that after a stale solution, the solve will continue without exiting
-        os.environ["USE_TORCH"] = "1"
 
         class ExitEarly(Exception):
             pass
@@ -596,7 +589,6 @@ class ExitEarly(Exception):
             1,
             msg="only tries to submit once, then exits",
         )
-        del os.environ["USE_TORCH"]
 
     def test_defaults_to_finney(self):
         sub = bittensor.subtensor()
5 changes: 5 additions & 0 deletions tests/unit_tests/conftest.py
@@ -2,6 +2,11 @@
 from aioresponses import aioresponses
 
 
+@pytest.fixture
+def force_legacy_torch_compat_api(monkeypatch):
+    monkeypatch.setenv("USE_TORCH", "1")
+
+
 @pytest.fixture
 def mock_aioresponse():
     with aioresponses() as m:
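Note: tests that still exercise the legacy torch path can now request this fixture instead of mutating `os.environ` by hand, and pytest's `monkeypatch` restores the variable on teardown. A hypothetical usage sketch (the test body is illustrative, not from this PR):

def test_some_legacy_behavior(force_legacy_torch_compat_api):
    # USE_TORCH=1 is already set here and is undone automatically after the test.
    from bittensor.utils.registration import use_torch

    assert use_torch() is True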
5 changes: 0 additions & 5 deletions tests/unit_tests/extrinsics/test_registration.py
@@ -50,11 +50,6 @@ def mock_new_wallet():
     return mock
 
 
-@pytest.fixture(autouse=True)
-def set_use_torch_env(monkeypatch):
-    monkeypatch.setenv("USE_TORCH", "1")
-
-
 @pytest.mark.parametrize(
     "wait_for_inclusion,wait_for_finalization,prompt,cuda,dev_id,tpb,num_processes,update_interval,log_verbose,expected",
     [
5 changes: 0 additions & 5 deletions tests/unit_tests/extrinsics/test_root.py
@@ -21,11 +21,6 @@ def mock_wallet():
     return mock
 
 
-@pytest.fixture(autouse=True)
-def set_use_torch_env(monkeypatch):
-    monkeypatch.setenv("USE_TORCH", "1")
-
-
 @pytest.mark.parametrize(
     "wait_for_inclusion, wait_for_finalization, hotkey_registered, registration_success, prompt, user_response, expected_result",
     [
45 changes: 45 additions & 0 deletions tests/unit_tests/utils/test_registration.py
@@ -0,0 +1,45 @@
+import pytest
+
+from bittensor.utils.registration import LazyLoadedTorch
+
+
+class MockBittensorLogging:
+    def __init__(self):
+        self.messages = []
+
+    def error(self, message):
+        self.messages.append(message)
+
+
+@pytest.fixture
+def mock_bittensor_logging(monkeypatch):
+    mock_logger = MockBittensorLogging()
+    monkeypatch.setattr("bittensor.btlogging", mock_logger)
+    return mock_logger
+
+
+def test_lazy_loaded_torch__torch_installed(monkeypatch, mock_bittensor_logging):
+    import torch
+
+    lazy_torch = LazyLoadedTorch()
+
+    assert bool(torch) is True
+
+    assert lazy_torch.nn is torch.nn
+    with pytest.raises(AttributeError):
+        lazy_torch.no_such_thing
+
+
+def test_lazy_loaded_torch__no_torch(monkeypatch, mock_bittensor_logging):
+    monkeypatch.setattr("bittensor.utils.registration._get_real_torch", lambda: None)
+
+    torch = LazyLoadedTorch()
+
+    assert not torch
+
+    with pytest.raises(ImportError):
+        torch.some_attribute
+
+    # Check if the error message is logged correctly
+    assert len(mock_bittensor_logging.messages) == 1
+    assert "This command requires torch." in mock_bittensor_logging.messages[0]