From 18d64d0651c64d311d06df9c565e81a429edd75f Mon Sep 17 00:00:00 2001 From: David Westgate Date: Thu, 25 Apr 2024 13:55:29 -0700 Subject: [PATCH] add support for ksp rag from hw1 --- hw1/__init__.py | 0 hw1/app.py | 33 ++++++++++++++++++--------------- hw2/__init__.py | 0 hw2/app.py | 4 ++-- hw2/tools.py | 11 +++++++++++ 5 files changed, 31 insertions(+), 17 deletions(-) create mode 100644 hw1/__init__.py create mode 100644 hw2/__init__.py diff --git a/hw1/__init__.py b/hw1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hw1/app.py b/hw1/app.py index 4dca458..15430b5 100644 --- a/hw1/app.py +++ b/hw1/app.py @@ -18,6 +18,15 @@ load_dotenv() def format_docs(docs): return "\n\n".join(doc.page_content for doc in docs) +def get_rag_chain(): + return ( + {"context": retriever | format_docs, "question": RunnablePassthrough()} + | prompt + | llm + | StrOutputParser() + ) + + vectorstore = Chroma( embedding_function=OpenAIEmbeddings(), persist_directory="./rag_data/.chromadb" @@ -26,23 +35,17 @@ prompt = hub.pull("rlm/rag-prompt") retriever = vectorstore.as_retriever() llm = ChatOpenAI(model="gpt-4") -rag_chain = ( - {"context": retriever | format_docs, "question": RunnablePassthrough()} - | prompt - | llm - | StrOutputParser() -) - -print("Welcome to the Kerbal Space Program RAG application. I will try to assist you with any questions ") document_data_sources = set() for doc_metadata in retriever.vectorstore.get()['metadatas']: document_data_sources.add(doc_metadata['sourceURL']) -while True: - line = input("llm>> ") - if line: - result = rag_chain.invoke(line) - print(result) - else: - break +if __name__ == "__main__" : + print("Welcome to the Kerbal Space Program RAG application. I will try to assist you with any questions ")
+    while True:
+        line = input("llm>> ")
+        if line:
+            result = get_rag_chain().invoke(line)
+            print(result)
+        else:
+            break
diff --git a/hw2/__init__.py b/hw2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hw2/app.py b/hw2/app.py index 3d3198b..1693abe 100644 --- a/hw2/app.py +++ b/hw2/app.py @@ -7,7 +7,7 @@ from langchain_community.tools.google_jobs import GoogleJobsQueryRun from langchain_community.utilities.google_jobs import GoogleJobsAPIWrapper from dotenv import load_dotenv -from tools import lookup_ip, lookup_name +from tools import lookup_ip, lookup_name, search_ksp from langsmith import Client """ @@ -25,7 +25,7 @@ load_dotenv() llm = ChatOpenAI(model_name="gpt-4-turbo", temperature=0) tools = load_tools(["serpapi", "terminal", "dalle-image-generator", "google-jobs"], allow_dangerous_tools=True, llm=llm) -tools.extend([lookup_name, lookup_ip]) +tools.extend([lookup_name, lookup_ip, search_ksp]) base_prompt = hub.pull("langchain-ai/react-agent-template") prompt = base_prompt.partial(instructions="Answer the user's request utilizing at most 8 tool calls") diff --git a/hw2/tools.py b/hw2/tools.py index 78507a7..f401f66 100644 --- a/hw2/tools.py +++ b/hw2/tools.py @@ -2,6 +2,10 @@ from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain.tools import tool import dns.resolver, dns.reversename import validators +import sys +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +from hw1.app import get_rag_chain """ Tools to be run by my custom agent. 
@@ -25,6 +29,13 @@ class LookupIPInput(BaseModel): return values raise ValueError("Malformed IP address") +class KSPTool(BaseModel): + query: str = Field(description="should be a kerbal space program (ksp) related query") + +@tool("kerbal_space_program_ksp_information", args_schema=KSPTool, return_direct=False) +def search_ksp(query:str) -> str: + """Given a query about kerbal space program (ksp), it will send the query to the KSP rag application""" + return get_rag_chain().invoke(query) @tool("lookup_name",args_schema=LookupNameInput, return_direct=False) def lookup_name(hostname):