SN1-195: add inference task (#344)
- Added dynamic model loading/unloading so validators can run multiple
reference models
- Added an inference task that forces miners to generate with a specific
model (a layer to be used for ensembling)
- Made the scoring loop run asynchronously from task generation (see the
sketch below)
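
A minimal sketch of the decoupling described in the last bullet, using hypothetical generate_tasks/scoring_loop helpers rather than the repository's actual code: task generation feeds a queue at its own cadence while a separate coroutine consumes and scores results.

```python
# Hypothetical illustration only -- not the validator's actual loop.
import asyncio
import random


async def generate_tasks(queue: asyncio.Queue) -> None:
    """Produce tasks on a fixed cadence, independent of scoring speed."""
    while True:
        await queue.put({"id": random.randint(0, 10_000), "prompt": "..."})
        await asyncio.sleep(1.0)


async def scoring_loop(queue: asyncio.Queue) -> None:
    """Score whatever has finished, without blocking task generation."""
    while True:
        task = await queue.get()
        await asyncio.sleep(0.5)  # stand-in for reference generation + scoring
        print(f"scored task {task['id']}")
        queue.task_done()


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue(maxsize=32)
    # Both coroutines run concurrently until interrupted.
    await asyncio.gather(generate_tasks(queue), scoring_loop(queue))


if __name__ == "__main__":
    asyncio.run(main())
```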

---------

Co-authored-by: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com>
Co-authored-by: bkb2135 <98138173+bkb2135@users.noreply.github.com>
Co-authored-by: cassova <29239592+cassova@users.noreply.github.com>
Co-authored-by: Nicholas Miller <nicholasmiller@Nicholass-MBP.fritz.box>
5 people committed Sep 4, 2024
1 parent 4a5cd1a commit e766486
Showing 46 changed files with 1,909 additions and 792 deletions.
11 changes: 10 additions & 1 deletion .env.validator.example
@@ -11,10 +11,19 @@ SUBTENSOR_CHAIN_ENDPOINT = None
WALLET_NAME="validator"

# The name of the hotkey associated with the validator wallet.
HOTKEY="default"
HOTKEY="validator_hotkey"

# Open port which can be used to connect to the network.
AXON_PORT=22116

# HuggingFace Access Token.
HF_TOKEN=""

# Key for logging to wandb
WANDB_API_KEY="ae29a588c238d0e168d620e0b18a5e29e283935a"
WANDB_ENTITY = "macrocosmos"
WANDB_PROJECT_NAME = "prompting-validators"
LLM_MODEL = "casperhansen/llama-3-8b-instruct-awq"
SUBTENSOR_NETWORK = "test"
MAX_ALLOWED_VRAM_GB = 40
LLM_MODEL_RAM = 44
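
The new MAX_ALLOWED_VRAM_GB and LLM_MODEL_RAM settings suggest the validator budgets GPU memory when deciding which reference models to keep loaded. Below is a rough sketch of that idea with a hypothetical ModelManager; the repository's actual loading/unloading logic may differ.

```python
# Sketch of a VRAM-budgeted model cache (hypothetical ModelManager; illustrates
# the intent of MAX_ALLOWED_VRAM_GB, not the repository's implementation).
import os


class ModelManager:
    def __init__(self) -> None:
        self.max_vram_gb = float(os.getenv("MAX_ALLOWED_VRAM_GB", "40"))
        self.loaded: dict[str, float] = {}  # model name -> approximate VRAM in GB

    def used_vram_gb(self) -> float:
        return sum(self.loaded.values())

    def load(self, model_name: str, model_vram_gb: float) -> None:
        # Evict the oldest-loaded models until the newcomer fits the budget
        # (or until nothing is left to evict).
        while self.loaded and self.used_vram_gb() + model_vram_gb > self.max_vram_gb:
            evicted, size = next(iter(self.loaded.items()))
            del self.loaded[evicted]
            print(f"unloaded {evicted} ({size} GB)")
        self.loaded[model_name] = model_vram_gb
        print(f"loaded {model_name} ({model_vram_gb} GB)")


manager = ModelManager()
manager.load("reference-model-a", 20.0)
manager.load("reference-model-b", 25.0)  # evicts model-a to stay under 40 GB
```

FIFO eviction is only the simplest policy that demonstrates the budget check; which model gets unloaded in practice is up to the validator.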
27 changes: 11 additions & 16 deletions neurons/forward.py
@@ -5,8 +5,7 @@
from prompting.base.dendrite import SynapseStreamResult
from prompting.base.protocol import StreamPromptingSynapse
from prompting.utils.misc import async_log, serialize_exception_to_string
from transformers import PreTrainedTokenizerFast as Tokenizer
from prompting.tasks.base_task import BaseTask
from prompting.tasks.base_task import BaseTextTask
from prompting.llms.base_llm import BasePipeline
from loguru import logger

@@ -17,24 +16,19 @@ async def execute_dendrite_call(dendrite_call):
return responses


async def process_stream(uid: int, async_iterator: Awaitable, tokenizer: Tokenizer) -> SynapseStreamResult:
async def process_stream(uid: int, async_iterator: Awaitable) -> SynapseStreamResult:
"""Process a single response asynchronously."""
synapse = None # Initialize synapse with a default value
exception = None
accumulated_chunks = []
accumulated_chunks_timings = []
accumulated_tokens_per_chunk = []
start_time = time.time()

try:
async for chunk in async_iterator: # most important loop, as this is where we acquire the final synapse.
if isinstance(chunk, str):
accumulated_chunks.append(chunk)
accumulated_chunks_timings.append(time.time() - start_time)

tokens_in_chunk = len(tokenizer.tokenize(chunk))
accumulated_tokens_per_chunk.append(tokens_in_chunk)

logger.debug(f"\nchunk for uid {uid}: {chunk}")

# Assuming last chunk of async_iterator holds the last value yielded as a StreamingSynapse
@@ -53,15 +47,14 @@ async def process_stream(uid: int, async_iterator: Awaitable, tokenizer: Tokenizer) -> SynapseStreamResult:
return SynapseStreamResult(
accumulated_chunks=accumulated_chunks,
accumulated_chunks_timings=accumulated_chunks_timings,
tokens_per_chunk=accumulated_tokens_per_chunk,
synapse=synapse,
uid=uid,
exception=exception,
)


@async_log
async def handle_response(stream_results_dict: Dict[int, Awaitable], tokenizer: Tokenizer) -> List[SynapseStreamResult]:
async def handle_response(stream_results_dict: Dict[int, Awaitable]) -> List[SynapseStreamResult]:
"""The handle_response function is responsible for creating asyncio tasks around acquiring streamed miner chunks
and processing them asynchronously. It then pairs the results with their original UIDs and returns a list of StreamResults.
@@ -79,31 +72,33 @@ async def handle_response(stream_results_dict: Dict[int, Awaitable], tokenizer: Tokenizer) -> List[SynapseStreamResult]:
] # Pair UIDs with their tasks

# Start tasks, preserving order and their associated UIDs
process_stream_tasks = [process_stream(uid, resp, tokenizer) for uid, resp in tasks_with_uid]
process_stream_tasks = [process_stream(uid, resp) for uid, resp in tasks_with_uid]
processed_stream_results = await asyncio.gather(*process_stream_tasks, return_exceptions=True)

return processed_stream_results


@async_log
async def generate_reference(task: BaseTask, pipeline: BasePipeline) -> str:
async def generate_reference(task: BaseTextTask, pipeline: BasePipeline) -> str:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, task.generate_reference, pipeline)
return result


def log_stream_results(stream_results: List[SynapseStreamResult]):
failed_responses = [response for response in stream_results if response.exception is not None]
failed_responses = [
response for response in stream_results if response.exception is not None or response.synapse is None
]
empty_responses = [
response for response in stream_results if response.exception is None and response.synapse.completion == ""
]
non_empty_responses = [
response for response in stream_results if response.exception is None and response.synapse.completion != ""
]

logger.info(f"Total of non_empty responses: ({len(non_empty_responses)})")
logger.info(f"Total of empty responses: ({len(empty_responses)})")
logger.info(f"Total of failed responses: ({len(failed_responses)}):\n {failed_responses}")
logger.debug(f"Total of non_empty responses: ({len(non_empty_responses)})\nRESPONSES: {non_empty_responses}")
logger.debug(f"Total of empty responses: ({len(empty_responses)})")
logger.debug(f"Total of failed responses: ({len(failed_responses)}):\n {failed_responses}")

for failed_response in failed_responses:
formatted_exception = serialize_exception_to_string(failed_response.exception)
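The forward.py change keeps the same fan-out shape after dropping the tokenizer: pair each UID with its streaming task, gather them concurrently, and record per-UID exceptions rather than letting one failure cancel the batch. A self-contained sketch of that pattern (toy query_miner stand-in, not the repository code):

```python
# Self-contained illustration of the gather pattern in handle_response:
# pair UIDs with their tasks, run them concurrently, keep exceptions per UID.
import asyncio


async def query_miner(uid: int) -> str:
    if uid == 3:
        raise RuntimeError("miner timed out")  # simulate one failing miner
    await asyncio.sleep(0.1 * uid)             # simulate streaming latency
    return f"completion from miner {uid}"


async def main() -> None:
    uids = [1, 2, 3]
    tasks_with_uid = [(uid, query_miner(uid)) for uid in uids]  # pair UIDs with tasks
    results = await asyncio.gather(
        *(task for _, task in tasks_with_uid), return_exceptions=True
    )
    for (uid, _), result in zip(tasks_with_uid, results):
        print(uid, result)  # uid 3 prints the RuntimeError instead of a completion


asyncio.run(main())
```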
165 changes: 165 additions & 0 deletions neurons/miners/inference_miner/miner.py
@@ -0,0 +1,165 @@
# ruff: noqa: E402

# This is an example miner that can respond to the inference task using a vllm model.
from prompting import settings

settings.settings = settings.Settings(mode="miner")
settings = settings.settings
import time
from functools import partial
from loguru import logger
from pydantic import model_validator
from prompting.base.miner import BaseStreamMinerNeuron
from prompting.base.protocol import StreamPromptingSynapse
from vllm import LLM, SamplingParams
from starlette.types import Send
from prompting.utils.logging import ErrorLoggingEvent, log_event
from prompting.base.protocol import AvailabilitySynapse
from prompting.llms.utils import GPUInfo

NEURON_MAX_TOKENS: int = 256
NEURON_TEMPERATURE: float = 0.7
NEURON_TOP_K: int = 50
NEURON_TOP_P: float = 0.95
NEURON_STREAMING_BATCH_SIZE: int = 12
NEURON_STOP_ON_FORWARD_EXCEPTION: bool = False

SYSTEM_PROMPT = """You are a helpful agent that does its best to answer all questions!"""


class VLLMMiner(BaseStreamMinerNeuron):
llm: LLM | None = None
accumulated_total_tokens: int = 0
accumulated_prompt_tokens: int = 0
accumulated_completion_tokens: int = 0
accumulated_total_cost: float = 0
should_exit: bool = False

@model_validator(mode="after")
def init_vllm(self) -> "VLLMMiner":
GPUInfo.log_gpu_info()
logger.debug("Loading vLLM model...")
self.llm = LLM(model=settings.MINER_LLM_MODEL, gpu_memory_utilization=0.3)
logger.debug("vLLM model loaded.")
GPUInfo.log_gpu_info()
return self

def forward(self, synapse: StreamPromptingSynapse) -> StreamPromptingSynapse:
"""The forward function generates text based on a prompt, model, and seed."""

async def _forward(
self: "VLLMMiner",
synapse: StreamPromptingSynapse,
init_time: float,
timeout_threshold: float,
send: Send,
):
buffer = []
accumulated_chunks = []
accumulated_chunks_timings = []
temp_completion = "" # for wandb logging
timeout_reached = False

try:
start_time = time.time()
sampling_params = SamplingParams(
seed=synapse.seed,
)

stream_response = self.llm.generate(prompts=[synapse.messages[0]], sampling_params=sampling_params)

for chunk in stream_response:
chunk_content = chunk.outputs[0].text

if not chunk_content:
logger.info("vLLM returned chunk content with None")
continue

accumulated_chunks.append(chunk_content)
accumulated_chunks_timings.append(time.time() - start_time)

buffer.append(chunk_content)

if time.time() - init_time > timeout_threshold:
logger.debug("⏰ Timeout reached, stopping streaming")
timeout_reached = True
break

if len(buffer) == NEURON_STREAMING_BATCH_SIZE:
joined_buffer = "".join(buffer)
temp_completion += joined_buffer
logger.debug(f"Streamed tokens: {joined_buffer}")

await send(
{
"type": "http.response.body",
"body": joined_buffer.encode("utf-8"),
"more_body": True,
}
)
buffer = []

if buffer and not timeout_reached: # Don't send the last buffer of data if timeout.
joined_buffer = "".join(buffer)
await send(
{
"type": "http.response.body",
"body": joined_buffer.encode("utf-8"),
"more_body": False,
}
)

except Exception as e:
logger.exception(e)
logger.error(f"Error in forward: {e}")
log_event(ErrorLoggingEvent(error=str(e)))
if NEURON_STOP_ON_FORWARD_EXCEPTION:
self.should_exit = True

finally:
synapse_latency = time.time() - init_time
self.log_event(
synapse=synapse,
timing=synapse_latency,
messages=synapse.messages,
accumulated_chunks=accumulated_chunks,
accumulated_chunks_timings=accumulated_chunks_timings,
)

logger.debug(
f"📧 Message received from {synapse.dendrite.hotkey}, IP: {synapse.dendrite.ip}; \nForwarding synapse: {synapse}"
)
init_time = time.time()
timeout_threshold = synapse.timeout

token_streamer = partial(
_forward,
self,
synapse,
init_time,
timeout_threshold,
)
return synapse.create_streaming_response(token_streamer)

def check_availability(self, synapse: AvailabilitySynapse) -> AvailabilitySynapse:
"""The check_availability function returns an AvailabilitySynapse which indicates
which tasks and models this miner can handle."""

logger.info(f"Checking availability of miner... {synapse}")
synapse.task_availabilities = {
task: True
for task, _ in synapse.task_availabilities.items()
if task == "SyntheticInferenceTask" or "OrganicInferenceTask"
}
synapse.llm_model_availabilities = {
model: True for model, _ in synapse.llm_model_availabilities.items() if model == settings.MINER_LLM_MODEL
}
return synapse


if __name__ == "__main__":
with VLLMMiner() as miner:
while not miner.should_exit:
miner.log_status()
time.sleep(5)
logger.warning("Ending miner...")
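
The availability handshake can be pictured without bittensor in the loop. The sketch below uses a plain dataclass stand-in for AvailabilitySynapse (the real class lives in prompting.base.protocol and travels over the wire) and mirrors the miner's filtering; SomeOtherTask and the second model name are made up for illustration.

```python
# Stand-in types for illustration only; the real AvailabilitySynapse is a
# bittensor synapse defined in prompting.base.protocol.
from dataclasses import dataclass, field

MINER_LLM_MODEL = "casperhansen/llama-3-8b-instruct-awq"  # assumed miner setting


@dataclass
class FakeAvailabilitySynapse:
    task_availabilities: dict[str, bool] = field(default_factory=dict)
    llm_model_availabilities: dict[str, bool] = field(default_factory=dict)


def check_availability(synapse: FakeAvailabilitySynapse) -> FakeAvailabilitySynapse:
    # Advertise only the inference tasks and only the model this miner serves.
    synapse.task_availabilities = {
        task: True
        for task in synapse.task_availabilities
        if task in ("SyntheticInferenceTask", "OrganicInferenceTask")
    }
    synapse.llm_model_availabilities = {
        model: True
        for model in synapse.llm_model_availabilities
        if model == MINER_LLM_MODEL
    }
    return synapse


request = FakeAvailabilitySynapse(
    task_availabilities={"SyntheticInferenceTask": False, "SomeOtherTask": False},
    llm_model_availabilities={MINER_LLM_MODEL: False, "some/other-model": False},
)
print(check_availability(request))  # only the supported task and model remain, set to True
```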
14 changes: 9 additions & 5 deletions neurons/miners/openai/miner.py
@@ -12,7 +12,8 @@
from prompting.base.protocol import StreamPromptingSynapse
from neurons.miners.openai.utils import OpenAIUtils
from starlette.types import Send
from prompting.utils.logging import ErrorEvent, log_event
from prompting.utils.logging import ErrorLoggingEvent, log_event
from prompting.base.protocol import AvailabilitySynapse

MODEL_ID: str = "gpt-3.5-turbo"
NEURON_MAX_TOKENS: int = 256
@@ -118,7 +119,7 @@ async def _forward(
except Exception as e:
logger.exception(e)
logger.error(f"Error in forward: {e}")
log_event(ErrorEvent(error=str(e)))
log_event(ErrorLoggingEvent(error=str(e)))
if NEURON_STOP_ON_FORWARD_EXCEPTION:
self.should_exit = True

Expand All @@ -135,7 +136,6 @@ async def _forward(
logger.debug(
f"📧 Message received from {synapse.dendrite.hotkey}, IP: {synapse.dendrite.ip}; \nForwarding synapse: {synapse}"
)

timeout_threshold = synapse.timeout

token_streamer = partial(
Expand All @@ -144,9 +144,13 @@ async def _forward(
synapse,
timeout_threshold,
)
return synapse.create_streaming_response(token_streamer)

streaming_response = synapse.create_streaming_response(token_streamer)
return streaming_response
def check_availability(self, synapse: AvailabilitySynapse) -> AvailabilitySynapse:
logger.info(f"Checking availability of miner... {synapse}")
# allow all tasks to be sent through
synapse.task_availabilities = {task: True for task, _ in synapse.task_availabilities.items()}
return synapse


if __name__ == "__main__":
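On the validator side, the reported availabilities would presumably be used to decide which miners receive a given task; a hypothetical filter, not taken from the repository:

```python
# Hypothetical validator-side use of the availability reports: keep only the
# UIDs whose miners advertise the requested task (and, optionally, the model).
def eligible_uids(
    reports: dict[int, dict[str, dict[str, bool]]],
    task_name: str,
    model_name: str | None = None,
) -> list[int]:
    uids = []
    for uid, report in reports.items():
        if not report.get("task_availabilities", {}).get(task_name, False):
            continue
        if model_name and not report.get("llm_model_availabilities", {}).get(model_name, False):
            continue
        uids.append(uid)
    return uids


reports = {
    7: {"task_availabilities": {"SyntheticInferenceTask": True},
        "llm_model_availabilities": {"casperhansen/llama-3-8b-instruct-awq": True}},
    9: {"task_availabilities": {"SyntheticInferenceTask": False},
        "llm_model_availabilities": {}},
}
print(eligible_uids(reports, "SyntheticInferenceTask", "casperhansen/llama-3-8b-instruct-awq"))  # [7]
```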