diff --git a/neurons/validators/validator.py b/neurons/validators/validator.py index fe1ec78..93d9375 100644 --- a/neurons/validators/validator.py +++ b/neurons/validators/validator.py @@ -194,8 +194,8 @@ def __init__(self): MockRewardModel(RewardModelType.nsfw.value), ] self.penalty_functions = [ - TaskValidationPenaltyModel(max_penalty=0.1), - ContentMatchPenaltyModel(max_penalty=0.1), + TaskValidationPenaltyModel(max_penalty=0.6), + ContentMatchPenaltyModel(max_penalty=0.2), KeywordMatchPenaltyModel(max_penalty=1), ] bt.logging.debug(str(self.reward_functions)) @@ -277,8 +277,8 @@ def __init__(self): ] self.penalty_functions = [ - TaskValidationPenaltyModel(max_penalty=0.1), - ContentMatchPenaltyModel(max_penalty=0.1), + TaskValidationPenaltyModel(max_penalty=0.6), + ContentMatchPenaltyModel(max_penalty=0.2), KeywordMatchPenaltyModel(max_penalty=1), ] diff --git a/prompting/validators/__init__.py b/prompting/validators/__init__.py index 6855712..6ca818c 100644 --- a/prompting/validators/__init__.py +++ b/prompting/validators/__init__.py @@ -27,7 +27,7 @@ from . import event from . import dataset -__version__ = "2.1.1" +__version__ = "2.1.2" version_split = __version__.split(".") __spec_version__ = ( (1000 * int(version_split[0])) diff --git a/prompting/validators/criteria.py b/prompting/validators/criteria.py index ae464dc..49fdbc5 100644 --- a/prompting/validators/criteria.py +++ b/prompting/validators/criteria.py @@ -18,7 +18,7 @@ import re import torch import numpy as np -from dataclasses import dataclass +from dataclasses import dataclass, field from abc import ABC, abstractmethod from typing import List from enum import Enum @@ -109,3 +109,136 @@ def evaluate(self, completions: List[str]) -> torch.FloatTensor: def compose_text(self) -> str: return self.text.format(target_length=self.target_length, unit=self.unit.value) + + +class ContentMatchTypeEnum(Enum): + STARTS_WITH = "starts with" + ENDS_WITH = "ends with" + INCLUDES = "includes" + + +@dataclass +class MatchContentCriteria(TaskCriterion): + default_text: str = ( + "Your response should {match_type} the following words: {words}." 
+ ) + text: str = default_text + penalty: float = 0.1 + n_words: int = 3 + words_array: List[str] = field(default_factory=list) + contentMatchType: ContentMatchTypeEnum = ContentMatchTypeEnum.STARTS_WITH + sampled_words: List[str] = field(init=False) + negate_match: bool = False + + def __post_init__(self): + # Randomly sample words from the array based on n_words + self.sampled_words = np.random.choice( + self.words_array, self.n_words, replace=False + ) + + def _get_regex_pattern(self): + # Escape all special characters in the sampled words + escaped_words = map(re.escape, self.sampled_words) + + if self.contentMatchType == ContentMatchTypeEnum.STARTS_WITH: + return rf"^\s*({'|'.join(escaped_words)})\b" + elif self.contentMatchType == ContentMatchTypeEnum.ENDS_WITH: + return rf"({'|'.join(escaped_words)})\s*$" + else: # ContentMatchTypeEnum.INCLUDES + return rf"({'|'.join(escaped_words)})" + + def evaluate(self, completions: List[str]) -> torch.FloatTensor: + penalties = torch.zeros(len(completions), dtype=torch.float32) + # Define regex pattern based on contentMatchType + pattern = self._get_regex_pattern() + + for idx, completion in enumerate(completions): + # Check if the completion matches the pattern + match = re.search(pattern, completion, re.IGNORECASE) + + completion_with_undesired_match = self.negate_match and match + completion_without_desired_match = not self.negate_match and not match + + if completion_with_undesired_match or completion_without_desired_match: + penalties[idx] = self.penalty + + return penalties + + def compose_text(self) -> str: + # Check if the text property is different than the default. If so, use that text. + if self.text != MatchContentCriteria.default_text: + return self.text + + # Adds "should" or "should not" instruction based on the negate_match property + should_match_text = "should" if not self.negate_match else "should not" + + # Get the list of selected sampled words + words_list = ", ".join(self.sampled_words) + + # Get the descriptive text of the match type + match_type_text = self.contentMatchType.value + + # Adjust the text based on the number of words + if self.n_words > 1: + text = f"Your response {should_match_text} {match_type_text} one of the following words: {words_list}." + else: + text = f"Your response {should_match_text} {match_type_text} the following word: {words_list}." + return text + + +@dataclass +class SimpleResponseLayoutCriteria(TaskCriterion): + penalty: float = 0.1 + text: str = "Your response should not contain any bullet points or numbered lists." + + def evaluate(self, completions: List[str]) -> torch.FloatTensor: + penalties = torch.zeros(len(completions), dtype=torch.float32) + + # Regex patterns to match bullet points (unordered lists) and numbered lists + bullet_point_pattern = re.compile(r"(\*|\-|\+|\•|\‣|\◦)\s") + numbered_list_pattern = re.compile(r"\d+\.\s") + + for idx, completion in enumerate(completions): + # Check if the completion contains a bullet point or numbered list + if bullet_point_pattern.search(completion) or numbered_list_pattern.search( + completion + ): + penalties[idx] = self.penalty + + return penalties + + def compose_text(self) -> str: + return self.text + + +class LayoutMatchTypeEnum(Enum): + UNORDERED_LIST = "unordered list" + NUMBERED_LIST = "numbered list" + + +@dataclass +class MatchLayoutCriteria(TaskCriterion): + layout_type: LayoutMatchTypeEnum = LayoutMatchTypeEnum.UNORDERED_LIST + penalty: float = 0.1 + text: str = "Your response should be ordered in format of {layout_type}." 
+ + def evaluate(self, completions: List[str]) -> torch.FloatTensor: + penalties = torch.zeros(len(completions), dtype=torch.float32) + + # Regex patterns based on layout type + bullet_point_pattern = re.compile(r"(\*|\-|\+|\•|\‣|\◦)\s") + numbered_list_pattern = re.compile(r"\d+\.\s") + + for idx, completion in enumerate(completions): + # Evaluate based on the layout type + if self.layout_type == LayoutMatchTypeEnum.UNORDERED_LIST: + if not bullet_point_pattern.search(completion): + penalties[idx] = self.penalty + elif self.layout_type == LayoutMatchTypeEnum.NUMBERED_LIST: + if not numbered_list_pattern.search(completion): + penalties[idx] = self.penalty + + return penalties + + def compose_text(self) -> str: + return self.text.format(layout_type=self.layout_type) diff --git a/prompting/validators/event.py b/prompting/validators/event.py index df7a26a..206554f 100644 --- a/prompting/validators/event.py +++ b/prompting/validators/event.py @@ -48,6 +48,12 @@ class EventSchema: List[float] ] # Output vector of the dahoas reward model blacklist_filter: Optional[List[float]] # Output vector of the blacklist filter + blacklist_filter_matched_ngram: Optional[ + List[str] + ] # Output vector of the blacklist filter + blacklist_filter_significance_score: Optional[ + List[float] + ] # Output vector of the blacklist filter nsfw_filter: Optional[List[float]] # Output vector of the nsfw filter reciprocate_reward_model: Optional[ List[float] @@ -55,6 +61,12 @@ class EventSchema: diversity_reward_model: Optional[ List[float] ] # Output vector of the diversity reward model + diversity_reward_model_historic: Optional[ + List[float] + ] # Output vector of the diversity reward model + diversity_reward_model_batch: Optional[ + List[float] + ] # Output vector of the diversity reward model dpo_reward_model: Optional[List[float]] # Output vector of the dpo reward model rlhf_reward_model: Optional[List[float]] # Output vector of the rlhf reward model prompt_reward_model: Optional[ @@ -65,6 +77,7 @@ class EventSchema: List[float] ] # Output vector of the dahoas reward model nsfw_filter_normalized: Optional[List[float]] # Output vector of the nsfw filter + nsfw_filter_score: Optional[List[float]] # Output vector of the nsfw filter reciprocate_reward_model_normalized: Optional[ List[float] ] # Output vector of the reciprocate reward model @@ -80,7 +93,16 @@ class EventSchema: prompt_reward_model_normalized: Optional[ List[float] ] # Output vector of the prompt reward model - relevance_filter_normalized: Optional[List[float]] + + relevance_filter_normalized: Optional[ + List[float] + ] # Output vector of the relevance scoring reward model + relevance_filter_bert_score: Optional[ + List[float] + ] # Output vector of the relevance scoring reward model + relevance_filter_mpnet_score: Optional[ + List[float] + ] # Output vector of the relevance scoring reward model # TODO: Add comments task_validation_penalty_raw: Optional[List[float]] task_validation_penalty_adjusted: Optional[List[float]] @@ -109,6 +131,12 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> "EventSchema": RewardModelType.reciprocate.value ), "diversity_reward_model": event_dict.get(RewardModelType.diversity.value), + "diversity_reward_model_historic": event_dict.get( + RewardModelType.diversity.value + "_historic" + ), + "diversity_reward_model_batch": event_dict.get( + RewardModelType.diversity.value + "_batch" + ), "dpo_reward_model": event_dict.get(RewardModelType.dpo.value), "rlhf_reward_model": 
event_dict.get(RewardModelType.rlhf.value), "prompt_reward_model": event_dict.get(RewardModelType.prompt.value), @@ -136,6 +164,19 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> "EventSchema": "prompt_reward_model_normalized": event_dict.get( RewardModelType.prompt.value + "_normalized" ), + "blacklist_filter_matched_ngram": event_dict.get( + RewardModelType.blacklist.value + "_matched_ngram" + ), + "blacklist_filter_significance_score": event_dict.get( + RewardModelType.blacklist.value + "_significance_score" + ), + "relevance_filter_bert_score": event_dict.get( + RewardModelType.relevance.value + "_bert_score" + ), + "relevance_filter_mpnet_score": event_dict.get( + RewardModelType.relevance.value + "_mpnet_score" + ), + "nsfw_filter_score": event_dict.get(RewardModelType.nsfw.value + "_score"), } penalties = { "task_validation_penalty_raw": event_dict.get( diff --git a/prompting/validators/forward.py b/prompting/validators/forward.py index 5ea5c94..0c54d78 100644 --- a/prompting/validators/forward.py +++ b/prompting/validators/forward.py @@ -96,44 +96,48 @@ async def run_step(self, task: Task, k: int, timeout: float, exclude: list = []) timeout=timeout, ) + # Update blacklist with completions so that n-gram filtering can be applied + self.blacklist.add( + [response.completion for response in responses if response.completion] + ) + # Restrict the format of acceptable followup completions. for response in responses: # remove leading and trailing periods completion = response.completion.strip(".") if "followup" in task_name and len(completion) > 0: + # take maximum of 40 words + max_words = 40 if "?" in completion: # take first question that is found and only use the sentence before the question mark completion = completion.split("?")[0].split(".")[-1] + response.completion = " ".join(completion.split(" ")[-max_words:]) + "?" else: # otherwise take the last sentence completion = completion.split(".")[-1].split(".")[-1] - - # take maximum of 40 words - response.completion = " ".join(completion.split(" ")[-40:]) + "?" + response.completion = " ".join(completion.split(" ")[-max_words:]) # Compute the rewards for the responses given the prompt. 
     rewards: torch.FloatTensor = torch.zeros(len(responses), dtype=torch.float32).to(
         self.device
     )
     for weight_i, reward_fn_i in zip(self.reward_weights, self.reward_functions):
-        reward_i, reward_i_normalized = reward_fn_i.apply(
+        reward_i_normalized, reward_event = reward_fn_i.apply(
             task.base_text, responses, task_name
         )
         rewards += weight_i * reward_i_normalized.to(self.device)
         if not self.config.neuron.disable_log_rewards:
-            event[reward_fn_i.name] = reward_i.tolist()
-            event[reward_fn_i.name + "_normalized"] = reward_i_normalized.tolist()
+            event = {**event, **reward_event}
         bt.logging.trace(str(reward_fn_i.name), reward_i_normalized.tolist())

     for masking_fn_i in self.masking_functions:
-        mask_i, mask_i_normalized = masking_fn_i.apply(
+        mask_i_normalized, reward_event = masking_fn_i.apply(
             task.base_text, responses, task_name
         )
         rewards *= mask_i_normalized.to(self.device)  # includes diversity
         if not self.config.neuron.disable_log_rewards:
-            event[masking_fn_i.name] = mask_i.tolist()
-            event[masking_fn_i.name + "_normalized"] = mask_i_normalized.tolist()
+            event = {**event, **reward_event}
         bt.logging.trace(str(masking_fn_i.name), mask_i_normalized.tolist())

     for penalty_fn_i in self.penalty_functions:
@@ -271,6 +275,3 @@ async def forward(self):
         )
         exclude += qa_event["uids"]
-
-        self.blacklist.question_blacklist.append(qg_event["best"])
-        self.blacklist.answer_blacklist.append(qa_event["best"])
diff --git a/prompting/validators/mock.py b/prompting/validators/mock.py
index 574a6d3..461aaff 100644
--- a/prompting/validators/mock.py
+++ b/prompting/validators/mock.py
@@ -62,7 +62,7 @@ def __init__(self, mock_name: str = "MockReward"):
     def apply(self, prompt: str, completion: List[str], name: str) -> torch.FloatTensor:
         mock_reward = torch.tensor([1 for _ in completion], dtype=torch.float32)
-        return mock_reward, mock_reward
+        return mock_reward, {}

     def reset(self):
         return self
diff --git a/prompting/validators/penalty/content_match.py b/prompting/validators/penalty/content_match.py
index fdba288..22757df 100644
--- a/prompting/validators/penalty/content_match.py
+++ b/prompting/validators/penalty/content_match.py
@@ -41,10 +41,15 @@ def calculate_penalties(
             r"the\s+question\s+was",  # the question was
             r"use\s+proper\s+grammar",  # use proper grammar
             r"what\s+is\s+the\s+correct\s+order\s+of\s+the\s+key\s+points",  # what is the correct order of the key points
-            r"sure!\s+here.+",  # sure! here...
+            r"sure!\s+her.+",  # sure! here...
             r"solution\s+\(in\s+\w+\)",  # solution (in \w+)
             r"great\s+job!\s+here(?:'s| is)",  # great job! here...
             r"keep\s+it\s+clear\s+and\s+concise.\s+Use\s+complete\s+sentences.",  # keep it clear and concise. Use complete sentences.
+            r"task\s*:",  # task:
+            r"i\s+can\s+help\s+you\s+with",  # I can help you with
+            r"what\s+did\s+I\s+learn\s+today\s*\?",  # what did I learn today?
+ r"paraphrase\s*:", # paraphrase: + r"your\s+task\s+now\s+is\s+to\s+write\s+a\s+tweet\s+about\s+the\s+previous\s+text", # your task now is to write a tweet about the previous text + r"what\s+is\s+the\s+main\s+point\s+of\s+the\s+passage", # what is the main point of the passage ] penalties = [] @@ -53,7 +58,7 @@ def calculate_penalties( # Trim and consider only the first 200 characters completion_segment = completion.strip()[:200].lower() for pattern in system_messages_penalizing_sentences: - if re.search(pattern, completion_segment): + if re.search(pattern, completion_segment, re.IGNORECASE): accumulated_penalty += 0.1 penalties.append(accumulated_penalty) diff --git a/prompting/validators/reward/blacklist.py b/prompting/validators/reward/blacklist.py index 4f45591..77b68b9 100644 --- a/prompting/validators/reward/blacklist.py +++ b/prompting/validators/reward/blacklist.py @@ -16,12 +16,24 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. +import re import torch -from typing import List +import math +from fuzzywuzzy import fuzz +from typing import List, Union from .config import RewardModelType -from .reward import BaseRewardModel +from .reward import BaseRewardModel, BaseRewardEvent +from transformers import BertTokenizer +from dataclasses import dataclass -blacklist = ["That is an excellent question."] + +# TODO: Use CLI arguments to set blacklist values: the most important being the boundary value and max_size + + +@dataclass +class BlacklistRewardEvent(BaseRewardEvent): + matched_ngram: str = None + significance_score: float = None class Blacklist(BaseRewardModel): @@ -29,34 +41,283 @@ class Blacklist(BaseRewardModel): def name(self) -> str: return RewardModelType.blacklist.value - def __init__(self): + def __init__( + self, + boundary: float = 6, + n_min: int = 5, + n_max: int = 14, + word_limit: int = 2000, + A: float = 1.3, + preprocess: str = "[^(\\w|\\s)]", + partial_ratio_boundary: float = 95, + half_life: int = 20000, + support: float = 0.01, + error: float = 0.001, + memory_lim: int = 1_000_000, + frequency_multiplier: float = 100, + ): + """N-gram blacklist reward model which penalizes overused phrases in the network + + Args: + boundary (float, optional): Cutoff for flagging completions and giving zero reward. Defaults to 6. + max_size (int, optional): Maximum size of sliding window to use for aggregating ngrams. Defaults to 1_000_000. + n_min (int, optional): Smallest ngram size. Defaults to 5. + n_max (int, optional): Largest ngram size. Defaults to 14. + word_limit (int, optional): Maximum word length, to prevent extremely long completions from overworking the queue. Defaults to 2000. + A (float, optional): Exponent used in significance scoring, smaller A gives more weight to smaller ngrams. Values of 1.1-2 are recommended. Defaults to 1.3. + preprocess (str, optional): Regex preprocessing string to make text more uniform. Defaults to '[^(\w|\s)]'. + partial_ratio_boundry (int, optional): Boundry for fuzzy match. Default to 95. + half_life (int, optional): Half life of the counter. ie. When the number of completions processed > half life, then put all the counters in half. + support (float, optional): The percentage of times that a phrase need to appear to get the phrase kept in counter. (support should be >> counter) + error (float, optional): Error parameter for lossy sampling, should be as small as possible, further decreasing it further will increase memory usage. 
(support should be >> error ) + memory_lim (int, optional): Max number of counter entry to save for memory protection. + frequency_multiplier (float, optional): Multiplier for phrases frequency. Default to 100. + """ super().__init__() - self.question_blacklist = [] - self.answer_blacklist = [] - def reward(self, prompt: str, completion: str, name: str) -> float: - if completion in blacklist: - return 0.0 + self.counter = {} + + self.n_min = n_min + self.n_max = n_max + self.word_limit = word_limit + + self.significance_scores = {} # Store significance scores + self.A = A + self.boundary = boundary + self.partial_ratio_boundary = partial_ratio_boundary + + self.preprocess = re.compile(preprocess) if preprocess else None + self._last_update = 0 + + # Lossy sampling parameters + self.support = support + self.error = error + self.window = math.ceil( + 1 / self.error + ) # Window size, counter would get pruned once for each window. + self.w_current = 1 # window index. + self.num_ngram = 0 + self.num_completion = 0 + + self.half_life = half_life + self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased") + self.memory_lim = memory_lim + self.frequency_multiplier = frequency_multiplier + + def add(self, texts: List[str]): + """Extract and add n-grams from a list of texts to counter + + Args: + texts (list): batch of completion texts + """ + + for text in texts: + # Extract n-grams from lowercased text + ngrams = self.extract_ngrams(text.lower()) + + if ngrams: + self._add_ngrams(ngrams) + + def extract_ngrams(self, text: str) -> List[tuple]: + """Extract n-grams from text string + + Args: + text (str): completion text + + Returns: + list: List of n-gram tuples + + """ + + if self.preprocess: + # remove all punctuation + text = self.preprocess.sub("", text) + + words = self.tokenizer(text.lower())["input_ids"][1:-1] + + if self.word_limit is not None: + words = words[: self.word_limit] + + ngrams = [] + for i in range(self.n_min, self.n_max + 1): + ngrams.extend(zip(*[words[j:] for j in range(i)])) + + return ngrams + + def _add_ngrams(self, ngrams: List[tuple]): + """Adds n-grams to counter, removing old n-grams periodically. + Counting and pruning method based on Lossy counter. + Reference: https://files.ifi.uzh.ch/dbtg/sdbs13/T01.3.pdf + + Args: + ngrams (List[tuple]): List of n-gram tuples + """ + + for ngram in ngrams: + if ngram in self.counter: + self.counter[ngram][0] += 1 + else: + # Store the tuple (frequency, max_error) + self.counter[ngram] = [1, self.w_current - 1] + + self.num_ngram += 1 + + self.num_completion += 1 + + # Prune when move to next window. + if self.num_completion % self.window == 0: + self.w_current = math.ceil(self.num_completion / self.window) + self.prune() + + # Safety feature: prune when reached max memory size. 
+ if len(self.counter) > self.memory_lim: + self.w_current += 1 + self.prune() + + # Apply half life for the counter + if self.num_completion > self.half_life: + self.set_counter_to_half() - if completion == prompt: - return 0.0 + def prune(self): + """Prune the counter when the count is smaller then bucket index.""" + prune_ele = [] + for ele, (frequency, max_error) in self.counter.items(): + if frequency + max_error <= self.w_current: + prune_ele.append(ele) - if completion in self.question_blacklist or completion in self.answer_blacklist: - return 0.0 + for ele in prune_ele: + del self.counter[ele] - return 1 + def reset(self): + """Reset counters to initial values.""" + self.num_ngram = 0 + self.num_completion = 0 + self.w_current = 1 + self.counter = {} + self.significance_scores = {} + self._last_update = 0 + + def calculate_significance(self) -> dict: + """Calculate significance of all n-grams in counter. By construction, n-grams with count 1 will have significance 0. + + Returns: + dict: Dictionary of n-gram tuples and their significance scores + """ + + significance_scores = {} + for ngram, count in self.counter.items(): + if count[0] + count[1] > max( + self.support * self.num_completion, self.w_current + 1 + ): + decoded_ngram = self.tokenizer.decode(ngram) + if len(decoded_ngram.split()) >= self.n_min: + # calculate significance score for ngram + significance_scores[decoded_ngram] = ( + self.A ** (len(decoded_ngram) - 1) + * ((count[0] + count[1]) / self.num_completion) + * self.frequency_multiplier + ) + + self._last_update = self.num_completion + + return dict( + sorted(significance_scores.items(), key=lambda x: x[1], reverse=True) + ) + + def get_significance(self) -> dict: + """Get significance scores, only recalculating if the counter has been updated. + + Returns: + dict: Dictionary of n-gram tuples and their significance scores + """ + + if self.num_completion - self._last_update > self.window: + self.significance_scores = self.calculate_significance() + + return self.significance_scores + + def most_common(self, n: int = 10) -> dict: + """Get most common n-grams in queue + + Args: + n (int): Number of most common n-grams to return. Defaults to 10. + + Returns: + dict: Sorted dictionary of n-gram tuples and their counts + """ + return sorted( + self.counter.items(), key=lambda x: x[1][0] + x[1][1], reverse=True + )[:n] + + def most_significant(self, n: int = 10, force_update: bool = True) -> dict: + """Get most significant n-grams in queue based on significance scores + + Args: + n (int, optional): Number of most significant n-grams to return. Defaults to 10. + force_update (bool, optional): Force recalculate the significance scores. Defaults to True. + + Returns: + dict: Sorted dictionary of n-gram tuples and their significance scores + """ + + scores = self.get_significance() if force_update else self.significance_scores + return sorted(scores.items(), key=lambda x: x[1], reverse=True)[:n] + + def set_counter_to_half(self): + """Set all the counters to half for a rolling window effect.""" + self.num_ngram = math.ceil(self.num_ngram / 2) + self.num_completion = math.ceil(self.num_completion / 2) + self.w_current = math.ceil(self.num_completion / self.window) + self.counter = { + tokens: [math.ceil(count[0] / 2), math.ceil(count[1] / 2)] + for tokens, count in self.counter.items() + } + self._last_update = 0 + + def reward(self, prompt: str, completion: str, name: str) -> BlacklistRewardEvent: + """Reward function for blacklist reward model. 
Returns 0 if the completion contains an n-gram whose significance score is above the boundary (matched via fuzzy partial ratio), 1 otherwise.
+
+        Args:
+            prompt (str): Prompt text
+            completion (str): Completion text
+            name (str): Name of the validation step
+
+        Returns:
+            BlacklistRewardEvent: Reward event with a reward of 0 or 1, plus the matched n-gram and its significance score when the completion is flagged.
+        """
+
+        reward_event = BlacklistRewardEvent()
+
+        if completion in prompt:
+            reward_event.reward = 0.0
+            return reward_event
+
+        # Get significance scores
+        scores = self.get_significance()
+
+        # Check if any n-grams have significance above the boundary
+        for ngram, score in scores.items():
+            if (
+                score > self.boundary
+                and fuzz.partial_ratio(ngram, completion.lower())
+                > self.partial_ratio_boundary
+            ):
+                reward_event.reward = 0
+                reward_event.matched_ngram = ngram
+                reward_event.significance_score = score
+                return reward_event
+
+        reward_event.reward = 1
+        return reward_event

     def get_rewards(
         self, prompt: str, completions: List[str], name: str
-    ) -> torch.FloatTensor:
-        return torch.tensor(
-            [self.reward(prompt, completion, name) for completion in completions],
-            dtype=torch.float32,
-        )
+    ) -> List[BlacklistRewardEvent]:
+        # Get all the reward results.
+        reward_events = [
+            self.reward(prompt, completion, name) for completion in completions
+        ]
+        return reward_events

     def normalize_rewards(self, rewards: torch.FloatTensor) -> torch.FloatTensor:
         return rewards
-
-    def reset(self):
-        self.question_blacklist = []
-        self.answer_blacklist = []
diff --git a/prompting/validators/reward/config.py b/prompting/validators/reward/config.py
index 7583924..61c7a67 100644
--- a/prompting/validators/reward/config.py
+++ b/prompting/validators/reward/config.py
@@ -29,6 +29,8 @@ class RewardModelType(Enum):
     blacklist = "blacklist_filter"
     nsfw = "nsfw_filter"
     relevance = "relevance_filter"
+    relevance_bert = "relevance_bert"
+    relevance_mpnet = "relevance_mpnet"
     task_validator = "task_validator_filter"
     keyword_match = "keyword_match_penalty"
diff --git a/prompting/validators/reward/dahoas.py b/prompting/validators/reward/dahoas.py
index b1183cc..27344ba 100644
--- a/prompting/validators/reward/dahoas.py
+++ b/prompting/validators/reward/dahoas.py
@@ -18,9 +18,9 @@
 import os
 import torch
-from typing import List
+from typing import List, Union
 from .config import RewardModelType
-from .reward import BaseRewardModel
+from .reward import BaseRewardModel, BaseRewardEvent
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
@@ -63,10 +63,14 @@ def __init__(self, path: str, device: str):
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.PAD_ID = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]

-    def reward(self, prompt: str, completion: str, name: str) -> float:
+    def reward(self, prompt: str, completion: str, name: str) -> BaseRewardEvent:
+        reward_event = BaseRewardEvent()
+
         def reward_fn(samples):
             if samples is None:
-                return 0
+                reward_event.reward = 0
+                return reward_event
+
             scores_list = []
             batch_size = 1
             for i in range(0, len(samples), batch_size):
@@ -92,21 +96,24 @@ def reward_fn(samples):
                     attention_mask=attn_masks.to(self.device),
                 )
                 scores_list.append(sub_scores["chosen_end_scores"])
-            scores = torch.cat(scores_list, dim=0).mean().item()
-            return scores
+            score = torch.cat(scores_list, dim=0).mean().item()
+            return score

         with torch.no_grad():
             combined_reward = reward_fn(prompt + completion)
             independent_reward = reward_fn(completion)
-            return float((combined_reward - independent_reward).item())
+            reward_event.reward = float((combined_reward - independent_reward).item())
+            return reward_event

     def get_rewards(
         self,
prompt: str, completions: List[str], name: str - ) -> torch.FloatTensor: - return torch.tensor( - [self.reward(prompt, completion, name) for completion in completions], - dtype=torch.float32, - ).to(self.device) + ) -> List[BaseRewardEvent]: + # Get all the reward results. + reward_events = [ + self.reward(prompt, completion, name) for completion in completions + ] + + return reward_events def forward( self, diff --git a/prompting/validators/reward/diversity.py b/prompting/validators/reward/diversity.py index 66314c5..744254d 100644 --- a/prompting/validators/reward/diversity.py +++ b/prompting/validators/reward/diversity.py @@ -18,11 +18,11 @@ import torch import torch.nn.functional as F -from typing import List +from typing import List, Union from .config import RewardModelType -from .reward import BaseRewardModel +from .reward import BaseRewardModel, BaseRewardEvent from transformers import AutoTokenizer, AutoModel - +from dataclasses import dataclass from torchmetrics.functional import pairwise_cosine_similarity @@ -48,6 +48,12 @@ def mean_pooling(model_output, attention_mask): ) +@dataclass +class DiversityRewardEvent(BaseRewardEvent): + historic: float = None + batch: float = None + + class DiversityRewardModel(BaseRewardModel): diversity_model_path = "sentence-transformers/all-mpnet-base-v2" @@ -155,10 +161,10 @@ def regularise(rewards): def get_rewards( self, prompt: str, completions: List[str], name: str - ) -> torch.FloatTensor: + ) -> List[DiversityRewardEvent]: # Check if completions are empty, return 0 if so if len(completions) == 0: - return torch.tensor([]).to(self.device) + return torch.tensor([]).to(self.device), None # Get embeddings for all completions. embeddings = self.get_embeddings(completions) @@ -171,11 +177,17 @@ def get_rewards( self.update_historic_embeddings(embeddings) - # Return all + reward_events = [] if historic_rewards != None: - return batch_rewards * historic_rewards + for b, h in zip(batch_rewards.tolist(), historic_rewards.tolist()): + reward_events.append( + DiversityRewardEvent(reward=b * h, batch=b, historic=h) + ) else: - return batch_rewards + for b in batch_rewards.tolist(): + reward_events.append(DiversityRewardEvent(reward=b, batch=b)) + + return reward_events def normalize_rewards(self, raw_rewards: torch.FloatTensor) -> torch.FloatTensor: # Applies binarization on the rewards. diff --git a/prompting/validators/reward/dpo.py b/prompting/validators/reward/dpo.py index a987f69..a26299b 100644 --- a/prompting/validators/reward/dpo.py +++ b/prompting/validators/reward/dpo.py @@ -18,9 +18,9 @@ import torch import bittensor as bt -from typing import List +from typing import List, Union from .config import RewardModelType -from .reward import BaseRewardModel +from .reward import BaseRewardModel, BaseRewardEvent from transformers import ( AutoTokenizer, AutoModelForCausalLM, @@ -51,15 +51,20 @@ def __init__(self, device: str): def reward_single( self, prompt: str, completion: str, name: str, with_penalty=True - ) -> float: + ) -> BaseRewardEvent: r"""Calculates a direct preference optimization (DPO) style reward for a completion, which is a reference model's average log-probability for completion tokens given a prompt. Uses guidance from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py. 
""" + + reward_event = BaseRewardEvent() + with torch.no_grad(): # Check if completion is if completion.strip() == "" or len(completion) <= 5: - return -11 # exp(-11)=1.67e-5 < 2e-5=1/50257 (typical vocab size) + # exp(-11)=1.67e-5 < 2e-5=1/50257 (typical vocab size) + reward_event.reward = -11.0 + return reward_event # Tokenize the combined prompt + completion. combined = ( @@ -74,7 +79,8 @@ def reward_single( # Completion doesn't fit into model sequence, so return lowest reward. if self.tokenizer.model_max_length <= len(prompt_part): - return -11.0 # exp(-11)=1.67e-5 < 2e-5=1/50257 (typical vocab size) + reward_event.reward = -11.0 + return reward_event # Truncate combined to fit into model max sequence length. if self.tokenizer.model_max_length < len(combined): @@ -123,18 +129,21 @@ def reward_single( # NaNs can possibly arise through log(0)=-inf, replace with suitably small logits. if torch.isnan(reward) or torch.isinf(reward): - return -11.0 # exp(-11)=1.67e-5 < 2e-5=1/50257 (typical vocab size) - return reward.item() + reward_event.reward = 11 + + reward_event.reward = reward.item() + return reward_event def get_rewards( self, prompt: str, completions: List[str], name: str - ) -> torch.FloatTensor: - rewards = torch.tensor( - [ - self.reward_single(prompt, completion, name) - for completion in completions - ], - dtype=torch.float32, - ).to(self.device) - bt.logging.trace(f"DirectPreferenceRewardModel | rewards: {rewards.tolist()}") - return rewards + ) -> List[BaseRewardEvent]: + # Get all the reward results. + reward_events = [ + self.reward_single(prompt, completion, name) for completion in completions + ] + + bt.logging.trace( + f"DirectPreferenceRewardModel | rewards: {[reward_event.reward for reward_event in reward_events]}" + ) + + return reward_events diff --git a/prompting/validators/reward/nsfw.py b/prompting/validators/reward/nsfw.py index bb649d2..710f1d9 100644 --- a/prompting/validators/reward/nsfw.py +++ b/prompting/validators/reward/nsfw.py @@ -17,10 +17,16 @@ # DEALINGS IN THE SOFTWARE. import torch -from typing import List +from typing import List, Union from .config import RewardModelType -from .reward import BaseRewardModel +from .reward import BaseRewardModel, BaseRewardEvent from transformers import AutoModelForSequenceClassification, AutoTokenizer +from dataclasses import dataclass + + +@dataclass +class NSFWRewardEvent(BaseRewardEvent): + score: float = None class NSFWRewardModel(BaseRewardModel): @@ -40,7 +46,9 @@ def __init__(self, device: str): NSFWRewardModel.nsfw_filter_model_path ).to(self.device) - def reward(self, prompt: str, completion: str, name: str) -> float: + def reward(self, prompt: str, completion: str, name: str) -> NSFWRewardEvent: + reward_event = NSFWRewardEvent() + boundary = -0.5 with torch.no_grad(): message = completion @@ -63,15 +71,20 @@ def sum_nsfw_scores(input_ids, chunk_size): return max_score # 0 when needs to be filtered out, 1 when it is safe - return 0.0 if sum_nsfw_scores(input_ids, chunk_size=512) > boundary else 1.0 + score = sum_nsfw_scores(input_ids, chunk_size=512) + reward_event.score = score + reward_event.reward = 0.0 if score > boundary else 1.0 + return reward_event def get_rewards( self, prompt: str, completions: List[str], name: str - ) -> torch.FloatTensor: - return torch.tensor( - [self.reward(prompt, completion, name) for completion in completions], - dtype=torch.float32, - ).to(self.device) + ) -> List[NSFWRewardEvent]: + # Get all the reward results. 
+ reward_events = [ + self.reward(prompt, completion, name) for completion in completions + ] + + return reward_events def normalize_rewards(self, rewards: torch.FloatTensor) -> torch.FloatTensor: return rewards diff --git a/prompting/validators/reward/open_assistant.py b/prompting/validators/reward/open_assistant.py index 77dfa36..342001c 100644 --- a/prompting/validators/reward/open_assistant.py +++ b/prompting/validators/reward/open_assistant.py @@ -17,9 +17,9 @@ # DEALINGS IN THE SOFTWARE. import torch -from typing import List +from typing import List, Union from .config import RewardModelType -from .reward import BaseRewardModel +from .reward import BaseRewardModel, BaseRewardEvent from transformers import AutoTokenizer, AutoModelForSequenceClassification @@ -40,20 +40,22 @@ def __init__(self, device: str): OpenAssistantRewardModel.reward_model_name ).to(self.device) - def reward_single(self, prompt: str, completion: str, name: str) -> float: + def reward_single(self, prompt: str, completion: str, name: str) -> BaseRewardEvent: + reward_event = BaseRewardEvent() + with torch.no_grad(): inputs = self.tokenizer(prompt, completion, return_tensors="pt").to( self.device ) - return float(self.model(**inputs).logits[0].cpu().detach()) + reward_event.reward = float(self.model(**inputs).logits[0].cpu().detach()) + return reward_event def get_rewards( self, prompt: str, completions: List[str], name: str - ) -> torch.FloatTensor: - return torch.tensor( - [ - self.reward_single(prompt, completion, name) - for completion in completions - ], - dtype=torch.float32, - ).to(self.device) + ) -> List[BaseRewardEvent]: + # Get all the reward results. + reward_events = [ + self.reward_single(prompt, completion, name) for completion in completions + ] + + return reward_events diff --git a/prompting/validators/reward/prompt.py b/prompting/validators/reward/prompt.py index b72e366..ad5d656 100644 --- a/prompting/validators/reward/prompt.py +++ b/prompting/validators/reward/prompt.py @@ -19,9 +19,9 @@ import time import torch import bittensor as bt -from typing import List +from typing import List, Union from .config import RewardModelType -from .reward import BaseRewardModel +from .reward import BaseRewardModel, BaseRewardEvent from prompting.validators.prompts import AugmentPrompt, FollowupPrompt, AnswerPrompt from transformers import AutoTokenizer, AutoModelForCausalLM @@ -50,7 +50,9 @@ def __init__(self, device: str): PromptRewardModel.reward_model_name, torch_dtype=torch.float16 ).to(self.device) - def reward(self, prompt: str, completion: str, name: str) -> float: + def reward(self, prompt: str, completion: str, name: str) -> BaseRewardEvent: + reward_event = BaseRewardEvent() + with torch.no_grad(): # Choose correct scoring prompt for request type. if name == "augment": @@ -60,7 +62,8 @@ def reward(self, prompt: str, completion: str, name: str) -> float: elif name == "answer": scoring_prompt = AnswerPrompt() else: - return 0 + reward_event.reward = 0 + return reward_event # Format scoring prompt for this completion. scoring_prompt_text = scoring_prompt.text(prompt, completion) @@ -96,18 +99,21 @@ def reward(self, prompt: str, completion: str, name: str) -> float: # Scale 0-10 score to 0-1 range. 
score /= 10.0 - return score + reward_event.reward = score + return reward_event def get_rewards( self, prompt: str, completions: List[str], name: str - ) -> torch.FloatTensor: + ) -> List[BaseRewardEvent]: bt.logging.debug( f"PromptRewardModel | Calculating {len(completions)} rewards (typically < 1 sec/reward)." ) bt.logging.trace( f"PromptRewardModel | prompt: {repr(prompt[:50])} ... {repr(prompt[-50:])}" ) - return torch.tensor( - [self.reward(prompt, completion, name) for completion in completions], - dtype=torch.float32, - ).to(self.device) + # Get all the reward results. + reward_events = [ + self.reward(prompt, completion, name) for completion in completions + ] + + return reward_events diff --git a/prompting/validators/reward/reciprocate.py b/prompting/validators/reward/reciprocate.py index ff2e572..7da1cdd 100644 --- a/prompting/validators/reward/reciprocate.py +++ b/prompting/validators/reward/reciprocate.py @@ -17,9 +17,9 @@ # DEALINGS IN THE SOFTWARE. import torch -from typing import List +from typing import List, Union from .config import RewardModelType -from .reward import BaseRewardModel +from .reward import BaseRewardModel, BaseRewardEvent from transformers import AutoTokenizer, AutoModelForSequenceClassification @@ -44,7 +44,8 @@ def __init__(self, device: str): torch_dtype=torch.float16, ).to(self.device) - def reward(self, prompt: str, completion: str, name: str) -> float: + def reward(self, prompt: str, completion: str, name: str) -> BaseRewardEvent: + reward_event = BaseRewardEvent() with torch.no_grad(): message = ( f"<|prompter|>{prompt}<|assistant|>{completion}<|endoftext|>" @@ -54,12 +55,15 @@ def reward(self, prompt: str, completion: str, name: str) -> float: return_tensors="pt", truncation=True, ).to(self.device) - return float(self.model(**inputs)[0].item()) + reward_event.reward = float(self.model(**inputs)[0].item()) + return reward_event def get_rewards( self, prompt: str, completions: List[str], name: str - ) -> torch.FloatTensor: - return torch.tensor( - [self.reward(prompt, completion, name) for completion in completions], - dtype=torch.float32, - ).to(self.device) + ) -> List[BaseRewardEvent]: + # Get all the reward results. + reward_events = [ + self.reward(prompt, completion, name) for completion in completions + ] + + return reward_events diff --git a/prompting/validators/reward/relevance.py b/prompting/validators/reward/relevance.py index 7ed0602..5f15bca 100644 --- a/prompting/validators/reward/relevance.py +++ b/prompting/validators/reward/relevance.py @@ -17,12 +17,13 @@ # DEALINGS IN THE SOFTWARE. 
import torch -from typing import List +from typing import List, Union from .config import RewardModelType -from .reward import BaseRewardModel +from .reward import BaseRewardModel, BaseRewardEvent from transformers import AutoTokenizer, AutoModel from torchmetrics.functional import pairwise_cosine_similarity import torch.nn.functional as F +from dataclasses import dataclass def mean_pooling(model_output, attention_mask): @@ -47,6 +48,12 @@ def mean_pooling(model_output, attention_mask): ) +@dataclass +class RelevanceRewardEvent(BaseRewardEvent): + bert_score: float = None + mpnet_score: float = None + + class RelevanceRewardModel(BaseRewardModel): @property def name(self) -> str: @@ -63,30 +70,44 @@ def __init__(self, device: str): def get_rewards( self, prompt: str, completions: List[str], name: str - ) -> torch.FloatTensor: - return torch.tensor( - [self.reward(prompt, completion, name) for completion in completions], - dtype=torch.float32, - ).to(self.device) + ) -> List[RelevanceRewardEvent]: + # Get all the reward results. + reward_events = [ + self.reward(prompt, completion, name) for completion in completions + ] + return reward_events def normalize_rewards(self, rewards: torch.FloatTensor) -> torch.FloatTensor: return rewards - def reward(self, prompt: str, completion: str, name: str) -> float: + def reward(self, prompt: str, completion: str, name: str) -> RelevanceRewardEvent: + reward_event = RelevanceRewardEvent() + for i, model in enumerate(self.models): # rewards diff = model.reward(prompt, completion) # If a model returns 0, stop iterating and return 0 if diff < self.bounds[i]: - return 0.0 + reward_event.reward = 0 + + if model.name == "relevance_bert": + reward_event.bert_score = diff + + elif model.name == "relevance_mpnet": + reward_event.mpnet_score = diff + # If none of the models returned 0, return 1 - return 1.0 + return reward_event class BertRelevanceRewardModel(BaseRewardModel): relevance_model_path = "bert-base-uncased" + @property + def name(self) -> str: + return RewardModelType.relevance_bert.value + def __init__(self, device: str): super().__init__() self.device = device @@ -142,6 +163,10 @@ def reward(self, prompt: str, completion: str) -> float: class MpnetRelevenceModel(BaseRewardModel): diversity_model_path = "sentence-transformers/all-mpnet-base-v2" + @property + def name(self) -> str: + return RewardModelType.relevance_mpnet.value + def __init__(self, device: str): super().__init__() self.device = device @@ -190,4 +215,4 @@ def reward(self, prompt: str, completion: str) -> torch.FloatTensor: # Calculate the pairwise cosine similarity. 
similarity = pairwise_cosine_similarity(prompt_embed, embeddings) - return torch.abs(similarity) + return torch.abs(similarity).item() diff --git a/prompting/validators/reward/reward.py b/prompting/validators/reward/reward.py index c20220d..23c8463 100644 --- a/prompting/validators/reward/reward.py +++ b/prompting/validators/reward/reward.py @@ -18,8 +18,24 @@ import torch import bittensor as bt -from typing import List +from typing import List, Union from abc import abstractmethod +from dataclasses import dataclass, asdict, fields + + +@dataclass +class BaseRewardEvent: + reward: float = 1.0 + normalized_reward: float = None + + @staticmethod + def parse_reward_events(reward_events): + field_names = [field.name for field in fields(reward_events[0])] + reward_events = [ + asdict(reward_event).values() for reward_event in reward_events + ] + reward_event = dict(zip(field_names, list(zip(*reward_events)))) + return reward_event class BaseRewardModel: @@ -37,7 +53,7 @@ def __repr__(self) -> str: @abstractmethod def get_rewards( self, prompt: str, completion: List[str], name: str - ) -> torch.FloatTensor: + ) -> Union[torch.FloatTensor, dict]: ... def __init__(self) -> None: @@ -101,7 +117,7 @@ def normalize_rewards(self, rewards: torch.FloatTensor) -> torch.FloatTensor: def apply( self, prompt: str, responses: List[bt.Synapse], name: str - ) -> torch.FloatTensor: + ) -> Union[torch.FloatTensor, dict]: """Applies the reward model across each call. Unsuccessful responses are zeroed.""" # Get indices of correctly responding calls. @@ -117,7 +133,12 @@ def apply( ] # Reward each completion. - successful_rewards = self.get_rewards(prompt, successful_completions, name) + reward_events = BaseRewardEvent.parse_reward_events( + self.get_rewards(prompt, successful_completions, name) + ) + successful_rewards = torch.tensor( + reward_events.pop("reward"), dtype=torch.float32 + ) # Softmax rewards across samples. successful_rewards_normalized = self.normalize_rewards(successful_rewards) @@ -135,5 +156,17 @@ def apply( filled_rewards[idx] = reward filled_rewards_normalized[idx] = reward_normalized + # Fill every item of the reward_events + for name, reward_values in reward_events.items(): + filled_values = [None] * len(responses) + for idx, reward_value in zip(successful_completions_indices, reward_values): + filled_values[idx] = reward_value + reward_events[name] = filled_values + + # Name each item of the reward event with the reward model name. + reward_events = {f"{self.name}_{k}": v for k, v in reward_events.items()} + reward_events[self.name] = filled_rewards.tolist() + reward_events[self.name + "_normalized"] = filled_rewards_normalized.tolist() + # Return the filled rewards. 
- return filled_rewards, filled_rewards_normalized + return filled_rewards_normalized, reward_events diff --git a/prompting/validators/tasks.py b/prompting/validators/tasks.py index ccf63de..dad365a 100644 --- a/prompting/validators/tasks.py +++ b/prompting/validators/tasks.py @@ -25,6 +25,9 @@ TaskCriterion, MatchLengthCriteria, TextLengthUnitEnum, + ContentMatchTypeEnum, + SimpleResponseLayoutCriteria, + MatchContentCriteria, ) @@ -143,7 +146,51 @@ def create_summarization_task(base_text: str) -> SummaryTask: def create_qg_task(base_text: str, index: int) -> QuestionGenerationTask: - possible_criterias = [ + questions_prefixes = [ + "who", + "what", + "when", + "where", + "why", + "how", + "is", + "are", + "can", + "do", + "does", + "did", + "would", + "could", + "will", + "shall", + "may", + "might", + "am", + "was", + "were", + "has", + "have", + "had", + "been", + "being", + ] + + question_starts_with_prefix_criteria = MatchContentCriteria( + contentMatchType=ContentMatchTypeEnum.STARTS_WITH, + penalty=0.25, + words_array=questions_prefixes, + n_words=3, + ) + + question_ends_with_criteria = MatchContentCriteria( + contentMatchType=ContentMatchTypeEnum.ENDS_WITH, + penalty=0.25, + words_array=["?"], + n_words=1, + text='Your response should end with a question mark, i.e. "?"', + ) + + other_random_criteria = [ MatchLengthCriteria( penalty=0.1, target_length=random.randint(10, 40), @@ -156,35 +203,47 @@ def create_qg_task(base_text: str, index: int) -> QuestionGenerationTask: ), ] - sampled_criterias = random.sample(possible_criterias, 1) + random_sampled_criteria = random.sample(other_random_criteria, 1) + criteria = [ + question_starts_with_prefix_criteria, + question_ends_with_criteria, + ] + random_sampled_criteria return QuestionGenerationTask( base_text=base_text, - criteria=sampled_criterias, + criteria=criteria, task_type="question-generation", task_name=f"followup{index}", ) def create_qa_task(base_text: str, index: int) -> QuestionAnswerTask: - possible_criterias = [ - MatchLengthCriteria( - penalty=0.1, - target_length=random.randint(50, 200), - unit=TextLengthUnitEnum.WORDS, - ), - MatchLengthCriteria( - penalty=0.1, - target_length=random.randint(4, 8), - unit=TextLengthUnitEnum.SENTENCES, - ), - ] + answer_should_not_include_criteria = MatchContentCriteria( + words_array=["?"], + n_words=1, + penalty=0.2, + contentMatchType=ContentMatchTypeEnum.INCLUDES, + negate_match=True, + text="Your response should not include any question marks", + ) - sampled_criterias = random.sample(possible_criterias, 1) + simple_response_layout_criteria = SimpleResponseLayoutCriteria(penalty=0.2) + + words_criteria = MatchLengthCriteria( + penalty=0.2, + target_length=random.randint(50, 200), + unit=TextLengthUnitEnum.WORDS, + ) + + criteria = [ + answer_should_not_include_criteria, + simple_response_layout_criteria, + words_criteria, + ] return QuestionAnswerTask( base_text=base_text, - criteria=sampled_criterias, + criteria=criteria, task_type="question-answer", task_name=f"answer{index}", )
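
Reviewer notes (not part of the diff). A quick illustration of what the three ContentMatchTypeEnum modes in criteria.py anchor to; the patterns mirror _get_regex_pattern above, while the sample words and completion are arbitrary:

    import re

    words = ["who", "what", "when"]  # stand-in for a sampled subset of words_array
    escaped = "|".join(map(re.escape, words))

    patterns = {
        "starts with": rf"^\s*({escaped})\b",  # STARTS_WITH
        "ends with": rf"({escaped})\s*$",      # ENDS_WITH
        "includes": rf"({escaped})",           # INCLUDES
    }

    completion = "What is the main argument of the passage?"
    for label, pattern in patterns.items():
        print(label, "->", bool(re.search(pattern, completion, re.IGNORECASE)))
    # starts with -> True, ends with -> False, includes -> True

With negate_match=True the penalty flips, which is how create_qa_task forbids question marks in answers.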
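
The tightened follow-up post-processing in forward.py only re-appends a question mark when the completion actually contained one, and caps the kept tail at 40 words in both branches. A minimal trace of that logic, with a made-up completion:

    max_words = 40
    completion = "Some preamble. What changed in the penalty weights? Trailing chatter".strip(".")

    if "?" in completion:
        # keep only the sentence leading up to the first question mark
        completion = completion.split("?")[0].split(".")[-1]
        completion = " ".join(completion.split(" ")[-max_words:]) + "?"
    else:
        # otherwise keep the last sentence, still capped at max_words
        completion = completion.split(".")[-1]
        completion = " ".join(completion.split(" ")[-max_words:])

    print(completion.strip())  # "What changed in the penalty weights?"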
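
The new Blacklist model counts n-grams with the lossy-counting scheme referenced in _add_ngrams: each entry stores a frequency plus a max_error bucket and is pruned once per window. A self-contained toy version of that bookkeeping, assuming whitespace tokenisation instead of the BERT tokenizer used in the diff:

    import math
    import re
    from typing import Dict, List, Tuple


    class ToyLossyNgramCounter:
        """Simplified lossy counter over word n-grams (illustration only)."""

        def __init__(self, error: float = 0.001, n_min: int = 2, n_max: int = 3):
            self.window = math.ceil(1 / error)  # bucket width derived from the error bound
            self.w_current = 1                  # current bucket index
            self.num_completion = 0
            # ngram -> [frequency, max_error], mirroring the diff's counter layout
            self.counter: Dict[Tuple[str, ...], List[int]] = {}
            self.n_min, self.n_max = n_min, n_max

        def _ngrams(self, text: str) -> List[Tuple[str, ...]]:
            words = re.sub(r"[^\w\s]", "", text.lower()).split()
            grams = []
            for n in range(self.n_min, self.n_max + 1):
                grams.extend(zip(*[words[i:] for i in range(n)]))
            return grams

        def add(self, completion: str) -> None:
            for gram in self._ngrams(completion):
                if gram in self.counter:
                    self.counter[gram][0] += 1
                else:
                    # a new entry may have been missed in earlier buckets,
                    # so record the maximum possible undercount
                    self.counter[gram] = [1, self.w_current - 1]
            self.num_completion += 1
            if self.num_completion % self.window == 0:
                self.w_current = math.ceil(self.num_completion / self.window)
                self.prune()

        def prune(self) -> None:
            # drop entries whose count + max_error can no longer reach the bucket index
            stale = [g for g, (f, e) in self.counter.items() if f + e <= self.w_current]
            for g in stale:
                del self.counter[g]


    counter = ToyLossyNgramCounter(error=0.5)  # window of 2 completions, prunes aggressively
    for text in ["the quick brown fox", "the quick brown dog", "a completely new reply"]:
        counter.add(text)
    print(counter.counter)  # repeated n-grams such as ("the", "quick") survive pruning

In the diff, surviving n-grams are then weighted by an exponential factor in the decoded phrase length times their relative frequency (scaled by frequency_multiplier); phrases that cross the boundary earn a zero reward via fuzzy matching.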
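
reward.py's apply() now returns (normalized_rewards, reward_event_dict) instead of (raw, normalized), and run_step simply merges that dict into the wandb event. A minimal sketch of the contract, using a made-up MiniRewardModel rather than any class from the repo and an identity normalisation:

    from dataclasses import asdict, dataclass, fields
    from typing import List, Optional

    import torch


    @dataclass
    class BaseRewardEvent:
        reward: float = 1.0
        normalized_reward: Optional[float] = None

        @staticmethod
        def parse_reward_events(reward_events):
            # [{reward: 1.0, score: -3.1}, ...] -> {"reward": [1.0, ...], "score": [-3.1, ...]}
            field_names = [f.name for f in fields(reward_events[0])]
            values = [asdict(event).values() for event in reward_events]
            return dict(zip(field_names, map(list, zip(*values))))


    @dataclass
    class NSFWRewardEvent(BaseRewardEvent):
        score: Optional[float] = None


    class MiniRewardModel:
        name = "nsfw_filter"

        def get_rewards(self, completions: List[str]) -> List[NSFWRewardEvent]:
            return [NSFWRewardEvent(reward=1.0, score=-3.1) for _ in completions]

        def apply(self, completions: List[str]):
            columns = BaseRewardEvent.parse_reward_events(self.get_rewards(completions))
            rewards = torch.tensor(columns.pop("reward"), dtype=torch.float32)
            normalized = rewards  # identity normalisation, just for the sketch
            # prefix the remaining columns with the model name, as reward.py now does
            event = {f"{self.name}_{k}": v for k, v in columns.items()}
            event[self.name] = rewards.tolist()
            event[self.name + "_normalized"] = normalized.tolist()
            return normalized, event


    normalized, event = MiniRewardModel().apply(["hello", "world"])
    print(normalized)     # tensor([1., 1.])
    print(sorted(event))  # ['nsfw_filter', 'nsfw_filter_normalized', 'nsfw_filter_normalized_reward', 'nsfw_filter_score']

The real apply() additionally fills None for UIDs whose responses failed; that bookkeeping is omitted here.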