Commit
fix pickle issue and relearning logic
dlqqq committed Jul 5, 2023
1 parent 710176b commit dff5129
Showing 2 changed files with 64 additions and 46 deletions.
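The pickle half of the fix turns on what `dask.delayed` ships to workers: every task argument is pickled, and an instantiated embedding model can hold unpicklable state such as network clients or locks. The sketch below is illustrative rather than code from this repo (`FakeEmbeddingModel` is invented); it shows the failure mode, and why passing the provider class plus a plain kwargs dict sidesteps it.

# Illustrative stand-in, not jupyter-ai code: why shipping an instantiated
# model to Dask can fail, while shipping (class, kwargs) succeeds.
import pickle
import threading

class FakeEmbeddingModel:
    def __init__(self, model_id: str):
        self.model_id = model_id
        self._lock = threading.Lock()  # stand-in for unpicklable client state

    def embed_query(self, text: str):
        return [float(len(text))]

try:
    pickle.dumps(FakeEmbeddingModel("my-model"))
except TypeError as e:
    print("cannot pickle the instance:", e)  # _thread.lock is unpicklable

# The class object and a plain dict pickle cleanly, so each Dask worker can
# rebuild its own model instance; this is what the reworked embed_chunk does.
cls, kwargs = pickle.loads(pickle.dumps((FakeEmbeddingModel, {"model_id": "my-model"})))
print(cls(**kwargs).embed_query("hello"))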
103 changes: 60 additions & 43 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
@@ -41,16 +41,29 @@ def __init__(
         self.index = None
         self.metadata = IndexMetadata(dirs=[])
         self.prev_em_id = None
-        self.embeddings = None

         if not os.path.exists(INDEX_SAVE_DIR):
             os.makedirs(INDEX_SAVE_DIR)

-        self.load_or_create()
+        self._load_or_create()

+    def _load_or_create(self):
+        """Loads the vector store and creates a new one if none exists."""
+        embeddings = self.get_embedding_model()
+        if not embeddings:
+            return
+        if self.index is None:
+            try:
+                self.index = FAISS.load_local(
+                    INDEX_SAVE_DIR, embeddings, index_name=self.index_name
+                )
+                self.load_metadata()
+            except Exception as e:
+                self.create()
+
     async def _process_message(self, message: HumanChatMessage):
         if not self.index:
-            self.load_or_create()
+            self._load_or_create()

         # If index is not still there, embeddings are not present
         if not self.index:
@@ -70,7 +83,7 @@ async def _process_message(self, message: HumanChatMessage):
         if args.list:
             self.reply(self._build_list_response())
             return

         # Make sure the path exists.
         if not len(args.path) == 1:
             self.reply(f"{self.parser.format_usage()}", message)
@@ -82,6 +95,9 @@ async def _process_message(self, message: HumanChatMessage):
             self.reply(response, message)
             return

+        # delete and relearn index if embedding model was changed
+        await self.delete_and_relearn()
+
         if args.verbose:
             self.reply(f"Loading and splitting files for {load_path}", message)

@@ -127,11 +143,13 @@ async def learn_dir(self, path: str):

         delayed = split(path, splitter=splitter)
         doc_chunks = await dask_client.compute(delayed)
-        em = self.get_embedding_model()
-        delayed = get_embeddings(doc_chunks, em)
+
+        em_provider_cls, em_provider_args = self.get_embedding_provider()
+        delayed = get_embeddings(doc_chunks, em_provider_cls, em_provider_args)
         embedding_records = await dask_client.compute(delayed)
         self.index.add_embeddings(*embedding_records)
         self._add_dir_to_metadata(path)
+        self.prev_em_id = em_provider_cls.id + ":" + em_provider_args["model_id"]

     def _add_dir_to_metadata(self, path: str):
         dirs = self.metadata.dirs
@@ -140,18 +158,37 @@ def _add_dir_to_metadata(self, path: str):
             dirs.append(IndexedDir(path=path))
         self.metadata.dirs = dirs

-    def delete_and_relearn(self):
+    async def delete_and_relearn(self):
+        """Delete the vector store and relearn all indexed directories if
+        necessary. If the embedding model is unchanged, this method does
+        nothing."""
         if not self.metadata.dirs:
             self.delete()
             return
-        message = """🔔 Hi there, It seems like you have updated the embeddings model. For the **/ask**
-        command to work with the new model, I have to re-learn the documents you had previously
-        submitted for learning. Please wait to use the **/ask** command until I am done with this task."""
+
+        em_provider_cls, em_provider_args = self.get_embedding_provider()
+        curr_em_id = em_provider_cls.id + ":" + em_provider_args["model_id"]
+        prev_em_id = self.prev_em_id
+
+        # TODO: Fix this condition to read the previous EM id from some
+        # persistent source. Right now, we just skip this validation on server
+        # init, meaning a user could switch embedding models in the config file
+        # directly and break their instance.
+        if (prev_em_id is None) or (prev_em_id == curr_em_id):
+            return
+
+        self.log.info(f"Switching embedding provider from {prev_em_id} to {curr_em_id}.")
+        message = f"""🔔 Hi there, it seems like you have updated the embeddings
+        model from `{prev_em_id}` to `{curr_em_id}`. I have to re-learn the
+        documents you had previously submitted for learning. Please wait to use
+        the **/ask** command until I am done with this task."""
+
         self.reply(message)

         metadata = self.metadata
         self.delete()
-        self.relearn(metadata)
+        await self.relearn(metadata)
+        self.prev_em_id = curr_em_id

     def delete(self):
         self.index = None
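The TODO in this hunk notes that `prev_em_id` lives only in memory, so the changed-model check is skipped on server init. Below is a hedged sketch of one way to persist it, modeled on how `save_metadata`/`load_metadata` use `METADATA_SAVE_PATH`; the `em_id.json` filename and helper names are invented, not part of this commit.

# Invented helpers: persist the embedding model ID beside the index so
# delete_and_relearn could validate it after a server restart. Assumes the
# INDEX_SAVE_DIR constant from this module.
import json
import os
from typing import Optional

EM_ID_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, "em_id.json")  # assumed path

def save_em_id(em_id: str) -> None:
    with open(EM_ID_SAVE_PATH, "w") as f:
        json.dump({"em_id": em_id}, f)

def load_em_id() -> Optional[str]:
    # A missing file means a fresh install; the caller would keep today's
    # behavior of skipping the relearn check.
    if not os.path.exists(EM_ID_SAVE_PATH):
        return None
    with open(EM_ID_SAVE_PATH) as f:
        return json.load(f).get("em_id")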
@@ -165,13 +202,15 @@ def delete(self):
             os.remove(path)
         self.create()

-    def relearn(self, metadata: IndexMetadata):
+    async def relearn(self, metadata: IndexMetadata):
         # Index all dirs in the metadata
         if not metadata.dirs:
             return

         for dir in metadata.dirs:
-            self.learn_dir(dir.path)
+            # TODO: do not relearn directories in serial, but instead
+            # concurrently or in parallel
+            await self.learn_dir(dir.path)

         self.save()
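The serial-relearn TODO could later be discharged with `asyncio.gather`. This is a sketch only, and it assumes `learn_dir` tolerates concurrent updates to the shared index and metadata, which the current implementation may not.

# Sketch of a concurrent relearn (not this commit's behavior). Assumes
# learn_dir is safe to run concurrently against the shared FAISS index.
import asyncio

async def relearn(self, metadata: IndexMetadata):
    if not metadata.dirs:
        return
    # Re-embed every indexed directory at once instead of one at a time.
    await asyncio.gather(*(self.learn_dir(dir.path) for dir in metadata.dirs))
    self.save()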

@@ -205,19 +244,6 @@ def save_metadata(self):
         with open(METADATA_SAVE_PATH, "w") as f:
             f.write(self.metadata.json())

-    def load_or_create(self):
-        embeddings = self.get_embedding_model()
-        if not embeddings:
-            return
-        if self.index is None:
-            try:
-                self.index = FAISS.load_local(
-                    INDEX_SAVE_DIR, embeddings, index_name=self.index_name
-                )
-                self.load_metadata()
-            except Exception as e:
-                self.create()
-
     def load_metadata(self):
         if not os.path.exists(METADATA_SAVE_PATH):
             return
@@ -237,21 +263,12 @@ async def aget_relevant_documents(
     ) -> Coroutine[Any, Any, List[Document]]:
         return self.get_relevant_documents(query)

-    def get_embedding_model(self):
-        em_provider = self.config_manager.get_em_provider()
-        em_provider_params = self.config_manager.get_em_provider_params()
-        curr_em_id = em_provider_params["model_id"]
-
-        if not em_provider:
-            return None
+    def get_embedding_provider(self):
+        em_provider_cls = self.config_manager.get_em_provider()
+        em_provider_args = self.config_manager.get_em_provider_params()

-        prev_em_id = self.prev_em_id
-        if prev_em_id != curr_em_id:
-            self.log.info(f"Switching embedding model from {prev_em_id} to {curr_em_id}.")
-            self.embeddings = em_provider(**em_provider_params)
-            self.prev_em_id = curr_em_id
-            if prev_em_id:
-                # delete the index
-                self.delete_and_relearn()
-
-        return self.embeddings
+        return em_provider_cls, em_provider_args
+
+    def get_embedding_model(self):
+        em_provider_cls, em_provider_args = self.get_embedding_provider()
+        return em_provider_cls(**em_provider_args)
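Net effect of the last hunk: the getter no longer caches a model or triggers relearning as a side effect. Callers split along pickling needs, roughly as below (`handler` is an illustrative stand-in for the chat handler instance).

# In-process code wants a live model; Dask-bound code wants picklable inputs.
embeddings = handler.get_embedding_model()  # e.g. for FAISS.load_local
em_provider_cls, em_provider_args = handler.get_embedding_provider()  # for get_embeddings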
7 changes: 4 additions & 3 deletions packages/jupyter-ai/jupyter_ai/document_loaders/directory.py
@@ -81,7 +81,8 @@ def join(embeddings):
     return (embedding_records, metadatas)


-def embed_chunk(chunk, em):
+def embed_chunk(chunk, em_provider_cls, em_provider_args):
+    em = em_provider_cls(**em_provider_args)
     metadata = chunk.metadata
     content = chunk.page_content
     embedding = em.embed_query(content)
@@ -90,13 +91,13 @@ def embed_chunk(chunk, em):

# TODO: figure out how to declare the typing of this fn
# dask.delayed.Delayed doesn't work, nor does dask.Delayed
-def get_embeddings(chunks, em):
+def get_embeddings(chunks, em_provider_cls, em_provider_args):
     # split documents in parallel w.r.t. each file
     embeddings = []

     # compute embeddings in parallel
     for chunk in chunks:
-        embedding = dask.delayed(embed_chunk)(chunk, em)
+        embedding = dask.delayed(embed_chunk)(chunk, em_provider_cls, em_provider_args)
         embeddings.append(embedding)

     return dask.delayed(join)(embeddings)
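An end-to-end driver for the reworked `get_embeddings`, as a hedged, self-contained sketch: `Chunk` and `TinyEmbeddings` are invented stand-ins for LangChain document chunks and a real embedding provider, and only a class plus a plain dict ever enter the Dask graph.

# Invented driver for the functions above; Chunk and TinyEmbeddings are
# stand-ins, not jupyter-ai or LangChain classes.
from dataclasses import dataclass, field

import dask

@dataclass
class Chunk:
    page_content: str
    metadata: dict = field(default_factory=dict)

class TinyEmbeddings:
    def __init__(self, dim: int = 4):
        self.dim = dim

    def embed_query(self, text: str):
        # Toy deterministic vector; a real provider would call a model here.
        return [float((len(text) + i) % 7) for i in range(self.dim)]

chunks = [
    Chunk("first chunk", {"path": "a.py"}),
    Chunk("second chunk", {"path": "b.py"}),
]

# Only TinyEmbeddings (a class) and {"dim": 4} (plain data) are pickled;
# each embed_chunk task constructs its own model on the worker.
delayed = get_embeddings(chunks, TinyEmbeddings, {"dim": 4})
embedding_records, metadatas = dask.compute(delayed)[0]
print(metadatas)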
