
openai-updates
AshishMahendra committed Jul 5, 2024
1 parent d18552d commit 8988501
Showing 8 changed files with 131 additions and 105 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -113,7 +113,7 @@ ipython_config.py
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

config.json
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

81 changes: 37 additions & 44 deletions ask_monk/app.py
@@ -1,56 +1,18 @@
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
from utils.pdf_processing import process_documents
from utils.question_answering import answer_question_from_pdf
from utils.question_answering import answer_question
from utils.highlight import highlight_text_in_pdf
import os
import logging
import traceback
import tempfile
import httpx
import boto3
import uvicorn
from urllib.parse import urlparse
from threading import Lock


class FileInfo(BaseModel):
id: int
name: str
file: str


class FolderInfo(BaseModel):
id: int
uid: str
name: str
slug: str
user_id: int
url: str
files: List[FileInfo]


class DocumentData(BaseModel):
data: FolderInfo


class QuestionRequest(BaseModel):
question: str
document_url: str


class FeedbackModel(BaseModel):
user_prompt: str
feedback: str


class HighlightRequest(BaseModel):
pdf_name: str
page_number: int
highlight_text: str

from models.schemas import DocumentData, QuestionRequest, FeedbackModel, HighlightRequest

app = FastAPI()

@@ -69,8 +31,8 @@ class HighlightRequest(BaseModel):
async def ingest_from_json(document_data: DocumentData):
file_urls = [file.file for file in document_data.data.files]
try:
results = await process_documents(file_urls)
return {"message": "Documents processed and vectorstore updated successfully", "results": results}
await process_documents(file_urls)
return {"message": "Documents processed and vectorstore updated successfully"}
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
@@ -80,8 +42,37 @@ async def ingest_from_json(document_data: DocumentData):
async def prompt_route(question_request: QuestionRequest):
with request_lock:
try:
answer, references = answer_question_from_pdf(question_request.document_url, question_request.question)
return {"answer": answer, "references": references}
answer, references = answer_question(question_request.question)
# Use a set to avoid duplicate file names in the source list
seen_files = set()
source_list = []

if not answer or "I apologize" in answer or "there is no information" in answer:
for document in references:
pdf_name = document.metadata["source"]
if pdf_name not in seen_files:
seen_files.add(pdf_name)
source_list.append({"PDF": pdf_name})
answer = "I apologize, but I'm unable to find detailed information on this topic. Please refer to the following sources for more information."
else:
for document in references:
pdf_name = document.metadata["source"]
source_list.append(
{
"filename": pdf_name,
"pageNumber": document.metadata.get("page_number"),
"highlightText": str(document.page_content),
}
)

prompt_response_dict = {
"Prompt": question_request.question,
"Answer": answer,
"Sources": source_list,
}

return prompt_response_dict

except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
@@ -142,4 +133,6 @@ def upload_image_to_s3(image_path, bucket, object_name):


if __name__ == "__main__":
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
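
With document_url dropped from QuestionRequest, a client sends only the question, and the endpoint returns the prompt, the answer, and a source list. A minimal client sketch, assuming the app above is running locally and the route is mounted at /prompt (the decorator line sits outside this hunk, so the path is an assumption):

import httpx

# Hypothetical local deployment; the "/prompt" path is assumed.
BASE_URL = "http://localhost:8000"

# Retrieval now runs against the persisted vectorstore, so the payload
# carries only the question -- no per-request document_url.
response = httpx.post(f"{BASE_URL}/prompt", json={"question": "What is the refund policy?"})
response.raise_for_status()

payload = response.json()
print(payload["Answer"])
for source in payload["Sources"]:
    # Fallback sources look like {"PDF": ...}; normal ones carry
    # filename, pageNumber, and highlightText.
    print(source)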
1 change: 0 additions & 1 deletion ask_monk/models/schemas.py
@@ -24,7 +24,6 @@ class DocumentData(BaseModel):

class QuestionRequest(BaseModel):
question: str
document_url: str


class FeedbackModel(BaseModel):
4 changes: 3 additions & 1 deletion ask_monk/requirements.txt
@@ -8,4 +8,6 @@ openai
httpx
boto3
pydantic
langchain
tiktoken
langchain
langchain-community
32 changes: 24 additions & 8 deletions ask_monk/utils/embeddings.py
@@ -1,29 +1,45 @@
from openai import OpenAI
import os
from langchain.vectorstores import Chroma
from constants import CHROMA_SETTINGS, PERSIST_DIRECTORY
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from .constants import CHROMA_SETTINGS, PERSIST_DIRECTORY

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


def get_openai_embeddings(text):
response = client.embeddings.create(model="text-embedding-ada-002", input=text)
return response.data[0].embedding
class EmbeddingWrapper(OpenAIEmbeddings):
def embed_documents(self, documents):
# This will now handle multiple documents correctly
return get_openai_embeddings(documents)

def embed_query(self, query):
# Handle a single query by wrapping it in a list
return get_openai_embeddings([query])[0]

def get_embeddings(text_chunks):

def get_openai_embeddings(texts):
embeddings = []
for text in texts:
# Ensure that each text is properly formatted and sent separately
response = client.embeddings.create(model="text-embedding-ada-002", input=[text])
embeddings.append(response.data[0].embedding)
return embeddings


def get_single_embedding(text_chunks):
embeddings = []
for chunk in text_chunks:
embedding = get_openai_embeddings(chunk)
embeddings.append(embedding)
return embeddings


def save_embeddings(embeddings, texts):
def save_embeddings(texts):
if not os.path.exists(PERSIST_DIRECTORY):
os.makedirs(PERSIST_DIRECTORY)

db = Chroma.from_embeddings(
embeddings = EmbeddingWrapper()
db = Chroma.from_documents(
texts,
embeddings,
persist_directory=PERSIST_DIRECTORY,
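
Taken together, save_embeddings now lets Chroma drive the embedding calls through the wrapper instead of receiving precomputed vectors. A usage sketch, assuming OPENAI_API_KEY is set; the sample Document chunks are placeholders for what pdf_processing would normally produce:

from langchain.docstore.document import Document
from utils.embeddings import EmbeddingWrapper, save_embeddings

# Stand-in chunks; the real ones come from RecursiveCharacterTextSplitter.
texts = [
    Document(page_content="Refunds are issued within 30 days.", metadata={"source": "policy.pdf"}),
    Document(page_content="Support is reachable by email.", metadata={"source": "faq.pdf"}),
]

# Chroma.from_documents calls embed_documents on the wrapper, which sends
# each text to text-embedding-ada-002 one request at a time.
save_embeddings(texts)

# embed_query wraps the single query string in a list and unwraps the
# single embedding that comes back.
query_vector = EmbeddingWrapper().embed_query("How do refunds work?")
print(len(query_vector))  # text-embedding-ada-002 vectors have 1536 dimensions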
2 changes: 1 addition & 1 deletion ask_monk/utils/highlight.py
@@ -1,5 +1,5 @@
import pdfplumber
from pdf_processing import preprocess_text
from .pdf_processing import preprocess_text
import re


5 changes: 2 additions & 3 deletions ask_monk/utils/pdf_processing.py
@@ -8,7 +8,7 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pdfplumber
import re
from utils.embeddings import get_embeddings, save_embeddings
from utils.embeddings import save_embeddings


def file_log(logentry):
@@ -63,8 +63,7 @@ async def process_documents(file_urls):
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = text_splitter.split_documents(documents)
logging.info(f"Loaded {len(texts)} chunks of text from documents")
embeddings = get_embeddings(texts)
save_embeddings(embeddings, texts)
save_embeddings(texts)
except Exception as e:
traceback.print_exc()
logging.error(f"Error processing PDF from {url}: {e}")
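
Since process_documents is async, a standalone driver needs an event loop. A minimal sketch, with a placeholder URL:

import asyncio

from utils.pdf_processing import process_documents

# Each PDF is downloaded, split into 1000-character chunks with 200
# characters of overlap, and handed straight to save_embeddings -- the
# separate get_embeddings step is gone.
asyncio.run(process_documents(["https://example.com/sample.pdf"]))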
109 changes: 63 additions & 46 deletions ask_monk/utils/question_answering.py
@@ -1,47 +1,64 @@
from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_community.llms import OpenAI
import os

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


def get_embeddings(text_chunks):
embeddings = []
for chunk in text_chunks:
response = client.embeddings.create(model="text-embedding-ada-002", input=chunk)
embeddings.append(response.data[0].embedding)
return embeddings


def get_most_relevant_chunks(question, chunk_embeddings, chunks, top_k=5):
question_embedding = client.embeddings.create(model="text-embedding-ada-002", input=question).data[0].embedding

similarities = cosine_similarity([question_embedding], chunk_embeddings)[0]

top_k_indices = similarities.argsort()[-top_k:][::-1]
relevant_chunks = [chunks[i] for i in top_k_indices]

return relevant_chunks


def generate_answer(question, relevant_chunks):
context = "\n\n".join(relevant_chunks)
prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"

response = client.completions.create(model="gpt-3.5-turbo-instruct-0914", prompt=prompt, max_tokens=200)

answer = response.choices[0].text.strip()
return answer


def answer_question_from_pdf(pdf_path, question):
from .pdf_processing import extract_text_from_pdf, semantic_chunk_text

document_text = extract_text_from_pdf(pdf_path)
chunks = semantic_chunk_text(document_text)
chunk_embeddings = get_embeddings(chunks)
relevant_chunks = get_most_relevant_chunks(question, chunk_embeddings, chunks)
answer = generate_answer(question, relevant_chunks)
references = relevant_chunks

return answer, references
from .constants import CHROMA_SETTINGS, PERSIST_DIRECTORY
from .embeddings import EmbeddingWrapper
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler # for streaming response
from langchain.callbacks.manager import CallbackManager
import traceback


callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
# Setup
api_key = os.getenv("OPENAI_API_KEY")
chroma_db_path = PERSIST_DIRECTORY
model_name = "gpt-3.5-turbo-instruct-0914"

embedding_func = EmbeddingWrapper()
# Initialize the components
llm = OpenAI(model=model_name)
chroma_db = Chroma(
persist_directory=PERSIST_DIRECTORY,
embedding_function=embedding_func,
client_settings=CHROMA_SETTINGS,
)


# Define the prompt structure for the QA
prompt_template = PromptTemplate(
input_variables=["context", "question"],
template="""You are a helpful assistant, you will use the provided context to answer user questions.
Read the given context before answering questions and think step by step. If you can not answer a user question based on
the provided context, inform the user. Do not use any other information for answering user. Provide a detailed answer to the question.
Context: {context}
Question: {question}
""",
)

# retrieval_qa = RetrievalQA.from_chain_type(
# retriever=chroma_db.as_retriever(), chain_type="stuff", llm=model, prompt_template=prompt_template
# )
retrieval_qa = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=chroma_db.as_retriever(),
return_source_documents=True,
callbacks=callback_manager,
chain_type_kwargs={
"prompt": prompt_template,
},
)


# Function to answer questions
def answer_question(question):
try:
result = retrieval_qa(question)
answer = result["result"]
return answer, result["source_documents"]
except Exception as e:
traceback.print_exc()
print(f"An error occurred: {e}")
raise e
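
Because the RetrievalQA chain is assembled once at import time over the persisted Chroma store, answering is a single call. A sketch, assuming the vectorstore has already been populated by the ingestion path above:

from utils.question_answering import answer_question

# The retriever pulls the top chunks from Chroma, the prompt template
# stuffs them into the context, and gpt-3.5-turbo-instruct-0914 generates
# the answer; source documents come back alongside it.
answer, sources = answer_question("What does the onboarding guide cover?")

print(answer)
for doc in sources:
    # app.py builds its "Sources" list from exactly this metadata.
    print(doc.metadata.get("source"), doc.metadata.get("page_number"))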
