Hotfix/3.6.2/validator logit parameters (#1057)
* additional parameters

* fixed naming to logit divergence

* versioning and fixes

* typo fixes

* bug fixes

* Tests cli fixes (#1058)

* fix btcli list with wallet.path (#1036)

fix path join

* remove mock subtensor and replace with mock calls

* additional fixes

* mock wallet

Co-authored-by: Cameron Fairchild <cameron@opentensor.ai>

* Log prune_len and logits_divergence

* Always get latest prune_len

Co-authored-by: Cameron Fairchild <cameron@opentensor.ai>
Co-authored-by: opentaco <opentaco@protonmail.com>
3 people committed Jan 19, 2023
1 parent 6551641 commit 5557396
Showing 6 changed files with 91 additions and 183 deletions.
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
-3.6.1
+3.6.2
2 changes: 1 addition & 1 deletion bittensor/__init__.py
@@ -23,7 +23,7 @@
nest_asyncio.apply()

# Bittensor code and protocol version.
-__version__ = '3.6.1'
+__version__ = '3.6.2'
version_split = __version__.split(".")
__version_as_int__ = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2]))

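For context on the version bump: the `__version_as_int__` expression above packs the dotted version into a single integer (assuming each component stays below 10). A minimal standalone sketch of that calculation, not part of the diff:

```python
# Sketch of the version-to-int mapping shown above (assumes single-digit components).
__version__ = '3.6.2'
version_split = __version__.split(".")
__version_as_int__ = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2]))
print(__version_as_int__)  # 362 (the previous '3.6.1' maps to 361)
```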
2 changes: 1 addition & 1 deletion bittensor/_cli/cli_impl.py
@@ -477,7 +477,7 @@ def list(self):
coldkeypub_str = '?'

wallet_tree = root.add("\n[bold white]{} ({})".format(w_name, coldkeypub_str))
-hotkeys_path = self.config.wallet.path + w_name + '/hotkeys'
+hotkeys_path = os.path.join(self.config.wallet.path, w_name, 'hotkeys')
try:
hotkeys = next(os.walk(os.path.expanduser(hotkeys_path)))
if len( hotkeys ) > 1:
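The fix above swaps manual string concatenation for `os.path.join`, which supplies the missing separator between the wallet path and the wallet name. A small illustration with a hypothetical wallet path:

```python
import os

w_name = 'default'  # hypothetical wallet name

# Old behaviour: concatenation only yields a valid path if wallet.path happens to end with '/'.
print('~/.bittensor/wallets' + w_name + '/hotkeys')
# -> ~/.bittensor/walletsdefault/hotkeys  (broken path)

# New behaviour: os.path.join inserts separators as needed.
print(os.path.join('~/.bittensor/wallets', w_name, 'hotkeys'))
# -> ~/.bittensor/wallets/default/hotkeys
```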
36 changes: 23 additions & 13 deletions bittensor/_neuron/text/core_validator/__init__.py
@@ -169,7 +169,7 @@ def __init__(
self.device = torch.device ( device = self.config.neuron.device )
self.nucleus = nucleus ( config = self.config, device = self.device, subtensor = self.subtensor ).to( self.device )
self.dataset = (bittensor.dataset(config=self.config, batch_size=self.subtensor.validator_batch_size,
-block_size=self.subtensor.validator_sequence_length + self.config.neuron.validation_len)
+block_size=self.subtensor.validator_sequence_length + self.config.neuron.validation_len + self.subtensor.prune_len)
if dataset is None else dataset)
self.optimizer = torch.optim.SGD(
self.nucleus.parameters(), lr=self.config.neuron.learning_rate, momentum=self.config.neuron.momentum
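The dataset block size now also reserves room for the pruned tokens on top of the sequence context and the validation holdout. A sketch of the arithmetic with hypothetical values (the real ones come from the chain and the neuron config):

```python
# Hypothetical values for illustration only.
validator_sequence_length = 256  # context tokens per request
validation_len = 8               # holdout tokens for phrase validation
prune_len = 1                    # tokens pruned from each validation input

# Before: block_size = 256 + 8      = 264
# After:  block_size = 256 + 8 + 1  = 265
block_size = validator_sequence_length + validation_len + prune_len
print(block_size)  # 265
```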
@@ -234,7 +234,7 @@ def add_args( cls, parser ):
parser.add_argument('--neuron.blocks_per_epoch', type=int, help='Blocks per epoch, -1 value means we use the chain value.', default = -1 )
parser.add_argument('--neuron.epochs_until_reset', type=int, help='Number of epochs before weights are reset.', default = -1 )
parser.add_argument('--neuron.validation_len', type=int, help='Number of tokens to holdout for phrase validation beyond sequence context.', default=8)
-parser.add_argument('--neuron.prune_len', type=int, help='Number of tokens to prune from each validation input sequence.', default=1)
+parser.add_argument('--neuron.prune_len', type=int, help='Number of tokens to prune from each validation input sequence. (default value: -1, pulling from subtensor directly)', default=-1)
parser.add_argument('--neuron.device', type=str, help='miner default training device cpu/cuda', default=("cuda" if torch.cuda.is_available() else "cpu"))
parser.add_argument('--neuron.clip_gradients', type=float, help='Implement gradient clipping to avoid exploding loss on smaller architectures.', default=1.0 )
parser.add_argument('--neuron.track_hotkey_changes', action='store_true', help='If True, track hotkey changes.', default=False)
@@ -400,13 +400,15 @@ def run_epoch( self ):
batch_size = self.subtensor.validator_batch_size
sequence_length = self.subtensor.validator_sequence_length
validation_len = self.config.neuron.validation_len # Number of tokens to holdout for phrase validation beyond sequence context
-prune_len = self.config.neuron.prune_len # Number of tokens to holdout for phrase validation beyond sequence context
+# Number of tokens to prune for phrase validation beyond sequence context
+prune_len = self.config.neuron.prune_len = self.subtensor.prune_len
min_allowed_weights = self.subtensor.min_allowed_weights
max_weight_limit = self.subtensor.max_weight_limit
blocks_per_epoch = self.subtensor.validator_epoch_length if self.config.neuron.blocks_per_epoch == -1 else self.config.neuron.blocks_per_epoch
epochs_until_reset = self.subtensor.validator_epochs_per_reset if self.config.neuron.epochs_until_reset == -1 else self.config.neuron.epochs_until_reset
self.config.nucleus.scaling_law_power = self.subtensor.scaling_law_power
self.config.nucleus.synergy_scaling_law_power = self.subtensor.synergy_scaling_law_power
+self.config.nucleus.logits_divergence = self.subtensor.logits_divergence

# === Logs Prometheus ===
self.prometheus_gauges.labels("current_block").set( current_block )
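A note on the chained assignment above: `prune_len = self.config.neuron.prune_len = self.subtensor.prune_len` refreshes both the stored config and the local variable from the chain each epoch ("Always get latest prune_len"). A minimal sketch of that pattern, with placeholder objects standing in for the real config and subtensor:

```python
from types import SimpleNamespace

# Placeholders standing in for the real neuron config and subtensor client.
config = SimpleNamespace(neuron=SimpleNamespace(prune_len=-1))
subtensor = SimpleNamespace(prune_len=1)  # hypothetical on-chain value

# Chained assignment: the chain value overwrites the config and feeds the local variable,
# so every epoch picks up the latest on-chain parameter.
prune_len = config.neuron.prune_len = subtensor.prune_len
print(prune_len, config.neuron.prune_len)  # 1 1
```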
@@ -688,7 +690,7 @@ def neuron_stats_update(self, neuron_stats: Dict[int, Dict[str, Any]]):

if 'logits_excess_nxt' in stats:
# penalize by logits divergence excess
-extra_stats['shapley_values_nxt'] /= 1 + stats['logits_excess_nxt']
+extra_stats['shapley_values_nxt'] /= 1 + self.config.nucleus.logits_divergence * stats['logits_excess_nxt']

# === EMA zeroing update ===
# Push zero into EMA for synapse_keys to exponentially decay weighting keys if neuron non-responsive
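The penalty above now scales the logits-divergence excess by the configured `logits_divergence` factor instead of applying the raw excess directly. A worked example with made-up numbers:

```python
# Hypothetical values for illustration only.
shapley_values_nxt = 1.0
logits_excess_nxt = 0.5   # measured divergence excess for a neuron
logits_divergence = 0.2   # penalty scaling pulled from the chain

# Old penalty: 1.0 / (1 + 0.5)        ≈ 0.667
# New penalty: 1.0 / (1 + 0.2 * 0.5)  ≈ 0.909
penalized = shapley_values_nxt / (1 + logits_divergence * logits_excess_nxt)
print(round(penalized, 3))  # 0.909
```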
@@ -825,6 +827,7 @@ def __init__( self, config, device, subtensor ):

self.config.nucleus.scaling_law_power = subtensor.scaling_law_power if self.config.nucleus.scaling_law_power == -1 else self.config.nucleus.scaling_law_power
self.config.nucleus.synergy_scaling_law_power = subtensor.synergy_scaling_law_power if self.config.nucleus.synergy_scaling_law_power == -1 else self.config.nucleus.synergy_scaling_law_power
+self.config.nucleus.logits_divergence = subtensor.logits_divergence if self.config.nucleus.logits_divergence == -1 else self.config.nucleus.logits_divergence

self.device = device
self.max_n = subtensor.max_n
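This follows the same sentinel convention as the scaling-law parameters: a CLI default of `-1` means "pull the value from the subtensor". A small sketch of the pattern, with hypothetical numbers:

```python
def resolve(cli_value: float, chain_value: float) -> float:
    """Return the chain value when the -1 sentinel default was left untouched."""
    return chain_value if cli_value == -1 else cli_value

print(resolve(-1, 0.2))    # 0.2  -> pulled from the subtensor
print(resolve(0.35, 0.2))  # 0.35 -> explicit CLI override wins
```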
@@ -872,6 +875,7 @@ def add_args( cls, parser ):
parser.add_argument('--nucleus.no_dendrite_backward', action='store_true', help='Pass backward request to the server side or not', default=False )
parser.add_argument('--nucleus.scaling_law_power', type=float, help='Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5. (default value: -1, pulling from subtensor directly)', default=-1)
parser.add_argument('--nucleus.synergy_scaling_law_power', type=float, help='Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5. (default value: -1, pulling from subtensor directly)', default=-1)
+parser.add_argument('--nucleus.logits_divergence', type=float, help=' the divergence value for logit anomaly detection (default value: -1, pulling from subtensor directly)', default=-1)

@classmethod
def config ( cls ):
@@ -983,7 +987,7 @@ def forward(
num_endpoints = len(random_endpoints) # in case len(self.permute_uids) < num_endpoints during random_uids select

logger.info(f'Forward \t| Routing forward <dim>[{time.time() - start_time:.3g}s]</dim>')
-logger.info(f'Dendrite \t| Request {num_endpoints} x {list(inputs_seq.shape)}')
+logger.info(f'Dendrite \t| Request {num_endpoints} x {list(inputs_seq.shape)} (prune_len={prune_len})')
request_start_time = time.time()

# === Define which synapse we want to use ===
@@ -1028,6 +1032,7 @@ def forward(
validation_params = (random_uids, query_responses, return_ops, times, routing_score,
inputs, val_len, self.loss_fct,
self.config.nucleus.scaling_law_power, self.config.nucleus.synergy_scaling_law_power,
+self.config.nucleus.logits_divergence,
console_width, self.config.logging.debug or self.config.logging.trace)

loss = torch.tensor(0.).to(self.device) # to accumulate neuron_loss and routing_loss over synapses
@@ -1057,7 +1062,7 @@ def scaling_law_loss_to_params(loss):
def textcausallm(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]], return_ops: List[torch.LongTensor],
times: List[torch.FloatTensor], routing_score: torch.FloatTensor,
inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable,
-scaling_law_power: float, synergy_scaling_law_power: float,
+scaling_law_power: float, synergy_scaling_law_power: float, logits_divergence_penalty: float,
console_width: int, logging, synapse: 'bittensor.TextCausalLM' = None, index_s: int = 0
) -> Tuple[torch.FloatTensor, Dict]:
r"""
@@ -1084,6 +1089,8 @@ def textcausallm(uids: torch.Tensor, query_responses: List[List[torch.FloatTenso
Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
synergy_scaling_law_power (:obj:`float`, `required`):
Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
+logits_divergence_penalty (:obj:`float`, `required`):
+Penalty scaling for logits divergence.
console_width (:obj:`int`, `required`):
Config console width for table print.
logging (:obj:`bool`, `required`):
@@ -1135,7 +1142,7 @@ def _synergy(first, second, target, _ext):
loss, stats, unsuccessful = shapley_base(uids, query_responses, return_ops, times, routing_score,
_base_params, index_s, ext='')

-logger.info(f'{str(synapse)} \t| Shapley base values (power={scaling_law_power:.1f})'
+logger.info(f'{str(synapse)} \t| Shapley base values (power={scaling_law_power:.1f}) '
f'<dim>[{time.time() - shapley_start_time:.3g}s]</dim>')

synergy_start_time = time.time()
@@ -1162,7 +1169,7 @@ def _synergy(first, second, target, _ext):
if hasattr(s[key], 'item'):
s[key] = s[key].item()

-logger.info(f'{str(synapse)} \t| Shapley synergy values (power={synergy_scaling_law_power:.1f})'
+logger.info(f'{str(synapse)} \t| Shapley synergy values (power={synergy_scaling_law_power:.1f}) '
f'<dim>[{time.time() - synergy_start_time:.3g}s]</dim>')

if logging:
@@ -1184,8 +1191,8 @@ def _synergy(first, second, target, _ext):
def textcausallmnext(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]], return_ops: List[torch.LongTensor],
times: List[torch.FloatTensor], routing_score: torch.FloatTensor,
inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable,
-scaling_law_power: float, synergy_scaling_law_power: float,
-console_width: int, logging, synapse: 'bittensor.TextCausalLMNext' = None, index_s: int = 0
+scaling_law_power: float, synergy_scaling_law_power: float, logits_divergence_penalty: float,
+console_width: int, logging, synapse: 'bittensor.TextCausalLMNext' = None, index_s: int = 0,
) -> Tuple[torch.FloatTensor, Dict]:
r"""
Calculate Shapley values and neuron response validation measure statistics, given TextCausalLMNext synapse responses.
@@ -1211,6 +1218,8 @@ def textcausallmnext(uids: torch.Tensor, query_responses: List[List[torch.FloatT
Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
synergy_scaling_law_power (:obj:`float`, `required`):
Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
+logits_divergence_penalty (:obj:`float`, `required`):
+Penalty scaling for logits divergence.
console_width (:obj:`int`, `required`):
Config console width for table print.
logging (:obj:`bool`, `required`):
@@ -1250,17 +1259,18 @@ def _synergy(first, second, target, ext):
shapley_start_time = time.time()
loss, stats, unsuccessful = shapley_base(uids, query_responses, return_ops, times, routing_score,
_base_params, index_s, ext='_nxt')
-logger.info(f'{str(synapse)} \t| Shapley base values (power={scaling_law_power:.1f})'
+logger.info(f'{str(synapse)} \t| Shapley base values (power={scaling_law_power:.1f}) '
f'<dim>[{time.time() - shapley_start_time:.3g}s]</dim>')

divergence_start_time = time.time()
with torch.no_grad():
logits_divergence(stats, uids, query_responses, return_ops, times, index_s, ext='_nxt')
-logger.info(f'{str(synapse)} \t| Logits divergences <dim>[{time.time() - divergence_start_time:.3g}s]</dim>')
+logger.info(f'{str(synapse)} \t| Logits divergences (penalty={logits_divergence_penalty}) '
+f'<dim>[{time.time() - divergence_start_time:.3g}s]</dim>')

synergy_start_time = time.time()
syn_loss_diff = shapley_synergy(stats, _synergy, '_nxt', scaling_law_power=synergy_scaling_law_power)
-logger.info(f'{str(synapse)} \t| Shapley synergy values (power={synergy_scaling_law_power:.1f})'
+logger.info(f'{str(synapse)} \t| Shapley synergy values (power={synergy_scaling_law_power:.1f}) '
f'<dim>[{time.time() - synergy_start_time:.3g}s]</dim>')

# === Shapley value combination ===
26 changes: 26 additions & 0 deletions bittensor/_subtensor/subtensor_impl.py
@@ -424,6 +424,32 @@ def make_substrate_call_with_retry():
).value
return make_substrate_call_with_retry()

+@property
+def prune_len (self) -> int:
+r""" Returns PruneLen
+Returns:
+prune_len (int):
+the number of pruned tokens from each requests
+"""
+@retry(delay=2, tries=3, backoff=2, max_delay=4)
+def make_substrate_call_with_retry():
+with self.substrate as substrate:
+return substrate.query( module='SubtensorModule', storage_function = 'ValidatorPruneLen' ).value
+return make_substrate_call_with_retry()
+
+@property
+def logits_divergence (self) -> int:
+r""" Returns logits_divergence
+Returns:
+logits_divergence (int):
+the divergence value for logit distances, a measure for anomaly detection
+"""
+@retry(delay=2, tries=3, backoff=2, max_delay=4)
+def make_substrate_call_with_retry():
+with self.substrate as substrate:
+U64MAX = 18446744073709551615
+return substrate.query( module='SubtensorModule', storage_function = 'ValidatorLogitsDivergence' ).value/U64MAX
+return make_substrate_call_with_retry()

def serve_axon (
self,
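The new `logits_divergence` property normalizes the on-chain u64 storage value into a float in [0, 1] by dividing by the maximum u64. A standalone sketch of that conversion; the divisor mirrors the diff, the raw value is made up:

```python
U64MAX = 18446744073709551615  # 2**64 - 1, as used in the new property

def normalize_u64(raw: int) -> float:
    """Map a raw u64 chain parameter onto [0, 1]."""
    return raw / U64MAX

raw_value = 3689348814741910323  # hypothetical raw storage value (~20% of the u64 range)
print(round(normalize_u64(raw_value), 3))  # 0.2
```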