Skip to content

Commit

Permalink
SynapseType fix in dendrite (#874)
Browse files Browse the repository at this point in the history
* added change to SynapseType

* added TEXT_LAST_HIDDEN_STATE

* added TEXT_CAUSAL_LM

Co-authored-by: joeylegere <joeylegere@gmail.com>
  • Loading branch information
robertalanm and joeylegere committed Aug 18, 2022
1 parent b171dd9 commit b6f04e2
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions bittensor/_dendrite/dendrite_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,8 @@ def text_causal_lm (
times (:obj:`torch.FloatTensor` of shape :obj:`[ num_endpoints ]`, `required`):
times per call.
"""
if synapse.synapse_type != bittensor.proto.Synapse.SynapseType.TextCausalLM:
raise ValueError( "Passed synapse must have type: {} got {} instead".formate( bittensor.proto.Synapse.SynapseType.TextCausalLM, synapses.synapse_type ) )
if synapse.synapse_type != bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM:
raise ValueError( "Passed synapse must have type: {} got {} instead".formate( bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM, synapses.synapse_type ) )

# Format inputs.
formatted_endpoints, formatted_inputs = self.format_text_inputs (
Expand Down Expand Up @@ -486,8 +486,8 @@ def text_causal_lm_next(
times (:obj:`torch.FloatTensor` of shape :obj:`[ num_endpoints ]`, `required`):
times per call.
"""
if synapse.synapse_type != bittensor.proto.Synapse.SynapseType.TextCausalLMNext:
raise ValueError(f"Passed synapse must have type: {bittensor.proto.Synapse.SynapseType.TextCausalLMNext} "
if synapse.synapse_type != bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT:
raise ValueError(f"Passed synapse must have type: {bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT} "
f"got {synapse.synapse_type} instead")

# Format inputs.
Expand Down Expand Up @@ -556,8 +556,8 @@ def text_last_hidden_state(
times (:obj:`torch.FloatTensor` of shape :obj:`[ num_endpoints ]`, `required`):
times per call.
"""
if synapse.synapse_type != bittensor.proto.Synapse.SynapseType.TextLastHiddenState:
raise ValueError( "Passed synapse must have type:{} got:{} instead".formate( bittensor.proto.Synapse.SynapseType.TextLastHiddenState, synapses.synapse_type ) )
if synapse.synapse_type != bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE:
raise ValueError( "Passed synapse must have type:{} got:{} instead".formate( bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE, synapses.synapse_type ) )

# Format inputs.
formatted_endpoints, formatted_inputs = self.format_text_inputs (
Expand Down

0 comments on commit b6f04e2

Please sign in to comment.