(no commit message)
This commit is contained in:
@@ -1,2 +0,0 @@
|
|||||||
PERSANA_KEY=
|
|
||||||
ANTHROPIC_API_KEY=
|
|
||||||
148
README.md
148
README.md
@@ -0,0 +1,148 @@
|
|||||||
|
# Persana Lead Gen Agent
|
||||||
|
|
||||||
|
Uses a process of human in the loop iterative refinement search to find leads given a company description and target customer description.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### With uv (preferred)
|
||||||
|
|
||||||
|
1. Create a new folder
|
||||||
|
2. [Install uv](https://docs.astral.sh/uv/getting-started/installation/)
|
||||||
|
3. Init workspace
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv init
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Install dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add textual rich modaic dspy python-dotenv
|
||||||
|
```
|
||||||
|
|
||||||
|
5. Run the file
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### With pip
|
||||||
|
|
||||||
|
1. Copy `main.py` to a new workspace folder.
|
||||||
|
2. Create a `.env` file with your API keys using the `.example.env` file.
|
||||||
|
3. Install dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install textual rich modaic dspy python-dotenv
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Run the file with `python main.py`.
|
||||||
|
|
||||||
|
Follow the prompts to create a new `dataset.jsonl` file.
|
||||||
|
|
||||||
|
## Run Prompt Optimization
|
||||||
|
|
||||||
|
Once you have a `dataset.jsonl` file, you can optimize the agent with dspy's built in prompt optimization.
|
||||||
|
|
||||||
|
1. Create a file called `compile.py` with the following code. Replace `<your-username>` with your modaic username.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
import dspy
|
||||||
|
import json
|
||||||
|
from dspy import Prediction, Example
|
||||||
|
from modaic import AutoAgent
|
||||||
|
|
||||||
|
searcher = AutoAgent.from_precompiled("swagginty/persana-lead-gen", api_key=os.getenv("PERSANA_KEY"))
|
||||||
|
feedback_creator = searcher.feedback_creator
|
||||||
|
|
||||||
|
|
||||||
|
class SearchExample(Example):
|
||||||
|
company_description: str
|
||||||
|
target_customer: str
|
||||||
|
selected_profiles: list[dict]
|
||||||
|
|
||||||
|
|
||||||
|
class SearchPrediction(Prediction):
|
||||||
|
profiles: list[dict]
|
||||||
|
search_parameters: dict
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_results(
|
||||||
|
target: SearchExample,
|
||||||
|
predictied: SearchPrediction,
|
||||||
|
trace=None,
|
||||||
|
pred_name=None,
|
||||||
|
pred_trace=None,
|
||||||
|
) -> Prediction:
|
||||||
|
"""
|
||||||
|
Evaluates the search results target results were retrieved
|
||||||
|
"""
|
||||||
|
# How many of the target profiles were retrieved
|
||||||
|
pred_ids = {result["profile_id"] for result in predictied.profiles}
|
||||||
|
count = 0
|
||||||
|
for t_result in target.selected_profiles:
|
||||||
|
if t_result["profile_id"] in pred_ids:
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
score = count / len(target.selected_profiles)
|
||||||
|
|
||||||
|
target_ids = {result["profile_id"] for result in target.selected_profiles}
|
||||||
|
# Which retrieved profiles were target profiles
|
||||||
|
selected_preds = [
|
||||||
|
result for result in predictied.profiles if result["profile_id"] in target_ids
|
||||||
|
]
|
||||||
|
# Which retrieved profiles were not target profiles
|
||||||
|
unselected_preds = [
|
||||||
|
result
|
||||||
|
for result in predictied.profiles
|
||||||
|
if result["profile_id"] not in target_ids
|
||||||
|
]
|
||||||
|
# Resuse feedback creator to get feedback for prompt creation
|
||||||
|
feedback = feedback_creator(
|
||||||
|
search_parameters=predictied.search_parameters,
|
||||||
|
selected_profiles=selected_preds,
|
||||||
|
unselected_profiles=unselected_preds,
|
||||||
|
user_feedback=None,
|
||||||
|
).feedback
|
||||||
|
return Prediction(
|
||||||
|
score=score,
|
||||||
|
feedback=feedback,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
data = [json.loads(line) for line in open("dataset.jsonl", "r")]
|
||||||
|
trainset = [
|
||||||
|
dspy.Example(
|
||||||
|
company_description=e["company_description"],
|
||||||
|
target_customer=e["target_customer"],
|
||||||
|
selected_profiles=e["selected_profiles"],
|
||||||
|
).with_inputs("company_description", "target_customer")
|
||||||
|
for e in data
|
||||||
|
]
|
||||||
|
compiler = dspy.GEPA(
|
||||||
|
metric=evaluate_results,
|
||||||
|
auto="light",
|
||||||
|
reflection_lm=dspy.LM("openai/gpt-5", temperature=1.0, max_tokens=32000),
|
||||||
|
)
|
||||||
|
compiled_searcher = compiler.compile(searcher, trainset=trainset)
|
||||||
|
compiled_searcher.save("compiled_searcher.json")
|
||||||
|
compiled_searcher.push_to_hub("<your-username>/persana-lead-gen") # Replace <your-username> with your username
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Run the file
|
||||||
|
With uv:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run compile.py
|
||||||
|
```
|
||||||
|
|
||||||
|
With python:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python compile.py
|
||||||
|
```
|
||||||
|
|||||||
236
agent/agent.py
236
agent/agent.py
@@ -1,236 +0,0 @@
|
|||||||
from modaic import PrecompiledAgent, PrecompiledConfig
|
|
||||||
import dspy
|
|
||||||
from typing import Literal, Optional, List
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
from .persana import PersanaClient, truncate_profiles
|
|
||||||
|
|
||||||
# Persana API Company Type Options
|
|
||||||
CompanyType = Literal[
|
|
||||||
"Public Company",
|
|
||||||
"Educational",
|
|
||||||
"Self Employed",
|
|
||||||
"Government Agency",
|
|
||||||
"Non Profit",
|
|
||||||
"Self Owned",
|
|
||||||
"Privately Held",
|
|
||||||
"Partnership",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# used to store feedback on the search parameters
|
|
||||||
class SearchFeedback(TypedDict):
|
|
||||||
search_parameters: dict
|
|
||||||
feedback: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
# Defines the prompt signature for the agent that takes in the user's company description
|
|
||||||
# and target customer description and runs a search on the Persana API
|
|
||||||
class ICPTOSearchSignature(dspy.Signature):
|
|
||||||
"""
|
|
||||||
You are a marketing agent that comes up with the ideal people search paramaters for Persana given a
|
|
||||||
company description and a target customer description. The search API allows the following parameters:
|
|
||||||
- Included Job Titles
|
|
||||||
- Excluded Job Titles
|
|
||||||
- Included Companies
|
|
||||||
- Excluded Companies
|
|
||||||
- Company Type
|
|
||||||
- Company Include Keywords
|
|
||||||
- Company Exclude Keywords
|
|
||||||
- Included Industries
|
|
||||||
- Excluded Industries
|
|
||||||
|
|
||||||
You may use all or some of these parameters to craft your search. It is recommended not to use all parameters unless necessary.
|
|
||||||
"""
|
|
||||||
|
|
||||||
company_description: str = dspy.InputField(
|
|
||||||
description="The description of the company"
|
|
||||||
)
|
|
||||||
target_customer: str = dspy.InputField(
|
|
||||||
description="The target customer of the company"
|
|
||||||
)
|
|
||||||
feedback_history: List[SearchFeedback] = dspy.InputField(
|
|
||||||
description="Feedback on previous searches"
|
|
||||||
)
|
|
||||||
tools: list[dspy.Tool] = dspy.InputField()
|
|
||||||
tool_calls: dspy.ToolCalls = dspy.OutputField()
|
|
||||||
|
|
||||||
|
|
||||||
# Defines the prompt signature for the agent that takes in the search parameters, its search results,
|
|
||||||
# and optional user feedback on the search and generates feedback for the AI agent that generated the search parameters
|
|
||||||
class FeedbackSignature(dspy.Signature):
|
|
||||||
"""
|
|
||||||
You are a helpful assistant that helps a user give feedback to an AI agent that generates search parameters for a Persana API,
|
|
||||||
given a company description and a target customer description from the user. Use the search result profiles the user selected, along with the
|
|
||||||
results the user deselected, along with the feedback from the user to give feedback to the AI agent.
|
|
||||||
"""
|
|
||||||
|
|
||||||
search_parameters: dict = dspy.InputField(
|
|
||||||
description="The search parameters for the Persana API"
|
|
||||||
)
|
|
||||||
|
|
||||||
selected_profiles: List[dict] = dspy.InputField(
|
|
||||||
description="The profiles the user selected"
|
|
||||||
)
|
|
||||||
unselected_profiles: List[dict] = dspy.InputField(
|
|
||||||
description="The profiles the user did not select"
|
|
||||||
)
|
|
||||||
user_feedback: Optional[str] = dspy.InputField(
|
|
||||||
description="Feedback from the user on the previous search"
|
|
||||||
)
|
|
||||||
|
|
||||||
feedback: str = dspy.OutputField(
|
|
||||||
description="Feedback to the AI agent that generated the search parameters"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Config for main agent, handles performance-altering parameters for the agent
|
|
||||||
class PersanaConfig(PrecompiledConfig):
|
|
||||||
model: str = "anthropic/claude-3-5-haiku-20241022"
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackDict(TypedDict):
|
|
||||||
selected_profiles: List[dict]
|
|
||||||
unselected_profiles: List[dict]
|
|
||||||
user_feedback: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
# Main agent that handles the search and feedback
|
|
||||||
class PersanaSearcher(PrecompiledAgent):
|
|
||||||
config: PersanaConfig # ! important, used to link the config class to the agent
|
|
||||||
|
|
||||||
feedback_history: List[SearchFeedback] = []
|
|
||||||
search_parameters: Optional[dict] = None
|
|
||||||
response: Optional[dspy.Prediction] = (
|
|
||||||
None # stores last response (used for debugging purposes)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Constructor, all PrecompiledAgent subclasses must have a constructor that takes in a config.
|
|
||||||
# You can optionally take in other variables specific to the environment
|
|
||||||
# It is recommended to track any performance-altering variables (model, temperature, etc.) in the config, not the constructor
|
|
||||||
def __init__(self, config: PersanaConfig, api_key: Optional[str] = None):
|
|
||||||
super().__init__(config)
|
|
||||||
self.client = PersanaClient(api_key=api_key)
|
|
||||||
self.search_creator = dspy.Predict(ICPTOSearchSignature)
|
|
||||||
self.search_creator.set_lm(dspy.LM(config.model))
|
|
||||||
self.feedback_creator = dspy.Predict(FeedbackSignature)
|
|
||||||
self.feedback_creator.set_lm(dspy.LM(config.model))
|
|
||||||
|
|
||||||
# all Modules must define a forward, contains the logic for the module. Called when you run PrecompiledAgent.__call__
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
company_description: str,
|
|
||||||
target_customer: str,
|
|
||||||
feedback_dict: Optional[FeedbackDict] = None,
|
|
||||||
) -> List[dict]:
|
|
||||||
"""
|
|
||||||
Runs a search on the Persana API with the given company description and target customer description.
|
|
||||||
Also saves the search parameters to self.search_parameters so they can be used for feedback later
|
|
||||||
Args:
|
|
||||||
company_description: The description of the company
|
|
||||||
target_customer: The target customer of the company
|
|
||||||
Returns:
|
|
||||||
The profiles returned from the search
|
|
||||||
"""
|
|
||||||
feedback_response = None
|
|
||||||
if feedback_dict is not None:
|
|
||||||
feedback_response = self.enter_feedback(
|
|
||||||
selected_profiles=feedback_dict["selected_profiles"],
|
|
||||||
unselected_profiles=feedback_dict["unselected_profiles"],
|
|
||||||
user_feedback=feedback_dict["user_feedback"],
|
|
||||||
).feedback
|
|
||||||
|
|
||||||
# retry up to 10 times if model generates invalid search parameters
|
|
||||||
search_tool = dspy.Tool(self.client.people_search)
|
|
||||||
for _ in range(10):
|
|
||||||
response = self.search_creator(
|
|
||||||
company_description=company_description,
|
|
||||||
target_customer=target_customer,
|
|
||||||
feedback_history=self.feedback_history,
|
|
||||||
tools=[search_tool],
|
|
||||||
)
|
|
||||||
tool_call = response.tool_calls.tool_calls[0]
|
|
||||||
self.search_parameters = tool_call.args
|
|
||||||
try:
|
|
||||||
profiles = search_tool(**tool_call.args)
|
|
||||||
self.response = response
|
|
||||||
if len(profiles) == 0:
|
|
||||||
self.feedback_history.append(
|
|
||||||
SearchFeedback(
|
|
||||||
search_parameters=self.search_parameters,
|
|
||||||
feedback="No profiles found. Try loosening the search parameters",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
return dspy.Prediction(
|
|
||||||
search_parameters=self.search_parameters,
|
|
||||||
profiles=profiles,
|
|
||||||
feedback=feedback_response,
|
|
||||||
)
|
|
||||||
except TypeError as e:
|
|
||||||
feedback = str(e)
|
|
||||||
self.feedback_history.append(
|
|
||||||
SearchFeedback(
|
|
||||||
search_parameters=self.search_parameters,
|
|
||||||
feedback=feedback,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raise Exception("Model failed to generate valid search parameters")
|
|
||||||
|
|
||||||
def enter_feedback(
|
|
||||||
self,
|
|
||||||
selected_profiles: List[dict],
|
|
||||||
unselected_profiles: List[dict],
|
|
||||||
user_feedback: Optional[str] = None,
|
|
||||||
) -> dspy.Prediction:
|
|
||||||
"""
|
|
||||||
Allows user to enter feedback on the previous search results
|
|
||||||
Args:
|
|
||||||
selected_profiles: The profiles the user selected
|
|
||||||
unselected_profiles: The profiles the user did not select
|
|
||||||
user_feedback: User's feedback on the previous search
|
|
||||||
Returns:
|
|
||||||
The feedback Prediction object (for debugging purposes)
|
|
||||||
"""
|
|
||||||
feedback_response = self.feedback_creator(
|
|
||||||
search_parameters=self.search_parameters,
|
|
||||||
selected_profiles=truncate_profiles(selected_profiles),
|
|
||||||
unselected_profiles=truncate_profiles(unselected_profiles),
|
|
||||||
user_feedback=user_feedback,
|
|
||||||
)
|
|
||||||
self.feedback_history.append(
|
|
||||||
SearchFeedback(
|
|
||||||
search_parameters=self.search_parameters,
|
|
||||||
feedback=feedback_response.feedback,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return feedback_response
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import os
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
searcher = PersanaSearcher(config=PersanaConfig(), api_key=os.getenv("PERSANA_KEY"))
|
|
||||||
results = searcher(
|
|
||||||
company_description="A company that makes software for businesses",
|
|
||||||
target_customer="A business that needs software",
|
|
||||||
)
|
|
||||||
|
|
||||||
print(results)
|
|
||||||
searcher.enter_feedback(
|
|
||||||
selected_profiles=results[:-1],
|
|
||||||
unselected_profiles=[results[-1]],
|
|
||||||
user_feedback="Look for businesses that would need mobile apps",
|
|
||||||
)
|
|
||||||
results = searcher(
|
|
||||||
company_description="A company that makes software for businesses",
|
|
||||||
target_customer="A business that needs software",
|
|
||||||
)
|
|
||||||
print("--------------------------------")
|
|
||||||
print("--------------------------------")
|
|
||||||
print("--------------------------------")
|
|
||||||
print("--------------------------------")
|
|
||||||
print(results)
|
|
||||||
print(searcher.feedback_history)
|
|
||||||
print(searcher.search_parameters)
|
|
||||||
113
agent/persana.py
113
agent/persana.py
@@ -1,113 +0,0 @@
|
|||||||
import requests
|
|
||||||
from typing import Optional, Literal, List
|
|
||||||
import os
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
CompanyType = Literal[
|
|
||||||
"Public Company",
|
|
||||||
"Educational",
|
|
||||||
"Self Employed",
|
|
||||||
"Government Agency",
|
|
||||||
"Non Profit",
|
|
||||||
"Self Owned",
|
|
||||||
"Privately Held",
|
|
||||||
"Partnership",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class PersanaClient:
|
|
||||||
def __init__(self, api_key: Optional[str] = None):
|
|
||||||
self.api_key = api_key or os.getenv("PERSANA_KEY")
|
|
||||||
if not self.api_key:
|
|
||||||
raise ValueError("PERSANA_KEY is not set")
|
|
||||||
|
|
||||||
def people_search(
|
|
||||||
self,
|
|
||||||
include_job_titles: Optional[list[str]] = None,
|
|
||||||
exclude_job_titles: Optional[list[str]] = None,
|
|
||||||
include_companies: Optional[list[str]] = None,
|
|
||||||
exclude_companies: Optional[list[str]] = None,
|
|
||||||
company_types: Optional[list[CompanyType]] = None,
|
|
||||||
company_include_keywords: Optional[list[str]] = None,
|
|
||||||
company_exclude_keywords: Optional[list[str]] = None,
|
|
||||||
include_industries: Optional[list[str]] = None,
|
|
||||||
exclude_industries: Optional[list[str]] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Runs a persana people search with the given filter parameters and returns the results
|
|
||||||
Args:
|
|
||||||
include_job_titles: The job titles to include in the search. (Optional) should stay set to None unless requested
|
|
||||||
exclude_job_titles: The job titles to exclude from the search. (Optional) should stay set to None unless requested
|
|
||||||
include_companies: The companies to include in the search. (Optional) should stay set to None unless requested
|
|
||||||
exclude_companies: The companies to exclude from the search. (Optional) should stay set to None unless requested
|
|
||||||
company_type: The type of the company. (Optional) should stay set to None unless requested
|
|
||||||
company_include_keywords: The keywords to use to search for the company. (Optional) should stay set to None unless requested
|
|
||||||
company_exclude_keywords: The keywords to exclude from the search for the company. (Optional) should stay set to None unless requested
|
|
||||||
include_industries: The industries to include in the search. (Optional) should stay set to None unless requested
|
|
||||||
exclude_industries: The industries to exclude from the search. (Optional) should stay set to None unless requested
|
|
||||||
Returns:
|
|
||||||
The results of the people search
|
|
||||||
"""
|
|
||||||
params = {
|
|
||||||
"title_includes": include_job_titles,
|
|
||||||
"title_excludes": exclude_job_titles,
|
|
||||||
"companies_includes": include_companies,
|
|
||||||
"companies_excludes": exclude_companies,
|
|
||||||
"company_types": company_types,
|
|
||||||
"company_keywords_includes": company_include_keywords,
|
|
||||||
"company_keywords_excludes": company_exclude_keywords,
|
|
||||||
"industries_includes": include_industries,
|
|
||||||
"industries_excludes": exclude_industries,
|
|
||||||
}
|
|
||||||
params = {k: v for k, v in params.items() if v is not None}
|
|
||||||
response = requests.post(
|
|
||||||
"https://prod.api.persana.ai/api/v1/people/search",
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"x-api-key": self.api_key,
|
|
||||||
},
|
|
||||||
json=params,
|
|
||||||
)
|
|
||||||
json_data = response.json()
|
|
||||||
return json_data["data"]["profiles"]
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
client = PersanaClient()
|
|
||||||
result = client.people_search(
|
|
||||||
include_job_titles=["Software Engineer"],
|
|
||||||
exclude_job_titles=["CTO"],
|
|
||||||
include_companies=["Google"],
|
|
||||||
exclude_companies=["Apple"],
|
|
||||||
company_types=["Public Company"],
|
|
||||||
company_include_keywords=["Software"],
|
|
||||||
company_exclude_keywords=["Data"],
|
|
||||||
include_industries=["Technology"],
|
|
||||||
exclude_industries=["Finance"],
|
|
||||||
)
|
|
||||||
print(result)
|
|
||||||
|
|
||||||
|
|
||||||
def truncate_profiles(profiles: List[dict]) -> List[dict]:
|
|
||||||
new_profiles = []
|
|
||||||
experience_data_keep_keys = [
|
|
||||||
"title",
|
|
||||||
"company_name",
|
|
||||||
"company_company_headline",
|
|
||||||
"company_hero_image",
|
|
||||||
"company_description",
|
|
||||||
"company_type",
|
|
||||||
]
|
|
||||||
profile_keep_keys = ["name", "headline", "company", "title"]
|
|
||||||
for profile in profiles:
|
|
||||||
new_experience_data = {
|
|
||||||
k: v
|
|
||||||
for k, v in profile["experience_data"].items()
|
|
||||||
if k in experience_data_keep_keys
|
|
||||||
}
|
|
||||||
new_profile = {k: v for k, v in profile.items() if k in profile_keep_keys}
|
|
||||||
new_profile["experience_data"] = new_experience_data
|
|
||||||
new_profiles.append(new_profile)
|
|
||||||
return new_profiles
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
{
|
|
||||||
"AutoConfig": "agent.agent.PersanaConfig",
|
|
||||||
"AutoAgent": "agent.agent.PersanaSearcher"
|
|
||||||
}
|
|
||||||
86
compile.py
86
compile.py
@@ -1,86 +0,0 @@
|
|||||||
from agent.agent import PersanaSearcher, PersanaConfig
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
import os
|
|
||||||
import dspy
|
|
||||||
import json
|
|
||||||
from dspy import Prediction, Example
|
|
||||||
|
|
||||||
searcher = PersanaSearcher(config=PersanaConfig(), api_key=os.getenv("PERSANA_KEY"))
|
|
||||||
feedback_creator = searcher.feedback_creator
|
|
||||||
searcher.push_to_hub("swagginty/persana-lead-gen", with_code=True)
|
|
||||||
raise Exception("Done")
|
|
||||||
|
|
||||||
|
|
||||||
class SearchExample(Example):
|
|
||||||
company_description: str
|
|
||||||
target_customer: str
|
|
||||||
selected_profiles: list[dict]
|
|
||||||
|
|
||||||
|
|
||||||
class SearchPrediction(Prediction):
|
|
||||||
profiles: list[dict]
|
|
||||||
search_parameters: dict
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_results(
|
|
||||||
target: SearchExample,
|
|
||||||
predictied: SearchPrediction,
|
|
||||||
trace=None,
|
|
||||||
pred_name=None,
|
|
||||||
pred_trace=None,
|
|
||||||
) -> Prediction:
|
|
||||||
"""
|
|
||||||
Evaluates the search results target results were retrieved
|
|
||||||
"""
|
|
||||||
# How many of the target profiles were retrieved
|
|
||||||
pred_ids = {result["profile_id"] for result in predictied.profiles}
|
|
||||||
count = 0
|
|
||||||
for t_result in target.selected_profiles:
|
|
||||||
if t_result["profile_id"] in pred_ids:
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
score = count / len(target.selected_profiles)
|
|
||||||
|
|
||||||
target_ids = {result["profile_id"] for result in target.selected_profiles}
|
|
||||||
# Which retrieved profiles were target profiles
|
|
||||||
selected_preds = [
|
|
||||||
result for result in predictied.profiles if result["profile_id"] in target_ids
|
|
||||||
]
|
|
||||||
# Which retrieved profiles were not target profiles
|
|
||||||
unselected_preds = [
|
|
||||||
result
|
|
||||||
for result in predictied.profiles
|
|
||||||
if result["profile_id"] not in target_ids
|
|
||||||
]
|
|
||||||
# Resuse feedback creator to get feedback for prompt creation
|
|
||||||
feedback = feedback_creator(
|
|
||||||
search_parameters=predictied.search_parameters,
|
|
||||||
selected_profiles=selected_preds,
|
|
||||||
unselected_profiles=unselected_preds,
|
|
||||||
user_feedback=None,
|
|
||||||
).feedback
|
|
||||||
return Prediction(
|
|
||||||
score=score,
|
|
||||||
feedback=feedback,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
data = [json.loads(line) for line in open("dataset.jsonl", "r")]
|
|
||||||
trainset = [
|
|
||||||
dspy.Example(
|
|
||||||
company_description=e["company_description"],
|
|
||||||
target_customer=e["target_customer"],
|
|
||||||
selected_profiles=e["selected_profiles"],
|
|
||||||
).with_inputs("company_description", "target_customer")
|
|
||||||
for e in data
|
|
||||||
]
|
|
||||||
compiler = dspy.GEPA(
|
|
||||||
metric=evaluate_results,
|
|
||||||
auto="light",
|
|
||||||
reflection_lm=dspy.LM("openai/gpt-5", temperature=1.0, max_tokens=32000),
|
|
||||||
)
|
|
||||||
compiled_searcher = compiler.compile(searcher, trainset=trainset)
|
|
||||||
compiled_searcher.push_to_hub("swagginty/persana-lead-gen")
|
|
||||||
File diff suppressed because one or more lines are too long
334
main.py
334
main.py
@@ -1,334 +0,0 @@
|
|||||||
from textual.app import App, ComposeResult
|
|
||||||
from textual.screen import Screen
|
|
||||||
from textual.widgets import (
|
|
||||||
Footer,
|
|
||||||
Header,
|
|
||||||
SelectionList,
|
|
||||||
Button,
|
|
||||||
TextArea,
|
|
||||||
Markdown,
|
|
||||||
MarkdownViewer,
|
|
||||||
)
|
|
||||||
from textual.containers import VerticalScroll, Horizontal, Vertical
|
|
||||||
from textual.binding import Binding
|
|
||||||
from textual.widgets.selection_list import Selection
|
|
||||||
from agent.agent import PersanaSearcher, PersanaConfig
|
|
||||||
from typing import Optional
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
import os
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
import json
|
|
||||||
import dspy
|
|
||||||
from dspy.streaming import (
|
|
||||||
StatusMessage,
|
|
||||||
StatusMessageProvider,
|
|
||||||
StreamListener,
|
|
||||||
StreamResponse,
|
|
||||||
)
|
|
||||||
from litellm.types.utils import ModelResponseStream
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
searcher = PersanaSearcher(config=PersanaConfig(), api_key=os.getenv("PERSANA_KEY"))
|
|
||||||
|
|
||||||
|
|
||||||
class StartScreen(Screen):
|
|
||||||
BINDINGS = [
|
|
||||||
("ctrl+s", "start", "Start"),
|
|
||||||
Binding(key="ctrl+x", action="quit", description="Quit", priority=True),
|
|
||||||
]
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
yield Header()
|
|
||||||
yield TextArea(
|
|
||||||
placeholder="Enter your company description",
|
|
||||||
id="company_description",
|
|
||||||
)
|
|
||||||
yield TextArea(
|
|
||||||
placeholder="Enter your target customer",
|
|
||||||
id="target_customer",
|
|
||||||
)
|
|
||||||
yield Button("Start", classes="buttons", id="start", action="start")
|
|
||||||
yield Footer()
|
|
||||||
|
|
||||||
def action_start(self) -> None:
|
|
||||||
self.app.company_description = self.query_one(
|
|
||||||
"#company_description", TextArea
|
|
||||||
).text
|
|
||||||
self.app.target_customer = self.query_one("#target_customer", TextArea).text
|
|
||||||
self.app.profiles = None
|
|
||||||
self.app.all_selected_profiles = []
|
|
||||||
self.app.push_screen("select")
|
|
||||||
|
|
||||||
|
|
||||||
class Status(StatusMessageProvider):
|
|
||||||
def tool_start_status_message(self, instance, inputs):
|
|
||||||
return "\n\nRunning search...\n\n" + self.dict_to_markdown_table(
|
|
||||||
inputs["kwargs"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def tool_end_status_message(self, instance, outputs):
|
|
||||||
"\n\n Done \n\n"
|
|
||||||
|
|
||||||
def dict_to_markdown_table(self, d: dict) -> str:
|
|
||||||
"""Convert a dict into a Markdown table with 'Name' and 'Value' columns."""
|
|
||||||
headers = ["Param", "Value"]
|
|
||||||
header_row = f"| {' | '.join(headers)} |"
|
|
||||||
separator = f"| {' | '.join(['---'] * len(headers))} |"
|
|
||||||
|
|
||||||
rows = []
|
|
||||||
for key, value in d.items():
|
|
||||||
formatted_value = self.format_value(value)
|
|
||||||
rows.append(f"| {key} | {formatted_value} |")
|
|
||||||
|
|
||||||
return "\n".join([header_row, separator, *rows])
|
|
||||||
|
|
||||||
def format_value(self, v):
|
|
||||||
if isinstance(v, list):
|
|
||||||
return ", ".join(v)
|
|
||||||
return str(v)
|
|
||||||
|
|
||||||
|
|
||||||
streaming_searcher = dspy.streamify(
|
|
||||||
searcher,
|
|
||||||
status_message_provider=Status(),
|
|
||||||
stream_listeners=[
|
|
||||||
StreamListener(
|
|
||||||
predict=searcher.feedback_creator,
|
|
||||||
signature_field_name="feedback",
|
|
||||||
allow_reuse=True,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SelectScreen(Screen):
|
|
||||||
BINDINGS = [
|
|
||||||
Binding(key="ctrl+r", action="refine", description="Refine", priority=True),
|
|
||||||
Binding(key="ctrl+s", action="submit", description="Submit", priority=True),
|
|
||||||
Binding(key="escape", action="back", description="Back"),
|
|
||||||
]
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
yield Header()
|
|
||||||
yield Horizontal(
|
|
||||||
Vertical(
|
|
||||||
VerticalScroll(SelectionList[int](id="profiles"), id="profiles_scroll"),
|
|
||||||
TextArea(placeholder="Feedback for next search", id="feedback"),
|
|
||||||
id="left_container",
|
|
||||||
),
|
|
||||||
MarkdownViewer(show_table_of_contents=False),
|
|
||||||
id="main_container",
|
|
||||||
)
|
|
||||||
yield Horizontal(
|
|
||||||
Button("Refine", id="refine", action="screen.refine"),
|
|
||||||
Button("Submit", id="submit", action="screen.submit"),
|
|
||||||
classes="buttons",
|
|
||||||
)
|
|
||||||
yield Footer(id="footer")
|
|
||||||
|
|
||||||
def on_mount(self) -> None:
|
|
||||||
self.query_one("#profiles", SelectionList)._bindings.bind(
|
|
||||||
"right", "screen.view_profile", description="View Profile", priority=True
|
|
||||||
)
|
|
||||||
self.query_one("#footer", Footer).refresh_bindings()
|
|
||||||
|
|
||||||
def on_screen_resume(self, event) -> None:
|
|
||||||
if self.app.profiles is None:
|
|
||||||
self.run_worker(self.run_search(with_feedback=False), exclusive=True)
|
|
||||||
|
|
||||||
def action_refine(self) -> None:
|
|
||||||
selected_profiles = self.query_one("#profiles", SelectionList).selected
|
|
||||||
self.app.all_selected_profiles.extend(self.get_profiles(selected_profiles))
|
|
||||||
self.run_worker(self.run_search(with_feedback=True), exclusive=True)
|
|
||||||
|
|
||||||
def action_submit(self) -> None:
|
|
||||||
selected_profiles = self.query_one("#profiles", SelectionList).selected
|
|
||||||
self.app.all_selected_profiles.extend(self.get_profiles(selected_profiles))
|
|
||||||
visited = set()
|
|
||||||
new_selected_profiles = []
|
|
||||||
for profile in self.app.all_selected_profiles:
|
|
||||||
if profile["profile_id"] not in visited:
|
|
||||||
new_selected_profiles.append(profile)
|
|
||||||
visited.add(profile["profile_id"])
|
|
||||||
self.app.all_selected_profiles = new_selected_profiles
|
|
||||||
entry = {
|
|
||||||
"company_description": self.app.company_description,
|
|
||||||
"target_customer": self.app.target_customer,
|
|
||||||
"selected_profiles": self.app.all_selected_profiles,
|
|
||||||
}
|
|
||||||
with open("dataset.jsonl", "a", encoding="utf-8") as f:
|
|
||||||
f.write(json.dumps(entry))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
self.app.all_selected_profiles = []
|
|
||||||
self.clear_content()
|
|
||||||
self.app.push_screen("start")
|
|
||||||
|
|
||||||
def action_view_profile(self) -> None:
|
|
||||||
selected_profile = self.query_one("#profiles", SelectionList).highlighted
|
|
||||||
self.app.push_screen(ProfileScreen(selected_profile))
|
|
||||||
|
|
||||||
def action_back(self) -> None:
|
|
||||||
if self.focused.id != "feedback":
|
|
||||||
self.app.push_screen("start")
|
|
||||||
|
|
||||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
||||||
if event.button.id == "refine":
|
|
||||||
self.action_refine()
|
|
||||||
elif event.button.id == "submit":
|
|
||||||
self.action_submit()
|
|
||||||
|
|
||||||
async def run_search(
|
|
||||||
self,
|
|
||||||
with_feedback: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Runs a search using the search agent and updates the pofiles list in UI. If with_feedback is True,
|
|
||||||
it will enter feedback and update the profiles list.
|
|
||||||
"""
|
|
||||||
viewer = self.query_one(MarkdownViewer)
|
|
||||||
ui_log = self.query_one(Markdown)
|
|
||||||
sl = self.query_one("#profiles", SelectionList)
|
|
||||||
sl.clear_options()
|
|
||||||
scroll = self.query_one("#profiles_scroll", VerticalScroll)
|
|
||||||
scroll.loading = True
|
|
||||||
feedback_dict = None
|
|
||||||
if with_feedback:
|
|
||||||
selected_profiles = self.query_one("#profiles", SelectionList).selected
|
|
||||||
num_profiles = len(self.query_one("#profiles", SelectionList).options)
|
|
||||||
unselected_profiles = list(
|
|
||||||
set(range(num_profiles)) - set(selected_profiles)
|
|
||||||
)
|
|
||||||
feedback_dict = {
|
|
||||||
"selected_profiles": self.get_profiles(selected_profiles),
|
|
||||||
"unselected_profiles": self.get_profiles(unselected_profiles),
|
|
||||||
"user_feedback": self.query_one("#feedback", TextArea).text,
|
|
||||||
}
|
|
||||||
stream = streaming_searcher(
|
|
||||||
company_description=self.app.company_description,
|
|
||||||
target_customer=self.app.target_customer,
|
|
||||||
feedback_dict=feedback_dict,
|
|
||||||
)
|
|
||||||
async for message in stream:
|
|
||||||
if isinstance(message, StatusMessage):
|
|
||||||
ui_log.append(message.message)
|
|
||||||
elif isinstance(message, dspy.Prediction):
|
|
||||||
self.app.profiles = message.profiles
|
|
||||||
elif isinstance(message, ModelResponseStream):
|
|
||||||
if message.choices[0].delta.content is not None:
|
|
||||||
ui_log.append(message.choices[0].delta.content)
|
|
||||||
elif isinstance(message, StreamResponse):
|
|
||||||
ui_log.append(message.chunk)
|
|
||||||
viewer.scroll_end(force=True)
|
|
||||||
|
|
||||||
for i, profile in enumerate(self.app.profiles):
|
|
||||||
summary = f"{profile['name']} - {profile['experience_data']['title']} | {profile['experience_data']['company_name']}"
|
|
||||||
option = Selection(summary, i)
|
|
||||||
sl.add_option(option)
|
|
||||||
scroll.loading = False
|
|
||||||
|
|
||||||
def get_profiles(self, indices: list[int]) -> list[dict]:
|
|
||||||
return [self.app.profiles[i] for i in indices]
|
|
||||||
|
|
||||||
def clear_content(self) -> None:
|
|
||||||
self.query_one("#profiles", SelectionList).clear_options()
|
|
||||||
self.query_one("#profiles_scroll", VerticalScroll).loading = False
|
|
||||||
self.query_one("#feedback", TextArea).text = ""
|
|
||||||
self.query_one(Markdown).update("")
|
|
||||||
|
|
||||||
|
|
||||||
class ProfileScreen(Screen):
|
|
||||||
BINDINGS = [
|
|
||||||
Binding(
|
|
||||||
key="left",
|
|
||||||
action="app.push_screen('select')",
|
|
||||||
description="Back to Select",
|
|
||||||
show=True,
|
|
||||||
priority=True,
|
|
||||||
),
|
|
||||||
Binding(
|
|
||||||
key="escape",
|
|
||||||
action="app.pop_screen",
|
|
||||||
description="Back",
|
|
||||||
priority=True,
|
|
||||||
show=True,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self, profile: dict) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.profile = profile
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
yield Header()
|
|
||||||
yield Markdown(
|
|
||||||
markdown=self.get_profile_markdown(self.app.profiles[self.profile]),
|
|
||||||
id="profile",
|
|
||||||
)
|
|
||||||
yield Footer()
|
|
||||||
|
|
||||||
def get_profile_markdown(self, profile: dict):
|
|
||||||
headline = f"{profile['headline']} \n" if profile["headline"] else ""
|
|
||||||
md_text = (
|
|
||||||
f"# {profile['name']} - {profile['experience_data']['title']} · {profile['locality']} [🔗]({profile['url']})",
|
|
||||||
f"{headline}",
|
|
||||||
f"## {profile['experience_data']['company_name']} - {profile['experience_data']['company_locality']} · {profile['experience_data']['company_type']} [🔗]({profile['experience_data']['company_linkedin_url']})",
|
|
||||||
f"{profile['experience_data']['company_description']}",
|
|
||||||
)
|
|
||||||
md_text = "\n".join(md_text)
|
|
||||||
return md_text
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetEntry(TypedDict):
|
|
||||||
company_description: str
|
|
||||||
target_customer: str
|
|
||||||
selected_profiles: list[dict]
|
|
||||||
|
|
||||||
|
|
||||||
class PersanaApp(App):
|
|
||||||
SCREENS = {"start": StartScreen, "select": SelectScreen, "profile": ProfileScreen}
|
|
||||||
BINDINGS = [
|
|
||||||
Binding(key="ctrl+x", action="quit", description="Quit", priority=True),
|
|
||||||
]
|
|
||||||
CSS = """
|
|
||||||
Horizontal {
|
|
||||||
height: 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
Button {
|
|
||||||
width: 1fr; /* 👈 expands equally like flex:1 */
|
|
||||||
margin: 1;
|
|
||||||
background: #bae0f5;
|
|
||||||
}
|
|
||||||
#start{
|
|
||||||
height: 3;
|
|
||||||
}
|
|
||||||
#left_container {
|
|
||||||
height: 1fr;
|
|
||||||
}
|
|
||||||
#main_container {
|
|
||||||
height: 1fr;
|
|
||||||
}
|
|
||||||
#log {
|
|
||||||
height: 5;
|
|
||||||
width: 1fr;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
company_description: Optional[str] = None
|
|
||||||
target_customer: Optional[str] = None
|
|
||||||
profiles: Optional[list[dict]] = None
|
|
||||||
all_selected_profiles: list[dict] = []
|
|
||||||
dataset: list[DatasetEntry] = []
|
|
||||||
log_buffer: str = ""
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
yield StartScreen()
|
|
||||||
|
|
||||||
def on_mount(self) -> None:
|
|
||||||
self.push_screen("start")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app = PersanaApp()
|
|
||||||
app.run()
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
[project]
|
|
||||||
name = "persana-lead-gen"
|
|
||||||
version = "0.1.0"
|
|
||||||
description = "Add your description here"
|
|
||||||
readme = "README.md"
|
|
||||||
requires-python = ">=3.11"
|
|
||||||
dependencies = ["modaic>=0.3.0", "rich>=14.1.0", "textual>=6.2.1", "textual-dev>=1.7.0"]
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user