Skip to content

Commit

Permalink
make the elasticsearch api support version which below 8.x (langchain…
Browse files Browse the repository at this point in the history
…-ai#5495)

the api which create index or search in the elasticsearch below 8.x is
different with 8.x. When use the es which below 8.x , it will throw
error. I fix the problem


Co-authored-by: gaofeng27692 <gaofeng27692@hundsun.com>
  • Loading branch information
2 people authored and Undertone0809 committed Jun 19, 2023
1 parent 58aabd6 commit 4d8555b
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions langchain/vectorstores/elastic_vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def add_texts(
except NotFoundError:
# TODO would be nice to create index before embedding,
# just to save expensive steps for last
self.client.indices.create(index=self.index_name, mappings=mapping)
self.create_index(self.client, self.index_name, mapping)

for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
Expand Down Expand Up @@ -226,7 +226,9 @@ def similarity_search_with_score(
"""
embedding = self.embedding.embed_query(query)
script_query = _default_script_query(embedding, filter)
response = self.client.search(index=self.index_name, query=script_query, size=k)
response = self.client_search(
self.client, self.index_name, script_query, size=k
)
hits = [hit for hit in response["hits"]["hits"]]
docs_and_scores = [
(
Expand Down Expand Up @@ -281,3 +283,24 @@ def from_texts(
texts, metadatas=metadatas, refresh_indices=refresh_indices
)
return vectorsearch

def create_index(self, client: Any, index_name: str, mapping: Dict) -> None:
version_num = client.info()["version"]["number"][0]
version_num = int(version_num)
if version_num >= 8:
client.indices.create(index=index_name, mappings=mapping)
else:
client.indices.create(index=index_name, body={"mappings": mapping})

def client_search(
self, client: Any, index_name: str, script_query: Dict, size: int
) -> Any:
version_num = client.info()["version"]["number"][0]
version_num = int(version_num)
if version_num >= 8:
response = client.search(index=index_name, query=script_query, size=size)
else:
response = client.search(
index=index_name, body={"query": script_query, "size": size}
)
return response

0 comments on commit 4d8555b

Please sign in to comment.