format all scripts; fix loader issue from hw1 firecrawl; start hw3
This commit is contained in:
parent
06dbb7fcf2
commit
266d5953cc
16
hw1/app.py
16
hw1/app.py
@ -15,9 +15,11 @@ I use the same rag-prompt since it's a good choice
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def format_docs(docs):
|
||||
return "\n\n".join(doc.page_content for doc in docs)
|
||||
|
||||
|
||||
def get_rag_chain():
|
||||
return (
|
||||
{"context": retriever | format_docs, "question": RunnablePassthrough()}
|
||||
@ -28,19 +30,20 @@ def get_rag_chain():
|
||||
|
||||
|
||||
vectorstore = Chroma(
|
||||
embedding_function=OpenAIEmbeddings(),
|
||||
persist_directory="./rag_data/.chromadb"
|
||||
embedding_function=OpenAIEmbeddings(), persist_directory="./rag_data/.chromadb"
|
||||
)
|
||||
prompt = hub.pull("rlm/rag-prompt")
|
||||
retriever = vectorstore.as_retriever()
|
||||
llm = ChatOpenAI(model="gpt-4")
|
||||
|
||||
document_data_sources = set()
|
||||
for doc_metadata in retriever.vectorstore.get()['metadatas']:
|
||||
document_data_sources.add(doc_metadata['sourceURL'])
|
||||
for doc_metadata in retriever.vectorstore.get()["metadatas"]:
|
||||
document_data_sources.add(doc_metadata["sourceURL"])
|
||||
|
||||
if __name__ == "__main__" :
|
||||
print("Welcome to the Kerbal Space Program RAG application. I will try to assist you with any questions ")
|
||||
if __name__ == "__main__":
|
||||
print(
|
||||
"Welcome to the Kerbal Space Program RAG application. I will try to assist you with any questions "
|
||||
)
|
||||
while True:
|
||||
line = input("llm>> ")
|
||||
if line:
|
||||
@ -48,4 +51,3 @@ if __name__ == "__main__" :
|
||||
print(result)
|
||||
else:
|
||||
break
|
||||
|
||||
|
@ -1,7 +1,9 @@
|
||||
from typing import List
|
||||
from langchain_community.vectorstores import Chroma
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_community.document_loaders import FireCrawlLoader
|
||||
from langchain_core.documents import Document
|
||||
from dotenv import load_dotenv
|
||||
|
||||
"""
|
||||
@ -26,59 +28,69 @@ This takes a while to crawl, so just run it once and watch out for firecrawl cre
|
||||
"""
|
||||
|
||||
load_dotenv()
|
||||
page_options = {"onlyMainContent": True}
|
||||
crawl_params = {
|
||||
'crawlerOptions': {
|
||||
#Exclude non-english paths, image resources, etc.
|
||||
'excludes': [
|
||||
'cs',
|
||||
'da',
|
||||
'de',
|
||||
'es',
|
||||
'fi',
|
||||
'fr',
|
||||
'he',
|
||||
'hr',
|
||||
'hu',
|
||||
'it',
|
||||
'ja',
|
||||
'ko',
|
||||
'nl',
|
||||
'no',
|
||||
'pl',
|
||||
'pt',
|
||||
'ru',
|
||||
'sv',
|
||||
'th',
|
||||
'tr',
|
||||
'zh-cn'
|
||||
'.jpg',
|
||||
'.png'
|
||||
'.gif'
|
||||
"crawlerOptions": {
|
||||
# Exclude non-english paths, image resources, etc.
|
||||
"excludes": [
|
||||
"cs",
|
||||
"da",
|
||||
"de",
|
||||
"es",
|
||||
"fi",
|
||||
"fr",
|
||||
"he",
|
||||
"hr",
|
||||
"hu",
|
||||
"it",
|
||||
"ja",
|
||||
"ko",
|
||||
"nl",
|
||||
"no",
|
||||
"pl",
|
||||
"pt",
|
||||
"ru",
|
||||
"sv",
|
||||
"th",
|
||||
"tr",
|
||||
"zh-cn",
|
||||
".jpg",
|
||||
".png",
|
||||
".gif",
|
||||
],
|
||||
'includes': ['wiki/*'],
|
||||
'limit': 75, #higher limit means more credits and more wait time.
|
||||
}
|
||||
"includes": ["wiki/*"],
|
||||
"limit": 75, # higher limit means more credits and more wait time.
|
||||
},
|
||||
"pageOptions": {"onlyMainContent": True},
|
||||
}
|
||||
|
||||
loader = FireCrawlLoader("https://wiki.kerbalspaceprogram.com/wiki/Main_Page", mode="crawl", params=crawl_params)
|
||||
docs = loader.load()
|
||||
print("docs loaded")
|
||||
loader = FireCrawlLoader(
|
||||
"https://wiki.kerbalspaceprogram.com/wiki/Main_Page",
|
||||
mode="crawl",
|
||||
params=crawl_params,
|
||||
)
|
||||
docs: List[Document] = loader.load()
|
||||
|
||||
# Split
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
|
||||
splits = text_splitter.split_documents(docs)
|
||||
print("split complete")
|
||||
splits: List[Document] = text_splitter.split_documents(docs)
|
||||
|
||||
# This metadata incompatiblity issue should be resolved by the firecrawl maintiner (ogLocalaeAlternate is an empty list, not allowed by Chroma)
|
||||
for doc in splits:
|
||||
doc.metadata.pop("ogLocaleAlternate", None)
|
||||
|
||||
# Embed
|
||||
vectorstore = Chroma.from_documents(documents=splits,
|
||||
vectorstore = Chroma.from_documents(
|
||||
documents=splits,
|
||||
embedding=OpenAIEmbeddings(),
|
||||
persist_directory="./rag_data/.chromadb")
|
||||
persist_directory="./rag_data/.chromadb",
|
||||
)
|
||||
|
||||
|
||||
print("RAG database initialized with the following sources.")
|
||||
retriever = vectorstore.as_retriever()
|
||||
document_data_sources = set()
|
||||
for doc_metadata in retriever.vectorstore.get()['metadatas']:
|
||||
document_data_sources.add(doc_metadata['sourceURL'])
|
||||
for doc_metadata in retriever.vectorstore.get()["metadatas"]:
|
||||
document_data_sources.add(doc_metadata["sourceURL"])
|
||||
for doc in document_data_sources:
|
||||
print(f" {doc}")
|
@ -11,10 +11,10 @@ Adapted from https://github.com/wu4f/cs410g-src/blob/main/03_RAG/07_rag_docsearc
|
||||
|
||||
load_dotenv()
|
||||
vectorstore = Chroma(
|
||||
embedding_function=OpenAIEmbeddings(),
|
||||
persist_directory="./rag_data/.chromadb"
|
||||
embedding_function=OpenAIEmbeddings(), persist_directory="./rag_data/.chromadb"
|
||||
)
|
||||
|
||||
|
||||
def search_db(query):
|
||||
docs = vectorstore.similarity_search(query)
|
||||
print(f"Query database for: {query}")
|
||||
@ -23,16 +23,19 @@ def search_db(query):
|
||||
else:
|
||||
print("No matching documents")
|
||||
|
||||
|
||||
print("RAG database initialized.")
|
||||
retriever = vectorstore.as_retriever()
|
||||
document_data_sources = set()
|
||||
for doc_metadata in retriever.vectorstore.get()['metadatas']:
|
||||
for doc_metadata in retriever.vectorstore.get()["metadatas"]:
|
||||
print(f"docm {doc_metadata}")
|
||||
document_data_sources.add(doc_metadata['sourceURL'])
|
||||
document_data_sources.add(doc_metadata["sourceURL"])
|
||||
for doc in document_data_sources:
|
||||
print(f" {doc}")
|
||||
|
||||
print("This program queries documents in the RAG database that are similar to whatever is entered.")
|
||||
print(
|
||||
"This program queries documents in the RAG database that are similar to whatever is entered."
|
||||
)
|
||||
while True:
|
||||
line = input(">> ")
|
||||
if line:
|
||||
|
27
hw2/app.py
27
hw2/app.py
@ -8,7 +8,8 @@ from langchain_community.utilities.google_jobs import GoogleJobsAPIWrapper
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from tools import lookup_ip, lookup_name, search_ksp
|
||||
#from langsmith import Client
|
||||
|
||||
# from langsmith import Client
|
||||
|
||||
"""
|
||||
This is the main runner of the custom agent. Custom agent tools are defined seperatly and imported from tools.py
|
||||
@ -19,27 +20,33 @@ Langsmith code can be uncommeted for testing/debugging
|
||||
"""
|
||||
|
||||
load_dotenv()
|
||||
#os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||
#os.environ["LANGCHAIN_PROJECT"] = f"LangSmith Introduction"
|
||||
#os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
|
||||
#client = Client()
|
||||
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||
# os.environ["LANGCHAIN_PROJECT"] = f"LangSmith Introduction"
|
||||
# os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
|
||||
# client = Client()
|
||||
llm = ChatOpenAI(model_name="gpt-4-turbo", temperature=0)
|
||||
|
||||
tools = load_tools(["serpapi", "terminal", "dalle-image-generator", "google-jobs"], allow_dangerous_tools=True, llm=llm)
|
||||
tools = load_tools(
|
||||
["serpapi", "terminal", "dalle-image-generator", "google-jobs"],
|
||||
allow_dangerous_tools=True,
|
||||
llm=llm,
|
||||
)
|
||||
tools.extend([lookup_name, lookup_ip, search_ksp])
|
||||
|
||||
base_prompt = hub.pull("langchain-ai/react-agent-template")
|
||||
prompt = base_prompt.partial(instructions="Answer the user's request utilizing at most 8 tool calls")
|
||||
agent = create_react_agent(llm,tools,prompt)
|
||||
prompt = base_prompt.partial(
|
||||
instructions="Answer the user's request utilizing at most 8 tool calls"
|
||||
)
|
||||
agent = create_react_agent(llm, tools, prompt)
|
||||
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
||||
print("Welcome to my application. I am configured with these tools:")
|
||||
for tool in agent_executor.tools:
|
||||
print(f' Tool: {tool.name} = {tool.description}')
|
||||
print(f" Tool: {tool.name} = {tool.description}")
|
||||
while True:
|
||||
line = input("llm>> ")
|
||||
try:
|
||||
if line:
|
||||
result = agent_executor.invoke({"input":line})
|
||||
result = agent_executor.invoke({"input": line})
|
||||
print(result)
|
||||
else:
|
||||
break
|
||||
|
35
hw2/tools.py
35
hw2/tools.py
@ -4,7 +4,8 @@ import dns.resolver, dns.reversename
|
||||
import validators
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from hw1.app import get_rag_chain
|
||||
|
||||
"""
|
||||
@ -14,34 +15,46 @@ These are the same tooling provided by the example in https://github.com/wu4f/cs
|
||||
with the addition of my Kerbal Space Program RAG Application tool
|
||||
"""
|
||||
|
||||
|
||||
class LookupNameInput(BaseModel):
|
||||
hostname: str = Field(description="Should be a hostname such as www.google.com")
|
||||
|
||||
@root_validator
|
||||
def is_dns_address(cls, values: dict[str,any]) -> str:
|
||||
def is_dns_address(cls, values: dict[str, any]) -> str:
|
||||
if validators.domain(values.get("hostname")):
|
||||
return values
|
||||
raise ValueError("Malformed hostname")
|
||||
|
||||
|
||||
class LookupIPInput(BaseModel):
|
||||
address: str = Field(description="Should be an IP address such as 208.91.197.27 or 143.95.239.83")
|
||||
address: str = Field(
|
||||
description="Should be an IP address such as 208.91.197.27 or 143.95.239.83"
|
||||
)
|
||||
|
||||
@root_validator
|
||||
def is_ip_address(cls, values: dict[str,any]) -> str:
|
||||
def is_ip_address(cls, values: dict[str, any]) -> str:
|
||||
if validators.ip_address.ipv4(values.get("address")):
|
||||
return values
|
||||
raise ValueError("Malformed IP address")
|
||||
|
||||
|
||||
class KSPTool(BaseModel):
|
||||
query: str = Field(description="should be a kerbal space program (ksp) related query")
|
||||
query: str = Field(
|
||||
description="should be a kerbal space program (ksp) related query"
|
||||
)
|
||||
|
||||
|
||||
@tool("kerbal_space_program_ksp_information", args_schema=KSPTool, return_direct=False)
|
||||
def search_ksp(query:str) -> str:
|
||||
def search_ksp(query: str) -> str:
|
||||
"""Given a query about kerbal space program (ksp), it will send the query to the KSP rag applciation"""
|
||||
return get_rag_chain().invoke(query)
|
||||
|
||||
@tool("lookup_name",args_schema=LookupNameInput, return_direct=False)
|
||||
|
||||
@tool("lookup_name", args_schema=LookupNameInput, return_direct=False)
|
||||
def lookup_name(hostname):
|
||||
"""Given a DNS hostname, it will return its IPv4 addresses"""
|
||||
result = dns.resolver.resolve(hostname, 'A')
|
||||
res = [ r.to_text() for r in result ]
|
||||
result = dns.resolver.resolve(hostname, "A")
|
||||
res = [r.to_text() for r in result]
|
||||
return res[0]
|
||||
|
||||
|
||||
@ -49,6 +62,6 @@ def lookup_name(hostname):
|
||||
def lookup_ip(address):
|
||||
"""Given an IP address, returns names associated with it"""
|
||||
n = dns.reversename.from_address(address)
|
||||
result = dns.resolver.resolve(n, 'PTR')
|
||||
res = [ r.to_text() for r in result ]
|
||||
result = dns.resolver.resolve(n, "PTR")
|
||||
res = [r.to_text() for r in result]
|
||||
return res[0]
|
||||
|
17
hw3/notes.MD
Normal file
17
hw3/notes.MD
Normal file
@ -0,0 +1,17 @@
|
||||
# Security testing
|
||||
|
||||
## LangChain RAG application (hw1)
|
||||
### Indirect prompt injection
|
||||
todo
|
||||
### Insecure output handling
|
||||
todo
|
||||
### Data poisoning
|
||||
todo
|
||||
|
||||
## LangChain agent (hw2)
|
||||
### Excessive agency
|
||||
todo
|
||||
### Insecure tool design
|
||||
todo
|
||||
### Sensitive information exposure
|
||||
todo
|
Reference in New Issue
Block a user