(no commit message)
This commit is contained in:
236
agent/agent.py
Normal file
236
agent/agent.py
Normal file
@@ -0,0 +1,236 @@
|
||||
from modaic import PrecompiledAgent, PrecompiledConfig
|
||||
import dspy
|
||||
from typing import Literal, Optional, List
|
||||
from typing_extensions import TypedDict
|
||||
from .persana import PersanaClient, truncate_profiles
|
||||
|
||||
# Persana API Company Type Options
|
||||
CompanyType = Literal[
|
||||
"Public Company",
|
||||
"Educational",
|
||||
"Self Employed",
|
||||
"Government Agency",
|
||||
"Non Profit",
|
||||
"Self Owned",
|
||||
"Privately Held",
|
||||
"Partnership",
|
||||
]
|
||||
|
||||
|
||||
# used to store feedback on the search parameters
|
||||
class SearchFeedback(TypedDict):
|
||||
search_parameters: dict
|
||||
feedback: Optional[str]
|
||||
|
||||
|
||||
# Defines the prompt signature for the agent that takes in the user's company description
|
||||
# and target customer description and runs a search on the Persana API
|
||||
class ICPTOSearchSignature(dspy.Signature):
|
||||
"""
|
||||
You are a marketing agent that comes up with the ideal people search paramaters for Persana given a
|
||||
company description and a target customer description. The search API allows the following parameters:
|
||||
- Included Job Titles
|
||||
- Excluded Job Titles
|
||||
- Included Companies
|
||||
- Excluded Companies
|
||||
- Company Type
|
||||
- Company Include Keywords
|
||||
- Company Exclude Keywords
|
||||
- Included Industries
|
||||
- Excluded Industries
|
||||
|
||||
You may use all or some of these parameters to craft your search. It is recommended not to use all parameters unless necessary.
|
||||
"""
|
||||
|
||||
company_description: str = dspy.InputField(
|
||||
description="The description of the company"
|
||||
)
|
||||
target_customer: str = dspy.InputField(
|
||||
description="The target customer of the company"
|
||||
)
|
||||
feedback_history: List[SearchFeedback] = dspy.InputField(
|
||||
description="Feedback on previous searches"
|
||||
)
|
||||
tools: list[dspy.Tool] = dspy.InputField()
|
||||
tool_calls: dspy.ToolCalls = dspy.OutputField()
|
||||
|
||||
|
||||
# Defines the prompt signature for the agent that takes in the search parameters, its search results,
|
||||
# and optional user feedback on the search and generates feedback for the AI agent that generated the search parameters
|
||||
class FeedbackSignature(dspy.Signature):
|
||||
"""
|
||||
You are a helpful assistant that helps a user give feedback to an AI agent that generates search parameters for a Persana API,
|
||||
given a company description and a target customer description from the user. Use the search result profiles the user selected, along with the
|
||||
results the user deselected, along with the feedback from the user to give feedback to the AI agent.
|
||||
"""
|
||||
|
||||
search_parameters: dict = dspy.InputField(
|
||||
description="The search parameters for the Persana API"
|
||||
)
|
||||
|
||||
selected_profiles: List[dict] = dspy.InputField(
|
||||
description="The profiles the user selected"
|
||||
)
|
||||
unselected_profiles: List[dict] = dspy.InputField(
|
||||
description="The profiles the user did not select"
|
||||
)
|
||||
user_feedback: Optional[str] = dspy.InputField(
|
||||
description="Feedback from the user on the previous search"
|
||||
)
|
||||
|
||||
feedback: str = dspy.OutputField(
|
||||
description="Feedback to the AI agent that generated the search parameters"
|
||||
)
|
||||
|
||||
|
||||
# Config for main agent, handles performance-altering parameters for the agent
|
||||
class PersanaConfig(PrecompiledConfig):
|
||||
model: str = "anthropic/claude-3-5-haiku-20241022"
|
||||
|
||||
|
||||
class FeedbackDict(TypedDict):
|
||||
selected_profiles: List[dict]
|
||||
unselected_profiles: List[dict]
|
||||
user_feedback: Optional[str]
|
||||
|
||||
|
||||
# Main agent that handles the search and feedback
|
||||
class PersanaSearcher(PrecompiledAgent):
|
||||
config: PersanaConfig # ! important, used to link the config class to the agent
|
||||
|
||||
feedback_history: List[SearchFeedback] = []
|
||||
search_parameters: Optional[dict] = None
|
||||
response: Optional[dspy.Prediction] = (
|
||||
None # stores last response (used for debugging purposes)
|
||||
)
|
||||
|
||||
# Constructor, all PrecompiledAgent subclasses must have a constructor that takes in a config.
|
||||
# You can optionally take in other variables specific to the environment
|
||||
# It is recommended to track any performance-altering variables (model, temperature, etc.) in the config, not the constructor
|
||||
def __init__(self, config: PersanaConfig, api_key: Optional[str] = None):
|
||||
super().__init__(config)
|
||||
self.client = PersanaClient(api_key=api_key)
|
||||
self.search_creator = dspy.Predict(ICPTOSearchSignature)
|
||||
self.search_creator.set_lm(dspy.LM(config.model))
|
||||
self.feedback_creator = dspy.Predict(FeedbackSignature)
|
||||
self.feedback_creator.set_lm(dspy.LM(config.model))
|
||||
|
||||
# all Modules must define a forward, contains the logic for the module. Called when you run PrecompiledAgent.__call__
|
||||
def forward(
|
||||
self,
|
||||
company_description: str,
|
||||
target_customer: str,
|
||||
feedback_dict: Optional[FeedbackDict] = None,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Runs a search on the Persana API with the given company description and target customer description.
|
||||
Also saves the search parameters to self.search_parameters so they can be used for feedback later
|
||||
Args:
|
||||
company_description: The description of the company
|
||||
target_customer: The target customer of the company
|
||||
Returns:
|
||||
The profiles returned from the search
|
||||
"""
|
||||
feedback_response = None
|
||||
if feedback_dict is not None:
|
||||
feedback_response = self.enter_feedback(
|
||||
selected_profiles=feedback_dict["selected_profiles"],
|
||||
unselected_profiles=feedback_dict["unselected_profiles"],
|
||||
user_feedback=feedback_dict["user_feedback"],
|
||||
).feedback
|
||||
|
||||
# retry up to 10 times if model generates invalid search parameters
|
||||
search_tool = dspy.Tool(self.client.people_search)
|
||||
for _ in range(10):
|
||||
response = self.search_creator(
|
||||
company_description=company_description,
|
||||
target_customer=target_customer,
|
||||
feedback_history=self.feedback_history,
|
||||
tools=[search_tool],
|
||||
)
|
||||
tool_call = response.tool_calls.tool_calls[0]
|
||||
self.search_parameters = tool_call.args
|
||||
try:
|
||||
profiles = search_tool(**tool_call.args)
|
||||
self.response = response
|
||||
if len(profiles) == 0:
|
||||
self.feedback_history.append(
|
||||
SearchFeedback(
|
||||
search_parameters=self.search_parameters,
|
||||
feedback="No profiles found. Try loosening the search parameters",
|
||||
)
|
||||
)
|
||||
continue
|
||||
return dspy.Prediction(
|
||||
search_parameters=self.search_parameters,
|
||||
profiles=profiles,
|
||||
feedback=feedback_response,
|
||||
)
|
||||
except TypeError as e:
|
||||
feedback = str(e)
|
||||
self.feedback_history.append(
|
||||
SearchFeedback(
|
||||
search_parameters=self.search_parameters,
|
||||
feedback=feedback,
|
||||
)
|
||||
)
|
||||
raise Exception("Model failed to generate valid search parameters")
|
||||
|
||||
def enter_feedback(
|
||||
self,
|
||||
selected_profiles: List[dict],
|
||||
unselected_profiles: List[dict],
|
||||
user_feedback: Optional[str] = None,
|
||||
) -> dspy.Prediction:
|
||||
"""
|
||||
Allows user to enter feedback on the previous search results
|
||||
Args:
|
||||
selected_profiles: The profiles the user selected
|
||||
unselected_profiles: The profiles the user did not select
|
||||
user_feedback: User's feedback on the previous search
|
||||
Returns:
|
||||
The feedback Prediction object (for debugging purposes)
|
||||
"""
|
||||
feedback_response = self.feedback_creator(
|
||||
search_parameters=self.search_parameters,
|
||||
selected_profiles=truncate_profiles(selected_profiles),
|
||||
unselected_profiles=truncate_profiles(unselected_profiles),
|
||||
user_feedback=user_feedback,
|
||||
)
|
||||
self.feedback_history.append(
|
||||
SearchFeedback(
|
||||
search_parameters=self.search_parameters,
|
||||
feedback=feedback_response.feedback,
|
||||
)
|
||||
)
|
||||
return feedback_response
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
searcher = PersanaSearcher(config=PersanaConfig(), api_key=os.getenv("PERSANA_KEY"))
|
||||
results = searcher(
|
||||
company_description="A company that makes software for businesses",
|
||||
target_customer="A business that needs software",
|
||||
)
|
||||
|
||||
print(results)
|
||||
searcher.enter_feedback(
|
||||
selected_profiles=results[:-1],
|
||||
unselected_profiles=[results[-1]],
|
||||
user_feedback="Look for businesses that would need mobile apps",
|
||||
)
|
||||
results = searcher(
|
||||
company_description="A company that makes software for businesses",
|
||||
target_customer="A business that needs software",
|
||||
)
|
||||
print("--------------------------------")
|
||||
print("--------------------------------")
|
||||
print("--------------------------------")
|
||||
print("--------------------------------")
|
||||
print(results)
|
||||
print(searcher.feedback_history)
|
||||
print(searcher.search_parameters)
|
||||
Reference in New Issue
Block a user