Skip to content

Commit

Permalink
[Hotfix] Fix CUDA Reg update block (#954)
Browse files Browse the repository at this point in the history
* bump version

* fix block update

* .

* verify new helper

* remove uneeded comment
  • Loading branch information
camfairchild committed Oct 13, 2022
1 parent f8387ca commit eb7aea2
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 49 deletions.
2 changes: 1 addition & 1 deletion bittensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from prometheus_client import Info

# Bittensor code and protocol version.
__version__ = '3.4.0'
__version__ = '3.4.1'
version_split = __version__.split(".")
__version_as_int__ = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2]))

Expand Down
133 changes: 85 additions & 48 deletions bittensor/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,25 +530,17 @@ def solve_for_difficulty_fast( subtensor, wallet, output_in_place: bool = True,
pass

# check for new block
block_number = subtensor.get_current_block()
if block_number != old_block_number:
old_block_number = block_number
# update block information
block_hash = subtensor.substrate.get_block_hash( block_number)
while block_hash == None:
block_hash = subtensor.substrate.get_block_hash( block_number)
block_bytes = block_hash.encode('utf-8')[2:]
difficulty = subtensor.difficulty

update_curr_block(curr_diff, curr_block, curr_block_num, block_number, block_bytes, difficulty, check_block)
# Set new block events for each solver
for worker in solvers:
worker.newBlockEvent.set()

# update stats
curr_stats.block_number = block_number
curr_stats.block_hash = block_hash
curr_stats.difficulty = difficulty
old_block_number = check_for_newest_block_and_update(
subtensor = subtensor,
old_block_number=old_block_number,
curr_diff=curr_diff,
curr_block=curr_block,
curr_block_num=curr_block_num,
curr_stats=curr_stats,
update_curr_block=update_curr_block,
check_block=check_block,
solvers=solvers
)

num_time = 0
for _ in range(len(solvers)*2):
Expand Down Expand Up @@ -636,6 +628,66 @@ def __exit__(self, *args):
# restore the old start method
multiprocessing.set_start_method(self._old_start_method, force=True)

def check_for_newest_block_and_update(
subtensor: 'bittensor.Subtensor',
old_block_number: int,
curr_diff: multiprocessing.Array,
curr_block: multiprocessing.Array,
curr_block_num: multiprocessing.Value,
update_curr_block: Callable,
check_block: 'multiprocessing.Lock',
solvers: List[Solver],
curr_stats: RegistrationStatistics
) -> int:
"""
Checks for a new block and updates the current block information if a new block is found.
Args:
subtensor (:obj:`bittensor.Subtensor`, `required`):
The subtensor object to use for getting the current block.
old_block_number (:obj:`int`, `required`):
The old block number to check against.
curr_diff (:obj:`multiprocessing.Array`, `required`):
The current difficulty as a multiprocessing array.
curr_block (:obj:`multiprocessing.Array`, `required`):
Where the current block is stored as a multiprocessing array.
curr_block_num (:obj:`multiprocessing.Value`, `required`):
Where the current block number is stored as a multiprocessing value.
update_curr_block (:obj:`Callable`, `required`):
A function that updates the current block.
check_block (:obj:`multiprocessing.Lock`, `required`):
A mp lock that is used to check for a new block.
solvers (:obj:`List[Solver]`, `required`):
A list of solvers to update the current block for.
curr_stats (:obj:`RegistrationStatistics`, `required`):
The current registration statistics to update.
Returns:
(int) The current block number.
"""
block_number = subtensor.get_current_block()
if block_number != old_block_number:
old_block_number = block_number
# update block information
block_hash = subtensor.substrate.get_block_hash( block_number)
while block_hash == None:
block_hash = subtensor.substrate.get_block_hash( block_number)
block_bytes = block_hash.encode('utf-8')[2:]
difficulty = subtensor.difficulty

update_curr_block(curr_diff, curr_block, curr_block_num, block_number, block_bytes, difficulty, check_block)
# Set new block events for each solver

for worker in solvers:
worker.newBlockEvent.set()

# update stats
curr_stats.block_number = block_number
curr_stats.block_hash = block_hash
curr_stats.difficulty = difficulty

return old_block_number


def solve_for_difficulty_fast_cuda( subtensor: 'bittensor.Subtensor', wallet: 'bittensor.Wallet', output_in_place: bool = True, update_interval: int = 50_000, TPB: int = 512, dev_id: Union[List[int], int] = 0, n_samples: int = 5, alpha_: float = 0.70, log_verbose: bool = False ) -> Optional[POWSolution]:
"""
Expand Down Expand Up @@ -680,13 +732,6 @@ def solve_for_difficulty_fast_cuda( subtensor: 'bittensor.Subtensor', wallet: 'b
curr_block_num = multiprocessing.Value('i', 0, lock=True) # int
curr_diff = multiprocessing.Array('Q', [0, 0], lock=True) # [high, low]

def update_curr_block(block_number: int, block_bytes: bytes, diff: int, lock: multiprocessing.Lock):
with lock:
curr_block_num.value = block_number
for i in range(64):
curr_block[i] = block_bytes[i]
registration_diff_pack(diff, curr_diff)

# Establish communication queues
stopEvent = multiprocessing.Event()
stopEvent.clear()
Expand All @@ -712,7 +757,7 @@ def update_curr_block(block_number: int, block_bytes: bytes, diff: int, lock: mu
old_block_number = block_number

# Set to current block
update_curr_block(block_number, block_bytes, difficulty, check_block)
update_curr_block(curr_diff, curr_block, curr_block_num, block_number, block_bytes, difficulty, check_block)

# Set new block events for each solver to start at the initial block
for worker in solvers:
Expand Down Expand Up @@ -755,27 +800,19 @@ def update_curr_block(block_number: int, block_bytes: bytes, diff: int, lock: mu
except Empty:
# No solution found, try again
pass

if block_number != old_block_number:
old_block_number = block_number
# update block information
block_hash = subtensor.substrate.get_block_hash( block_number)
while block_hash == None:
block_hash = subtensor.substrate.get_block_hash( block_number)
block_bytes = block_hash.encode('utf-8')[2:]
difficulty = subtensor.difficulty

update_curr_block(block_number, block_bytes, difficulty, check_block)
# Set new block events for each solver

for worker in solvers:
worker.newBlockEvent.set()


# update stats
curr_stats.block_number = block_number
curr_stats.block_hash = block_hash
curr_stats.difficulty = difficulty

# check for new block
old_block_number = check_for_newest_block_and_update(
subtensor = subtensor,
curr_diff=curr_diff,
curr_block=curr_block,
curr_block_num=curr_block_num,
old_block_number=old_block_number,
curr_stats=curr_stats,
update_curr_block=update_curr_block,
check_block=check_block,
solvers=solvers
)

num_time = 0
# Get times for each solver
Expand Down
82 changes: 82 additions & 0 deletions tests/unit_tests/bittensor_tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,88 @@ def test_registration_diff_pack_unpack_over_32_bits():
bittensor.utils.registration_diff_pack(fake_diff, mock_diff)
assert bittensor.utils.registration_diff_unpack(mock_diff) == fake_diff

class TestUpdateCurrentBlockDuringRegistration(unittest.TestCase):
def test_check_for_newest_block_and_update_same_block(self):
# if the block is the same, the function should return the same block number
subtensor = MagicMock()
current_block_num: int = 1
subtensor.get_current_block = MagicMock( return_value=current_block_num )

self.assertEqual(bittensor.utils.check_for_newest_block_and_update(
subtensor,
MagicMock(),
MagicMock(),
MagicMock(),
MagicMock(),
MagicMock(),
MagicMock(),
MagicMock(),
MagicMock(),
), current_block_num)

def test_check_for_newest_block_and_update_new_block(self):
# if the block is new, the function should return the new block_number
mock_block_hash = '0xba7ea4eb0b16dee271dbef5911838c3f359fcf598c74da65a54b919b68b67279'

current_block_num: int = 1
current_diff: int = 0

mock_substrate = MagicMock(
get_block_hash=MagicMock(
return_value=mock_block_hash
),

)
subtensor = MagicMock(
substrate=mock_substrate,
difficulty=current_diff + 1, # new diff
)
subtensor.get_current_block = MagicMock( return_value=current_block_num + 1 ) # new block

mock_update_curr_block = MagicMock()

mock_solvers = [
MagicMock(
newBlockEvent=MagicMock(
set=MagicMock()
)
),
MagicMock(
newBlockEvent=MagicMock(
set=MagicMock()
)
)]

mock_curr_stats = MagicMock(
block_number=current_block_num,
block_hash=b'',
difficulty=0,
)

self.assertEqual(bittensor.utils.check_for_newest_block_and_update(
subtensor,
MagicMock(),
MagicMock(),
MagicMock(),
MagicMock(),
mock_update_curr_block,
MagicMock(),
mock_solvers,
mock_curr_stats,
), current_block_num + 1)

# check that the update_curr_block function was called
mock_update_curr_block.assert_called_once()

# check that the solvers got the event
for solver in mock_solvers:
solver.newBlockEvent.set.assert_called_once()

# check the stats were updated
self.assertEqual(mock_curr_stats.block_number, current_block_num + 1)
self.assertEqual(mock_curr_stats.block_hash, mock_block_hash)
self.assertEqual(mock_curr_stats.difficulty, current_diff + 1)

class TestGetBlockWithRetry(unittest.TestCase):
def test_get_block_with_retry_network_error_exit(self):
mock_subtensor = MagicMock(
Expand Down

0 comments on commit eb7aea2

Please sign in to comment.