Skip to content

Commit

Permalink
Merge pull request #28 from opentensor/topk
Browse files (browse the repository at this point in the history)
ensure topk isn't out of bounds
  • Loading branch information
ifrit98 committed Oct 2, 2023
2 parents 6322c96 + e2e348f commit d6e2919
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions prompting/validators/reward/diversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,10 @@ def regularise(rewards):
)

# Reward to be at the bottom_k smallest of the 1 - similarity score.
- rewards = torch.topk(
- (1 - torch.abs(similarity)), self.history_reward_bottom_k, largest=False
- )[0][:, -1]
+ bottom_k = min(self.history_reward_bottom_k, len(similarity))
+ rewards = torch.topk((1 - torch.abs(similarity)), bottom_k, largest=False)[0][
+ :, -1
+ ]

return regularise(rewards)

Expand All @@ -144,9 +145,10 @@ def regularise(rewards):
similarity = pairwise_cosine_similarity(embeddings, embeddings)

# Reward to be at the 10% quantile of the 1 - similarity score.
- rewards = torch.topk(
- (1 - torch.abs(similarity)), self.reward_bottom_k, largest=False
- )[0][:, -1]
+ bottom_k = min(self.reward_bottom_k, len(similarity))
+ rewards = torch.topk((1 - torch.abs(similarity)), bottom_k, largest=False)[0][
+ :, -1
+ ]

return regularise(rewards)

Expand Down

0 comments on commit d6e2919

Please sign in to comment.