diff --git a/neurons/validator.py b/neurons/validator.py index 7b502029..d3ee8da0 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -44,7 +44,6 @@ def __init__(self, config=None): super(Validator, self).__init__(config=config) bt.logging.info("load_state()") - self.load_state() # TODO(developer): Anything specific to your use case you can do here diff --git a/template/base/neuron.py b/template/base/neuron.py index ef2caf05..eb0a6542 100644 --- a/template/base/neuron.py +++ b/template/base/neuron.py @@ -99,6 +99,8 @@ def __init__(self, config=None): ) self.step = 0 + self.load_state() + @abstractmethod async def forward(self, synapse: bt.Synapse) -> bt.Synapse: ... diff --git a/template/base/validator.py b/template/base/validator.py index ec069d47..00f592af 100644 --- a/template/base/validator.py +++ b/template/base/validator.py @@ -19,6 +19,7 @@ import copy +import os import torch import asyncio import threading @@ -339,8 +340,12 @@ def load_state(self): """Loads the state of the validator from a file.""" bt.logging.info("Loading validator state.") + if not os.path.exists(self.config.neuron.full_path + "/state.pt"): + bt.logging.warning("No saved state found") + return + # Load the state of the validator from file. state = torch.load(self.config.neuron.full_path + "/state.pt") self.step = state["step"] - self.scores = state["scores"] + self.scores = state["scores"].to(self.device) self.hotkeys = state["hotkeys"]