From 1a9876f4edcbea156757d25481cdf92704a38c69 Mon Sep 17 00:00:00 2001
From: Oskar Liew
Date: Thu, 26 Sep 2024 15:09:08 +0200
Subject: [PATCH] Add support for pooling with python backend

---
 .../server/text_embeddings_server/cli.py        | 11 +++++----
 .../text_embeddings_server/models/__init__.py   | 14 ++++++-----
 .../models/default_model.py                     | 24 +++++++++++++------
 .../server/text_embeddings_server/server.py     | 16 ++++++-------
 backends/python/src/lib.rs                      | 12 ++++------
 backends/python/src/management.rs               | 14 ++++++++++-
 6 files changed, 56 insertions(+), 35 deletions(-)

diff --git a/backends/python/server/text_embeddings_server/cli.py b/backends/python/server/text_embeddings_server/cli.py
index 9497dc20..b6c14bee 100644
--- a/backends/python/server/text_embeddings_server/cli.py
+++ b/backends/python/server/text_embeddings_server/cli.py
@@ -1,10 +1,10 @@
 import sys
-import typer
-
+from enum import Enum
 from pathlib import Path
-from loguru import logger
 from typing import Optional
-from enum import Enum
+
+import typer
+from loguru import logger
 
 app = typer.Typer()
 
@@ -24,6 +24,7 @@ def serve(
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
     otlp_service_name: str = "text-embeddings-inference.server",
+    pool: str = "cls",
 ):
     # Remove default handler
     logger.remove()
@@ -48,7 +49,7 @@ def serve(
     # Downgrade enum into str for easier management later on
     dtype = None if dtype is None else dtype.value
 
-    server.serve(model_path, dtype, uds_path)
+    server.serve(model_path, dtype, uds_path, pool)
 
 
 if __name__ == "__main__":
diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py
index 47867187..588311eb 100644
--- a/backends/python/server/text_embeddings_server/models/__init__.py
+++ b/backends/python/server/text_embeddings_server/models/__init__.py
@@ -1,13 +1,13 @@
-import torch
-
-from loguru import logger
 from pathlib import Path
 from typing import Optional
+
+import torch
+from loguru import logger
 from transformers import AutoConfig
 from transformers.models.bert import BertConfig
 
-from text_embeddings_server.models.model import Model
 from text_embeddings_server.models.default_model import DefaultModel
+from text_embeddings_server.models.model import Model
 
 __all__ = ["Model"]
 
@@ -25,7 +25,7 @@
 __all__.append(FlashBert)
 
 
-def get_model(model_path: Path, dtype: Optional[str]):
+def get_model(model_path: Path, dtype: Optional[str], pool: str):
     if dtype == "float32":
         dtype = torch.float32
     elif dtype == "float16":
@@ -52,8 +52,10 @@ def get_model(model_path: Path, dtype: Optional[str]):
         and dtype in [torch.float16, torch.bfloat16]
         and FLASH_ATTENTION
     ):
+        if pool != "cls":
+            raise ValueError("FlashBert only supports cls pooling")
         return FlashBert(model_path, device, dtype)
     else:
-        return DefaultModel(model_path, device, dtype)
+        return DefaultModel(model_path, device, dtype, pool)
 
     raise NotImplementedError
diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py
index dc39fdc8..f75255bb 100644
--- a/backends/python/server/text_embeddings_server/models/default_model.py
+++ b/backends/python/server/text_embeddings_server/models/default_model.py
@@ -1,21 +1,25 @@
 import inspect
-import torch
-
 from pathlib import Path
-from typing import Type, List
-from transformers import AutoModel
+from typing import List, Type
+
+import torch
 from opentelemetry import trace
+from sentence_transformers.models import Pooling
+from transformers import AutoModel
 
 from text_embeddings_server.models import Model
-from text_embeddings_server.models.types import PaddedBatch, Embedding
+from text_embeddings_server.models.types import Embedding, PaddedBatch
 
 tracer = trace.get_tracer(__name__)
 
 
 class DefaultModel(Model):
-    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
+    def __init__(
+        self, model_path: Path, device: torch.device, dtype: torch.dtype, pool: str
+    ):
         model = AutoModel.from_pretrained(model_path).to(dtype).to(device)
         self.hidden_size = model.config.hidden_size
+        self.pooling = Pooling(self.hidden_size, pooling_mode=pool)
 
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
@@ -41,7 +45,13 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
             kwargs["position_ids"] = batch.position_ids
 
         output = self.model(**kwargs)
-        embedding = output[0][:, 0]
+
+        pooling_features = {
+            "token_embeddings": output[0],
+            "attention_mask": batch.attention_mask,
+        }
+        embedding = self.pooling.forward(pooling_features)["sentence_embedding"]
+
         cpu_results = embedding.view(-1).tolist()
 
         return [
diff --git a/backends/python/server/text_embeddings_server/server.py b/backends/python/server/text_embeddings_server/server.py
index d0a43ace..d1a92261 100644
--- a/backends/python/server/text_embeddings_server/server.py
+++ b/backends/python/server/text_embeddings_server/server.py
@@ -1,17 +1,16 @@
 import asyncio
-import torch
+from pathlib import Path
+from typing import Optional
 
+import torch
 from grpc import aio
-from loguru import logger
-
 from grpc_reflection.v1alpha import reflection
-from pathlib import Path
-from typing import Optional
+from loguru import logger
 
 from text_embeddings_server.models import Model, get_model
-from text_embeddings_server.pb import embed_pb2_grpc, embed_pb2
-from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_embeddings_server.pb import embed_pb2, embed_pb2_grpc
 from text_embeddings_server.utils.interceptor import ExceptionInterceptor
+from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor
 
 
 class EmbeddingService(embed_pb2_grpc.EmbeddingServiceServicer):
@@ -37,6 +36,7 @@ def serve(
     model_path: Path,
     dtype: Optional[str],
     uds_path: Path,
+    pool: str,
 ):
     async def serve_inner(
         model_path: Path,
@@ -45,7 +45,7 @@ async def serve_inner(
         unix_socket = f"unix://{uds_path}"
 
         try:
-            model = get_model(model_path, dtype)
+            model = get_model(model_path, dtype, pool)
         except Exception:
             logger.exception("Error when initializing model")
             raise
diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs
index 195f1d37..f68bd20c 100644
--- a/backends/python/src/lib.rs
+++ b/backends/python/src/lib.rs
@@ -5,7 +5,7 @@ use backend_grpc_client::Client;
 use nohash_hasher::BuildNoHashHasher;
 use std::collections::HashMap;
 use text_embeddings_backend_core::{
-    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
+    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
 };
 use tokio::runtime::Runtime;
 
@@ -24,18 +24,13 @@ impl PythonBackend {
         otlp_endpoint: Option<String>,
         otlp_service_name: String,
     ) -> Result<Self, BackendError> {
-        match model_type {
+        let pool = match model_type {
             ModelType::Classifier => {
                 return Err(BackendError::Start(
                     "`classifier` model type is not supported".to_string(),
                 ))
             }
-            ModelType::Embedding(pool) => {
-                if pool != Pool::Cls {
-                    return Err(BackendError::Start(format!("{pool:?} is not supported")));
-                }
-                pool
-            }
+            ModelType::Embedding(pool) => pool,
         };
 
         let backend_process = management::BackendProcess::new(
@@ -44,6 +39,7 @@ impl PythonBackend {
             &uds_path,
             otlp_endpoint,
             otlp_service_name,
+            pool,
         )?;
         let tokio_runtime = tokio::runtime::Builder::new_current_thread()
             .enable_all()
diff --git a/backends/python/src/management.rs b/backends/python/src/management.rs
index 911c6984..977ab045 100644
--- a/backends/python/src/management.rs
+++ b/backends/python/src/management.rs
@@ -8,7 +8,7 @@ use std::sync::mpsc;
 use std::thread::sleep;
 use std::time::{Duration, Instant};
 use std::{env, fs, io, thread};
-use text_embeddings_backend_core::BackendError;
+use text_embeddings_backend_core::{BackendError, Pool};
 
 #[derive(Debug)]
 pub(crate) struct BackendProcess {
@@ -22,6 +22,7 @@ impl BackendProcess {
         uds_path: &str,
         otlp_endpoint: Option<String>,
         otlp_service_name: String,
+        pool: Pool,
     ) -> Result<Self, BackendError> {
         // Get UDS path
         let uds = Path::new(uds_path);
@@ -31,6 +32,15 @@
             fs::remove_file(uds).expect("could not remove UDS file");
         }
 
+        let pool = match pool {
+            Pool::Cls => "cls",
+            Pool::Mean => "mean",
+            Pool::LastToken => "lasttoken",
+            Pool::Splade => {
+                return Err(BackendError::Start(format!("{pool:?} is not supported")));
+            },
+        };
+
         // Process args
         let mut python_server_args = vec![
             model_path,
@@ -41,6 +51,8 @@
             "--logger-level".to_owned(),
             "INFO".to_owned(),
             "--json-output".to_owned(),
+            "--pool".to_owned(),
+            pool.to_owned(),
         ];
 
         // OpenTelemetry
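
Reviewer note: below is a quick standalone sanity check of the new pooling
path. It is a sketch, not part of the patch: it assumes sentence-transformers
is installed, and the tensor shapes and values are illustrative only. It shows
that Pooling with pooling_mode="cls" reproduces the old hard-coded
`output[0][:, 0]`, while "mean" averages over non-padding tokens:

import torch
from sentence_transformers.models import Pooling

hidden_size = 4
# Fake encoder output: a batch of 2 sequences, 3 tokens each.
token_embeddings = torch.randn(2, 3, hidden_size)
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])  # second sequence padded

features = {
    "token_embeddings": token_embeddings,
    "attention_mask": attention_mask,
}

# "cls" pooling: identical to the previous `embedding = output[0][:, 0]`.
cls_pool = Pooling(hidden_size, pooling_mode="cls")
cls_embedding = cls_pool.forward(dict(features))["sentence_embedding"]
assert torch.equal(cls_embedding, token_embeddings[:, 0])

# "mean" pooling: masked average over tokens where attention_mask == 1.
mean_pool = Pooling(hidden_size, pooling_mode="mean")
mean_embedding = mean_pool.forward(dict(features))["sentence_embedding"]
print(mean_embedding.shape)  # torch.Size([2, 4])

End to end: `BackendProcess::new` maps the backend's `Pool` enum to a string
("cls", "mean" or "lasttoken", matching sentence-transformers' pooling mode
names) and forwards it as the Python server's new `--pool` flag; `Pool::Splade`
is rejected when the backend process starts, and `get_model` rejects anything
other than cls pooling when the FlashBert path is taken.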