format all scripts; fix loader issue from hw1 firecrawl; start hw3

This commit is contained in:
David Westgate 2024-04-30 16:00:31 -07:00
parent 06dbb7fcf2
commit 266d5953cc
6 changed files with 129 additions and 75 deletions

View File

@ -15,9 +15,11 @@ I use the same rag-prompt since it's a good choice
load_dotenv() load_dotenv()
def format_docs(docs): def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs) return "\n\n".join(doc.page_content for doc in docs)
def get_rag_chain(): def get_rag_chain():
return ( return (
{"context": retriever | format_docs, "question": RunnablePassthrough()} {"context": retriever | format_docs, "question": RunnablePassthrough()}
@ -28,19 +30,20 @@ def get_rag_chain():
vectorstore = Chroma( vectorstore = Chroma(
embedding_function=OpenAIEmbeddings(), embedding_function=OpenAIEmbeddings(), persist_directory="./rag_data/.chromadb"
persist_directory="./rag_data/.chromadb"
) )
prompt = hub.pull("rlm/rag-prompt") prompt = hub.pull("rlm/rag-prompt")
retriever = vectorstore.as_retriever() retriever = vectorstore.as_retriever()
llm = ChatOpenAI(model="gpt-4") llm = ChatOpenAI(model="gpt-4")
document_data_sources = set() document_data_sources = set()
for doc_metadata in retriever.vectorstore.get()['metadatas']: for doc_metadata in retriever.vectorstore.get()["metadatas"]:
document_data_sources.add(doc_metadata['sourceURL']) document_data_sources.add(doc_metadata["sourceURL"])
if __name__ == "__main__": if __name__ == "__main__":
print("Welcome to the Kerbal Space Program RAG application. I will try to assist you with any questions ") print(
"Welcome to the Kerbal Space Program RAG application. I will try to assist you with any questions "
)
while True: while True:
line = input("llm>> ") line = input("llm>> ")
if line: if line:
@ -48,4 +51,3 @@ if __name__ == "__main__" :
print(result) print(result)
else: else:
break break

View File

@ -1,7 +1,9 @@
from typing import List
from langchain_community.vectorstores import Chroma from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings from langchain_openai import OpenAIEmbeddings
from langchain_community.document_loaders import FireCrawlLoader from langchain_community.document_loaders import FireCrawlLoader
from langchain_core.documents import Document
from dotenv import load_dotenv from dotenv import load_dotenv
""" """
@ -26,59 +28,69 @@ This takes a while to crawl, so just run it once and watch out for firecrawl cre
""" """
load_dotenv() load_dotenv()
page_options = {"onlyMainContent": True}
crawl_params = { crawl_params = {
'crawlerOptions': { "crawlerOptions": {
# Exclude non-english paths, image resources, etc. # Exclude non-english paths, image resources, etc.
'excludes': [ "excludes": [
'cs', "cs",
'da', "da",
'de', "de",
'es', "es",
'fi', "fi",
'fr', "fr",
'he', "he",
'hr', "hr",
'hu', "hu",
'it', "it",
'ja', "ja",
'ko', "ko",
'nl', "nl",
'no', "no",
'pl', "pl",
'pt', "pt",
'ru', "ru",
'sv', "sv",
'th', "th",
'tr', "tr",
'zh-cn' "zh-cn",
'.jpg', ".jpg",
'.png' ".png",
'.gif' ".gif",
], ],
'includes': ['wiki/*'], "includes": ["wiki/*"],
'limit': 75, #higher limit means more credits and more wait time. "limit": 75, # higher limit means more credits and more wait time.
} },
"pageOptions": {"onlyMainContent": True},
} }
loader = FireCrawlLoader("https://wiki.kerbalspaceprogram.com/wiki/Main_Page", mode="crawl", params=crawl_params) loader = FireCrawlLoader(
docs = loader.load() "https://wiki.kerbalspaceprogram.com/wiki/Main_Page",
print("docs loaded") mode="crawl",
params=crawl_params,
)
docs: List[Document] = loader.load()
# Split # Split
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(docs) splits: List[Document] = text_splitter.split_documents(docs)
print("split complete")
# This metadata incompatiblity issue should be resolved by the firecrawl maintiner (ogLocalaeAlternate is an empty list, not allowed by Chroma)
for doc in splits:
doc.metadata.pop("ogLocaleAlternate", None)
# Embed # Embed
vectorstore = Chroma.from_documents(documents=splits, vectorstore = Chroma.from_documents(
documents=splits,
embedding=OpenAIEmbeddings(), embedding=OpenAIEmbeddings(),
persist_directory="./rag_data/.chromadb") persist_directory="./rag_data/.chromadb",
)
print("RAG database initialized with the following sources.") print("RAG database initialized with the following sources.")
retriever = vectorstore.as_retriever() retriever = vectorstore.as_retriever()
document_data_sources = set() document_data_sources = set()
for doc_metadata in retriever.vectorstore.get()['metadatas']: for doc_metadata in retriever.vectorstore.get()["metadatas"]:
document_data_sources.add(doc_metadata['sourceURL']) document_data_sources.add(doc_metadata["sourceURL"])
for doc in document_data_sources: for doc in document_data_sources:
print(f" {doc}") print(f" {doc}")

View File

@ -11,10 +11,10 @@ Adapted from https://github.com/wu4f/cs410g-src/blob/main/03_RAG/07_rag_docsearc
load_dotenv() load_dotenv()
vectorstore = Chroma( vectorstore = Chroma(
embedding_function=OpenAIEmbeddings(), embedding_function=OpenAIEmbeddings(), persist_directory="./rag_data/.chromadb"
persist_directory="./rag_data/.chromadb"
) )
def search_db(query): def search_db(query):
docs = vectorstore.similarity_search(query) docs = vectorstore.similarity_search(query)
print(f"Query database for: {query}") print(f"Query database for: {query}")
@ -23,16 +23,19 @@ def search_db(query):
else: else:
print("No matching documents") print("No matching documents")
print("RAG database initialized.") print("RAG database initialized.")
retriever = vectorstore.as_retriever() retriever = vectorstore.as_retriever()
document_data_sources = set() document_data_sources = set()
for doc_metadata in retriever.vectorstore.get()['metadatas']: for doc_metadata in retriever.vectorstore.get()["metadatas"]:
print(f"docm {doc_metadata}") print(f"docm {doc_metadata}")
document_data_sources.add(doc_metadata['sourceURL']) document_data_sources.add(doc_metadata["sourceURL"])
for doc in document_data_sources: for doc in document_data_sources:
print(f" {doc}") print(f" {doc}")
print("This program queries documents in the RAG database that are similar to whatever is entered.") print(
"This program queries documents in the RAG database that are similar to whatever is entered."
)
while True: while True:
line = input(">> ") line = input(">> ")
if line: if line:

View File

@ -8,6 +8,7 @@ from langchain_community.utilities.google_jobs import GoogleJobsAPIWrapper
from dotenv import load_dotenv from dotenv import load_dotenv
from tools import lookup_ip, lookup_name, search_ksp from tools import lookup_ip, lookup_name, search_ksp
# from langsmith import Client # from langsmith import Client
""" """
@ -25,16 +26,22 @@ load_dotenv()
# client = Client() # client = Client()
llm = ChatOpenAI(model_name="gpt-4-turbo", temperature=0) llm = ChatOpenAI(model_name="gpt-4-turbo", temperature=0)
tools = load_tools(["serpapi", "terminal", "dalle-image-generator", "google-jobs"], allow_dangerous_tools=True, llm=llm) tools = load_tools(
["serpapi", "terminal", "dalle-image-generator", "google-jobs"],
allow_dangerous_tools=True,
llm=llm,
)
tools.extend([lookup_name, lookup_ip, search_ksp]) tools.extend([lookup_name, lookup_ip, search_ksp])
base_prompt = hub.pull("langchain-ai/react-agent-template") base_prompt = hub.pull("langchain-ai/react-agent-template")
prompt = base_prompt.partial(instructions="Answer the user's request utilizing at most 8 tool calls") prompt = base_prompt.partial(
instructions="Answer the user's request utilizing at most 8 tool calls"
)
agent = create_react_agent(llm, tools, prompt) agent = create_react_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
print("Welcome to my application. I am configured with these tools:") print("Welcome to my application. I am configured with these tools:")
for tool in agent_executor.tools: for tool in agent_executor.tools:
print(f' Tool: {tool.name} = {tool.description}') print(f" Tool: {tool.name} = {tool.description}")
while True: while True:
line = input("llm>> ") line = input("llm>> ")
try: try:

View File

@ -4,7 +4,8 @@ import dns.resolver, dns.reversename
import validators import validators
import sys import sys
import os import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from hw1.app import get_rag_chain from hw1.app import get_rag_chain
""" """
@ -14,33 +15,45 @@ These are the same tooling provided by the example in https://github.com/wu4f/cs
with the addition of my Kerbal Space Program RAG Application tool with the addition of my Kerbal Space Program RAG Application tool
""" """
class LookupNameInput(BaseModel): class LookupNameInput(BaseModel):
hostname: str = Field(description="Should be a hostname such as www.google.com") hostname: str = Field(description="Should be a hostname such as www.google.com")
@root_validator @root_validator
def is_dns_address(cls, values: dict[str, any]) -> str: def is_dns_address(cls, values: dict[str, any]) -> str:
if validators.domain(values.get("hostname")): if validators.domain(values.get("hostname")):
return values return values
raise ValueError("Malformed hostname") raise ValueError("Malformed hostname")
class LookupIPInput(BaseModel): class LookupIPInput(BaseModel):
address: str = Field(description="Should be an IP address such as 208.91.197.27 or 143.95.239.83") address: str = Field(
description="Should be an IP address such as 208.91.197.27 or 143.95.239.83"
)
@root_validator @root_validator
def is_ip_address(cls, values: dict[str, any]) -> str: def is_ip_address(cls, values: dict[str, any]) -> str:
if validators.ip_address.ipv4(values.get("address")): if validators.ip_address.ipv4(values.get("address")):
return values return values
raise ValueError("Malformed IP address") raise ValueError("Malformed IP address")
class KSPTool(BaseModel): class KSPTool(BaseModel):
query: str = Field(description="should be a kerbal space program (ksp) related query") query: str = Field(
description="should be a kerbal space program (ksp) related query"
)
@tool("kerbal_space_program_ksp_information", args_schema=KSPTool, return_direct=False) @tool("kerbal_space_program_ksp_information", args_schema=KSPTool, return_direct=False)
def search_ksp(query: str) -> str: def search_ksp(query: str) -> str:
"""Given a query about kerbal space program (ksp), it will send the query to the KSP rag applciation""" """Given a query about kerbal space program (ksp), it will send the query to the KSP rag applciation"""
return get_rag_chain().invoke(query) return get_rag_chain().invoke(query)
@tool("lookup_name", args_schema=LookupNameInput, return_direct=False) @tool("lookup_name", args_schema=LookupNameInput, return_direct=False)
def lookup_name(hostname): def lookup_name(hostname):
"""Given a DNS hostname, it will return its IPv4 addresses""" """Given a DNS hostname, it will return its IPv4 addresses"""
result = dns.resolver.resolve(hostname, 'A') result = dns.resolver.resolve(hostname, "A")
res = [r.to_text() for r in result] res = [r.to_text() for r in result]
return res[0] return res[0]
@ -49,6 +62,6 @@ def lookup_name(hostname):
def lookup_ip(address): def lookup_ip(address):
"""Given an IP address, returns names associated with it""" """Given an IP address, returns names associated with it"""
n = dns.reversename.from_address(address) n = dns.reversename.from_address(address)
result = dns.resolver.resolve(n, 'PTR') result = dns.resolver.resolve(n, "PTR")
res = [r.to_text() for r in result] res = [r.to_text() for r in result]
return res[0] return res[0]

17
hw3/notes.MD Normal file
View File

@ -0,0 +1,17 @@
# Security testing
## LangChain RAG application (hw1)
### Indirect prompt injection
todo
### Insecure output handling
todo
### Data poisoning
todo
## LangChain agent (hw2)
### Excessive agency
todo
### Insecure tool design
todo
### Sensitive information exposure
todo