Commit
fix PoW and other functions not working with USE_TORCH=0 despite torch being available
mjurbanski-reef committed May 22, 2024
1 parent 87df079 commit 85cdc00
Showing 8 changed files with 86 additions and 71 deletions.
14 changes: 10 additions & 4 deletions bittensor/extrinsics/registration.py
@@ -20,7 +20,12 @@
 import time
 from rich.prompt import Confirm
 from typing import List, Union, Optional, Tuple
-from bittensor.utils.registration import POWSolution, create_pow, torch, use_torch
+from bittensor.utils.registration import (
+    POWSolution,
+    create_pow,
+    torch,
+    log_no_torch_error,
+)


 def register_extrinsic(
@@ -100,7 +105,8 @@ def register_extrinsic(
         ):
             return False

-    if not use_torch():
+    if not torch:
+        log_no_torch_error()
         return False

     # Attempt rolling registration.
@@ -380,8 +386,8 @@ def run_faucet_extrinsic(
         ):
             return False, ""

-    if not use_torch():
-        torch.error()
+    if not torch:
+        log_no_torch_error()
         return False, "Requires torch"

     # Unlock coldkey
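The substance of the change in both extrinsics: the old guard use_torch() only checked the USE_TORCH environment variable, so PoW registration refused to run with USE_TORCH=0 even when torch was installed. The new guard tests the lazily loaded torch object itself, which is falsy exactly when the import would fail. A minimal standalone sketch of that guard pattern (illustrative only, not the bittensor implementation; LazyModule is a made-up name):

import importlib.util


class LazyModule:
    """Stand-in for an optional dependency; falsy when the module cannot be imported."""

    def __init__(self, name: str):
        self._name = name

    def __bool__(self) -> bool:
        # Truthiness reflects availability of the real module, not an env var.
        return importlib.util.find_spec(self._name) is not None


torch = LazyModule("torch")

if not torch:
    print("This command requires torch.")  # bittensor logs an error and returns here
else:
    print("torch is importable; PoW registration can proceed")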
67 changes: 26 additions & 41 deletions bittensor/utils/registration.py
@@ -1,10 +1,12 @@
 import binascii
+import functools
 import hashlib
 import math
 import multiprocessing
 import os
 import random
 import time
+import typing
 from dataclasses import dataclass
 from datetime import timedelta
 from queue import Empty, Full
@@ -19,62 +21,45 @@
 from .formatting import get_human_readable, millify
 from ._register_cuda import solve_cuda

-try:
-    import torch
-except ImportError:
-    torch = None


-def use_torch() -> bool:
-    """Force the use of torch over numpy for certain operations."""
-    return True if os.getenv("USE_TORCH") == "1" else False


-class Torch:
-    def __init__(self):
-        self._transformed = False

-    @staticmethod
-    def _error():
-        bittensor.logging.warning(
-            "This command requires torch. You can install torch for bittensor"
-            ' with `pip install bittensor[torch]` or `pip install ".[torch]"`'
-            " if installing from source, and then run the command with USE_TORCH=1 {command}"
-        )
+@functools.cache
+def _get_real_torch():
+    try:
+        import torch as _real_torch
+    except ImportError:
+        _real_torch = None
+    return _real_torch

-    def error(self):
-        self._error()

-    def _transform(self):
-        try:
-            import torch as real_torch
+def log_no_torch_error():
+    bittensor.btlogging.error(
+        "This command requires torch. You can install torch for bittensor"
+        ' with `pip install bittensor[torch]` or `pip install ".[torch]"`'
+        " if installing from source, and then run the command with USE_TORCH=1 {command}"
+    )

-            self.__dict__.update(real_torch.__dict__)
-            self._transformed = True
-        except ImportError:
-            self._error()

+class LazyLoadedTorch:
     def __bool__(self):
-        return False
+        return bool(_get_real_torch())

     def __getattr__(self, name):
-        if not self._transformed and use_torch():
-            self._transform()
-        if self._transformed:
-            return getattr(self, name)
+        if real_torch := _get_real_torch():
+            return getattr(real_torch, name)
         else:
-            self._error()
+            log_no_torch_error()
+            raise ImportError("torch not installed")

-    def __call__(self, *args, **kwargs):
-        if not self._transformed and use_torch():
-            self._transform()
-        if self._transformed:
-            return self(*args, **kwargs)
-        else:
-            self._error()


-if not torch or not use_torch():
-    torch = Torch()
+if typing.TYPE_CHECKING:
+    import torch
+else:
+    torch = LazyLoadedTorch()


 class CUDAException(Exception):
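From a caller's point of view, the rewritten module behaves roughly like this (a sketch based on the diff above; the logged error text is abbreviated):

from bittensor.utils.registration import torch, log_no_torch_error

if not torch:  # LazyLoadedTorch.__bool__ -> bool(_get_real_torch())
    log_no_torch_error()
else:
    zeros = torch.zeros(3)  # __getattr__ delegates to the real torch module

Because _get_real_torch is wrapped in functools.cache, the import is attempted at most once per process; later attribute accesses reuse the cached module (or the cached None). The typing.TYPE_CHECKING branch binds the name torch to the real module for static type checkers only, while at runtime the name is always the LazyLoadedTorch proxy.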
11 changes: 5 additions & 6 deletions example.env
@@ -1,6 +1,5 @@
-# To use Torch functionality in bittensor, you must set the USE_TORCH flag to 1:
-USE_TORCH=1
-
-# If set to 0 (or anything else), you will use the numpy functions.
-# This is generally what you want unless you have a specific reason for using torch
-# such as POW registration or legacy interoperability.
+# To use legacy Torch-based of bittensor, you must set USE_TORCH=1
+USE_TORCH=0
+# If set to 0 (or anything else than 1), it will use current, numpy-based, bittensor interface.
+# This is generally what you want unless you want legacy interoperability.
+# Please note that the legacy interface is deprecated, and is not tested nearly as much.
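The flag itself is read as a plain environment variable, and only the literal string "1" turns the legacy interface on; the deleted use_torch() helper above shows the convention. A sketch of that check (mirroring the removed helper; whatever switch bittensor still uses for legacy mode lives outside this diff):

import os


def use_torch() -> bool:
    """True only when USE_TORCH=1, i.e. the legacy torch-based interface is requested."""
    return os.getenv("USE_TORCH") == "1"

With this commit the flag no longer decides whether torch can be used at all, since PoW registration now keys off torch's actual availability; it only selects the legacy interface described in the comments above.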
2 changes: 0 additions & 2 deletions tests/integration_tests/test_cli.py
@@ -2090,7 +2090,6 @@ def test_register(self, _):
     def test_pow_register(self, _):
         # Not the best way to do this, but I need to finish these tests, and unittest doesn't make this
         # as simple as pytest
-        os.environ["USE_TORCH"] = "1"
         config = self.config
         config.command = "subnets"
         config.subcommand = "pow_register"
@@ -2114,7 +2113,6 @@ class MockException(Exception):
             mock_create_wallet.assert_called_once()

             self.assertEqual(mock_is_stale.call_count, 1)
-        del os.environ["USE_TORCH"]

     def test_stake(self, _):
         amount_to_stake: Balance = Balance.from_tao(0.5)
8 changes: 0 additions & 8 deletions tests/integration_tests/test_subtensor_integration.py
@@ -423,7 +423,6 @@ def test_is_hotkey_registered_not_registered(self):
         self.assertFalse(registered, msg="Hotkey should not be registered")

     def test_registration_multiprocessed_already_registered(self):
-        os.environ["USE_TORCH"] = "1"
         workblocks_before_is_registered = random.randint(5, 10)
         # return False each work block but return True after a random number of blocks
         is_registered_return_values = (
@@ -477,10 +476,8 @@ def test_registration_multiprocessed_already_registered(self):
             self.subtensor.is_hotkey_registered.call_count
             == workblocks_before_is_registered + 2
         )
-        del os.environ["USE_TORCH"]

     def test_registration_partly_failed(self):
-        os.environ["USE_TORCH"] = "1"
         do_pow_register_mock = MagicMock(
             side_effect=[(False, "Failed"), (False, "Failed"), (True, None)]
         )
@@ -514,10 +511,8 @@ def is_registered_side_effect(*args, **kwargs):
                 ),
                 msg="Registration should succeed",
             )
-        del os.environ["USE_TORCH"]

     def test_registration_failed(self):
-        os.environ["USE_TORCH"] = "1"
         is_registered_return_values = [False for _ in range(100)]
         current_block = [i for i in range(0, 100)]
         mock_neuron = MagicMock()
@@ -551,11 +546,9 @@ def test_registration_failed(self):
                 msg="Registration should fail",
             )
             self.assertEqual(mock_create_pow.call_count, 3)
-        del os.environ["USE_TORCH"]

     def test_registration_stale_then_continue(self):
         # verify that after a stale solution, the solve will continue without exiting
-        os.environ["USE_TORCH"] = "1"

         class ExitEarly(Exception):
             pass
@@ -596,7 +589,6 @@ class ExitEarly(Exception):
                 1,
                 msg="only tries to submit once, then exits",
             )
-        del os.environ["USE_TORCH"]

     def test_defaults_to_finney(self):
         sub = bittensor.subtensor()
5 changes: 0 additions & 5 deletions tests/unit_tests/extrinsics/test_registration.py
@@ -50,11 +50,6 @@ def mock_new_wallet():
     return mock


-@pytest.fixture(autouse=True)
-def set_use_torch_env(monkeypatch):
-    monkeypatch.setenv("USE_TORCH", "1")
-
-
 @pytest.mark.parametrize(
     "wait_for_inclusion,wait_for_finalization,prompt,cuda,dev_id,tpb,num_processes,update_interval,log_verbose,expected",
     [
5 changes: 0 additions & 5 deletions tests/unit_tests/extrinsics/test_root.py
@@ -21,11 +21,6 @@ def mock_wallet():
     return mock


-@pytest.fixture(autouse=True)
-def set_use_torch_env(monkeypatch):
-    monkeypatch.setenv("USE_TORCH", "1")
-
-
 @pytest.mark.parametrize(
     "wait_for_inclusion, wait_for_finalization, hotkey_registered, registration_success, prompt, user_response, expected_result",
     [
45 changes: 45 additions & 0 deletions tests/unit_tests/utils/test_registration.py
@@ -0,0 +1,45 @@
+import pytest
+
+from bittensor.utils.registration import LazyLoadedTorch
+
+
+class MockBittensorLogging:
+    def __init__(self):
+        self.messages = []
+
+    def error(self, message):
+        self.messages.append(message)
+
+
+@pytest.fixture
+def mock_bittensor_logging(monkeypatch):
+    mock_logger = MockBittensorLogging()
+    monkeypatch.setattr("bittensor.btlogging", mock_logger)
+    return mock_logger
+
+
+def test_lazy_loaded_torch__torch_installed(monkeypatch, mock_bittensor_logging):
+    import torch
+
+    lazy_torch = LazyLoadedTorch()
+
+    assert bool(torch) is True
+
+    assert lazy_torch.nn is torch.nn
+    with pytest.raises(AttributeError):
+        lazy_torch.no_such_thing
+
+
+def test_lazy_loaded_torch__no_torch(monkeypatch, mock_bittensor_logging):
+    monkeypatch.setattr("bittensor.utils.registration._get_real_torch", lambda: None)
+
+    torch = LazyLoadedTorch()
+
+    assert not torch
+
+    with pytest.raises(ImportError):
+        torch.some_attribute
+
+    # Check if the error message is logged correctly
+    assert len(mock_bittensor_logging.messages) == 1
+    assert "This command requires torch." in mock_bittensor_logging.messages[0]
