forked from PromtEngineer/localGPT
Commit 8988501 (1 parent: d18552d)
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 8 changed files with 131 additions and 105 deletions.
@@ -8,4 +8,6 @@ openai
 httpx
 boto3
 pydantic
-langchain
+tiktoken
+langchain
+langchain-community
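
The dependency change tracks LangChain's 0.1 reorganization, which moved third-party integrations such as the Chroma vector store and the OpenAI LLM wrapper out of the core langchain package into langchain-community. A minimal sketch of the import migration this enables (the old paths are shown commented out; only the two new imports actually appear in this commit):

# Sketch of the langchain -> langchain-community import migration.
# Old core-package locations (deprecated as of langchain 0.1):
#   from langchain.vectorstores import Chroma
#   from langchain.llms import OpenAI
# New locations, provided by the langchain-community package added above:
from langchain_community.vectorstores import Chroma
from langchain_community.llms import OpenAI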
@@ -1,47 +1,64 @@
-from openai import OpenAI
-from sklearn.metrics.pairwise import cosine_similarity
+from langchain.chains import RetrievalQA
+from langchain.prompts import PromptTemplate
+from langchain_community.vectorstores import Chroma
+from langchain_community.llms import OpenAI
 import os
+from .constants import CHROMA_SETTINGS, PERSIST_DIRECTORY
+from .embeddings import EmbeddingWrapper
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler  # for streaming response
+from langchain.callbacks.manager import CallbackManager
+import traceback
 
-client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 
+callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+# Setup
+api_key = os.getenv("OPENAI_API_KEY")
+chroma_db_path = PERSIST_DIRECTORY
+model_name = "gpt-3.5-turbo-instruct-0914"
 
-def get_embeddings(text_chunks):
-    embeddings = []
-    for chunk in text_chunks:
-        response = client.embeddings.create(model="text-embedding-ada-002", input=chunk)
-        embeddings.append(response.data[0].embedding)
-    return embeddings
+embedding_func = EmbeddingWrapper()
+# Initialize the components
+llm = OpenAI(model=model_name)
+chroma_db = Chroma(
+    persist_directory=PERSIST_DIRECTORY,
+    embedding_function=embedding_func,
+    client_settings=CHROMA_SETTINGS,
+)
 
 
-def get_most_relevant_chunks(question, chunk_embeddings, chunks, top_k=5):
-    question_embedding = client.embeddings.create(model="text-embedding-ada-002", input=question).data[0].embedding
-
-    similarities = cosine_similarity([question_embedding], chunk_embeddings)[0]
-
-    top_k_indices = similarities.argsort()[-top_k:][::-1]
-    relevant_chunks = [chunks[i] for i in top_k_indices]
-
-    return relevant_chunks
-
-
-def generate_answer(question, relevant_chunks):
-    context = "\n\n".join(relevant_chunks)
-    prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
-
-    response = client.completions.create(model="gpt-3.5-turbo-instruct-0914", prompt=prompt, max_tokens=200)
-
-    answer = response.choices[0].text.strip()
-    return answer
-
-
-def answer_question_from_pdf(pdf_path, question):
-    from .pdf_processing import extract_text_from_pdf, semantic_chunk_text
-
-    document_text = extract_text_from_pdf(pdf_path)
-    chunks = semantic_chunk_text(document_text)
-    chunk_embeddings = get_embeddings(chunks)
-    relevant_chunks = get_most_relevant_chunks(question, chunk_embeddings, chunks)
-    answer = generate_answer(question, relevant_chunks)
-    references = relevant_chunks
-
-    return answer, references
+# Define the prompt structure for the QA
+prompt_template = PromptTemplate(
+    input_variables=["context", "question"],
+    template="""You are a helpful assistant, you will use the provided context to answer user questions.
+Read the given context before answering questions and think step by step. If you can not answer a user question based on
+the provided context, inform the user. Do not use any other information for answering user. Provide a detailed answer to the question.
+Context: {context}
+Question: {question}
+""",
+)
+
+# retrieval_qa = RetrievalQA.from_chain_type(
+#     retriever=chroma_db.as_retriever(), chain_type="stuff", llm=model, prompt_template=prompt_template
+# )
+retrieval_qa = RetrievalQA.from_chain_type(
+    llm=llm,
+    chain_type="stuff",
+    retriever=chroma_db.as_retriever(),
+    return_source_documents=True,
+    callbacks=callback_manager,
+    chain_type_kwargs={
+        "prompt": prompt_template,
+    },
+)
+
+
+# Function to answer questions
+def answer_question(question):
+    try:
+        result = retrieval_qa(question)
+        answer = result["result"]
+        return answer, result["source_documents"]
+    except Exception as e:
+        traceback.print_exc()
+        print(f"An error occurred: {e}")
+        raise e
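
Two things the new module depends on are not part of this diff: the EmbeddingWrapper class imported from .embeddings, and any caller of answer_question. Chroma's embedding_function parameter expects an object implementing LangChain's Embeddings interface (embed_documents and embed_query), so a wrapper consistent with the embedding calls in the deleted code might look like the following hypothetical sketch; none of these names are confirmed by the commit:

# Hypothetical sketch of EmbeddingWrapper -- the real implementation lives
# in .embeddings and is not shown in this commit.
import os
from openai import OpenAI
from langchain_core.embeddings import Embeddings


class EmbeddingWrapper(Embeddings):
    """Adapts the OpenAI embeddings API to LangChain's Embeddings interface."""

    def __init__(self, model="text-embedding-ada-002"):
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.model = model

    def embed_documents(self, texts):
        # One batched API call; response.data preserves input order.
        response = self.client.embeddings.create(model=self.model, input=texts)
        return [item.embedding for item in response.data]

    def embed_query(self, text):
        return self.embed_documents([text])[0]

And a minimal driver, assuming the changed file is importable as qa inside the same package (a hypothetical name) and the Chroma index under PERSIST_DIRECTORY has already been populated by the ingestion step:

# Hypothetical usage -- module name and question are illustrative only.
from qa import answer_question

answer, sources = answer_question("What is this document about?")
print(answer)
for doc in sources:
    print(doc.metadata)  # each source is a LangChain Document with metadata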