Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
ifrit98 committed Oct 2, 2023
1 parent 2d2daba commit e2e348f
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions prompting/validators/reward/diversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ def regularise(rewards):

# Reward to be at the bottom_k smallest of the 1 - similarity score.
bottom_k = min(self.history_reward_bottom_k, len(similarity))
rewards = torch.topk(
(1 - torch.abs(similarity)), bottom_k, largest=False
)[0][:, -1]
rewards = torch.topk((1 - torch.abs(similarity)), bottom_k, largest=False)[0][
:, -1
]

return regularise(rewards)

Expand All @@ -146,9 +146,9 @@ def regularise(rewards):

# Reward to be at the 10% quantile of the 1 - similarity score.
bottom_k = min(self.reward_bottom_k, len(similarity))
rewards = torch.topk(
(1 - torch.abs(similarity)), bottom_k, largest=False
)[0][:, -1]
rewards = torch.topk((1 - torch.abs(similarity)), bottom_k, largest=False)[0][
:, -1
]

return regularise(rewards)

Expand Down

0 comments on commit e2e348f

Please sign in to comment.