diff --git a/prompting/validators/reward/diversity.py b/prompting/validators/reward/diversity.py index 404abc4..a286f35 100644 --- a/prompting/validators/reward/diversity.py +++ b/prompting/validators/reward/diversity.py @@ -129,9 +129,10 @@ def regularise(rewards): ) # Reward to be at the bottom_k smallest of the 1 - similarity score. - rewards = torch.topk( - (1 - torch.abs(similarity)), self.history_reward_bottom_k, largest=False - )[0][:, -1] + bottom_k = min(self.history_reward_bottom_k, len(similarity)) + rewards = torch.topk((1 - torch.abs(similarity)), bottom_k, largest=False)[0][ + :, -1 + ] return regularise(rewards) @@ -144,9 +145,10 @@ def regularise(rewards): similarity = pairwise_cosine_similarity(embeddings, embeddings) # Reward to be at the 10% quantile of the 1 - similarity score. - rewards = torch.topk( - (1 - torch.abs(similarity)), self.reward_bottom_k, largest=False - )[0][:, -1] + bottom_k = min(self.reward_bottom_k, len(similarity)) + rewards = torch.topk((1 - torch.abs(similarity)), bottom_k, largest=False)[0][ + :, -1 + ] return regularise(rewards)