
Commit 1e3d367

Merge pull request #67 from opentensor/staging
2.1.2 Release
p-ferreira committed Nov 9, 2023
2 parents 070e599 + cb9e98c commit 1e3d367
Showing 19 changed files with 768 additions and 155 deletions.
8 changes: 4 additions & 4 deletions neurons/validators/validator.py
@@ -194,8 +194,8 @@ def __init__(self):
MockRewardModel(RewardModelType.nsfw.value),
]
self.penalty_functions = [
-TaskValidationPenaltyModel(max_penalty=0.1),
-ContentMatchPenaltyModel(max_penalty=0.1),
+TaskValidationPenaltyModel(max_penalty=0.6),
+ContentMatchPenaltyModel(max_penalty=0.2),
KeywordMatchPenaltyModel(max_penalty=1),
]
bt.logging.debug(str(self.reward_functions))
@@ -277,8 +277,8 @@ def __init__(self):
]

self.penalty_functions = [
-TaskValidationPenaltyModel(max_penalty=0.1),
-ContentMatchPenaltyModel(max_penalty=0.1),
+TaskValidationPenaltyModel(max_penalty=0.6),
+ContentMatchPenaltyModel(max_penalty=0.2),
KeywordMatchPenaltyModel(max_penalty=1),
]

2 changes: 1 addition & 1 deletion prompting/validators/__init__.py
@@ -27,7 +27,7 @@
from . import event
from . import dataset

__version__ = "2.1.1"
__version__ = "2.1.2"
version_split = __version__.split(".")
__spec_version__ = (
(1000 * int(version_split[0]))
135 changes: 134 additions & 1 deletion prompting/validators/criteria.py
@@ -18,7 +18,7 @@
import re
import torch
import numpy as np
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from typing import List
from enum import Enum
@@ -109,3 +109,136 @@ def evaluate(self, completions: List[str]) -> torch.FloatTensor:

def compose_text(self) -> str:
return self.text.format(target_length=self.target_length, unit=self.unit.value)


class ContentMatchTypeEnum(Enum):
STARTS_WITH = "starts with"
ENDS_WITH = "ends with"
INCLUDES = "includes"


@dataclass
class MatchContentCriteria(TaskCriterion):
default_text: str = (
"Your response should {match_type} the following words: {words}."
)
text: str = default_text
penalty: float = 0.1
n_words: int = 3
words_array: List[str] = field(default_factory=list)
contentMatchType: ContentMatchTypeEnum = ContentMatchTypeEnum.STARTS_WITH
sampled_words: List[str] = field(init=False)
negate_match: bool = False

def __post_init__(self):
# Randomly sample words from the array based on n_words
self.sampled_words = np.random.choice(
self.words_array, self.n_words, replace=False
)

def _get_regex_pattern(self):
# Escape all special characters in the sampled words
escaped_words = map(re.escape, self.sampled_words)

if self.contentMatchType == ContentMatchTypeEnum.STARTS_WITH:
return rf"^\s*({'|'.join(escaped_words)})\b"
elif self.contentMatchType == ContentMatchTypeEnum.ENDS_WITH:
return rf"({'|'.join(escaped_words)})\s*$"
else: # ContentMatchTypeEnum.INCLUDES
return rf"({'|'.join(escaped_words)})"

def evaluate(self, completions: List[str]) -> torch.FloatTensor:
penalties = torch.zeros(len(completions), dtype=torch.float32)
# Define regex pattern based on contentMatchType
pattern = self._get_regex_pattern()

for idx, completion in enumerate(completions):
# Check if the completion matches the pattern
match = re.search(pattern, completion, re.IGNORECASE)

completion_with_undesired_match = self.negate_match and match
completion_without_desired_match = not self.negate_match and not match

if completion_with_undesired_match or completion_without_desired_match:
penalties[idx] = self.penalty

return penalties

def compose_text(self) -> str:
# Check if the text property is different than the default. If so, use that text.
if self.text != MatchContentCriteria.default_text:
return self.text

# Adds "should" or "should not" instruction based on the negate_match property
should_match_text = "should" if not self.negate_match else "should not"

# Get the list of selected sampled words
words_list = ", ".join(self.sampled_words)

# Get the descriptive text of the match type
match_type_text = self.contentMatchType.value

# Adjust the text based on the number of words
if self.n_words > 1:
text = f"Your response {should_match_text} {match_type_text} one of the following words: {words_list}."
else:
text = f"Your response {should_match_text} {match_type_text} the following word: {words_list}."
return text


@dataclass
class SimpleResponseLayoutCriteria(TaskCriterion):
penalty: float = 0.1
text: str = "Your response should not contain any bullet points or numbered lists."

def evaluate(self, completions: List[str]) -> torch.FloatTensor:
penalties = torch.zeros(len(completions), dtype=torch.float32)

# Regex patterns to match bullet points (unordered lists) and numbered lists
bullet_point_pattern = re.compile(r"(\*|\-|\+|\•|\‣|\◦)\s")
numbered_list_pattern = re.compile(r"\d+\.\s")

for idx, completion in enumerate(completions):
# Check if the completion contains a bullet point or numbered list
if bullet_point_pattern.search(completion) or numbered_list_pattern.search(
completion
):
penalties[idx] = self.penalty

return penalties

def compose_text(self) -> str:
return self.text


class LayoutMatchTypeEnum(Enum):
UNORDERED_LIST = "unordered list"
NUMBERED_LIST = "numbered list"


@dataclass
class MatchLayoutCriteria(TaskCriterion):
layout_type: LayoutMatchTypeEnum = LayoutMatchTypeEnum.UNORDERED_LIST
penalty: float = 0.1
text: str = "Your response should be ordered in format of {layout_type}."

def evaluate(self, completions: List[str]) -> torch.FloatTensor:
penalties = torch.zeros(len(completions), dtype=torch.float32)

# Regex patterns based on layout type
bullet_point_pattern = re.compile(r"(\*|\-|\+|\•|\‣|\◦)\s")
numbered_list_pattern = re.compile(r"\d+\.\s")

for idx, completion in enumerate(completions):
# Evaluate based on the layout type
if self.layout_type == LayoutMatchTypeEnum.UNORDERED_LIST:
if not bullet_point_pattern.search(completion):
penalties[idx] = self.penalty
elif self.layout_type == LayoutMatchTypeEnum.NUMBERED_LIST:
if not numbered_list_pattern.search(completion):
penalties[idx] = self.penalty

return penalties

def compose_text(self) -> str:
return self.text.format(layout_type=self.layout_type)
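
The following is an illustrative sketch, not part of this commit, showing how one of the new criteria classes might be exercised on its own. The word pool and completions are made up, and the composed instruction depends on which words __post_init__ randomly samples.

from prompting.validators.criteria import ContentMatchTypeEnum, MatchContentCriteria

# Hypothetical word pool; two words are sampled at construction time.
criterion = MatchContentCriteria(
    words_array=["river", "mountain", "forest", "desert"],
    n_words=2,
    contentMatchType=ContentMatchTypeEnum.STARTS_WITH,
    penalty=0.1,
)

print(criterion.compose_text())
# e.g. "Your response should starts with one of the following words: river, forest."

penalties = criterion.evaluate(["River deltas form slowly.", "No matching opener."])
# Assuming "river" is among the sampled words, the first completion matches the
# ^\s*(...)\b pattern and gets 0.0, while the second receives the 0.1 penalty.
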
43 changes: 42 additions & 1 deletion prompting/validators/event.py
@@ -48,13 +48,25 @@ class EventSchema:
List[float]
] # Output vector of the dahoas reward model
blacklist_filter: Optional[List[float]] # Output vector of the blacklist filter
blacklist_filter_matched_ngram: Optional[
List[str]
] # Output vector of the blacklist filter
blacklist_filter_significance_score: Optional[
List[float]
] # Output vector of the blacklist filter
nsfw_filter: Optional[List[float]] # Output vector of the nsfw filter
reciprocate_reward_model: Optional[
List[float]
] # Output vector of the reciprocate reward model
diversity_reward_model: Optional[
List[float]
] # Output vector of the diversity reward model
diversity_reward_model_historic: Optional[
List[float]
] # Output vector of the diversity reward model
diversity_reward_model_batch: Optional[
List[float]
] # Output vector of the diversity reward model
dpo_reward_model: Optional[List[float]] # Output vector of the dpo reward model
rlhf_reward_model: Optional[List[float]] # Output vector of the rlhf reward model
prompt_reward_model: Optional[
@@ -65,6 +77,7 @@ class EventSchema:
List[float]
] # Output vector of the dahoas reward model
nsfw_filter_normalized: Optional[List[float]] # Output vector of the nsfw filter
nsfw_filter_score: Optional[List[float]] # Output vector of the nsfw filter
reciprocate_reward_model_normalized: Optional[
List[float]
] # Output vector of the reciprocate reward model
@@ -80,7 +93,16 @@ class EventSchema:
prompt_reward_model_normalized: Optional[
List[float]
] # Output vector of the prompt reward model
-relevance_filter_normalized: Optional[List[float]]

+relevance_filter_normalized: Optional[
+    List[float]
+] # Output vector of the relevance scoring reward model
relevance_filter_bert_score: Optional[
List[float]
] # Output vector of the relevance scoring reward model
relevance_filter_mpnet_score: Optional[
List[float]
] # Output vector of the relevance scoring reward model
# TODO: Add comments
task_validation_penalty_raw: Optional[List[float]]
task_validation_penalty_adjusted: Optional[List[float]]
@@ -109,6 +131,12 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> "EventSchema":
RewardModelType.reciprocate.value
),
"diversity_reward_model": event_dict.get(RewardModelType.diversity.value),
"diversity_reward_model_historic": event_dict.get(
RewardModelType.diversity.value + "_historic"
),
"diversity_reward_model_batch": event_dict.get(
RewardModelType.diversity.value + "_batch"
),
"dpo_reward_model": event_dict.get(RewardModelType.dpo.value),
"rlhf_reward_model": event_dict.get(RewardModelType.rlhf.value),
"prompt_reward_model": event_dict.get(RewardModelType.prompt.value),
@@ -136,6 +164,19 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> "EventSchema":
"prompt_reward_model_normalized": event_dict.get(
RewardModelType.prompt.value + "_normalized"
),
"blacklist_filter_matched_ngram": event_dict.get(
RewardModelType.blacklist.value + "_matched_ngram"
),
"blacklist_filter_significance_score": event_dict.get(
RewardModelType.blacklist.value + "_significance_score"
),
"relevance_filter_bert_score": event_dict.get(
RewardModelType.relevance.value + "_bert_score"
),
"relevance_filter_mpnet_score": event_dict.get(
RewardModelType.relevance.value + "_mpnet_score"
),
"nsfw_filter_score": event_dict.get(RewardModelType.nsfw.value + "_score"),
}
penalties = {
"task_validation_penalty_raw": event_dict.get(
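
As a rough illustration (not part of this commit), the new schema fields are filled from event keys built by suffixing each reward model's name, exactly as in the from_dict lookups above. The import path and the numeric values below are assumptions made for this sketch.

from prompting.validators.reward import RewardModelType  # assumed import path

# Hypothetical per-response values; the key suffixes mirror the lookups above.
reward_event = {
    RewardModelType.blacklist.value + "_matched_ngram": ["", "copied phrase"],
    RewardModelType.blacklist.value + "_significance_score": [0.0, 0.93],
    RewardModelType.relevance.value + "_bert_score": [0.88, 0.41],
    RewardModelType.relevance.value + "_mpnet_score": [0.79, 0.35],
    RewardModelType.nsfw.value + "_score": [0.01, 0.02],
    RewardModelType.diversity.value + "_historic": [0.65, 0.70],
    RewardModelType.diversity.value + "_batch": [0.55, 0.60],
}
# run_step merges dicts like this into the step event (event = {**event, **reward_event}),
# and EventSchema.from_dict then maps each key onto the matching Optional field.
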
25 changes: 13 additions & 12 deletions prompting/validators/forward.py
@@ -96,44 +96,48 @@ async def run_step(self, task: Task, k: int, timeout: float, exclude: list = [])
timeout=timeout,
)

# Update blacklist with completions so that n-gram filtering can be applied
self.blacklist.add(
[response.completion for response in responses if response.completion]
)

# Restrict the format of acceptable followup completions.
for response in responses:
# remove leading and trailing periods
completion = response.completion.strip(".")

if "followup" in task_name and len(completion) > 0:
# take maximum of 40 words
max_words = 40
if "?" in completion:
# take first question that is found and only use the sentence before the question mark
completion = completion.split("?")[0].split(".")[-1]
response.completion = " ".join(completion.split(" ")[-max_words:]) + "?"
else:
# otherwise take the last sentence
completion = completion.split(".")[-1].split(".")[-1]

# take maximum of 40 words
response.completion = " ".join(completion.split(" ")[-40:]) + "?"
response.completion = " ".join(completion.split(" ")[-max_words:])

# Compute the rewards for the responses given the prompt.
rewards: torch.FloatTensor = torch.zeros(len(responses), dtype=torch.float32).to(
self.device
)
for weight_i, reward_fn_i in zip(self.reward_weights, self.reward_functions):
-reward_i, reward_i_normalized = reward_fn_i.apply(
+reward_i_normalized, reward_event = reward_fn_i.apply(
task.base_text, responses, task_name
)
rewards += weight_i * reward_i_normalized.to(self.device)
if not self.config.neuron.disable_log_rewards:
-event[reward_fn_i.name] = reward_i.tolist()
-event[reward_fn_i.name + "_normalized"] = reward_i_normalized.tolist()
+event = {**event, **reward_event}
bt.logging.trace(str(reward_fn_i.name), reward_i_normalized.tolist())

for masking_fn_i in self.masking_functions:
-mask_i, mask_i_normalized = masking_fn_i.apply(
+mask_i_normalized, reward_event = masking_fn_i.apply(
task.base_text, responses, task_name
)
rewards *= mask_i_normalized.to(self.device) # includes diversity
if not self.config.neuron.disable_log_rewards:
-event[masking_fn_i.name] = mask_i.tolist()
-event[masking_fn_i.name + "_normalized"] = mask_i_normalized.tolist()
+event = {**event, **reward_event}
bt.logging.trace(str(masking_fn_i.name), mask_i_normalized.tolist())

for penalty_fn_i in self.penalty_functions:
@@ -271,6 +275,3 @@ async def forward(self):
)

exclude += qa_event["uids"]

-self.blacklist.question_blacklist.append(qg_event["best"])
-self.blacklist.answer_blacklist.append(qa_event["best"])
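
A minimal sketch, not part of this commit, of the reward/masking interface implied by the updated run_step: apply() now returns the normalized score tensor together with a dict of extra event fields, instead of the raw/normalized pair it returned before (compare the MockRewardModel change below). The class name and event key here are hypothetical.

import torch

class ExampleRewardModel:  # hypothetical reward model, for illustration only
    name = "example_reward_model"

    def apply(self, prompt: str, responses, task_name: str):
        # Placeholder scores: one normalized reward per response.
        rewards_normalized = torch.ones(len(responses), dtype=torch.float32)
        # Extra per-model fields destined for the step event / EventSchema.
        reward_event = {self.name + "_normalized": rewards_normalized.tolist()}
        return rewards_normalized, reward_event
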
2 changes: 1 addition & 1 deletion prompting/validators/mock.py
@@ -62,7 +62,7 @@ def __init__(self, mock_name: str = "MockReward"):

def apply(self, prompt: str, completion: List[str], name: str) -> torch.FloatTensor:
mock_reward = torch.tensor([1 for _ in completion], dtype=torch.float32)
-return mock_reward, mock_reward
+return mock_reward, {}

def reset(self):
return self
9 changes: 7 additions & 2 deletions prompting/validators/penalty/content_match.py
@@ -41,10 +41,15 @@ def calculate_penalties(
r"the\s+question\s+was", # the question was
r"use\s+proper\s+grammar", # use proper grammar
r"what\s+is\s+the\s+correct\s+order\s+of\s+the\s+key\s+points", # what is the correct order of the key points
r"sure!\s+here.+", # sure! here...
r"sure!\s+her.+", # sure! here...
r"solution\s+\(in\s+\w+\)", # solution (in \w+)
r"great\s+job!\s+here(?:'s| is)", # great job! here...
r"keep\s+it\s+clear\s+and\s+concise.\s+Use\s+complete\s+sentences.", # keep it clear and concise. Use complete sentences.
r"task\s*:" r"i\s+can\s+help\s+you\s+with", # task: # I can help you with
r"what\s+did\s+I\s+learn\s+today\s*\?", # what did I learn today?
r"paraphrase\s*:", # paraphrase:
r"your\s+task\s+now\s+is\s+to\s+write\s+a\s+tweet\s+about\s+the\s+previous\s+text", # your task now is to write a tweet about the previous text
r"what\s+is\s+the\s+main\s+point\s+of\s+the\s+passage", # what is the main point of the passage
]

penalties = []
@@ -53,7 +58,7 @@
# Trim and consider only the first 200 characters
completion_segment = completion.strip()[:200].lower()
for pattern in system_messages_penalizing_sentences:
-if re.search(pattern, completion_segment):
+if re.search(pattern, completion_segment, re.IGNORECASE):
accumulated_penalty += 0.1
penalties.append(accumulated_penalty)
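
A small sketch, not part of this commit, of how the accumulated penalty grows: each system-message pattern that matches the first 200 characters of a completion adds 0.1, and the search is now explicitly case-insensitive. The completion text is made up and only two of the patterns are checked here.

import re

segment = "Sure! Here is a paraphrase: a short restatement.".strip()[:200].lower()
patterns = [r"sure!\s+her.+", r"paraphrase\s*:"]  # subset of the list above
accumulated_penalty = sum(
    0.1 for p in patterns if re.search(p, segment, re.IGNORECASE)
)
# accumulated_penalty -> 0.2 for this hypothetical completion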

(Diffs for the remaining changed files are not shown here.)

