diff --git a/.gitignore b/.gitignore index 1a7d166..12f5096 100644 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,7 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -localtesting/ \ No newline at end of file +localtesting/ + +# pm2 config file +app.config.js \ No newline at end of file diff --git a/docs/running_a_validator.md b/docs/running_a_validator.md index 0672de5..4fd86a3 100644 --- a/docs/running_a_validator.md +++ b/docs/running_a_validator.md @@ -57,7 +57,7 @@ btcli wallet faucet --wallet.name validator --subtensor.network test Register your UID on the test network: ```sh -btcli wallet recycle_register --subtensor.network test +btcli subnets register --subtensor.network test ``` ## 5. Start the Process diff --git a/neurons/miners/bittensorLM/README.md b/neurons/miners/bittensorLM/README.md index 84f062b..63f1cb7 100644 --- a/neurons/miners/bittensorLM/README.md +++ b/neurons/miners/bittensorLM/README.md @@ -14,7 +14,7 @@ usage: miner.py [-h] [--axon.port AXON.PORT] [--subtensor.network SUBTENSOR.NETW [--miner.blacklist.force_validator_permit] [--miner.blacklist.allow_non_registered] [--miner.blacklist.minimum_stake_requirement MINER.BLACKLIST.MINIMUM_STAKE_REQUIREMENT] [--miner.blacklist.prompt_cache_block_span MINER.BLACKLIST.PROMPT_CACHE_BLOCK_SPAN] [--miner.blacklist.use_prompt_cache] [--miner.blacklist.min_request_period MINER.BLACKLIST.MIN_REQUEST_PERIOD] [--miner.priority.default MINER.PRIORITY.DEFAULT] [--miner.priority.time_stake_multiplicate MINER.PRIORITY.TIME_STAKE_MULTIPLICATE] - [--miner.priority.len_request_timestamps MINER.PRIORITY.LEN_REQUEST_TIMESTAMPS] [--miner.no_set_weights] [--miner.no_serve] [--miner.no_start_axon] [--miner.no_register] [--miner.mock_subtensor] [--wandb.on] + [--miner.priority.len_request_timestamps MINER.PRIORITY.LEN_REQUEST_TIMESTAMPS] [--miner.no_set_weights] [--miner.no_serve] [--miner.no_start_axon] [--miner.mock_subtensor] [--wandb.on] [--wandb.project_name WANDB.PROJECT_NAME] [--wandb.entity WANDB.ENTITY] [--logging.debug] [--logging.trace] [--logging.record_log] [--logging.logging_dir LOGGING.LOGGING_DIR] [--wallet.name WALLET.NAME] [--wallet.hotkey WALLET.HOTKEY] [--wallet.path WALLET.PATH] [--config CONFIG] [--strict] [--no_version_checking] [--no_prompt] @@ -60,7 +60,6 @@ options: --miner.no_serve If True, the miner doesnt serve the axon. --miner.no_start_axon If True, the miner doesnt start the axon. - --miner.no_register If True, the miner doesnt register its wallet. --miner.mock_subtensor If True, the miner will allow non-registered hotkeys to mine. --wandb.on Turn on wandb. 
diff --git a/neurons/miners/openai/README.md b/neurons/miners/openai/README.md index 593c969..2f07337 100644 --- a/neurons/miners/openai/README.md +++ b/neurons/miners/openai/README.md @@ -38,7 +38,7 @@ usage: miner.py [-h] [--axon.port AXON.PORT] [--subtensor.network SUBTENSOR.NETW [--miner.blacklist.force_validator_permit] [--miner.blacklist.allow_non_registered] [--miner.blacklist.minimum_stake_requirement MINER.BLACKLIST.MINIMUM_STAKE_REQUIREMENT] [--miner.blacklist.prompt_cache_block_span MINER.BLACKLIST.PROMPT_CACHE_BLOCK_SPAN] [--miner.blacklist.use_prompt_cache] [--miner.blacklist.min_request_period MINER.BLACKLIST.MIN_REQUEST_PERIOD] [--miner.priority.default MINER.PRIORITY.DEFAULT] [--miner.priority.time_stake_multiplicate MINER.PRIORITY.TIME_STAKE_MULTIPLICATE] - [--miner.priority.len_request_timestamps MINER.PRIORITY.LEN_REQUEST_TIMESTAMPS] [--miner.no_set_weights] [--miner.no_serve] [--miner.no_start_axon] [--miner.no_register] [--miner.mock_subtensor] [--wandb.on] + [--miner.priority.len_request_timestamps MINER.PRIORITY.LEN_REQUEST_TIMESTAMPS] [--miner.no_set_weights] [--miner.no_serve] [--miner.no_start_axon] [--miner.mock_subtensor] [--wandb.on] [--wandb.project_name WANDB.PROJECT_NAME] [--wandb.entity WANDB.ENTITY] [--logging.debug] [--logging.trace] [--logging.record_log] [--logging.logging_dir LOGGING.LOGGING_DIR] [--wallet.name WALLET.NAME] [--wallet.hotkey WALLET.HOTKEY] [--wallet.path WALLET.PATH] [--config CONFIG] [--strict] [--no_version_checking] [--no_prompt] @@ -84,7 +84,6 @@ options: --miner.no_serve If True, the miner doesnt serve the axon. --miner.no_start_axon If True, the miner doesnt start the axon. - --miner.no_register If True, the miner doesnt register its wallet. --miner.mock_subtensor If True, the miner will allow non-registered hotkeys to mine. --wandb.on Turn on wandb. diff --git a/neurons/miners/vicuna/README.md b/neurons/miners/vicuna/README.md index d162dce..75788c1 100644 --- a/neurons/miners/vicuna/README.md +++ b/neurons/miners/vicuna/README.md @@ -69,7 +69,7 @@ usage: miner.py [-h] [--axon.port AXON.PORT] [--subtensor.network SUBTENSOR.NETW [--miner.blacklist.force_validator_permit] [--miner.blacklist.allow_non_registered] [--miner.blacklist.minimum_stake_requirement MINER.BLACKLIST.MINIMUM_STAKE_REQUIREMENT] [--miner.blacklist.prompt_cache_block_span MINER.BLACKLIST.PROMPT_CACHE_BLOCK_SPAN] [--miner.blacklist.use_prompt_cache] [--miner.blacklist.min_request_period MINER.BLACKLIST.MIN_REQUEST_PERIOD] [--miner.priority.default MINER.PRIORITY.DEFAULT] [--miner.priority.time_stake_multiplicate MINER.PRIORITY.TIME_STAKE_MULTIPLICATE] - [--miner.priority.len_request_timestamps MINER.PRIORITY.LEN_REQUEST_TIMESTAMPS] [--miner.no_set_weights] [--miner.no_serve] [--miner.no_start_axon] [--miner.no_register] [--miner.mock_subtensor] [--wandb.on] + [--miner.priority.len_request_timestamps MINER.PRIORITY.LEN_REQUEST_TIMESTAMPS] [--miner.no_set_weights] [--miner.no_serve] [--miner.no_start_axon] [--miner.mock_subtensor] [--wandb.on] [--wandb.project_name WANDB.PROJECT_NAME] [--wandb.entity WANDB.ENTITY] [--logging.debug] [--logging.trace] [--logging.record_log] [--logging.logging_dir LOGGING.LOGGING_DIR] [--wallet.name WALLET.NAME] [--wallet.hotkey WALLET.HOTKEY] [--wallet.path WALLET.PATH] [--config CONFIG] [--strict] [--no_version_checking] [--no_prompt] @@ -115,7 +115,6 @@ options: --miner.no_serve If True, the miner doesnt serve the axon. --miner.no_start_axon If True, the miner doesnt start the axon. 
- --miner.no_register If True, the miner doesnt register its wallet. --miner.mock_subtensor If True, the miner will allow non-registered hotkeys to mine. --wandb.on Turn on wandb. diff --git a/neurons/validators/validator.py b/neurons/validators/validator.py index 869c918..fe1ec78 100644 --- a/neurons/validators/validator.py +++ b/neurons/validators/validator.py @@ -46,19 +46,23 @@ # Load gating models from prompting.validators.reward import ( Blacklist, - TaskValidator, NSFWRewardModel, DirectPreferenceRewardModel, OpenAssistantRewardModel, ReciprocateRewardModel, RelevanceRewardModel, - MockRewardModel, DahoasRewardModel, DiversityRewardModel, PromptRewardModel, RewardModelType, ) +from prompting.validators.penalty import ( + TaskValidationPenaltyModel, + KeywordMatchPenaltyModel, + ContentMatchPenaltyModel, +) + class neuron: @classmethod @@ -189,8 +193,12 @@ def __init__(self): self.blacklist, MockRewardModel(RewardModelType.nsfw.value), ] + self.penalty_functions = [ + TaskValidationPenaltyModel(max_penalty=0.1), + ContentMatchPenaltyModel(max_penalty=0.1), + KeywordMatchPenaltyModel(max_penalty=1), + ] bt.logging.debug(str(self.reward_functions)) - self.blacklist = MockRewardModel(RewardModelType.blacklist.value) else: self.reward_weights = torch.tensor( [ @@ -245,11 +253,6 @@ def __init__(self): if not self.config.neuron.blacklist_off else MockRewardModel(RewardModelType.blacklist.value) ) - task_validator = ( - TaskValidator() - if not self.config.neuron.task_validator_off - else MockRewardModel(RewardModelType.task_validator.value) - ) relevance_model = ( RelevanceRewardModel(device=self.device) if not self.config.neuron.relevance_off @@ -268,13 +271,20 @@ def __init__(self): self.masking_functions = [ self.blacklist, - task_validator, relevance_model, self.diversity_model, nsfw_model, ] + + self.penalty_functions = [ + TaskValidationPenaltyModel(max_penalty=0.1), + ContentMatchPenaltyModel(max_penalty=0.1), + KeywordMatchPenaltyModel(max_penalty=1), + ] + bt.logging.debug(str(self.reward_functions)) bt.logging.debug(str(self.masking_functions)) + bt.logging.debug(str(self.penalty_functions)) # Init the event loop. self.loop = asyncio.get_event_loop() diff --git a/prompting/baseminer/config.py b/prompting/baseminer/config.py index decc504..7f52818 100644 --- a/prompting/baseminer/config.py +++ b/prompting/baseminer/config.py @@ -199,12 +199,6 @@ def get_config() -> "bt.Config": help="If True, the miner doesnt start the axon.", default=False, ) - parser.add_argument( - "--miner.no_register", - action="store_true", - help="If True, the miner doesnt register its wallet.", - default=False, - ) # Mocks. parser.add_argument( diff --git a/prompting/baseminer/miner.py b/prompting/baseminer/miner.py index 0d78ed6..bb3da36 100644 --- a/prompting/baseminer/miner.py +++ b/prompting/baseminer/miner.py @@ -68,9 +68,6 @@ def __init__(self, config=None, axon=None, wallet=None, subtensor=None): # Activating Bittensor's logging with the set configurations. bt.logging(config=self.config, logging_dir=self.config.full_path) - bt.logging.info( - f"Running miner for subnet: {self.config.netuid} on network: {self.config.subtensor.chain_endpoint} with config:" - ) if not self.config.miner.blacklist.force_validator_permit: bt.logging.warning( @@ -89,7 +86,10 @@ def __init__(self, config=None, axon=None, wallet=None, subtensor=None): # subtensor manages the blockchain connection, facilitating interaction with the Bittensor blockchain. 
self.subtensor = subtensor or bt.subtensor(config=self.config) - bt.logging.info(f"Subtensor: {subtensor}") + bt.logging.info(f"Subtensor: {self.subtensor}") + bt.logging.info( + f"Running miner for subnet: {self.config.netuid} on network: {self.subtensor.chain_endpoint} with config:" + ) # metagraph provides the network's current state, holding state about other participants in a subnet. self.metagraph = self.subtensor.metagraph(self.config.netuid) diff --git a/prompting/baseminer/run.py b/prompting/baseminer/run.py index 293e657..3327a09 100644 --- a/prompting/baseminer/run.py +++ b/prompting/baseminer/run.py @@ -28,7 +28,7 @@ def run(self): Initiates and manages the main loop for the miner on the Bittensor network. This function performs the following primary tasks: - 1. Optionally registers the miner's wallet with the network. + 1. Checks that the miner's hotkey is registered on the Bittensor network. 2. Attaches the miner's forward, blacklist, and priority functions to its axon. 3. Starts the miner's axon, making it active on the network. 4. Regularly updates the metagraph with the latest network state. @@ -48,12 +48,16 @@ def run(self): KeyboardInterrupt: If the miner is stopped by a manual interruption. Exception: For unforeseen errors during the miner's operation, which are logged for diagnosis. """ - # --- Optionally register the wallet. - if not self.config.miner.no_register: - bt.logging.info( - f"Registering wallet: {self.wallet} on netuid {self.config.netuid}" + # --- Check for registration. + if not self.subtensor.is_hotkey_registered( + netuid=self.config.netuid, + hotkey=self.wallet.hotkey.ss58_address, + ): + bt.logging.error( + f"Wallet: {self.wallet} is not registered on netuid {self.config.netuid}. " + f"Please register the hotkey using `btcli subnets register` before trying again." ) - self.subtensor.register(netuid=self.config.netuid, wallet=self.wallet) + exit() # Serve passes the axon information to the network + netuid we are hosting on. # This will auto-update if the axon port of external ip have changed. diff --git a/prompting/protocol.py b/prompting/protocol.py index b43ab75..074501e 100644 --- a/prompting/protocol.py +++ b/prompting/protocol.py @@ -160,17 +160,6 @@ class StreamPrompting(bt.StreamingSynapse): - `extract_response_json`: Extracts relevant JSON data from the response, useful for gaining insights on the response's metadata or for debugging purposes. - Example usage: - ```python - stream_prompter = StreamPrompting(roles=["role1", "role2"], messages=["message1", "message2"]) - # Process a streaming response... - stream_prompter.process_streaming_response(response) - # Access the result - result = stream_prompter.deserialize() - # Extract response metadata - json_info = stream_prompter.extract_response_json(response) - ``` - Note: While you can directly use the `StreamPrompting` class, it's designed to be extensible. Thus, you can create subclasses to further customize behavior for specific prompting scenarios or requirements. """ @@ -215,18 +204,6 @@ async def process_streaming_response(self, response: StreamingResponse): Args: response: The streaming response object containing the content chunks to be processed. Each chunk in this response is expected to be a set of tokens that can be decoded and split into individual messages or prompts. - - Usage: - Generally, this method is called when there's an incoming streaming response to be processed.
- - ```python - stream_prompter = StreamPrompting(roles=["role1", "role2"], messages=["message1", "message2"]) - await stream_prompter.process_streaming_response(response) - ``` - - Note: - It's important to remember that this method is asynchronous. Ensure it's called within an appropriate - asynchronous context. """ if self.completion is None: self.completion = "" @@ -235,6 +212,7 @@ async def process_streaming_response(self, response: StreamingResponse): for token in tokens: if token: self.completion += token + yield tokens def deserialize(self) -> str: """ @@ -266,19 +244,6 @@ def extract_response_json(self, response: StreamingResponse) -> dict: - Dendrite and Axon related information extracted from headers. - Roles and Messages pertaining to the current StreamPrompting instance. - The accumulated completion. - - Usage: - This method can be used after processing a response to gather detailed metadata: - - ```python - stream_prompter = StreamPrompting(roles=["role1", "role2"], messages=["message1", "message2"]) - # After processing the response... - json_info = stream_prompter.extract_response_json(response) - ``` - - Note: - While the primary output is the structured dictionary, understanding this output can be instrumental in - troubleshooting or in extracting specific insights about the interaction with the Bittensor network. """ headers = { k.decode("utf-8"): v.decode("utf-8") diff --git a/prompting/validators/__init__.py b/prompting/validators/__init__.py index 80a4c94..24575cc 100644 --- a/prompting/validators/__init__.py +++ b/prompting/validators/__init__.py @@ -27,7 +27,7 @@ from . import event from . import dataset -__version__ = "2.0.2" +__version__ = "2.1.0" version_split = __version__.split(".") __spec_version__ = ( (1000 * int(version_split[0])) diff --git a/prompting/validators/config.py b/prompting/validators/config.py index 430eac8..401848d 100644 --- a/prompting/validators/config.py +++ b/prompting/validators/config.py @@ -135,7 +135,7 @@ def add_args(cls, parser): "--neuron.num_followup_steps", type=int, help="How many followup steps to take.", - default=4, + default=3, ) parser.add_argument( diff --git a/prompting/validators/criteria.py b/prompting/validators/criteria.py new file mode 100644 index 0000000..ae464dc --- /dev/null +++ b/prompting/validators/criteria.py @@ -0,0 +1,111 @@ +# The MIT License (MIT) +# Copyright © 2023 Yuma Rao +# Copyright © 2023 Opentensor Foundation + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. 
+import re +import torch +import numpy as np +from dataclasses import dataclass +from abc import ABC, abstractmethod +from typing import List +from enum import Enum + + +@dataclass +class TaskCriterion(ABC): + """ + Abstract base class for defining task-specific evaluation criteria. + + Attributes: + text (str): Text of the criterion to be added to the prompt. + penalty (float): Penalty value associated with the criterion. + Returns: + torch.FloatTensor: Tensor containing the penalty values for each response. + """ + + text: str + penalty: float + + @abstractmethod + def evaluate(self, completions: List[str]) -> torch.FloatTensor: + pass + + @abstractmethod + def compose_text(self) -> str: + pass + + +class TextLengthUnitEnum(Enum): + CHARACTERS = "characters" + WORDS = "words" + SENTENCES = "sentences" + PARAGRAPHS = "paragraphs" + + +@dataclass +class MatchLengthCriteria(TaskCriterion): + text: str = "Your response must have {target_length} {unit}." + penalty: float = 0.1 + target_length: int = 100 + unit: TextLengthUnitEnum = TextLengthUnitEnum.WORDS + + def _count_sentences(self, text): + # Define str pattern to match: terminal punctuation followed by whitespace + pattern = r"(?<=[.!?])\s+" + # Split on sentence boundaries and drop empty fragments + sentences = [s for s in re.split(pattern, text.strip()) if s] + return len(sentences) + + def _get_completion_length(self, response: str) -> int: + unit_to_split_pattern = { + TextLengthUnitEnum.CHARACTERS: None, + TextLengthUnitEnum.SENTENCES: None, + TextLengthUnitEnum.WORDS: r"\s+", + TextLengthUnitEnum.PARAGRAPHS: r"\n\n+", + } + + if self.unit == TextLengthUnitEnum.CHARACTERS: + return len(response) + elif self.unit == TextLengthUnitEnum.SENTENCES: + return self._count_sentences(response) + else: + split_pattern = unit_to_split_pattern[self.unit] + return len(re.split(split_pattern, response.strip())) + + def evaluate(self, completions: List[str]) -> torch.FloatTensor: + penalties = torch.zeros(len(completions), dtype=torch.float32) + for idx, completion in enumerate(completions): + completion_length = self._get_completion_length(completion) + if completion_length != self.target_length: + # Computes the relative error as the deviation of the response length from the target length, normalized by the target length. + # Scales the penalty using an exponential function based on this relative error. + # The penalty starts off small for minor deviations but increases rapidly for larger deviations. + # The formula ensures that the penalty lies between 0 and 1. + relative_error = ( + self.target_length - completion_length + ) / self.target_length + penalty_scale_factor = 1 - np.exp(-10 * relative_error**2) + + scaled_penalty = self.penalty * penalty_scale_factor + penalties[idx] = scaled_penalty + + return penalties + + def compose_text(self) -> str: + return self.text.format(target_length=self.target_length, unit=self.unit.value) diff --git a/prompting/validators/event.py b/prompting/validators/event.py index 6e51584..df7a26a 100644 --- a/prompting/validators/event.py +++ b/prompting/validators/event.py @@ -19,8 +19,8 @@ import bittensor as bt from dataclasses import dataclass from typing import List, Optional - from prompting.validators.reward import RewardModelType +from prompting.validators.penalty import PenaltyModelType @dataclass @@ -34,6 +34,7 @@ class EventSchema: str ] # List of completion status codes for a given prompt name: str # Prompt type, e.g. 'followup', 'answer' + task_type: str # Task type, e.g.
'summary', 'question' block: float # Current block at given step gating_loss: float # Gating model loss for given step uids: List[int] # Queried uids @@ -59,11 +60,7 @@ class EventSchema: prompt_reward_model: Optional[ List[float] ] # Output vector of the prompt reward model - relevance_filter: Optional[ - List[float] - ] # Output vector of the relevance scoring reward model - task_validator_filter: Optional[List[float]] - + relevance_filter: Optional[List[float]] dahoas_reward_model_normalized: Optional[ List[float] ] # Output vector of the dahoas reward model @@ -83,10 +80,19 @@ class EventSchema: prompt_reward_model_normalized: Optional[ List[float] ] # Output vector of the prompt reward model - relevance_filter_normalized: Optional[ - List[float] - ] # Output vector of the relevance scoring reward model - task_validator_filter_normalized: Optional[List[float]] + relevance_filter_normalized: Optional[List[float]] + # TODO: Add comments + task_validation_penalty_raw: Optional[List[float]] + task_validation_penalty_adjusted: Optional[List[float]] + task_validation_penalty_applied: Optional[List[float]] + + keyword_match_penalty_raw: Optional[List[float]] + keyword_match_penalty_adjusted: Optional[List[float]] + keyword_match_penalty_applied: Optional[List[float]] + + sentence_match_penalty_raw: Optional[List[float]] + sentence_match_penalty_adjusted: Optional[List[float]] + sentence_match_penalty_applied: Optional[List[float]] # Weights data set_weights: Optional[List[List[float]]] @@ -97,9 +103,6 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> "EventSchema": rewards = { "blacklist_filter": event_dict.get(RewardModelType.blacklist.value), "dahoas_reward_model": event_dict.get(RewardModelType.dahoas.value), - "task_validator_filter": event_dict.get( - RewardModelType.task_validator.value - ), "nsfw_filter": event_dict.get(RewardModelType.nsfw.value), "relevance_filter": event_dict.get(RewardModelType.relevance.value), "reciprocate_reward_model": event_dict.get( @@ -112,9 +115,6 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> "EventSchema": "dahoas_reward_model_normalized": event_dict.get( RewardModelType.dahoas.value + "_normalized" ), - "task_validator_filter_normalized": event_dict.get( - RewardModelType.task_validator.value + "_normalized" - ), "nsfw_filter_normalized": event_dict.get( RewardModelType.nsfw.value + "_normalized" ), @@ -137,6 +137,35 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> "EventSchema": RewardModelType.prompt.value + "_normalized" ), } + penalties = { + "task_validation_penalty_raw": event_dict.get( + PenaltyModelType.task_validation_penalty.value + "_raw" + ), + "task_validation_penalty_adjusted": event_dict.get( + PenaltyModelType.task_validation_penalty.value + "_adjusted" + ), + "task_validation_penalty_applied": event_dict.get( + PenaltyModelType.task_validation_penalty.value + "_applied" + ), + "keyword_match_penalty_raw": event_dict.get( + PenaltyModelType.keyword_match_penalty.value + "_raw" + ), + "keyword_match_penalty_adjusted": event_dict.get( + PenaltyModelType.keyword_match_penalty.value + "_adjusted" + ), + "keyword_match_penalty_applied": event_dict.get( + PenaltyModelType.keyword_match_penalty.value + "_applied" + ), + "sentence_match_penalty_raw": event_dict.get( + PenaltyModelType.sentence_match_penalty.value + "_raw" + ), + "sentence_match_penalty_adjusted": event_dict.get( + PenaltyModelType.sentence_match_penalty.value + "_adjusted" + ), + "sentence_match_penalty_applied": event_dict.get( + 
PenaltyModelType.sentence_match_penalty.value + "_applied" + ), + } # Logs warning that expected data was not set properly if not disable_log_rewards and any(value is None for value in rewards.values()): @@ -152,6 +181,7 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> "EventSchema": completion_status_messages=event_dict["completion_status_messages"], completion_status_codes=event_dict["completion_status_codes"], name=event_dict["name"], + task_type=event_dict["task_type"], block=event_dict["block"], gating_loss=event_dict["gating_loss"], uids=event_dict["uids"], @@ -160,5 +190,6 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> "EventSchema": best=event_dict["best"], rewards=event_dict["rewards"], **rewards, + **penalties, set_weights=None, ) diff --git a/prompting/validators/forward.py b/prompting/validators/forward.py index 3b9192f..8dfcd1e 100644 --- a/prompting/validators/forward.py +++ b/prompting/validators/forward.py @@ -30,6 +30,12 @@ from prompting.validators.misc import ttl_get_block from prompting.validators.prompts import followup_prompt, answer_prompt, augment_prompt from prompting.validators.utils import check_uid_availability +from prompting.validators.tasks import ( + Task, + create_summarization_task, + create_qg_task, + create_qa_task, +) import prompting @@ -69,22 +75,14 @@ def get_random_uids(self, k: int, exclude: List[int] = None) -> torch.LongTensor return uids -async def run_step( - self, - prompt: str, - k: int, - timeout: float, - name: str, - exclude: list = [], - base_prompt=None, -): - if base_prompt == None: - base_prompt = prompt +async def run_step(self, task: Task, k: int, timeout: float, exclude: list = []): + task_name = task.task_name + prompt = task.compose_prompt() - bt.logging.debug("run_step", name) + bt.logging.debug("run_step", task_name) # Record event start time. - event = {"name": name} + event = {"name": task_name, "task_type": task.task_type} start_time = time.time() # Get the list of uids to query for this step. uids = get_random_uids(self, k=k, exclude=exclude).to(self.device) @@ -103,7 +101,7 @@ async def run_step( # remove leading and trailing periods completion = response.completion.strip(".") - if "followup" in name and len(completion) > 0: + if "followup" in task_name and len(completion) > 0: if "?" 
in completion: # take first question that is found and only use the sentence before the question mark completion = completion.split("?")[0].split(".")[-1] @@ -119,7 +117,9 @@ async def run_step( self.device ) for weight_i, reward_fn_i in zip(self.reward_weights, self.reward_functions): - reward_i, reward_i_normalized = reward_fn_i.apply(prompt, responses, name) + reward_i, reward_i_normalized = reward_fn_i.apply( + task.base_text, responses, task_name + ) rewards += weight_i * reward_i_normalized.to(self.device) if not self.config.neuron.disable_log_rewards: event[reward_fn_i.name] = reward_i.tolist() @@ -127,13 +127,28 @@ async def run_step( bt.logging.trace(str(reward_fn_i.name), reward_i_normalized.tolist()) for masking_fn_i in self.masking_functions: - mask_i, mask_i_normalized = masking_fn_i.apply(base_prompt, responses, name) + mask_i, mask_i_normalized = masking_fn_i.apply( + task.base_text, responses, task_name + ) rewards *= mask_i_normalized.to(self.device) # includes diversity if not self.config.neuron.disable_log_rewards: event[masking_fn_i.name] = mask_i.tolist() event[masking_fn_i.name + "_normalized"] = mask_i_normalized.tolist() bt.logging.trace(str(masking_fn_i.name), mask_i_normalized.tolist()) + for penalty_fn_i in self.penalty_functions: + ( + raw_penalty_i, + adjusted_penalty_i, + applied_penalty_i, + ) = penalty_fn_i.apply_penalties(responses, task) + rewards *= applied_penalty_i.to(self.device) + if not self.config.neuron.disable_log_rewards: + event[penalty_fn_i.name + "_raw"] = raw_penalty_i.tolist() + event[penalty_fn_i.name + "_adjusted"] = adjusted_penalty_i.tolist() + event[penalty_fn_i.name + "_applied"] = applied_penalty_i.tolist() + bt.logging.trace(str(penalty_fn_i.name), applied_penalty_i.tolist()) + # Train the gating model based on the predicted scores and the actual rewards. gating_scores: torch.FloatTensor = self.gating_model(prompt).to(self.device) gating_loss: torch.FloatTensor = self.gating_model.backward( @@ -209,68 +224,54 @@ async def forward(self): random_cutoff = random.randint(15, 30) # Truncate context to a limited set of sentences. base_text = ".".join(data.split(".", maxsplit=random_cutoff)[:-1]) - aug_prompt = augment_prompt(base_text) + + # Create a summary task from the context. + summary_task: Task = create_summarization_task(base_text) # Reset Blacklist reward model self.blacklist.reset() # Request a summary, given the original context. - augment_event = await run_step( + summarization_event = await run_step( self, - prompt=aug_prompt, - name="augment", + task=summary_task, k=self.config.neuron.followup_sample_size, timeout=self.config.neuron.followup_timeout, ) - base_text = augment_event["best"] - base_prompt = augment_event["best"] - exclude = augment_event["uids"] + best_summary = summarization_event["best"] + exclude = summarization_event["uids"] + prompt_context = "### SUMMARY CONTEXT:\n" + best_summary for k in range(self.config.neuron.num_followup_steps): # Get a followup question, given the summarized context. - prompt = followup_prompt(base_text, i=k) - followup_event = await run_step( + qg_task = create_qg_task(base_text=prompt_context, index=k) + qg_event = await run_step( self, - prompt=prompt, - name="followup" + str(k), + task=qg_task, k=self.config.neuron.followup_sample_size, timeout=self.config.neuron.followup_timeout, exclude=exclude, - base_prompt=base_prompt, ) - exclude += followup_event["uids"] + exclude += qg_event["uids"] - # Ask the followup question, given the original context. 
- prompt = answer_prompt(base_text, followup_event["best"]) - answer_event = await run_step( + # Adds the best question to the prompt context. + best_question = qg_event["best"] + prompt_context += f"\n### QUESTION {k}:\n{best_question}" + + qa_task = create_qa_task(prompt_context, index=k) + qa_event = await run_step( self, - prompt=prompt, - name="answer" + str(k), + task=qa_task, k=self.config.neuron.answer_sample_size, timeout=self.config.neuron.answer_timeout, exclude=exclude, - base_prompt=followup_event["best"], ) - exclude += answer_event["uids"] - - self.blacklist.question_blacklist.append(followup_event["best"]) - self.blacklist.answer_blacklist.append(answer_event["best"]) - - if k == 0: - # Extend the base text with the best answer. - base_text = ( - base_text - + "\nPrevious Question \nQuestion:" - + followup_event["best"] - + "\nAnswer:" - + answer_event["best"] - ) - else: - base_text = ( - base_text - + "\nQuestion:" - + followup_event["best"] - + "\nAnswer:" - + answer_event["best"] - ) + + best_answer = qa_event["best"] + prompt_context += f"\n### ANSWER {k}:\n{best_answer}" + + exclude += qa_event["uids"] + + self.blacklist.question_blacklist.append(qg_event["best"]) + self.blacklist.answer_blacklist.append(qa_event["best"]) diff --git a/prompting/validators/mock.py b/prompting/validators/mock.py index 5c1056c..574a6d3 100644 --- a/prompting/validators/mock.py +++ b/prompting/validators/mock.py @@ -21,13 +21,14 @@ import bittensor as bt from prompting.validators.prompts import FirewallPrompt, FollowupPrompt, AnswerPrompt from prompting.validators.gating import BaseGatingModel +from prompting.validators.reward import BaseRewardModel from typing import List class MockGatingModel(BaseGatingModel): def __init__(self, num_uids: int): super(MockGatingModel, self).__init__() - # super(MockGatingModel, self).__init__() + self.num_uids = num_uids self.linear = torch.nn.Linear(256, 10) @@ -45,7 +46,27 @@ def resync( pass -class MockRewardModel(torch.nn.Module): +class MockRewardModel(BaseRewardModel): + question_blacklist = [] + answer_blacklist = [] + + @property + def name(self) -> str: + return self.mock_name + + def __init__(self, mock_name: str = "MockReward"): + super().__init__() + self.mock_name = mock_name + self.question_blacklist = [] + self.answer_blacklist = [] + + def apply(self, prompt: str, completion: List[str], name: str) -> torch.FloatTensor: + mock_reward = torch.tensor([1 for _ in completion], dtype=torch.float32) + return mock_reward, mock_reward + + def reset(self): + return self + def reward( self, completions_with_prompt: List[str], diff --git a/prompting/validators/penalty/__init__.py b/prompting/validators/penalty/__init__.py new file mode 100644 index 0000000..2e4e67a --- /dev/null +++ b/prompting/validators/penalty/__init__.py @@ -0,0 +1,4 @@ +from .penalty import BasePenaltyModel, PenaltyModelType +from .task_validation import TaskValidationPenaltyModel +from .keyword_match import KeywordMatchPenaltyModel +from .content_match import ContentMatchPenaltyModel diff --git a/prompting/validators/penalty/content_match.py b/prompting/validators/penalty/content_match.py new file mode 100644 index 0000000..fdba288 --- /dev/null +++ b/prompting/validators/penalty/content_match.py @@ -0,0 +1,60 @@ +# The MIT License (MIT) +# Copyright © 2023 Yuma Rao +# Copyright © 2023 Opentensor Foundation + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the 
Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +import re +import torch +from typing import List +from prompting.validators.tasks import Task +from prompting.validators.penalty.penalty import BasePenaltyModel, PenaltyModelType + + +class ContentMatchPenaltyModel(BasePenaltyModel): + @property + def name(self) -> str: + return PenaltyModelType.sentence_match_penalty.value + + def calculate_penalties( + self, task: Task, completions: List[str] + ) -> torch.FloatTensor: + # NOTE: This is an example placeholder, the data source can be easily expanded to include more sentences + # or be externalized in a public hugging face dataset. + system_messages_penalizing_sentences = [ + r"here(?:\s+is|\s*'s)\s+a\s+task", # here is a task, here's a task + r"here(?:\s+is|\s*'s)\s+the\s+solution", # here is the solution, here's the solution + r"here(?:\s+is|\s*'s)\s+my\s+question", # here is my question, here's my question + r"what\s+have\s+we\s+learned\s+from\s+this\s+task\?", # what have we learned from this task? + r"use\s+complete\s+sentences", # use complete sentences + r"the\s+question\s+was", # the question was + r"use\s+proper\s+grammar", # use proper grammar + r"what\s+is\s+the\s+correct\s+order\s+of\s+the\s+key\s+points", # what is the correct order of the key points + r"sure!\s+here.+", # sure! here... + r"solution\s+\(in\s+\w+\)", # solution (in \w+) + r"great\s+job!\s+here(?:'s| is)", # great job! here... + r"keep\s+it\s+clear\s+and\s+concise.\s+Use\s+complete\s+sentences.", # keep it clear and concise. Use complete sentences. + ] + + penalties = [] + for completion in completions: + accumulated_penalty = 0.0 + # Trim and consider only the first 200 characters + completion_segment = completion.strip()[:200].lower() + for pattern in system_messages_penalizing_sentences: + if re.search(pattern, completion_segment): + accumulated_penalty += 0.1 + penalties.append(accumulated_penalty) + + return torch.tensor(penalties, dtype=torch.float32) diff --git a/prompting/validators/reward/task_validator.py b/prompting/validators/penalty/keyword_match.py similarity index 69% rename from prompting/validators/reward/task_validator.py rename to prompting/validators/penalty/keyword_match.py index 0e3f494..c497220 100644 --- a/prompting/validators/reward/task_validator.py +++ b/prompting/validators/penalty/keyword_match.py @@ -15,21 +15,19 @@ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. 
+import re import torch from typing import List -from .config import RewardModelType -from .reward import BaseRewardModel +from prompting.validators.tasks import Task +from prompting.validators.penalty.penalty import BasePenaltyModel, PenaltyModelType -class TaskValidator(BaseRewardModel): +class KeywordMatchPenaltyModel(BasePenaltyModel): @property def name(self) -> str: - return RewardModelType.task_validator.value + return PenaltyModelType.keyword_match_penalty.value - def __init__(self): - super().__init__() - - def reward(self, prompt: str, completion: str, name: str) -> float: + def check_exploits_keywords(self, completion: str, name: str) -> float: summary_keywords = ["Summary:", "Paraphrase:", "Paraphrasing:", "Paraphrased:"] question_keywords = ["Question:", "Query:", "Q:"] answer_keywords = ["Answer:", "Response:", "A:", "Completion:"] @@ -54,28 +52,36 @@ def reward(self, prompt: str, completion: str, name: str) -> float: if ( is_summarization_prompt or is_question_prompt ) and completion_contains_answer: - return 0.0 + return 1 if ( is_summarization_prompt or is_answer_prompt ) and completion_contains_question: - return 0.0 + return 1 if not is_summarization_prompt and completion_contains_summary: - return 0.0 - - return 1 - - def get_rewards( - self, prompt: str, completions: List[str], name: str + return 1 + + # Patterns defined accordingly to task orchestrator in forward function. + # Punishes responses that copy the context + text_separation_patterns = [ + r"#+[\d\s]*QUESTION[\d\s]*:", + r"f\"\\n#+[\d\s]*ANSWER[\d\s]*:", + r"#+[\d\s]*SUMMARY[\d\s]*CONTEXT:", + ] + for pattern in text_separation_patterns: + if re.search(pattern, completion, re.IGNORECASE): + return 1 + + return 0 + + def calculate_penalties( + self, task: Task, completions: List[str] ) -> torch.FloatTensor: return torch.tensor( - [self.reward(prompt, completion, name) for completion in completions], + [ + self.check_exploits_keywords(completion, task.task_name) + for completion in completions + ], dtype=torch.float32, ) - - def normalize_rewards(self, rewards: torch.FloatTensor) -> torch.FloatTensor: - return rewards - - def reset(self): - pass diff --git a/prompting/validators/penalty/penalty.py b/prompting/validators/penalty/penalty.py new file mode 100644 index 0000000..73035dd --- /dev/null +++ b/prompting/validators/penalty/penalty.py @@ -0,0 +1,66 @@ +# The MIT License (MIT) +# Copyright © 2023 Yuma Rao +# Copyright © 2023 Opentensor Foundation + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. 
+import torch +import bittensor as bt +from enum import Enum +from typing import List +from abc import ABC, abstractmethod +from prompting.validators.tasks import Task + + +class BasePenaltyModel(ABC): + def __init__(self, max_penalty: float): + self.max_penalty = max_penalty + + @property + @abstractmethod + def name(self) -> str: + ... + + def __str__(self) -> str: + return str(self.name) + + def __repr__(self) -> str: + return str(self.name) + + @abstractmethod + def calculate_penalties(task: Task, completions: List[str]) -> torch.FloatTensor: + ... + + def apply_penalties( + self, responses: List[bt.Synapse], task: Task + ) -> torch.FloatTensor: + completions = [response.completion for response in responses] + raw_penalties = self.calculate_penalties(task, completions) + + # Clip penalties between 0 and 1 + adjusted_penalties = torch.clip(raw_penalties, 0, 1) + + # Clip penalties between 0 and self.max_penalty + adjusted_penalties = torch.clip(adjusted_penalties, 0, self.max_penalty) + + # Invert penalties to scale rewards accordingly + applied_penalties = 1 - adjusted_penalties + + return raw_penalties, adjusted_penalties, applied_penalties + + +class PenaltyModelType(Enum): + task_validation_penalty = "task_validation_penalty" + keyword_match_penalty = "keyword_match_penalty" + sentence_match_penalty = "sentence_match_penalty" diff --git a/prompting/validators/penalty/task_validation.py b/prompting/validators/penalty/task_validation.py new file mode 100644 index 0000000..3715528 --- /dev/null +++ b/prompting/validators/penalty/task_validation.py @@ -0,0 +1,40 @@ +# The MIT License (MIT) +# Copyright © 2023 Yuma Rao +# Copyright © 2023 Opentensor Foundation + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. 
+import torch +from typing import List +from prompting.validators.tasks import Task +from prompting.validators.penalty.penalty import BasePenaltyModel, PenaltyModelType + + +class TaskValidationPenaltyModel(BasePenaltyModel): + @property + def name(self) -> str: + return PenaltyModelType.task_validation_penalty.value + + def calculate_penalties( + self, task: Task, completions: List[str] + ) -> torch.FloatTensor: + accumulated_penalties: torch.FloatTensor = torch.zeros( + len(completions), dtype=torch.float32 + ) + + # Accumulate penalties for each criterion + for criterion in task.criteria: + accumulated_penalties.add_(criterion.evaluate(completions)) + + return accumulated_penalties diff --git a/prompting/validators/reward/__init__.py b/prompting/validators/reward/__init__.py index d26773f..b83f85f 100644 --- a/prompting/validators/reward/__init__.py +++ b/prompting/validators/reward/__init__.py @@ -1,12 +1,10 @@ from .blacklist import Blacklist -from .task_validator import TaskValidator from .nsfw import NSFWRewardModel from .dpo import DirectPreferenceRewardModel from .open_assistant import OpenAssistantRewardModel from .reciprocate import ReciprocateRewardModel from .relevance import RelevanceRewardModel from .reward import BaseRewardModel -from .reward import MockRewardModel from .dahoas import DahoasRewardModel from .diversity import DiversityRewardModel from .prompt import PromptRewardModel diff --git a/prompting/validators/reward/config.py b/prompting/validators/reward/config.py index ea5df05..ecf9160 100644 --- a/prompting/validators/reward/config.py +++ b/prompting/validators/reward/config.py @@ -30,6 +30,7 @@ class RewardModelType(Enum): nsfw = "nsfw_filter" relevance = "relevance_filter" task_validator = "task_validator_filter" + keyword_match = "keyword_match_penalty" @dataclass(frozen=True) @@ -38,8 +39,8 @@ class DefaultRewardFrameworkConfig: Note: All the weights should add up to 1.0. """ - dpo_model_weight: float = 0.2 - rlhf_model_weight: float = 0.4 - reciprocate_model_weight: float = 0.4 + dpo_model_weight: float = 0.425 + rlhf_model_weight: float = 0.15 + reciprocate_model_weight: float = 0.425 dahoas_model_weight: float = 0 prompt_model_weight: float = 0 diff --git a/prompting/validators/reward/diversity.py b/prompting/validators/reward/diversity.py index d0a5b77..66314c5 100644 --- a/prompting/validators/reward/diversity.py +++ b/prompting/validators/reward/diversity.py @@ -68,7 +68,7 @@ def __init__(self, device: str): self.history_reward_bottom_k = 2 self.historic_embeddings = torch.tensor([]).to(self.device) self.history_range = (500, 15500) - self.boundary = 0.5 + self.boundary = 0.2 def get_embeddings(self, sentences: List[str]) -> "torch.FloatTensor": """Runs a forward pass through the model. diff --git a/prompting/validators/reward/reward.py b/prompting/validators/reward/reward.py index 23e7479..c20220d 100644 --- a/prompting/validators/reward/reward.py +++ b/prompting/validators/reward/reward.py @@ -137,20 +137,3 @@ def apply( # Return the filled rewards. 
return filled_rewards, filled_rewards_normalized - - -class MockRewardModel(BaseRewardModel): - @property - def name(self) -> str: - return self.mock_name - - def __init__(self, mock_name: str = "MockReward"): - super().__init__() - self.mock_name = mock_name - - def apply(self, prompt: str, completion: List[str], name: str) -> torch.FloatTensor: - mock_reward = torch.tensor([1 for _ in completion], dtype=torch.float32) - return mock_reward, mock_reward - - def reset(self): - return self diff --git a/prompting/validators/tasks.py b/prompting/validators/tasks.py new file mode 100644 index 0000000..e161d7b --- /dev/null +++ b/prompting/validators/tasks.py @@ -0,0 +1,189 @@ +# The MIT License (MIT) +# Copyright © 2023 Yuma Rao +# Copyright © 2023 Opentensor Foundation + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +import torch +import textwrap +import random +from dataclasses import dataclass, field +from abc import ABC, abstractmethod +from typing import List +from prompting.validators.criteria import ( + TaskCriterion, + MatchLengthCriteria, + TextLengthUnitEnum, +) + + +@dataclass +class Task(ABC): + base_text: str + task_name: str + task_type: str + criteria: List[TaskCriterion] = field(default_factory=list) + + @abstractmethod + def compose_prompt(self) -> str: + ... + + +class SummaryTask(Task): + def compose_prompt(self) -> str: + # Aggregates criteria in bullet points + criteria_bullet_points = [ + f"- {criterion.compose_text()}" for criterion in self.criteria + ] + criteria_bullet_points_str = "\n".join(criteria_bullet_points) + + prompt_template = textwrap.dedent( + """\ + Your task is to summarize the text delimited with triple backticks: + '''{base_text}''' + + The following criteria must be respected: + {criteria} + - Do not try to create questions or answers for your summarization. + """ + ) + + prompt = prompt_template.format( + base_text=self.base_text, criteria=criteria_bullet_points_str + ) + return prompt + + +class QuestionGenerationTask(Task): + def compose_prompt(self) -> str: + # Aggregates criteria in bullet points + criteria_bullet_points = [ + f"- {criterion.compose_text()}" for criterion in self.criteria + ] + criteria_bullet_points_str = "\n".join(criteria_bullet_points) + + prompt_template = textwrap.dedent( + """\ + Your task is to ask a single relevant and insightful question about the preceding context delimited with triple backticks: + '''{base_text}''' + + The following criteria must be respected: + {criteria} + - Do not answer the question you generate. 
+ - Do not try to summarize the text + """ + ) + + prompt = prompt_template.format( + base_text=self.base_text, criteria=criteria_bullet_points_str + ) + return prompt + + +class QuestionAnswerTask(Task): + def compose_prompt(self) -> str: + # Aggregates criteria in bullet points + criteria_bullet_points = [ + f"- {criterion.compose_text()}" for criterion in self.criteria + ] + criteria_bullet_points_str = "\n".join(criteria_bullet_points) + + prompt_template = textwrap.dedent( + """\ + Read the preceding context delimited with triple backticks carefully. + Your task is to provide a step-by-step answer to the last question found in the text, elaborating on your thought process: + '''{base_text}''' + + The following criteria must be respected: + {criteria} + - Do not include questions or summaries in your answer. + """ + ) + + prompt = prompt_template.format( + base_text=self.base_text, criteria=criteria_bullet_points_str + ) + return prompt + + +def create_summarization_task(base_text: str) -> SummaryTask: + possible_criterias = [ + MatchLengthCriteria( + penalty=0.1, + target_length=random.randint(50, 200), + unit=TextLengthUnitEnum.WORDS, + ), + MatchLengthCriteria( + penalty=0.1, + target_length=random.randint(4, 8), + unit=TextLengthUnitEnum.SENTENCES, + ), + ] + + sampled_criterias = random.sample(possible_criterias, 1) + + return SummaryTask( + base_text=base_text, + criteria=sampled_criterias, + task_type="summarization", + task_name="augment", + ) + + +def create_qg_task(base_text: str, index: int) -> QuestionGenerationTask: + possible_criterias = [ + MatchLengthCriteria( + penalty=0.1, + target_length=random.randint(10, 40), + unit=TextLengthUnitEnum.WORDS, + ), + MatchLengthCriteria( + penalty=0.1, + target_length=random.randint(40, 150), + unit=TextLengthUnitEnum.CHARACTERS, + ), + ] + + sampled_criterias = random.sample(possible_criterias, 1) + + return QuestionGenerationTask( + base_text=base_text, + criteria=sampled_criterias, + task_type="question-generation", + task_name=f"followup{index}", + ) + + +def create_qa_task(base_text: str, index: int) -> QuestionAnswerTask: + possible_criterias = [ + MatchLengthCriteria( + penalty=0.1, + target_length=random.randint(50, 200), + unit=TextLengthUnitEnum.WORDS, + ), + MatchLengthCriteria( + penalty=0.1, + target_length=random.randint(4, 8), + unit=TextLengthUnitEnum.SENTENCES, + ), + ] + + sampled_criterias = random.sample(possible_criterias, 1) + + return QuestionAnswerTask( + base_text=base_text, + criteria=sampled_criterias, + task_type="question-answer", + task_name=f"answer{index}", + ) diff --git a/prompting/validators/utils.py b/prompting/validators/utils.py index f39daed..072542e 100644 --- a/prompting/validators/utils.py +++ b/prompting/validators/utils.py @@ -23,7 +23,6 @@ import bittensor as bt import prompting.validators as validators from prompting.validators.misc import ttl_get_block -from prompting.validators.reward import MockRewardModel def should_reinit_wandb(self): @@ -49,7 +48,7 @@ def init_wandb(self, reinit=False): if self.config.neuron.use_custom_gating_model: tags.append("custom_gating_model") for fn in self.reward_functions: - if not isinstance(fn, MockRewardModel): + if not self.config.neuron.mock_reward_models: tags.append(str(fn.name)) if self.config.neuron.disable_set_weights: tags.append("disable_set_weights")
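For reviewers, the sketch below shows how the new pieces introduced in this change are intended to compose: a task built by `create_qg_task` carries randomly sampled `MatchLengthCriteria`, each penalty model turns the responses into raw/adjusted/applied tensors via `apply_penalties`, and the validator multiplies rewards by the applied factor (mirroring the `penalty_functions` list in `validator.py` and the loop in `forward.py`). This is a minimal, illustrative example, not part of the patch: the `SimpleNamespace` objects stand in for real `bt.Synapse` responses, and the sample completions and reward values are invented for demonstration.

```python
import torch
from types import SimpleNamespace

from prompting.validators.tasks import create_qg_task
from prompting.validators.penalty import (
    TaskValidationPenaltyModel,
    KeywordMatchPenaltyModel,
    ContentMatchPenaltyModel,
)

# Build a question-generation task; the sampled MatchLengthCriteria is embedded in the prompt text.
task = create_qg_task(
    base_text="### SUMMARY CONTEXT:\nThe 2021 eruption lasted three months.", index=0
)
print(task.compose_prompt())  # criteria bullet, e.g. "Your response must have 25 words."

# Stand-ins for bt.Synapse responses; apply_penalties only reads `.completion`.
responses = [
    SimpleNamespace(completion="How long did the eruption last?"),
    SimpleNamespace(completion="Sure! Here is the solution: Question: how long?"),
]

rewards = torch.tensor([0.8, 0.7])
penalty_functions = [
    TaskValidationPenaltyModel(max_penalty=0.1),
    ContentMatchPenaltyModel(max_penalty=0.1),
    KeywordMatchPenaltyModel(max_penalty=1),
]

for fn in penalty_functions:
    raw, adjusted, applied = fn.apply_penalties(responses, task)
    # `adjusted` is clipped to [0, max_penalty]; `applied = 1 - adjusted` scales rewards down.
    rewards *= applied

print(rewards)
```

Because each model clips its raw penalties to `[0, max_penalty]` before inverting, a single criterion can shave at most `max_penalty` (e.g. 10%) off a reward, while `KeywordMatchPenaltyModel`, configured with `max_penalty=1`, can zero out a response that copies the context markers or mixes task types.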