Skip to content

Commit

Permalink
feat(python): add cls and mean pooling (#402)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbroad1881 authored Sep 17, 2024
1 parent c6c5e45 commit 80259c9
Showing 1 changed file with 8 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
kwargs["position_ids"] = batch.position_ids

output = self.model(**kwargs)
embedding = output[0][:, 0]

if self.pooling_mode == "cls":
embedding = output[0][:, 0]
elif self.pooling_mode == "mean":
embedding = output[0].mean(dim=1)
else:
raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend")

cpu_results = embedding.view(-1).tolist()

return [
Expand Down

0 comments on commit 80259c9

Please sign in to comment.