Feature/python pooling #415

Open · wants to merge 1 commit into base: main
11 changes: 6 additions & 5 deletions backends/python/server/text_embeddings_server/cli.py
@@ -1,10 +1,10 @@
 import sys
-import typer
-
+from enum import Enum
 from pathlib import Path
-from loguru import logger
 from typing import Optional
-from enum import Enum
 
+import typer
+from loguru import logger
+
 app = typer.Typer()

@@ -24,6 +24,7 @@ def serve(
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
     otlp_service_name: str = "text-embeddings-inference.server",
+    pool: str = "cls",
 ):
     # Remove default handler
     logger.remove()
@@ -48,7 +49,7 @@ def serve(
     # Downgrade enum into str for easier management later on
     dtype = None if dtype is None else dtype.value
 
-    server.serve(model_path, dtype, uds_path)
+    server.serve(model_path, dtype, uds_path, pool)
 
 
 if __name__ == "__main__":
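Note: typer derives a CLI option from each keyword parameter, so the new `pool` argument surfaces as a `--pool` flag defaulting to `cls`. A minimal standalone sketch of that mapping (hypothetical, not part of this diff):

```python
import typer

app = typer.Typer()


@app.command()
def serve(model_path: str, pool: str = "cls"):
    # typer exposes `pool` as the `--pool` option, e.g. `serve /data/model --pool mean`
    typer.echo(f"serving {model_path} with pooling mode {pool!r}")


if __name__ == "__main__":
    app()
```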
14 changes: 8 additions & 6 deletions backends/python/server/text_embeddings_server/models/__init__.py
@@ -1,13 +1,13 @@
-import torch
-
-from loguru import logger
 from pathlib import Path
 from typing import Optional
+
+import torch
+from loguru import logger
 from transformers import AutoConfig
 from transformers.models.bert import BertConfig
 
-from text_embeddings_server.models.model import Model
 from text_embeddings_server.models.default_model import DefaultModel
+from text_embeddings_server.models.model import Model
 
 __all__ = ["Model"]

@@ -25,7 +25,7 @@
     __all__.append(FlashBert)
 
 
-def get_model(model_path: Path, dtype: Optional[str]):
+def get_model(model_path: Path, dtype: Optional[str], pool: str):
     if dtype == "float32":
         dtype = torch.float32
     elif dtype == "float16":
@@ -52,8 +52,10 @@ def get_model(model_path: Path, dtype: Optional[str]):
             and dtype in [torch.float16, torch.bfloat16]
             and FLASH_ATTENTION
         ):
+            if pool != "cls":
+                raise ValueError("FlashBert only supports cls pooling")
             return FlashBert(model_path, device, dtype)
         else:
-            return DefaultModel(model_path, device, dtype)
+            return DefaultModel(model_path, device, dtype, pool)
 
     raise NotImplementedError
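With this dispatch, the FlashBert fast path stays CLS-only and fails fast on any other mode, while `DefaultModel` receives the requested pooling. An illustrative call (checkpoint path and dtype are made up for the example):

```python
from pathlib import Path

from text_embeddings_server.models import get_model

# Hypothetical checkpoint directory; dtype and pool as the CLI would pass them.
model = get_model(Path("/data/bert-base"), "float16", pool="mean")
# On the FlashBert path (fp16/bf16 + flash attention), pool="mean" would
# instead raise ValueError("FlashBert only supports cls pooling").
```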
backends/python/server/text_embeddings_server/models/default_model.py
@@ -1,21 +1,25 @@
 import inspect
-import torch
-
 from pathlib import Path
-from typing import Type, List
-from transformers import AutoModel
+from typing import List, Type
+
+import torch
 from opentelemetry import trace
+from sentence_transformers.models import Pooling
+from transformers import AutoModel
 
 from text_embeddings_server.models import Model
-from text_embeddings_server.models.types import PaddedBatch, Embedding
+from text_embeddings_server.models.types import Embedding, PaddedBatch
 
 tracer = trace.get_tracer(__name__)
 
 
 class DefaultModel(Model):
-    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
+    def __init__(
+        self, model_path: Path, device: torch.device, dtype: torch.dtype, pool: str
+    ):
         model = AutoModel.from_pretrained(model_path).to(dtype).to(device)
         self.hidden_size = model.config.hidden_size
+        self.pooling = Pooling(self.hidden_size, pooling_mode=pool)
 
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
@@ -41,7 +45,13 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
             kwargs["position_ids"] = batch.position_ids
 
         output = self.model(**kwargs)
-        embedding = output[0][:, 0]
+
+        pooling_features = {
+            "token_embeddings": output[0],
+            "attention_mask": batch.attention_mask,
+        }
+        embedding = self.pooling.forward(pooling_features)["sentence_embedding"]
 
         cpu_results = embedding.view(-1).tolist()
 
         return [
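This is the core of the change: the hard-coded CLS slice `output[0][:, 0]` is replaced by sentence-transformers' `Pooling` module, which consumes the token embeddings plus the attention mask. For intuition, a minimal sketch of what the `mean` mode computes (a standalone approximation, not the PR's code):

```python
import torch


def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings, counting only non-padding positions."""
    # token_embeddings: [batch, seq_len, hidden]; attention_mask: [batch, seq_len]
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against all-padding rows
    return summed / counts
```

The `cls` mode, by contrast, reduces to exactly the old `output[0][:, 0]` behavior, so the default stays backward compatible.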
16 changes: 8 additions & 8 deletions backends/python/server/text_embeddings_server/server.py
@@ -1,17 +1,16 @@
 import asyncio
-import torch
+from pathlib import Path
+from typing import Optional
 
+import torch
 from grpc import aio
-from loguru import logger
-
 from grpc_reflection.v1alpha import reflection
-from pathlib import Path
-from typing import Optional
+from loguru import logger
 
 from text_embeddings_server.models import Model, get_model
-from text_embeddings_server.pb import embed_pb2_grpc, embed_pb2
-from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_embeddings_server.pb import embed_pb2, embed_pb2_grpc
 from text_embeddings_server.utils.interceptor import ExceptionInterceptor
+from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor
 
 
 class EmbeddingService(embed_pb2_grpc.EmbeddingServiceServicer):
@@ -37,6 +36,7 @@ def serve(
     model_path: Path,
     dtype: Optional[str],
     uds_path: Path,
+    pool: str,
 ):
     async def serve_inner(
         model_path: Path,
@@ -45,7 +45,7 @@ async def serve_inner(
         unix_socket = f"unix://{uds_path}"
 
         try:
-            model = get_model(model_path, dtype)
+            model = get_model(model_path, dtype, pool)
         except Exception:
             logger.exception("Error when initializing model")
             raise
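`serve` simply threads the new argument through to `get_model`. A hypothetical direct call with illustrative values, mirroring what the CLI does after parsing flags:

```python
from pathlib import Path

from text_embeddings_server import server

# Illustrative arguments; in practice the typer CLI supplies these.
server.serve(Path("/data/bert-base"), "float16", Path("/tmp/tei.sock"), "mean")
```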
12 changes: 4 additions & 8 deletions backends/python/src/lib.rs
@@ -5,7 +5,7 @@ use backend_grpc_client::Client;
 use nohash_hasher::BuildNoHashHasher;
 use std::collections::HashMap;
 use text_embeddings_backend_core::{
-    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
+    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
 };
 use tokio::runtime::Runtime;

@@ -24,18 +24,13 @@ impl PythonBackend {
         otlp_endpoint: Option<String>,
         otlp_service_name: String,
     ) -> Result<Self, BackendError> {
-        match model_type {
+        let pool = match model_type {
             ModelType::Classifier => {
                 return Err(BackendError::Start(
                     "`classifier` model type is not supported".to_string(),
                 ))
             }
-            ModelType::Embedding(pool) => {
-                if pool != Pool::Cls {
-                    return Err(BackendError::Start(format!("{pool:?} is not supported")));
-                }
-                pool
-            }
+            ModelType::Embedding(pool) => pool,
         };
 
         let backend_process = management::BackendProcess::new(
@@ -44,6 +39,7 @@
             &uds_path,
             otlp_endpoint,
             otlp_service_name,
+            pool,
         )?;
         let tokio_runtime = tokio::runtime::Builder::new_current_thread()
             .enable_all()
14 changes: 13 additions & 1 deletion backends/python/src/management.rs
@@ -8,7 +8,7 @@ use std::sync::mpsc;
 use std::thread::sleep;
 use std::time::{Duration, Instant};
 use std::{env, fs, io, thread};
-use text_embeddings_backend_core::BackendError;
+use text_embeddings_backend_core::{BackendError, Pool};
 
 #[derive(Debug)]
 pub(crate) struct BackendProcess {
@@ -22,6 +22,7 @@ impl BackendProcess {
         uds_path: &str,
         otlp_endpoint: Option<String>,
         otlp_service_name: String,
+        pool: Pool,
     ) -> Result<Self, BackendError> {
         // Get UDS path
         let uds = Path::new(uds_path);
@@ -31,6 +32,15 @@
             fs::remove_file(uds).expect("could not remove UDS file");
         }
 
+        let pool = match pool {
+            Pool::Cls => "cls",
+            Pool::Mean => "mean",
+            Pool::LastToken => "lasttoken",
+            Pool::Splade => {
+                return Err(BackendError::Start(format!("{pool:?} is not supported")));
+            },
+        };
+
         // Process args
         let mut python_server_args = vec![
             model_path,
@@ -41,6 +51,8 @@
             "--logger-level".to_owned(),
             "INFO".to_owned(),
             "--json-output".to_owned(),
+            "--pool".to_owned(),
+            pool.to_owned(),
         ];
 
         // OpenTelemetry
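The Splade rejection thus moves from lib.rs into this mapping, and the emitted strings must match what sentence-transformers' `Pooling` accepts as `pooling_mode` on the Python side. A quick sanity check (hypothetical snippet, assuming a sentence-transformers version that supports the `lasttoken` mode):

```python
from sentence_transformers.models import Pooling

# Every string the Rust side can emit should construct a valid Pooling module.
for mode in ("cls", "mean", "lasttoken"):
    Pooling(word_embedding_dimension=768, pooling_mode=mode)
```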