add support for ksp rag from hw1

This commit is contained in:
David Westgate 2024-04-25 13:55:29 -07:00
parent fd20f8d694
commit 18d64d0651
5 changed files with 31 additions and 17 deletions

0
hw1/__init__.py Normal file
View File

View File

@ -18,6 +18,15 @@ load_dotenv()
def format_docs(docs): def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs) return "\n\n".join(doc.page_content for doc in docs)
def get_rag_chain():
    """Compose and return the retrieval-augmented question-answering chain.

    A new chain object is assembled on every call from the module-level
    ``retriever``, ``prompt`` and ``llm``, so callers simply invoke the
    returned object with the question.

    Returns:
        The composed pipeline: input mapping -> prompt -> llm ->
        ``StrOutputParser``.
    """
    # Documents found by the retriever are joined into one text by
    # format_docs for the "context" slot; the question itself is forwarded
    # unchanged through RunnablePassthrough.
    prompt_inputs = {
        "context": retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    pipeline = prompt_inputs | prompt | llm | StrOutputParser()
    return pipeline
vectorstore = Chroma( vectorstore = Chroma(
embedding_function=OpenAIEmbeddings(), embedding_function=OpenAIEmbeddings(),
persist_directory="./rag_data/.chromadb" persist_directory="./rag_data/.chromadb"
@ -26,23 +35,17 @@ prompt = hub.pull("rlm/rag-prompt")
retriever = vectorstore.as_retriever() retriever = vectorstore.as_retriever()
llm = ChatOpenAI(model="gpt-4") llm = ChatOpenAI(model="gpt-4")
rag_chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
print("Welcome to the Kerbal Space Program RAG application. I will try to assist you with any questions ")
document_data_sources = set() document_data_sources = set()
for doc_metadata in retriever.vectorstore.get()['metadatas']: for doc_metadata in retriever.vectorstore.get()['metadatas']:
document_data_sources.add(doc_metadata['sourceURL']) document_data_sources.add(doc_metadata['sourceURL'])
while True: if __name__ == "__main__" :
line = input("llm>> ") print("Welcome to the Kerbal Space Program RAG application. I will try to assist you with any questions ")
if line: while True:
result = rag_chain.invoke(line) line = input("llm>> ")
print(result) if line:
else: result = get_rag_chain().invoke(line)
break print(result)
else:
break

0
hw2/__init__.py Normal file
View File

View File

@ -7,7 +7,7 @@ from langchain_community.tools.google_jobs import GoogleJobsQueryRun
from langchain_community.utilities.google_jobs import GoogleJobsAPIWrapper from langchain_community.utilities.google_jobs import GoogleJobsAPIWrapper
from dotenv import load_dotenv from dotenv import load_dotenv
from tools import lookup_ip, lookup_name from tools import lookup_ip, lookup_name, search_ksp
from langsmith import Client from langsmith import Client
""" """
@ -25,7 +25,7 @@ load_dotenv()
llm = ChatOpenAI(model_name="gpt-4-turbo", temperature=0) llm = ChatOpenAI(model_name="gpt-4-turbo", temperature=0)
tools = load_tools(["serpapi", "terminal", "dalle-image-generator", "google-jobs"], allow_dangerous_tools=True, llm=llm) tools = load_tools(["serpapi", "terminal", "dalle-image-generator", "google-jobs"], allow_dangerous_tools=True, llm=llm)
tools.extend([lookup_name, lookup_ip]) tools.extend([lookup_name, lookup_ip, search_ksp])
base_prompt = hub.pull("langchain-ai/react-agent-template") base_prompt = hub.pull("langchain-ai/react-agent-template")
prompt = base_prompt.partial(instructions="Answer the user's request utilizing at most 8 tool calls") prompt = base_prompt.partial(instructions="Answer the user's request utilizing at most 8 tool calls")

View File

@ -2,6 +2,10 @@ from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain.tools import tool from langchain.tools import tool
import dns.resolver, dns.reversename import dns.resolver, dns.reversename
import validators import validators
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from hw1.app import get_rag_chain
""" """
Tools to be run by my custom agent. Tools to be run by my custom agent.
@ -25,6 +29,13 @@ class LookupIPInput(BaseModel):
return values return values
raise ValueError("Malformed IP address") raise ValueError("Malformed IP address")
# Argument schema handed to the search_ksp tool via args_schema: one
# free-text field carrying the question to forward.
class KSPTool(BaseModel):
    query: str = Field(
        description="should be a kerbal space program (ksp) related query",
    )
# NOTE(review): under @tool the function docstring is presumably what the agent
# sees as the tool description (see LangChain docs), so it is kept to a single
# sentence and only its spelling was corrected ("applciation" -> "application").
@tool("kerbal_space_program_ksp_information", args_schema=KSPTool, return_direct=False)
def search_ksp(query: str) -> str:
    """Given a query about kerbal space program (ksp), it will send the query to the KSP rag application"""
    # Ask hw1 for a chain on each call and run it on the raw query text;
    # whatever the chain produces is returned to the agent unchanged.
    return get_rag_chain().invoke(query)
@tool("lookup_name",args_schema=LookupNameInput, return_direct=False) @tool("lookup_name",args_schema=LookupNameInput, return_direct=False)
def lookup_name(hostname): def lookup_name(hostname):