add support for ksp rag from hw1
parent fd20f8d694
commit 18d64d0651
hw1/__init__.py  (new file, empty)

hw1/app.py  (23 changed lines)
@@ -18,6 +18,15 @@ load_dotenv()
 def format_docs(docs):
     return "\n\n".join(doc.page_content for doc in docs)
 
+
+def get_rag_chain():
+    return (
+        {"context": retriever | format_docs, "question": RunnablePassthrough()}
+        | prompt
+        | llm
+        | StrOutputParser()
+    )
+
 vectorstore = Chroma(
     embedding_function=OpenAIEmbeddings(),
     persist_directory="./rag_data/.chromadb"
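Note that get_rag_chain() refers to retriever, prompt and llm, which this file only defines further down at module level (see the next hunk). That is safe in Python because names used inside a function body are resolved when the function is called, not when it is defined; a minimal illustration:

    # Names inside a function are looked up at call time, so defining the
    # function before its module-level dependencies exist is fine.
    def chain():
        return retriever              # 'retriever' does not exist yet here

    retriever = "defined later at module level"
    print(chain())                    # prints: defined later at module level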
@@ -26,22 +35,16 @@ prompt = hub.pull("rlm/rag-prompt")
 retriever = vectorstore.as_retriever()
 llm = ChatOpenAI(model="gpt-4")
 
-rag_chain = (
-    {"context": retriever | format_docs, "question": RunnablePassthrough()}
-    | prompt
-    | llm
-    | StrOutputParser()
-)
-
-print("Welcome to the Kerbal Space Program RAG application. I will try to assist you with any questions ")
 document_data_sources = set()
 for doc_metadata in retriever.vectorstore.get()['metadatas']:
     document_data_sources.add(doc_metadata['sourceURL'])
 
-while True:
-    line = input("llm>> ")
-    if line:
-        result = rag_chain.invoke(line)
-        print(result)
-    else:
-        break
+if __name__ == "__main__":
+    print("Welcome to the Kerbal Space Program RAG application. I will try to assist you with any questions ")
+    while True:
+        line = input("llm>> ")
+        if line:
+            result = get_rag_chain().invoke(line)
+            print(result)
+        else:
+            break
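Together the two hunks make hw1/app.py importable: the chain construction now lives in get_rag_chain(), and the interactive loop only runs behind the __main__ guard. A sketch of how another module can now use it (the query string is only illustrative); importing still executes the module-level setup (Chroma store, retriever, llm), but no longer starts the CLI:

    # Sketch, assuming hw1 is on the import path (hw1/__init__.py was added in this commit).
    from hw1.app import get_rag_chain

    # Build the chain on demand and ask it a KSP question.
    answer = get_rag_chain().invoke("How do I circularize an orbit around Kerbin?")
    print(answer)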
hw2/__init__.py  (new file, empty)
@@ -7,7 +7,7 @@ from langchain_community.tools.google_jobs import GoogleJobsQueryRun
 from langchain_community.utilities.google_jobs import GoogleJobsAPIWrapper
 
 from dotenv import load_dotenv
-from tools import lookup_ip, lookup_name
+from tools import lookup_ip, lookup_name, search_ksp
 from langsmith import Client
 
 """
@@ -25,7 +25,7 @@ load_dotenv()
 llm = ChatOpenAI(model_name="gpt-4-turbo", temperature=0)
 
 tools = load_tools(["serpapi", "terminal", "dalle-image-generator", "google-jobs"], allow_dangerous_tools=True, llm=llm)
-tools.extend([lookup_name, lookup_ip])
+tools.extend([lookup_name, lookup_ip, search_ksp])
 
 base_prompt = hub.pull("langchain-ai/react-agent-template")
 prompt = base_prompt.partial(instructions="Answer the user's request utilizing at most 8 tool calls")
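The hunks end before the agent itself is assembled. Assuming the rest of the file follows the standard ReAct wiring that the langchain-ai/react-agent-template prompt is built for (create_react_agent and AgentExecutor are assumptions, not shown in this diff), the extended tool list would be used roughly like this:

    # Hypothetical continuation of the file, not part of the diff above.
    from langchain.agents import AgentExecutor, create_react_agent

    # llm, tools (now including search_ksp) and prompt come from the lines shown above.
    agent = create_react_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)

    # KSP questions can now be routed to the hw1 RAG chain via the new tool.
    executor.invoke({"input": "What engine should I use for a Mun lander?"})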
hw2/tools.py  (11 changed lines)
@@ -2,6 +2,10 @@ from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
 from langchain.tools import tool
 import dns.resolver, dns.reversename
 import validators
+import sys
+import os
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+from hw1.app import get_rag_chain
 
 """
 Tools to be run by my custom agent.
@@ -25,6 +29,13 @@ class LookupIPInput(BaseModel):
             return values
         raise ValueError("Malformed IP address")
 
+class KSPTool(BaseModel):
+    query: str = Field(description="should be a kerbal space program (ksp) related query")
+
+@tool("kerbal_space_program_ksp_information", args_schema=KSPTool, return_direct=False)
+def search_ksp(query: str) -> str:
+    """Given a query about Kerbal Space Program (KSP), send the query to the KSP RAG application."""
+    return get_rag_chain().invoke(query)
 
 @tool("lookup_name",args_schema=LookupNameInput, return_direct=False)
 def lookup_name(hostname):
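For a quick check of the new tool outside the agent, it can be invoked directly. A sketch, assuming the OpenAI credentials are loaded and the Chroma store at ./rag_data/.chromadb is reachable from the working directory:

    # Sketch: call the tool the same way the agent would.
    from tools import search_ksp

    # A @tool-decorated function is invoked with a dict matching its args schema.
    print(search_ksp.invoke({"query": "How much delta-v do I need to reach Minmus?"}))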