This repository has been archived on 2025-04-28. You can view files and clone it, but cannot push or open issues or pull requests.
gensec-westgate-djw2/hw2/tools.py
2024-04-25 22:15:25 -07:00

55 lines
2.1 KiB
Python

from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain.tools import tool
import dns.resolver, dns.reversename
import validators
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from hw1.app import get_rag_chain
"""
Tools to be run by my custom agent.
These are the same tools provided by the example in https://github.com/wu4f/cs410g-src/blob/main/04_Agents/07_tools_custom_agent.py,
with the addition of my Kerbal Space Program RAG application tool.
"""
# Argument schema for the `lookup_name` tool: a single DNS hostname that must
# pass `validators.domain` before the tool is allowed to run.
class LookupNameInput(BaseModel):
    hostname: str = Field(description="Should be a hostname such as www.google.com")

    @root_validator
    def is_dns_address(cls, values: dict[str, object]) -> dict[str, object]:
        """Reject the input unless ``hostname`` is a well-formed domain name.

        Args:
            values: Mapping of field names to the values collected for this
                model.

        Returns:
            ``values`` unchanged when the hostname passes validation.

        Raises:
            ValueError: If ``hostname`` is absent, is not a string, or is not
                accepted by ``validators.domain``.
        """
        hostname = values.get("hostname")
        # `.get` yields None when the key is absent; guard so a non-string is
        # reported as a malformed hostname instead of being handed to
        # `validators.domain`.
        if isinstance(hostname, str) and validators.domain(hostname):
            return values
        raise ValueError("Malformed hostname")
# Argument schema for the `lookup_ip` tool: a single address that must pass
# the IPv4 check in `validators.ip_address` before the tool is allowed to run.
class LookupIPInput(BaseModel):
    address: str = Field(description="Should be an IP address such as 208.91.197.27 or 143.95.239.83")

    @root_validator
    def is_ip_address(cls, values: dict[str, object]) -> dict[str, object]:
        """Reject the input unless ``address`` is a well-formed IPv4 address.

        Args:
            values: Mapping of field names to the values collected for this
                model.

        Returns:
            ``values`` unchanged when the address passes validation.

        Raises:
            ValueError: If ``address`` is absent, is not a string, or is not
                accepted by ``validators.ip_address.ipv4``.
        """
        address = values.get("address")
        # `.get` yields None when the key is absent; guard so a non-string is
        # reported as a malformed address instead of being handed to the
        # IPv4 validator.
        if isinstance(address, str) and validators.ip_address.ipv4(address):
            return values
        raise ValueError("Malformed IP address")
# Argument schema for the `search_ksp` tool: one free-text query field, with
# no validation beyond the `str` type.
class KSPTool(BaseModel):
    query: str = Field(
        description="should be a kerbal space program (ksp) related query",
    )
# Agent tool that forwards a Kerbal Space Program question to the RAG chain
# obtained from `hw1.app.get_rag_chain` and returns whatever the chain yields.
@tool("kerbal_space_program_ksp_information", args_schema=KSPTool, return_direct=False)
def search_ksp(query: str) -> str:
    """Given a query about kerbal space program (ksp), it will send the query to the KSP rag applciation"""
    # NOTE(review): the docstring above is presumably what `@tool` exposes as
    # the tool description, so its wording is left exactly as written.
    # A chain is requested on every call, matching the original behaviour.
    rag_chain = get_rag_chain()
    answer = rag_chain.invoke(query)
    return answer
# Agent tool: forward DNS lookup of the 'A' records for a hostname.
@tool("lookup_name", args_schema=LookupNameInput, return_direct=False)
def lookup_name(hostname):
    """Given a DNS hostname, it will return its IPv4 addresses"""
    answer = dns.resolver.resolve(hostname, "A")
    addresses = []
    for record in answer:
        addresses.append(record.to_text())
    # Every record is rendered to text, but only the first one is handed
    # back, even though the description above speaks of "addresses".
    return addresses[0]
# Agent tool: reverse DNS lookup of the 'PTR' records for an IP address.
@tool("lookup_ip", args_schema=LookupIPInput, return_direct=False)
def lookup_ip(address):
    """Given an IP address, returns names associated with it"""
    # Convert the address into its reverse-lookup name before querying.
    reverse_name = dns.reversename.from_address(address)
    answer = dns.resolver.resolve(reverse_name, "PTR")
    names = []
    for record in answer:
        names.append(record.to_text())
    # Every record is rendered to text, but only the first one is handed
    # back, even though the description above speaks of "names".
    return names[0]