246 lines
9.4 KiB
Python
246 lines
9.4 KiB
Python
from modaic import PrecompiledAgent, PrecompiledConfig
|
|
import dspy
|
|
from typing import Literal, Optional, List
|
|
from typing_extensions import TypedDict
|
|
from .persana import PersanaClient, truncate_profiles
|
|
|
|
# Company-type options of the Persana API, expressed as a Literal so an annotation
# can restrict a value to exactly one of these strings.
CompanyType = Literal[
    "Public Company", "Educational", "Self Employed", "Government Agency",
    "Non Profit", "Self Owned", "Privately Held", "Partnership",
]
|
|
|
|
|
|
# One entry of the feedback history that is shown to the search-generating agent.
class SearchFeedback(TypedDict):
    """Pairs a set of search parameters with the feedback it received.

    ``feedback`` may be ``None``; both keys are required.
    """

    # The parameters (tool-call arguments) of the search being commented on.
    search_parameters: dict
    # Free-text feedback about that search, or None.
    feedback: Optional[str]
|
|
|
|
|
|
# Defines the prompt signature for the agent that takes in the user's company description
# and target customer description and produces a tool call that runs a search on the Persana API.
# NOTE: the class docstring below is sent to the LM as its instructions, so edits to it
# change the prompt.
class ICPTOSearchSignature(dspy.Signature):
    """
    You are a marketing agent that comes up with the ideal people search parameters for Persana given a
    company description and a target customer description. The search API allows the following parameters:
    - Included Job Titles
    - Excluded Job Titles
    - Included Companies
    - Excluded Companies
    - Company Type
    - Company Include Keywords
    - Company Exclude Keywords
    - Included Industries
    - Excluded Industries

    You may use all or some of these parameters to craft your search. It is recommended not to use all parameters unless necessary.
    """

    # Free-text description of the user's own company.
    company_description: str = dspy.InputField(
        description="The description of the company"
    )
    # Free-text description of who the company wants to reach.
    target_customer: str = dspy.InputField(
        description="The target customer of the company"
    )
    # Earlier search parameters paired with feedback on them (errors, empty results,
    # or condensed user feedback), so the LM can improve on previous attempts.
    feedback_history: List[SearchFeedback] = dspy.InputField(
        description="Feedback on previous searches"
    )
    # Tools the LM may call; the caller passes the Persana people-search tool here.
    tools: list[dspy.Tool] = dspy.InputField()
    # The tool call(s) chosen by the LM; the caller executes them.
    tool_calls: dspy.ToolCalls = dspy.OutputField()
|
|
|
|
|
|
# Defines the prompt signature for the agent that takes in the search parameters, its search results,
# and optional user feedback on the search and generates feedback for the AI agent that generated the search parameters.
# NOTE: the docstring, field order and annotations below all become part of the LM prompt;
# change them only deliberately.
class FeedbackSignature(dspy.Signature):
    """
    You are a helpful assistant that helps a user give feedback to an AI agent that generates search parameters for a Persana API,
    given a company description and a target customer description from the user. Use the search result profiles the user selected, along with the
    results the user deselected, along with the feedback from the user to give feedback to the AI agent.
    """

    # Parameters of the search whose results are being judged (may be None if no
    # search has run yet — the caller passes its latest parameters as-is).
    search_parameters: dict = dspy.InputField(
        description="The search parameters for the Persana API"
    )

    # Result profiles the user kept.
    selected_profiles: List[dict] = dspy.InputField(
        description="The profiles the user selected"
    )
    # Result profiles the user rejected.
    unselected_profiles: List[dict] = dspy.InputField(
        description="The profiles the user did not select"
    )
    # Optional free-text comment from the user.
    user_feedback: Optional[str] = dspy.InputField(
        description="Feedback from the user on the previous search"
    )

    # Condensed feedback addressed to the search-generating agent.
    feedback: str = dspy.OutputField(
        description="Feedback to the AI agent that generated the search parameters"
    )
|
|
|
|
|
|
# Config for main agent, handles performance-altering parameters for the agent
class PersanaConfig(PrecompiledConfig):
    # LM identifier passed to dspy.LM for both the search and the feedback predictors.
    model: str = "anthropic/claude-3-5-haiku-20241022"
|
|
|
|
|
|
class FeedbackDict(TypedDict):
    """A user's reaction to one page of search results.

    All three keys are declared; ``user_feedback`` may hold ``None``.
    """

    # Profiles from the last search that the user kept.
    selected_profiles: List[dict]
    # Profiles from the last search that the user rejected.
    unselected_profiles: List[dict]
    # Optional free-text comment on the last search.
    user_feedback: Optional[str]
|
|
|
|
|
|
# Main agent that handles the search and feedback
class PersanaSearcher(PrecompiledAgent):
    """Generate Persana people-search parameters with an LM and run the search.

    ``forward`` asks ``search_creator`` for a tool call to ``self.client.people_search``,
    executes it, and retries when the call is unusable or finds nothing. Each failure
    is appended to ``feedback_history`` so the next attempt can learn from it.
    ``enter_feedback`` condenses the user's reaction to a result set into feedback
    that is likewise appended to ``feedback_history``.
    """

    config: PersanaConfig  # ! important, used to link the config class to the agent

    # Maximum number of LM attempts forward() makes before giving up.
    _MAX_SEARCH_ATTEMPTS = 10

    # feedback_history is only *annotated* here and assigned in __init__: a class-level
    # ``= []`` default is a single list shared by every instance (and by any deepcopy
    # of the module), so feedback from one searcher would leak into all others.
    feedback_history: List[SearchFeedback]
    search_parameters: Optional[dict] = None  # parameters of the most recent attempt
    response: Optional[dspy.Prediction] = (
        None  # stores last response (used for debugging purposes)
    )

    # Constructor, all PrecompiledAgent subclasses must have a constructor that takes in a config.
    # You can optionally take in other variables specific to the environment
    # It is recommended to track any performance-altering variables (model, temperature, etc.) in the config, not the constructor
    def __init__(
        self, config: PersanaConfig, api_key: Optional[str] = None, train: bool = False
    ):
        """
        Args:
            config: Agent configuration; ``config.model`` selects the LM.
            api_key: Forwarded to ``PersanaClient``.
            train: If True, ``forward`` returns the generated search parameters
                without executing the search.
        """
        super().__init__(config)
        self.client = PersanaClient(api_key=api_key)
        # Per-instance history (see the class-level note on feedback_history).
        self.feedback_history = []
        self.search_creator = dspy.Predict(ICPTOSearchSignature)
        self.search_creator.set_lm(dspy.LM(config.model))
        self.feedback_creator = dspy.Predict(FeedbackSignature)
        self.feedback_creator.set_lm(dspy.LM(config.model))
        self.train = train

    # all Modules must define a forward, contains the logic for the module. Called when you run PrecompiledAgent.__call__
    def forward(
        self,
        company_description: str,
        target_customer: str,
        feedback_dict: Optional[FeedbackDict] = None,
    ) -> dspy.Prediction:
        """
        Runs a search on the Persana API with the given company description and target customer description.

        If ``feedback_dict`` is given, it is first turned into feedback via
        ``enter_feedback``, which also appends it to ``feedback_history``. The LM is
        then asked for search parameters up to ``_MAX_SEARCH_ATTEMPTS`` times. An
        attempt is retried, with an explanatory entry appended to
        ``feedback_history``, when the LM makes no tool call, when the tool rejects
        the arguments (TypeError/ValueError), or when the search returns no profiles.
        The latest parameters are kept in ``self.search_parameters`` so they can be
        used for feedback later.

        Args:
            company_description: The description of the company
            target_customer: The target customer of the company
            feedback_dict: Optional user reaction to the previous search results

        Returns:
            A ``dspy.Prediction`` with ``search_parameters``, ``profiles`` (``None`` in
            train mode, where the first generated tool call is returned without being
            executed) and ``feedback`` (the generated feedback text, or ``None`` when
            no ``feedback_dict`` was given).

        Raises:
            RuntimeError: If no attempt produced a search that returned profiles.
        """
        feedback_response = None
        if feedback_dict is not None:
            feedback_response = self.enter_feedback(
                selected_profiles=feedback_dict["selected_profiles"],
                unselected_profiles=feedback_dict["unselected_profiles"],
                user_feedback=feedback_dict.get("user_feedback"),
            ).feedback

        search_tool = dspy.Tool(self.client.people_search)
        # retry if model generates invalid search parameters or the search finds nothing
        for _ in range(self._MAX_SEARCH_ATTEMPTS):
            response = self.search_creator(
                company_description=company_description,
                target_customer=target_customer,
                feedback_history=self.feedback_history,
                tools=[search_tool],
            )

            tool_calls = response.tool_calls.tool_calls
            if not tool_calls:
                # The LM answered without calling the tool; say so and try again
                # instead of failing with an IndexError.
                self._record_feedback(
                    "No search was run because no tool call was made. "
                    "Call the search tool with your search parameters",
                    search_parameters={},
                )
                continue

            tool_call = tool_calls[0]
            self.search_parameters = tool_call.args
            if self.train:
                return dspy.Prediction(
                    search_parameters=self.search_parameters,
                    profiles=None,
                    feedback=feedback_response,
                )

            try:
                profiles = search_tool(**tool_call.args)
            except (TypeError, ValueError) as e:
                # dspy.Tool reports unknown/invalid arguments as ValueError; a missing
                # required argument surfaces as TypeError from the call itself.
                # Feed the message back so the LM can correct its parameters.
                self._record_feedback(str(e))
                continue

            self.response = response
            if not profiles:
                self._record_feedback(
                    "No profiles found. Try loosening the search parameters"
                )
                continue

            return dspy.Prediction(
                search_parameters=self.search_parameters,
                profiles=profiles,
                feedback=feedback_response,
            )

        raise RuntimeError("Model failed to generate valid search parameters")

    def enter_feedback(
        self,
        selected_profiles: List[dict],
        unselected_profiles: List[dict],
        user_feedback: Optional[str] = None,
    ) -> dspy.Prediction:
        """
        Allows user to enter feedback on the previous search results

        The profiles are passed through ``truncate_profiles`` (presumably to keep the
        prompt small — see that helper) and sent, together with
        ``self.search_parameters``, to ``feedback_creator``. The generated feedback is
        appended to ``feedback_history``.

        Args:
            selected_profiles: The profiles the user selected
            unselected_profiles: The profiles the user did not select
            user_feedback: User's feedback on the previous search

        Returns:
            The feedback Prediction object (for debugging purposes)
        """
        # NOTE: if no search has run yet, self.search_parameters is None and is sent as-is.
        feedback_response = self.feedback_creator(
            search_parameters=self.search_parameters,
            selected_profiles=truncate_profiles(selected_profiles),
            unselected_profiles=truncate_profiles(unselected_profiles),
            user_feedback=user_feedback,
        )
        self._record_feedback(feedback_response.feedback)
        return feedback_response

    def _record_feedback(
        self, feedback: str, search_parameters: Optional[dict] = None
    ) -> None:
        """Append ``feedback`` to ``feedback_history`` for ``search_parameters``
        (defaults to the most recent ``self.search_parameters``)."""
        if search_parameters is None:
            search_parameters = self.search_parameters
        self.feedback_history.append(
            SearchFeedback(search_parameters=search_parameters, feedback=feedback)
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: run a search, give feedback on its results, then search again.
    # PERSANA_KEY is read from the environment after load_dotenv() loads any .env file.
    import os

    from dotenv import load_dotenv

    load_dotenv()

    company_description = "A company that makes software for businesses"
    target_customer = "A business that needs software"
    separator = "--------------------------------"

    searcher = PersanaSearcher(config=PersanaConfig(), api_key=os.getenv("PERSANA_KEY"))
    results = searcher(
        company_description=company_description,
        target_customer=target_customer,
    )
    print(results)

    # The searcher returns a dspy.Prediction, not a list, so the profiles have to be
    # read from `.profiles` before slicing. The last profile plays the rejected one.
    profiles = results.profiles
    searcher.enter_feedback(
        selected_profiles=profiles[:-1],
        unselected_profiles=profiles[-1:],
        user_feedback="Look for businesses that would need mobile apps",
    )

    results = searcher(
        company_description=company_description,
        target_customer=target_customer,
    )
    for _ in range(4):
        print(separator)
    print(results)
    print(searcher.feedback_history)
    print(searcher.search_parameters)
|