Commit 17acdba

Merge pull request #9 from charliehuang09/parent_docs

add parent_docs

maxstrid committed Dec 25, 2023
2 parents 6979760 + 528500c

Showing 2 changed files with 48 additions and 14 deletions.
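In broad strokes, the PR adds a simple form of parent-document retrieval: each file is split twice, into large "parent" chunks and small "child" chunks. Only the child chunks are embedded, but each child's stored text is then replaced by its parent chunk, so a precise vector match on a ~100-token span returns the surrounding ~1000-token context. A standalone sketch of the idea (illustrative only, not code from this commit; example.txt is a hypothetical input):

    # Sketch of the parent/child trick, using the same splitter the PR uses.
    from langchain.schema import Document
    from langchain.text_splitter import TokenTextSplitter

    parent_splitter = TokenTextSplitter(chunk_size=1000, chunk_overlap=0)
    child_splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=0)

    doc = Document(page_content=open("example.txt").read())
    for parent in parent_splitter.split_documents([doc]):
        children = child_splitter.split_documents([parent])
        texts_to_embed = [c.page_content for c in children]  # embed the small spans
        for c in children:
            c.page_content = parent.page_content  # but store the large context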
57 changes: 45 additions & 12 deletions scripts/encode_data.py
@@ -9,24 +9,33 @@
 import numpy as np
 from dotenv import load_dotenv
 
-from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredPDFLoader, UnstructuredRSTLoader
+from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredPDFLoader, UnstructuredRSTLoader, UnstructuredMarkdownLoader
 from langchain.embeddings import OpenAIEmbeddings
 
 from langchain.text_splitter import TokenTextSplitter
 
 
 class CustomDataLoader:
 
-    def __init__(self, chunk_size=500) -> None:
+    def __init__(self, chunk_size=500, parent_chunk_size=1000, child_chunk_size=100, do_parent_document=False, thread=False) -> None:
+        '''
+        If do_parent_document is True, the dataloader will generate parent
+        documents and child documents from the data.
+        parent_chunk_size and child_chunk_size only apply when
+        do_parent_document is set.
+        '''
         self.files = []
 
+        self.thread = thread
         self.documents = []
 
+        self.do_parent_document = do_parent_document
+
         self.embeddings_model = OpenAIEmbeddings(
             model="text-embedding-ada-002")
 
-        self.text_splitter = TokenTextSplitter(chunk_size=chunk_size,
-                                               chunk_overlap=0)
+        self.text_splitter = TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
+
+        self.parent_splitter = TokenTextSplitter(chunk_size=parent_chunk_size, chunk_overlap=0)
+
+        self.child_splitter = TokenTextSplitter(chunk_size=child_chunk_size, chunk_overlap=0)
 
         self.counter = 0

@@ -42,16 +51,34 @@ def embed_documents(self) -> None:

         if name == "data/links":
             continue
 
-        thread = threading.Thread(target=self._process_document, args=(file,))
-
-        self.semaphore.acquire()
-
-        thread.start()
+        if self.thread:
+            threading.Thread(target=self._process_document, args=(file,)).start()
+            self.semaphore.acquire()
+        else:
+            self._process_document(file)
 
     def save(self) -> None:
         print(len(self.documents))
         np.save('data.npy', np.asarray(self.documents, dtype=object))
 
+    def _generate_child_docs(self, parent_document):
+        child_docs = self.child_splitter.split_documents([parent_document])
+        # Embed the small child chunks.
+        vector_responses = self.embeddings_model.embed_documents(
+            list(map(lambda document: document.page_content, child_docs))
+        )
+        # Swap in the parent's text so a match on a child returns the larger context.
+        for doc in child_docs:
+            doc.page_content = parent_document.page_content
+        document_map = []
+        for vector, doc in zip(vector_responses, child_docs):
+            document_map.append({
+                "vector": vector,
+                "document": doc
+            })
+        return document_map
 
     def _process_document(self, file) -> None:
         _, extension = os.path.splitext(file)

@@ -62,9 +89,12 @@ def _process_document(self, file) -> None:
                 loader = UnstructuredPDFLoader(file)
             case ".rst":
                 loader = UnstructuredRSTLoader(file)
+            case ".md":
+                loader = UnstructuredMarkdownLoader(file)
 
 
         documents = self.text_splitter.split_documents(loader.load())
+        parent_documents = self.parent_splitter.split_documents(loader.load())
         vector_responses = self.embeddings_model.embed_documents(
             list(map(lambda document: document.page_content, documents))
         )
@@ -76,6 +106,9 @@ def _process_document(self, file) -> None:
"vector": doc[0],
"document": doc[1]
})
if self.do_parent_document:
for doc in parent_documents:
document_map.extend(self._generate_child_docs(doc))

with self.lock:
print('\r', f'Embedding progress: {self.counter + 1}/{len(self.files) - 1}')
@@ -93,7 +126,7 @@ def _load_files(self) -> None:
             name, extension = os.path.splitext(full_path)
 
             match extension:
-                case ".pdf" | ".rst":
+                case ".pdf" | ".rst" | ".md":
                     pass
                 case _:
                     # Links is a special case; it's where we load arbitrary HTML.
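For orientation, a minimal sketch of how the updated loader might be driven. The call order is assumed, not shown in this diff; it presumes _load_files has populated self.files from a data/ directory and that OPENAI_API_KEY is supplied via .env:

    from dotenv import load_dotenv

    load_dotenv()  # assumed to provide OPENAI_API_KEY

    loader = CustomDataLoader(parent_chunk_size=1000,
                              child_chunk_size=100,
                              do_parent_document=True,
                              thread=False)  # use the new single-threaded path
    loader.embed_documents()  # embeds each file; also builds parent/child docs
    loader.save()             # writes data.npy for scripts/qdrant_upsert.py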
5 changes: 3 additions & 2 deletions scripts/qdrant_upsert.py
@@ -7,7 +7,8 @@

 def main():
     client = QdrantClient("0.0.0.0", port=6333)
-    client.create_collection(collection_name='default',
+    collection_name = 'default'
+    client.create_collection(collection_name=collection_name,
                              vectors_config=models.VectorParams(
                                  size=1536, distance=models.Distance.COSINE))

@@ -19,7 +20,7 @@ def main():
                 end='',
                 flush=True)
 
-    client.upsert(collection_name='default',
+    client.upsert(collection_name=collection_name,
                   points=[
                       models.PointStruct(id=i,
                                          payload={
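At query time, the upserted points can be searched as usual. A hedged sketch, not part of this commit, assuming the payload written by encode_data.py carries the document text:

    from langchain.embeddings import OpenAIEmbeddings
    from qdrant_client import QdrantClient

    client = QdrantClient("0.0.0.0", port=6333)
    embeddings_model = OpenAIEmbeddings(model="text-embedding-ada-002")

    # Embed the question with the same model used at index time.
    query_vector = embeddings_model.embed_query("example question")
    hits = client.search(collection_name='default',
                         query_vector=query_vector,
                         limit=3)

    # With do_parent_document=True, each child vector's payload already holds
    # the parent chunk's text, so a hit returns the larger context rather than
    # the 100-token span that matched.
    for hit in hits:
        print(hit.payload)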
