from langchain import hub
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from dotenv import load_dotenv

"""
|
|
User facing RAG application. Mostly adapted from https://github.com/wu4f/cs410g-src/blob/main/03_RAG/08_rag_query.py
|
|
Small changes made regarding OpenAI Embedding, and loading env from dotenv.
|
|
|
|
I use the same rag-prompt since it's a good choice
|
|
|
|
"""
|
|
|
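# load_dotenv() reads a local .env file so that OPENAI_API_KEY is available
# to ChatOpenAI and OpenAIEmbeddings below.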
load_dotenv()

def format_docs(docs):
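    # Join the retrieved documents' text into a single context string,
    # separated by blank lines, for insertion into the prompt.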
    return "\n\n".join(doc.page_content for doc in docs)

def get_rag_chain():
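    # LCEL pipeline: retrieve and format the context, pass the question
    # through unchanged, fill in the prompt, query the model, and parse the
    # output to a plain string. Relies on the module-level retriever, prompt,
    # and llm defined below, which exist by the time the chain is invoked.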
    return (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

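# Open the persisted Chroma collection; the embedding function must match the
# one used when the documents were indexed.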
vectorstore = Chroma(
    embedding_function=OpenAIEmbeddings(), persist_directory="./rag_data/.chromadb"
)
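# Pull the community "rlm/rag-prompt" template from LangChain Hub, and set up
# the retriever and chat model used by the chain.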
prompt = hub.pull("rlm/rag-prompt")
retriever = vectorstore.as_retriever()
llm = ChatOpenAI(model="gpt-4")

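# Collect the source URLs indexed in the vector store; this assumes each
# document was stored with a "sourceURL" metadata entry at ingestion time.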
document_data_sources = set()
for doc_metadata in retriever.vectorstore.get()["metadatas"]:
    document_data_sources.add(doc_metadata["sourceURL"])

if __name__ == "__main__":
    print(
        "Welcome to the Kerbal Space Program RAG application. I will try to "
        "assist you with any questions. Enter an empty line to exit."
    )
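    # List the indexed sources so the user can see what material is available.
    for url in sorted(document_data_sources):
        print(f"  {url}")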
    while True:
        line = input("llm>> ")
        if line:
            result = get_rag_chain().invoke(line)
            print(result)
        else:
            # An empty line ends the session.
            break