Merge pull request #872 from opentensor/BIT-540-validator-weight-setting-improvement

[BIT-540] Choose responsive UIDs for setting weights in validator + validator save/load
opentaco committed Aug 12, 2022
2 parents 4d8bd79 + e54f025 commit e2ca6a9
Showing 1 changed file with 116 additions and 39 deletions.
155 changes: 116 additions & 39 deletions bittensor/_neuron/text/core_validator/__init__.py
@@ -29,14 +29,15 @@
import os
import wandb
import math
import random
import pandas
import traceback
from rich import print
from rich.console import Console
from rich.style import Style
from rich.table import Table
from rich.traceback import install
from typing import List, Tuple, Callable, Dict, Any, Union
from typing import List, Tuple, Callable, Dict, Any, Union, Set

from ..neuron_utilities import ThreadQueue, PositionalEncoding, calc_loss_fct
from bittensor.utils.tokenizer_utils import phrase_cross_entropy
@@ -158,7 +159,8 @@ def __init__(
self.loss_agg_mutex = Lock()

# === Neuron statistics variables ===
self.neuron_stats = {}
self.neuron_stats = {} # neuron statistics dict of dicts: [uid] -> {'stat1': val1, 'stat2': val2, ...}
self.neuron_hotkeys = [] # keep neuron hotkeys to compare and check for changes after metagraph.sync()
self.alpha = 0.05 # EMA coefficient in [0, 1], higher alpha discounts older observations faster

if self.config.neuron.validation_synapse == 'TextCausalLMNext':
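For reference, the update implied by self.alpha is the standard exponential moving average; the sketch below is illustrative only (the names prev and new are hypothetical, not from this diff):

    def ema_update(prev: float, new: float, alpha: float = 0.05) -> float:
        # Higher alpha weights the newest observation more, discounting history faster.
        return alpha * new + (1 - alpha) * prev

    # Non-responsive neurons get a zero observation pushed to the duplicated
    # '!'-suffixed stat key, so their penalized EMA decays toward zero:
    penalized = ema_update(prev=0.8, new=0.0)  # 0.76 after one epoch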
@@ -170,6 +172,10 @@
# stat keys to duplicate (['key']->['key!']) and push zero to its EMA if neuron non-responsive
self.synapse_keys = ['shapley_values_min']

# load last saved validator values from the file system
if not config.neuron.restart:
self.load()

@classmethod
def check_config( cls, config: 'bittensor.Config' ):
r""" Checks/validates the config namespace object.
@@ -198,6 +204,8 @@ def add_args( cls, parser ):
parser.add_argument('--neuron.validation_len', type=int, help='Number of tokens to holdout for phrase validation beyond sequence context.', default=8)
parser.add_argument('--neuron.device', type=str, help='miner default training device cpu/cuda', default=("cuda" if torch.cuda.is_available() else "cpu"))
parser.add_argument('--neuron.clip_gradients', type=float, help='Implement gradient clipping to avoid exploding loss on smaller architectures.', default=1.0 )
parser.add_argument('--neuron.print_neuron_stats', action='store_true', help='If True, print neuron_stats and exit.', default=False)
parser.add_argument('--neuron.restart', action='store_true', help='If True, reset neuron_stats and validate anew.', default=False)
parser.add_argument('--neuron.restart_on_failure', action='store_true', help='''Restart neuron on unknown error.''', default=True )
parser.add_argument('--neuron._mock', action='store_true', help='To turn on neuron mocking for testing purposes.', default=False )
parser.add_argument('--neuron.wait_for_finalization', action='store_true', help='''when setting weights the miner waits for transaction finalization.''', default=False)
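The two new flags follow the same store_true pattern as the existing ones, so passing the bare flag flips the default; a standalone argparse sketch (not validator code) illustrates:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--neuron.restart', action='store_true', default=False,
                        help='If True, reset neuron_stats and validate anew.')
    config = parser.parse_args(['--neuron.restart'])
    print(getattr(config, 'neuron.restart'))  # True; omit the flag and it stays False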
@@ -261,6 +269,38 @@ def __enter__(self):
root_dir = self.config.neuron.full_path
)

def save(self, path=None):
r""" Save validated hotkeys and neuron_stats to filesystem. """
try:
if path is None:
path = self.config.neuron.full_path

state_dict = {
'neuron_stats': self.neuron_stats,
'neuron_hotkeys': self.neuron_hotkeys
}

torch.save(state_dict, f'{path}/model.torch')
bittensor.logging.success(prefix='Saved model', sufix=f'<blue>{path}/model.torch</blue>')

except Exception as e:
logger.warning(f'Failed to save model with error: {e}')

def load(self, path=None):
r""" Load validated hotkeys and neuron_stats from filesystem. """
try:
if path is None:
path = self.config.neuron.full_path
state_dict = torch.load(f'{path}/model.torch')

self.neuron_stats = state_dict['neuron_stats']
self.neuron_hotkeys = state_dict['neuron_hotkeys']

bittensor.logging.success(prefix='Reloaded model', sufix=f'<blue>{path}/model.torch</blue>')

except Exception as e:
logger.warning(f'Failed to load model with error: {e}')
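Since the state is a plain dict serialized with torch.save, the round-trip behaves as below; this is an illustrative sketch with hypothetical path and values, not code from this diff:

    import torch

    state_dict = {
        'neuron_stats': {0: {'shapley_values_min': 0.12}},  # [uid] -> stats dict
        'neuron_hotkeys': ['5Fexample'],                    # hotkey per uid (hypothetical)
    }
    torch.save(state_dict, '/tmp/model.torch')              # hypothetical path
    restored = torch.load('/tmp/model.torch')
    assert restored['neuron_stats'][0]['shapley_values_min'] == 0.12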

def run ( self ):
r""" Run the validator and terminate on Keyboard interrupt.
"""
@@ -304,7 +344,7 @@ def run_epoch( self ):
batch_size = self.subtensor.validator_batch_size
sequence_length = self.subtensor.validator_sequence_length
validation_len = self.config.neuron.validation_len # Number of tokens to holdout for phrase validation beyond sequence context
n_topk_peer_weights = self.subtensor.min_allowed_weights
min_allowed_weights = self.subtensor.min_allowed_weights
max_allowed_ratio = self.subtensor.max_allowed_min_max_ratio
blocks_per_epoch = self.subtensor.validator_epoch_length if self.config.neuron.blocks_per_epoch == -1 else self.config.neuron.blocks_per_epoch
epochs_until_reset = self.subtensor.validator_epochs_per_reset if self.config.neuron.epochs_until_reset == -1 else self.config.neuron.epochs_until_reset
@@ -317,7 +357,7 @@
if self.config.using_wandb:
wandb.log({'era/batch_size': batch_size, 'era/sequence_length': sequence_length,
'era/validation_len': validation_len,
'era/n_topk_peer_weights': n_topk_peer_weights, 'era/max_allowed_ratio': max_allowed_ratio,
'era/min_allowed_weights': min_allowed_weights, 'era/max_allowed_ratio': max_allowed_ratio,
'era/blocks_per_epoch': blocks_per_epoch, 'era/epochs_until_reset': epochs_until_reset},
step=current_block)

@@ -326,7 +366,10 @@
# This gives us a consistent network wide timer.
# Here we run until blocks_per_epochs have progressed.
self.metagraph_sync() # Reset metagraph.

epoch_steps = 0
epoch_responsive_uids = set()
epoch_queried_uids = set()

start_block = self.subtensor.block
while self.subtensor.block < start_block + blocks_per_epoch:
@@ -341,14 +384,17 @@
# Backwards gradients through model to train gating and remote endpoints.
if hasattr(loss, 'grad_fn') and loss.grad_fn is not None:
logger.info(f'Backward <dim>(loss: {loss:.3f})</dim>')
start_time = time.time()
bw_start_time = time.time()
(loss / self.config.neuron.forward_num).backward()
logger.info(f'Backward <dim>[{time.time() - start_time:.3g}s]</dim>')
logger.info(f'Backward <dim>[{time.time() - bw_start_time:.3g}s]</dim>')

# === Stats update ===
# Updates moving averages and history.
responsive_uids, queried_uids = self.neuron_stats_update(stats)

epoch_responsive_uids |= set(responsive_uids)
epoch_queried_uids |= set(queried_uids)

# === State update ===
# Prints step logs to screen.
epoch_steps += 1
Expand Down Expand Up @@ -376,6 +422,9 @@ def run_epoch( self ):
f'Stake \u03C4[magenta not bold]{self.metagraph.stake[self.uid]:.5f}[/magenta not bold] '
f'[dim](retrieved [yellow]{current_block - start_block}[/yellow] blocks ago from {self.subtensor.network})[/dim]')

# save neuron_stats to filesystem
self.save()

# step update console message (every validation step)
print(f"[white not bold]{datetime.datetime.now():%Y-%m-%d %H:%M:%S}[/white not bold]{' ' * 4} | "
f"{f'[magenta dim not bold]#{current_block}[/magenta dim not bold]'.center(16 + len('[magenta dim not bold][/magenta dim not bold]'))} | "
@@ -401,8 +450,8 @@
f'[white] Step {epoch_steps} ({self.global_step} global) \[{step_time:.3g}s] [/white]') # caption

# === Calculate neuron weights ===
topk_uids, topk_weights = self.calculate_weights()
self.weights_table(topk_uids, topk_weights,
sample_uids, sample_weights = self.calculate_weights(epoch_responsive_uids, epoch_queried_uids)
self.weights_table(sample_uids, sample_weights,
include_uids=list(stats.keys()), num_rows=2 * len(stats)) # print weights table

# === Logs ===
@@ -431,43 +480,50 @@
self.epoch += 1

# === Calculate neuron weights ===
topk_uids, topk_weights = self.calculate_weights()
sample_uids, sample_weights = self.calculate_weights(epoch_responsive_uids, epoch_queried_uids)

if self.config.logging.debug or self.config.logging.trace:
self.weights_table(topk_uids, topk_weights) # print weights table
self.weights_table(sample_uids, sample_weights) # print weights table

self.subtensor.set_weights(
uids = topk_uids.detach().to('cpu'),
weights = topk_weights.detach().to('cpu'),
wallet = self.wallet,
wait_for_finalization = self.config.neuron.wait_for_finalization,
uids=sample_uids.detach().to('cpu'),
weights=sample_weights.detach().to('cpu'),
wallet=self.wallet,
wait_for_finalization=self.config.neuron.wait_for_finalization,
)

# === Wandb Logs ===
# Optionally send validator logs to wandb.
if self.config.using_wandb:
# Logging history to wandb.
df = pandas.concat( [
bittensor.utils.indexed_values_to_dataframe( prefix = 'weights', index = topk_uids, values = torch.zeros( self.metagraph.n ).scatter( dim = 0, src = topk_weights, index = topk_uids ) ),
bittensor.utils.indexed_values_to_dataframe( prefix = 'weights', index = sample_uids, values = torch.zeros( self.metagraph.n ).scatter( dim = 0, src = sample_weights, index = sample_uids ) ),
self.dendrite.to_dataframe( metagraph = self.metagraph )
], axis = 1); df['uid'] = df.index
wandb_data_dend = self.dendrite.to_wandb()
wandb_weight = {f'stats/weight_{uid}': weight for uid, weight in zip (topk_uids, topk_weights)}
wandb_weight = {f'stats/weight_{uid}': weight for uid, weight in zip (sample_uids, sample_weights)}
wandb_data = { 'stake': self.metagraph.S[ self.uid ].item(), 'dividends': self.metagraph.D[ self.uid ].item() }
wandb.log( { 'stats': wandb.Table( dataframe = df ) }, step = current_block, commit=False)
wandb.log( { **wandb_data, **wandb_data_dend, **wandb_weight }, step = current_block, commit=True)
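The dense per-uid weight vector logged above is produced by scattering the sparse (uid, weight) pairs into a zero tensor; a minimal standalone sketch with made-up values:

    import torch

    n = 8                                      # metagraph size (illustrative)
    sample_uids = torch.LongTensor([2, 5])     # uids that received weight
    sample_weights = torch.tensor([0.6, 0.4])
    dense = torch.zeros(n).scatter(dim=0, index=sample_uids, src=sample_weights)
    print(dense)  # -> [0, 0, 0.6, 0, 0, 0.4, 0, 0]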

def metagraph_sync(self):
r""" Syncing metagraph together with other metagraph-size related objects
"""
old_hotkeys = self.metagraph.hotkeys
old_hotkeys = self.neuron_hotkeys if self.neuron_hotkeys else self.metagraph.hotkeys
self.metagraph.sync()
self.neuron_hotkeys = self.metagraph.hotkeys

changed_hotkeys = []
# === Reset neuron stats if uid got replaced
for uid, old_hotkey in enumerate(old_hotkeys):
if old_hotkey != self.metagraph.hotkeys[uid]:
if old_hotkey != self.neuron_hotkeys[uid]:
if uid in self.neuron_stats:
del self.neuron_stats[uid]
changed_hotkeys += [uid]

if len(changed_hotkeys):
logger.info(f"Hotkeys changed: {changed_hotkeys}")
self.save() # save neuron_stats and neuron_hotkeys to filesystem
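The comparison is positional: if the hotkey registered at a uid changes between syncs, that slot was taken over by a new peer and its accumulated EMA stats no longer describe it. A standalone sketch with hypothetical hotkeys:

    old_hotkeys = ['5Fa', '5Fb', '5Fc']
    new_hotkeys = ['5Fa', '5Fz', '5Fc']  # uid 1 was re-registered by a new peer
    neuron_stats = {0: {'w': 0.1}, 1: {'w': 0.9}}

    changed = [uid for uid, old in enumerate(old_hotkeys) if old != new_hotkeys[uid]]
    for uid in changed:
        neuron_stats.pop(uid, None)      # drop stale stats for the replaced uid
    print(changed, neuron_stats)         # [1] {0: {'w': 0.1}}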

def neuron_stats_update(self, neuron_stats: Dict[int, Dict[str, Any]]):
r""" Updates self.neuron_stats with new individual dictionaries per uid.
@@ -513,39 +569,60 @@ def neuron_stats_update(self, neuron_stats: Dict[int, Dict[str, Any]]):

return responsive_uids, list(neuron_stats.keys()) # responsive_uids, queried_uids

def calculate_weights(self):
def calculate_weights(self, responsive_uids: Set, queried_uids: Set):
r""" Calculates neuron set-weights from weight_key mapped values. Defines weight_key as the neuron stats key
used to obtain the mapped stat value (typically a Shapley value) that the final set-weights are calculated from.
"""

weight_key = self.weight_key + '!' # use zeroing key to penalize non-responsive neurons
n_topk_peer_weights = self.subtensor.min_allowed_weights

# === Randomize UIDs in preferred order (responsive -> queried -> rest) ===
min_allowed_weights = self.subtensor.min_allowed_weights
max_allowed_ratio = self.subtensor.max_allowed_min_max_ratio

# === Calculate neuron weights ===
neuron_weights = torch.zeros_like(self.metagraph.S) # allow unevaluated UIDs to be selected to meet min topk
non_responsive_uids = queried_uids - responsive_uids
non_queried_uids = set(range(self.metagraph.n)) - queried_uids

# random.sample(population, k, *, counts=None): Return a k length list of unique elements chosen from
# the population sequence or set. Used for random sampling without replacement (so no uid duplicates expected).
preferred_uids = (random.sample(list(responsive_uids), len(responsive_uids)) +
random.sample(list(non_responsive_uids), len(non_responsive_uids)) +
random.sample(list(non_queried_uids), len(non_queried_uids))) # preferred UID random order

preferred_uids = torch.LongTensor(preferred_uids)

# === Populate neuron weights ===
neuron_weights = torch.zeros_like(self.metagraph.S) # allow unevaluated UIDs for min_allowed_weights

for uid in self.neuron_stats:
if weight_key in self.neuron_stats[uid]:
neuron_weights[uid] = torch.tensor([self.neuron_stats[uid][weight_key]])

# Find the n_topk_peer_weights peers to set weights to.
topk_weights, topk_uids = bittensor.unbiased_topk(neuron_weights, k=n_topk_peer_weights)
topk_weights = bittensor.utils.weight_utils.normalize_max_multiple(x=topk_weights,
multiple=max_allowed_ratio)
return topk_uids, topk_weights
# === Filter to non-zero weights ===
neuron_weights = neuron_weights[preferred_uids] # rearrange neuron_weights to match preferred_uids order
preferred_uids = preferred_uids[neuron_weights > 0] # filter to non-zero weights
neuron_weights = neuron_weights[neuron_weights > 0] # filter to non-zero weights

# === Slice min_allowed_weights UIDs ===
sample_uids = preferred_uids[:min_allowed_weights] # slice to min_allowed_weights
sample_weights = neuron_weights[:min_allowed_weights] # slice to min_allowed_weights

# === Normalize and apply max_allowed_ratio ===
sample_weights = bittensor.utils.weight_utils.normalize_max_multiple(x=sample_weights,
multiple=max_allowed_ratio)
return sample_uids, sample_weights
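Putting the steps together: shuffle within each preference tier (responsive, queried-but-unresponsive, unqueried), keep non-zero-weight uids in that order, then truncate to min_allowed_weights. A condensed standalone sketch with made-up values; plain sum-normalization stands in for normalize_max_multiple, which additionally caps the max:min ratio:

    import random
    import torch

    n, min_allowed_weights = 8, 3                 # illustrative chain parameters
    responsive, queried = {1, 4}, {1, 4, 6}
    neuron_weights = torch.tensor([0., .2, 0., 0., .5, 0., .1, .3])

    tiers = [responsive, queried - responsive, set(range(n)) - queried]
    preferred = [u for tier in tiers for u in random.sample(list(tier), len(tier))]
    preferred = torch.LongTensor(preferred)

    w = neuron_weights[preferred]
    uids, w = preferred[w > 0], w[w > 0]          # keep non-zero weights only
    uids, w = uids[:min_allowed_weights], w[:min_allowed_weights]
    print(uids, w / w.sum())                      # e.g. uids 1, 4, 6 in shuffled tier order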

def weights_table(self, topk_uids, topk_weights, include_uids=None, num_rows: int = None):
r""" Prints weights table given topk_uids and topk_weights.
def weights_table(self, sample_uids, sample_weights, include_uids=None, num_rows: int = None):
r""" Prints weights table given sample_uids and sample_weights.
"""
n_topk_peer_weights = self.subtensor.min_allowed_weights
min_allowed_weights = self.subtensor.min_allowed_weights
max_allowed_ratio = self.subtensor.max_allowed_min_max_ratio

# === Weight table ===
# Prints exponential moving average statistics of valid neurons and latest weights
_neuron_stats = {}
unvalidated = []
for uid, weight in zip(topk_uids.tolist(), topk_weights.tolist()):
for uid, weight in zip(sample_uids.tolist(), sample_weights.tolist()):
if uid in self.neuron_stats:
_neuron_stats[uid] = {k: v for k, v in self.neuron_stats[uid].items()}
_neuron_stats[uid]['weight'] = weight
@@ -555,21 +632,21 @@ def weights_table(self, topk_uids, topk_weights, include_uids=None, num_rows: in
avail_include_uids = None
if include_uids is not None and num_rows is not None:
avail_include_uids = list(set(_neuron_stats.keys()) & set(include_uids)) # exclude include_uids with no stats
if len(_neuron_stats) > num_rows: # limit table to included_uids and remaining topk up to num_rows
remaining_uids = set(_neuron_stats.keys()) - set(include_uids) # find topk remaining, loses topk ordering
remaining_uids = [uid for uid in _neuron_stats if uid in remaining_uids] # recover topk ordering
if len(_neuron_stats) > num_rows: # limit table to included_uids and remaining sample up to num_rows
remaining_uids = set(_neuron_stats.keys()) - set(include_uids) # find sample remaining, loses sample ordering
remaining_uids = [uid for uid in _neuron_stats if uid in remaining_uids] # recover sample ordering
limited_uids = avail_include_uids + remaining_uids[:num_rows - len(include_uids)]
_neuron_stats = {uid: stats for uid, stats in _neuron_stats.items() if uid in limited_uids}

print()
stats_table(_neuron_stats, 'weight', self.config.get('width', None),
f'[white] Neuron weights [/white] | ' + str(self), # title
f'Validated {n_topk_peer_weights}/'
f'Validated {min_allowed_weights}/'
f'[bold]{len(self.neuron_stats)}[/bold]/{self.metagraph.n} (min/[bold]valid[/bold]/total) | '
f'sum:{topk_weights.sum().item():.2g} '
f'[white] max:[bold]{topk_weights.max().item():.4g}[/bold] / '
f'min:[bold]{topk_weights.min().item():.4g}[/bold] [/white] '
f'\[{topk_weights.max().item() / topk_weights.min().item():.1f}:1] '
f'sum:{sample_weights.sum().item():.2g} '
f'[white] max:[bold]{sample_weights.max().item():.4g}[/bold] / '
f'min:[bold]{sample_weights.min().item():.4g}[/bold] [/white] '
f'\[{sample_weights.max().item() / sample_weights.min().item():.1f}:1] '
f'({max_allowed_ratio} allowed)', # caption
mark_uids=avail_include_uids)
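The '(max_allowed_ratio allowed)' caption in this table refers to the chain constraint that max(weight)/min(weight) stay within the allowed multiple. Conceptually the normalization does something like the following; an illustrative stand-in, not the actual normalize_max_multiple implementation:

    import torch

    def clamp_ratio(w: torch.Tensor, multiple: float) -> torch.Tensor:
        # Raise the floor so max(w) / min(w) <= multiple, then renormalize to sum 1.
        w = torch.clamp(w, min=w.max() / multiple)
        return w / w.sum()

    w = clamp_ratio(torch.tensor([0.90, 0.05, 0.05]), multiple=10.0)
    print(w.max() / w.min())  # tensor(10.) -- ratio capped at 10:1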

