Commit

Merge branch 'rocm-support' of github.com:huggingface/text-embeddings-inference into rocm-support
fxmarty committed Jun 24, 2024
2 parents 839a445 + 09b8b22 commit c6c5e45
Showing 15 changed files with 6,380 additions and 3,410 deletions.
8 changes: 4 additions & 4 deletions Makefile
@@ -1,11 +1,11 @@
integration-tests:
cargo test --release
cargo test

cuda-integration-tests:
cargo test -F text-embeddings-backend-candle/cuda -F text-embeddings-backend-candle/flash-attn -F text-embeddings-router/candle-cuda --release
cargo test -F text-embeddings-backend-candle/cuda -F text-embeddings-backend-candle/flash-attn -F text-embeddings-router/candle-cuda --profile release-debug

integration-tests-review:
cargo insta test --review --release
cargo insta test --review

cuda-integration-tests-review:
cargo insta test --review --features "text-embeddings-backend-candle/cuda text-embeddings-backend-candle/flash-attn text-embeddings-router/candle-cuda" --release
cargo insta test --review --features "text-embeddings-backend-candle/cuda text-embeddings-backend-candle/flash-attn text-embeddings-router/candle-cuda" --profile release-debug
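
Note on the new flag: `--profile release-debug` points the cuda test targets at a custom Cargo profile. Cargo requires such a profile to be declared in the workspace `Cargo.toml` under `[profile.release-debug]`; its definition is not shown in the hunks here, but it is presumably a release build that keeps debug information for the CUDA integration tests.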
File renamed without changes.
133 changes: 67 additions & 66 deletions backends/candle/src/lib.rs
@@ -12,7 +12,7 @@ use crate::compute_cap::{
};
use crate::models::{
BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
JinaCodeConfig, JinaConfig, Model, NomicBertModel, NomicConfig,
Model, NomicBertModel, NomicConfig,
};
#[cfg(feature = "cuda")]
use crate::models::{
@@ -30,17 +30,28 @@ use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
};

/// This enum is needed to be able to differentiate between jina models that also use
/// the `bert` model type and valid Bert models.
/// We use the `_name_or_path` field in the config to do so. This might not be robust in the long
/// run but is still better than the other options...
#[derive(Debug, Clone, PartialEq, Deserialize)]
#[serde(tag = "_name_or_path")]
pub enum BertConfigWrapper {
#[serde(rename = "jinaai/jina-bert-implementation")]
JinaBert(BertConfig),
#[serde(rename = "jinaai/jina-bert-v2-qk-post-norm")]
JinaCodeBert(BertConfig),
#[serde(untagged)]
Bert(BertConfig),
}
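
For reference, a minimal sketch (not part of this diff) of how the `_name_or_path` dispatch above resolves at deserialization time. `MiniBertConfig` is a stand-in for the crate's `BertConfig`, recent serde plus serde_json are assumed as dependencies, and a missing or unrecognized `_name_or_path` falls through to the untagged variant:

    use serde::Deserialize;

    // Stand-in for the real `BertConfig` (one field, for illustration only).
    #[derive(Debug, Deserialize)]
    struct MiniBertConfig {
        hidden_size: usize,
    }

    // Mirrors the wrapper above: dispatch on the `_name_or_path` field,
    // with an untagged fallback for plain BERT checkpoints.
    #[derive(Debug, Deserialize)]
    #[serde(tag = "_name_or_path")]
    enum MiniBertConfigWrapper {
        #[serde(rename = "jinaai/jina-bert-implementation")]
        JinaBert(MiniBertConfig),
        #[serde(rename = "jinaai/jina-bert-v2-qk-post-norm")]
        JinaCodeBert(MiniBertConfig),
        #[serde(untagged)]
        Bert(MiniBertConfig),
    }

    fn main() {
        // A Jina checkpoint advertises itself through `_name_or_path`.
        let jina: MiniBertConfigWrapper = serde_json::from_str(
            r#"{ "_name_or_path": "jinaai/jina-bert-implementation", "hidden_size": 768 }"#,
        )
        .unwrap();
        assert!(matches!(jina, MiniBertConfigWrapper::JinaBert(_)));

        // Any other value (or a missing field) is treated as a plain BERT config.
        let bert: MiniBertConfigWrapper = serde_json::from_str(
            r#"{ "_name_or_path": "bert-base-uncased", "hidden_size": 384 }"#,
        )
        .unwrap();
        assert!(matches!(bert, MiniBertConfigWrapper::Bert(_)));
    }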

#[derive(Deserialize)]
#[serde(tag = "model_type", rename_all = "kebab-case")]
enum Config {
Bert(BertConfig),
Bert(BertConfigWrapper),
XlmRoberta(BertConfig),
Camembert(BertConfig),
Roberta(BertConfig),
#[serde(rename(deserialize = "jina_bert"))]
JinaBert(JinaConfig),
#[serde(rename(deserialize = "jina_code_bert"))]
JinaCodeBert(JinaCodeConfig),
#[serde(rename(deserialize = "distilbert"))]
DistilBert(DistilBertConfig),
#[serde(rename(deserialize = "nomic_bert"))]
@@ -76,7 +87,7 @@ impl CandleBackend {
"Runtime compute cap {} is not compatible with compile time compute cap {}",
get_runtime_compute_cap().unwrap(),
get_compile_compute_cap().unwrap()
)))
)));
}
Err(err) => {
tracing::warn!("Could not find a compatible CUDA device on host: {err:?}");
@@ -123,20 +134,22 @@ impl CandleBackend {
(_, Device::Cuda(_)) => Err(BackendError::Start(
"`cuda` feature is not enabled".to_string(),
)),
(Config::Bert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting Bert model on {:?}", device);
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
}
(Config::JinaBert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting JinaBertModel model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
(Config::JinaCodeBert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
Ok(Box::new(
JinaCodeBertModel::load(vb, &config, model_type).s()?,
))
}
(Config::Bert(config), Device::Cpu | Device::Metal(_)) => match config {
BertConfigWrapper::JinaBert(config) => {
tracing::info!("Starting JinaBertModel model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
BertConfigWrapper::JinaCodeBert(config) => {
tracing::info!("Starting JinaCodeBert model on {:?}", device);
Ok(Box::new(
JinaCodeBertModel::load(vb, &config, model_type).s()?,
))
}
BertConfigWrapper::Bert(config) => {
tracing::info!("Starting Bert model on {:?}", device);
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
}
},
(
Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config),
Device::Cpu | Device::Metal(_),
@@ -160,56 +173,45 @@ impl CandleBackend {
(Config::Bert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
if config.position_embedding_type == PositionEmbeddingType::Alibi {
tracing::info!("Starting FlashBert model on {:?}", device);
Ok(Box::new(FlashBertModel::load(vb, &config, model_type).s()?))
} else {
tracing::info!("Starting Bert model on {:?}", device);
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
match config {
BertConfigWrapper::JinaBert(config) => {
tracing::info!("Starting FlashJinaBert model on {:?}", device);
Ok(Box::new(
FlashJinaBertModel::load(vb, &config, model_type).s()?,
))
}
BertConfigWrapper::JinaCodeBert(config) => {
tracing::info!("Starting FlashJinaCodeBert model on {:?}", device);
Ok(Box::new(
FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,
))
}
BertConfigWrapper::Bert(config) => {
tracing::info!("Starting FlashBert model on {:?}", device);
Ok(Box::new(FlashBertModel::load(vb, &config, model_type).s()?))
}
}
}
}
#[cfg(feature = "cuda")]
(Config::JinaBert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
tracing::info!("Starting FlashJinaBertModel model on {:?}", device);
Ok(Box::new(
FlashJinaBertModel::load(vb, &config, model_type).s()?,
))
} else {
tracing::info!("Starting JinaBertModel model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
}
#[cfg(feature = "cuda")]
(Config::JinaCodeBert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
Ok(Box::new(
FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,
))
} else {
tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
Ok(Box::new(
JinaCodeBertModel::load(vb, &config, model_type).s()?,
))
match config {
BertConfigWrapper::JinaBert(config) => {
tracing::info!("Starting JinaBertModel model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
BertConfigWrapper::JinaCodeBert(config) => {
tracing::info!("Starting JinaCodeBert model on {:?}", device);
Ok(Box::new(
JinaCodeBertModel::load(vb, &config, model_type).s()?,
))
}
BertConfigWrapper::Bert(config) => {
tracing::info!("Starting Bert model on {:?}", device);
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
}
}
}
}
#[cfg(feature = "cuda")]
@@ -219,7 +221,6 @@
) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
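
Several of the CUDA arms above gate flash attention on the USE_FLASH_ATTENTION environment variable. The same check, restated as a standalone helper (a sketch; in the diff it is written inline in each condition):

    /// Flash attention is enabled unless USE_FLASH_ATTENTION is set to anything other
    /// than "true" (case-insensitive); an unset variable defaults to enabled.
    fn use_flash_attention() -> bool {
        std::env::var("USE_FLASH_ATTENTION")
            .unwrap_or_else(|_| "True".to_string())
            .to_lowercase()
            == "true"
    }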
35 changes: 17 additions & 18 deletions backends/candle/src/models/flash_jina.rs
@@ -2,14 +2,13 @@ use crate::alibi::alibi_head_slopes
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{HiddenAct, LayerNorm, Linear};
use crate::models::bert::PositionEmbeddingType;
use crate::models::jina::BertEmbeddings;
use crate::models::jina::{BertEmbeddings, JinaConfig};
use crate::models::Model;
use crate::models::jina::JinaEmbeddings;
use crate::models::{BertConfig, Model};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::VarBuilder;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

struct AlibiBertAttention {
struct JinaAttention {
qkv_linear: Linear,
dense: Linear,
layer_norm: LayerNorm,
@@ -23,7 +22,7 @@ struct AlibiBertAttention {
span: tracing::Span,
}

impl AlibiBertAttention {
impl JinaAttention {
pub fn load(vb: VarBuilder, config: &BertConfig, alibi_slopes: Option<Tensor>) -> Result<Self> {
let attention_head_size = config.hidden_size / config.num_attention_heads;
let all_head_size = config.num_attention_heads * attention_head_size;
@@ -117,7 +116,7 @@ impl AlibiBertAttention {
}

struct JinaBertLayer {
attention: AlibiBertAttention,
attention: JinaAttention,
gated: Linear,
output: Linear,
layer_norm: LayerNorm,
@@ -130,7 +129,7 @@

impl JinaBertLayer {
pub fn load(vb: VarBuilder, config: &BertConfig, alibi: Option<Tensor>) -> Result<Self> {
let attention = AlibiBertAttention::load(vb.pp("attention"), config, alibi)?;
let attention = JinaAttention::load(vb.pp("attention"), config, alibi)?;

let gated_weight = vb
.pp("mlp")
@@ -174,14 +173,14 @@ impl JinaBertLayer {
let residual = hidden_states.clone();

let hidden_states = self.gated.forward(&hidden_states)?;
let gated = hidden_states.i((.., 0..self.intermediate_size))?;
let gated = hidden_states.narrow(1, 0, self.intermediate_size)?;
let gated = match self.act {
HiddenAct::Gelu => gated.gelu(),
HiddenAct::Relu => gated.relu(),
HiddenAct::Swiglu => gated.silu(),
}?;

let non_gated = hidden_states.i((.., self.intermediate_size..))?;
let non_gated = hidden_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
let hidden_states = (gated * non_gated)?;

let hidden_states = self.output.forward(&hidden_states)?;
@@ -191,12 +190,12 @@
}
}

struct BertEncoder {
struct JinaBertEncoder {
layers: Vec<JinaBertLayer>,
span: tracing::Span,
}

impl BertEncoder {
impl JinaBertEncoder {
pub fn load(vb: VarBuilder, config: &BertConfig, alibi: Option<Tensor>) -> Result<Self> {
let layers = (0..config.num_hidden_layers)
.map(|index| {
@@ -205,7 +204,7 @@ impl BertEncoder {
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");

Ok(BertEncoder { layers, span })
Ok(JinaBertEncoder { layers, span })
}

fn forward(&self, hidden_states: &Tensor, cu_seqlens: &Tensor, max_s: usize) -> Result<Tensor> {
@@ -223,8 +222,8 @@
}

pub struct FlashJinaBertModel {
embeddings: BertEmbeddings,
encoder: BertEncoder,
embeddings: JinaEmbeddings,
encoder: JinaBertEncoder,
pool: Pool,
pub device: Device,

@@ -266,14 +265,14 @@ impl FlashJinaBertModel {
};

let (embeddings, encoder) = match (
BertEmbeddings::load(vb.pp("embeddings"), config),
BertEncoder::load(vb.pp("encoder"), config, alibi.clone()),
JinaEmbeddings::load(vb.pp("embeddings"), config),
JinaBertEncoder::load(vb.pp("encoder"), config, alibi.clone()),
) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(embeddings), Ok(encoder)) = (
BertEmbeddings::load(vb.pp("bert.embeddings"), config),
BertEncoder::load(vb.pp("bert.encoder"), config, alibi.clone()),
JinaEmbeddings::load(vb.pp("bert.embeddings"), config),
JinaBertEncoder::load(vb.pp("bert.encoder"), config, alibi.clone()),
) {
(embeddings, encoder)
} else {
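
Aside from the renames, the diff above also swaps `IndexOp` range indexing for `Tensor::narrow` when splitting the gated-MLP activations; both select the same column slice, `narrow(dim, start, len)` just states the slice length explicitly. A minimal sketch of the equivalence, assuming candle-core (re-exported as `candle` in this backend):

    use candle::{Device, IndexOp, Result, Tensor};

    fn main() -> Result<()> {
        let intermediate_size = 4;
        // Stand-in for the post-projection activations: (tokens, 2 * intermediate_size).
        let hidden = Tensor::rand(0f32, 1f32, (3, 2 * intermediate_size), &Device::Cpu)?;

        // Old style: range indexing over the last dimension.
        let gated_a = hidden.i((.., 0..intermediate_size))?;
        // New style: narrow(dim, start, len) with an explicit length.
        let gated_b = hidden.narrow(1, 0, intermediate_size)?;
        assert_eq!(gated_a.to_vec2::<f32>()?, gated_b.to_vec2::<f32>()?);

        let non_gated = hidden.narrow(1, intermediate_size, intermediate_size)?;
        assert_eq!(non_gated.dims(), &[3, intermediate_size]);
        Ok(())
    }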
(Diffs for the remaining changed files are not shown.)
