format all scripts; fix loader issue from hw1 firecrawl; start hw3

parent 06dbb7fcf2
commit 266d5953cc

hw1/app.py (16 changed lines)
@@ -15,9 +15,11 @@ I use the same rag-prompt since it's a good choice
 load_dotenv()
 
+
 def format_docs(docs):
     return "\n\n".join(doc.page_content for doc in docs)
 
+
 def get_rag_chain():
     return (
         {"context": retriever | format_docs, "question": RunnablePassthrough()}
@@ -28,19 +30,20 @@ def get_rag_chain():
 
 
 vectorstore = Chroma(
-    embedding_function=OpenAIEmbeddings(),
-    persist_directory="./rag_data/.chromadb"
+    embedding_function=OpenAIEmbeddings(), persist_directory="./rag_data/.chromadb"
 )
 prompt = hub.pull("rlm/rag-prompt")
 retriever = vectorstore.as_retriever()
 llm = ChatOpenAI(model="gpt-4")
 
 document_data_sources = set()
-for doc_metadata in retriever.vectorstore.get()['metadatas']:
-    document_data_sources.add(doc_metadata['sourceURL'])
+for doc_metadata in retriever.vectorstore.get()["metadatas"]:
+    document_data_sources.add(doc_metadata["sourceURL"])
 
-if __name__ == "__main__" :
-    print("Welcome to the Kerbal Space Program RAG application. I will try to assist you with any questions ")
+if __name__ == "__main__":
+    print(
+        "Welcome to the Kerbal Space Program RAG application. I will try to assist you with any questions "
+    )
     while True:
         line = input("llm>> ")
         if line:
@@ -48,4 +51,3 @@ if __name__ == "__main__" :
             print(result)
         else:
             break
-

hw1 FireCrawl loader script (file header not captured in this view)
@@ -1,7 +1,9 @@
+from typing import List
 from langchain_community.vectorstores import Chroma
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_openai import OpenAIEmbeddings
 from langchain_community.document_loaders import FireCrawlLoader
+from langchain_core.documents import Document
 from dotenv import load_dotenv
 
 """
@@ -26,59 +28,69 @@ This takes a while to crawl, so just run it once and watch out for firecrawl cre
 """
 
 load_dotenv()
+page_options = {"onlyMainContent": True}
 crawl_params = {
-    'crawlerOptions': {
-        #Exclude non-english paths, image resources, etc.
-        'excludes': [
-            'cs',
-            'da',
-            'de',
-            'es',
-            'fi',
-            'fr',
-            'he',
-            'hr',
-            'hu',
-            'it',
-            'ja',
-            'ko',
-            'nl',
-            'no',
-            'pl',
-            'pt',
-            'ru',
-            'sv',
-            'th',
-            'tr',
-            'zh-cn'
-            '.jpg',
-            '.png'
-            '.gif'
+    "crawlerOptions": {
+        # Exclude non-english paths, image resources, etc.
+        "excludes": [
+            "cs",
+            "da",
+            "de",
+            "es",
+            "fi",
+            "fr",
+            "he",
+            "hr",
+            "hu",
+            "it",
+            "ja",
+            "ko",
+            "nl",
+            "no",
+            "pl",
+            "pt",
+            "ru",
+            "sv",
+            "th",
+            "tr",
+            "zh-cn",
+            ".jpg",
+            ".png",
+            ".gif",
         ],
-        'includes': ['wiki/*'],
-        'limit': 75, #higher limit means more credits and more wait time.
-    }
+        "includes": ["wiki/*"],
+        "limit": 75, # higher limit means more credits and more wait time.
+    },
+    "pageOptions": {"onlyMainContent": True},
 }
 
-loader = FireCrawlLoader("https://wiki.kerbalspaceprogram.com/wiki/Main_Page", mode="crawl", params=crawl_params)
-docs = loader.load()
-print("docs loaded")
+loader = FireCrawlLoader(
+    "https://wiki.kerbalspaceprogram.com/wiki/Main_Page",
+    mode="crawl",
+    params=crawl_params,
+)
+docs: List[Document] = loader.load()
 
 # Split
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-splits = text_splitter.split_documents(docs)
-print("split complete")
+splits: List[Document] = text_splitter.split_documents(docs)
+# This metadata incompatiblity issue should be resolved by the firecrawl maintiner (ogLocalaeAlternate is an empty list, not allowed by Chroma)
+for doc in splits:
+    doc.metadata.pop("ogLocaleAlternate", None)
+
 
 # Embed
-vectorstore = Chroma.from_documents(documents=splits,
-                                    embedding=OpenAIEmbeddings(),
-                                    persist_directory="./rag_data/.chromadb")
+vectorstore = Chroma.from_documents(
+    documents=splits,
+    embedding=OpenAIEmbeddings(),
+    persist_directory="./rag_data/.chromadb",
+)
 
 
 print("RAG database initialized with the following sources.")
 retriever = vectorstore.as_retriever()
 document_data_sources = set()
-for doc_metadata in retriever.vectorstore.get()['metadatas']:
-    document_data_sources.add(doc_metadata['sourceURL'])
+for doc_metadata in retriever.vectorstore.get()["metadatas"]:
+    document_data_sources.add(doc_metadata["sourceURL"])
 for doc in document_data_sources:
     print(f" {doc}")

RAG document search script (file header not captured in this view)
@@ -11,10 +11,10 @@ Adapted from https://github.com/wu4f/cs410g-src/blob/main/03_RAG/07_rag_docsearc
 
 load_dotenv()
 vectorstore = Chroma(
-    embedding_function=OpenAIEmbeddings(),
-    persist_directory="./rag_data/.chromadb"
+    embedding_function=OpenAIEmbeddings(), persist_directory="./rag_data/.chromadb"
 )
 
+
 def search_db(query):
     docs = vectorstore.similarity_search(query)
     print(f"Query database for: {query}")
@@ -23,16 +23,19 @@ def search_db(query):
     else:
         print("No matching documents")
 
+
 print("RAG database initialized.")
 retriever = vectorstore.as_retriever()
 document_data_sources = set()
-for doc_metadata in retriever.vectorstore.get()['metadatas']:
+for doc_metadata in retriever.vectorstore.get()["metadatas"]:
     print(f"docm {doc_metadata}")
-    document_data_sources.add(doc_metadata['sourceURL'])
+    document_data_sources.add(doc_metadata["sourceURL"])
 for doc in document_data_sources:
     print(f" {doc}")
 
-print("This program queries documents in the RAG database that are similar to whatever is entered.")
+print(
+    "This program queries documents in the RAG database that are similar to whatever is entered."
+)
 while True:
     line = input(">> ")
     if line:

hw2/app.py (27 changed lines)
@@ -8,7 +8,8 @@ from langchain_community.utilities.google_jobs import GoogleJobsAPIWrapper
 
 from dotenv import load_dotenv
 from tools import lookup_ip, lookup_name, search_ksp
-#from langsmith import Client
+
+# from langsmith import Client
 
 """
 This is the main runner of the custom agent. Custom agent tools are defined seperatly and imported from tools.py
@@ -19,27 +20,33 @@ Langsmith code can be uncommeted for testing/debugging
 """
 
 load_dotenv()
-#os.environ["LANGCHAIN_TRACING_V2"] = "true"
-#os.environ["LANGCHAIN_PROJECT"] = f"LangSmith Introduction"
-#os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
-#client = Client()
+# os.environ["LANGCHAIN_TRACING_V2"] = "true"
+# os.environ["LANGCHAIN_PROJECT"] = f"LangSmith Introduction"
+# os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
+# client = Client()
 llm = ChatOpenAI(model_name="gpt-4-turbo", temperature=0)
 
-tools = load_tools(["serpapi", "terminal", "dalle-image-generator", "google-jobs"], allow_dangerous_tools=True, llm=llm)
+tools = load_tools(
+    ["serpapi", "terminal", "dalle-image-generator", "google-jobs"],
+    allow_dangerous_tools=True,
+    llm=llm,
+)
 tools.extend([lookup_name, lookup_ip, search_ksp])
 
 base_prompt = hub.pull("langchain-ai/react-agent-template")
-prompt = base_prompt.partial(instructions="Answer the user's request utilizing at most 8 tool calls")
-agent = create_react_agent(llm,tools,prompt)
+prompt = base_prompt.partial(
+    instructions="Answer the user's request utilizing at most 8 tool calls"
+)
+agent = create_react_agent(llm, tools, prompt)
 agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
 print("Welcome to my application. I am configured with these tools:")
 for tool in agent_executor.tools:
-    print(f' Tool: {tool.name} = {tool.description}')
+    print(f" Tool: {tool.name} = {tool.description}")
 while True:
     line = input("llm>> ")
     try:
         if line:
-            result = agent_executor.invoke({"input":line})
+            result = agent_executor.invoke({"input": line})
             print(result)
         else:
             break

hw2/tools.py (35 changed lines)
@@ -4,7 +4,8 @@ import dns.resolver, dns.reversename
 import validators
 import sys
 import os
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 from hw1.app import get_rag_chain
 
 """
@@ -14,34 +15,46 @@ These are the same tooling provided by the example in https://github.com/wu4f/cs
 with the addition of my Kerbal Space Program RAG Application tool
 """
 
+
 class LookupNameInput(BaseModel):
     hostname: str = Field(description="Should be a hostname such as www.google.com")
 
     @root_validator
-    def is_dns_address(cls, values: dict[str,any]) -> str:
+    def is_dns_address(cls, values: dict[str, any]) -> str:
         if validators.domain(values.get("hostname")):
             return values
         raise ValueError("Malformed hostname")
 
+
 class LookupIPInput(BaseModel):
-    address: str = Field(description="Should be an IP address such as 208.91.197.27 or 143.95.239.83")
+    address: str = Field(
+        description="Should be an IP address such as 208.91.197.27 or 143.95.239.83"
+    )
 
     @root_validator
-    def is_ip_address(cls, values: dict[str,any]) -> str:
+    def is_ip_address(cls, values: dict[str, any]) -> str:
         if validators.ip_address.ipv4(values.get("address")):
             return values
         raise ValueError("Malformed IP address")
 
+
 class KSPTool(BaseModel):
-    query: str = Field(description="should be a kerbal space program (ksp) related query")
+    query: str = Field(
+        description="should be a kerbal space program (ksp) related query"
+    )
 
+
 @tool("kerbal_space_program_ksp_information", args_schema=KSPTool, return_direct=False)
-def search_ksp(query:str) -> str:
+def search_ksp(query: str) -> str:
     """Given a query about kerbal space program (ksp), it will send the query to the KSP rag applciation"""
     return get_rag_chain().invoke(query)
 
-@tool("lookup_name",args_schema=LookupNameInput, return_direct=False)
+
+@tool("lookup_name", args_schema=LookupNameInput, return_direct=False)
 def lookup_name(hostname):
     """Given a DNS hostname, it will return its IPv4 addresses"""
-    result = dns.resolver.resolve(hostname, 'A')
-    res = [ r.to_text() for r in result ]
+    result = dns.resolver.resolve(hostname, "A")
+    res = [r.to_text() for r in result]
     return res[0]
 
 
@@ -49,6 +62,6 @@ def lookup_name(hostname):
 def lookup_ip(address):
     """Given an IP address, returns names associated with it"""
     n = dns.reversename.from_address(address)
-    result = dns.resolver.resolve(n, 'PTR')
-    res = [ r.to_text() for r in result ]
+    result = dns.resolver.resolve(n, "PTR")
+    res = [r.to_text() for r in result]
     return res[0]

hw3/notes.MD (new file, 17 lines)
@@ -0,0 +1,17 @@
+# Security testing
+
+## LangChain RAG application (hw1)
+### Indirect prompt injection
+todo
+### Insecure output handling
+todo
+### Data poisoning
+todo
+
+## LangChain agent (hw2)
+### Excessive agency
+todo
+### Insecure tool design
+todo
+### Sensitive information exposure
+todo