(no commit message)

This commit is contained in:
2025-10-06 21:49:50 -07:00
parent 8a2608f427
commit d0dbad868e
11 changed files with 1064 additions and 9 deletions

217
compile.py Normal file
View File

@@ -0,0 +1,217 @@
from sqlalchemy.sql import true
from agent.agent import PersanaSearcher, PersanaConfig
from dotenv import load_dotenv
import os
import dspy
import json
from dspy import Prediction, Example
from typing import Optional, Tuple
from agent.persana import CompanyType
# Load variables from a local .env file *before* reading PERSANA_KEY. This
# module-level code runs at import time, i.e. earlier than the load_dotenv()
# call inside the ``__main__`` guard, so without this call the key would only
# be found if it had already been exported in the process environment.
load_dotenv()

# Searcher program that is optimised in the ``__main__`` section below.
searcher = PersanaSearcher(
    config=PersanaConfig(),
    api_key=os.getenv("PERSANA_KEY"),
    train=True,
)

# Shared handle on the searcher's feedback component, used by
# evaluate_results_expensive to produce textual feedback.
feedback_creator = searcher.feedback_creator
class SearchExample(Example):
    """Typed view of one dataset example consumed by the metric functions.

    Used in this file as the annotation for the ``target`` argument of the
    evaluation metrics; it declares fields only and adds no behaviour.
    """

    # Input field: description of the company (see ``with_inputs`` in main).
    company_description: str

    # Input field: description of the customer being targeted.
    target_customer: str

    # Label field: the profiles a good search should match. The metrics read
    # each entry's "profile_id" and "experience_data" keys.
    selected_profiles: list[dict]
class SearchPrediction(Prediction):
    """Typed view of the searcher output consumed by the metric functions.

    Used in this file as the annotation for the ``predictied`` argument of
    the evaluation metrics; it declares fields only and adds no behaviour.
    """

    # Profiles returned by the search, or None when none are attached.
    # evaluate_results_expensive reads each entry's "profile_id".
    profiles: Optional[list[dict]]

    # Filter settings produced by the searcher, e.g. "include_job_titles",
    # "exclude_companies", "company_include_keywords".
    search_parameters: dict
def any_in(list: list[str], string: Optional[str]) -> bool:
    """Return True if any item of *list* is a substring of *string*.

    The comparison is case-insensitive (both sides are lower-cased).

    Args:
        list: Candidate substrings to look for. (The name shadows the builtin
            but is kept so existing keyword callers keep working.)
        string: Text to search in. ``None`` — e.g. a profile field that is
            null in the dataset — is treated as "nothing matches".

    Returns:
        True when at least one item occurs in *string*, otherwise False.
        An empty *list* always yields False.
    """
    if string is None:
        return False
    # Lower-case the haystack once instead of once per candidate.
    haystack = string.lower()
    return any(item.lower() in haystack for item in list)
def include_profile(
    search_parameters: dict,
    profile: dict,
) -> Tuple[bool, Optional[str]]:
    """Check whether *profile* passes every "include" filter that is set.

    Filters that are missing or empty in *search_parameters* are skipped.
    The checks run in a fixed order and the first failing one decides the
    result.

    Args:
        search_parameters: Search filters. The keys read here are
            "include_job_titles", "include_companies", "company_types" and
            "company_include_keywords".
        profile: Profile dict; its "experience_data" sub-dict is read only
            for the filters that are actually set.

    Returns:
        ``(True, None)`` when the profile passes all active filters,
        otherwise ``(False, reason)`` where *reason* names the failing filter
        and the profile value it was compared against.
    """
    if titles := search_parameters.get("include_job_titles"):
        title = profile["experience_data"]["title"]
        if not any_in(titles, title):
            return False, f"include_job_titles: {titles} not in {title!r}"

    if companies := search_parameters.get("include_companies"):
        company_name = profile["experience_data"]["company_name"]
        if not any_in(companies, company_name):
            return False, f"include_companies: {companies} not in {company_name!r}"

    if company_types := search_parameters.get("company_types"):
        company_type = profile["experience_data"]["company_type"]
        if not any_in(company_types, company_type):
            return False, f"company_types: {company_types} not in {company_type!r}"

    if company_keywords := search_parameters.get("company_include_keywords"):
        experience = profile["experience_data"]
        # A keyword hit in either the headline or the description is enough;
        # the profile only fails when neither text contains any keyword.
        if not any_in(
            company_keywords, experience["company_company_headline"]
        ) and not any_in(company_keywords, experience["company_description"]):
            return (
                False,
                f"company_include_keywords: {company_keywords} not in "
                f"{experience['company_company_headline']!r} or "
                f"{experience['company_description']!r}",
            )

    return True, None
def exclude_profile(
    search_parameters: dict,
    profile: dict,
) -> Tuple[bool, Optional[str]]:
    """Check whether *profile* is rejected by any "exclude" filter that is set.

    Filters that are missing or empty in *search_parameters* are skipped.
    The checks run in a fixed order and the first matching one decides the
    result.

    Args:
        search_parameters: Search filters. The keys read here are
            "exclude_job_titles", "exclude_companies" and
            "company_exclude_keywords".
        profile: Profile dict; its "experience_data" sub-dict is read only
            for the filters that are actually set.

    Returns:
        ``(True, reason)`` when an exclude filter matches, where *reason*
        names the filter and the profile value that triggered it; otherwise
        ``(False, None)``.
    """
    if titles := search_parameters.get("exclude_job_titles"):
        title = profile["experience_data"]["title"]
        if any_in(titles, title):
            return True, f"exclude_job_titles: {titles} in {title!r}"

    if companies := search_parameters.get("exclude_companies"):
        company_name = profile["experience_data"]["company_name"]
        if any_in(companies, company_name):
            return True, f"exclude_companies: {companies} in {company_name!r}"

    if company_keywords := search_parameters.get("company_exclude_keywords"):
        experience = profile["experience_data"]
        # An excluded keyword in EITHER text rejects the profile. This mirrors
        # include_profile (which accepts a hit in either text) and matches the
        # "... or ..." wording of the reason message.
        if any_in(
            company_keywords, experience["company_company_headline"]
        ) or any_in(company_keywords, experience["company_description"]):
            return (
                True,
                f"company_exclude_keywords: {company_keywords} in "
                f"{experience['company_company_headline']!r} or "
                f"{experience['company_description']!r}",
            )

    return False, None
def get_search_eval(
    search_parameters: dict,
    target_profiles: list[dict],
) -> Tuple[float, list[str]]:
    """Score *search_parameters* against the profiles they should match.

    A target profile counts as matched when it passes include_profile and is
    not rejected by exclude_profile.

    Args:
        search_parameters: Search filters to evaluate.
        target_profiles: Profiles the search is expected to match.

    Returns:
        ``(score, reasons)`` where *score* is the fraction of target profiles
        matched and *reasons* holds one explanation string per profile that
        was not matched. With no target profiles the result is ``(0.0, [])``
        rather than a division by zero.
    """
    if not target_profiles:
        return 0.0, []

    matched = 0
    miss_reasons: list[str] = []
    for profile in target_profiles:
        included, reason = include_profile(search_parameters, profile)
        if included:
            # Only profiles that pass the include filters are checked against
            # the exclude filters; the reason then comes from that check.
            excluded, reason = exclude_profile(search_parameters, profile)
            if not excluded:
                matched += 1
        if reason:
            miss_reasons.append(reason)

    return matched / len(target_profiles), miss_reasons
def evaluate_results_expensive(
    target: SearchExample,
    predictied: SearchPrediction,
    trace=None,
    pred_name=None,
    pred_trace=None,
) -> Prediction:
    """Score a prediction by how many target profiles it actually retrieved.

    Profiles are compared by their "profile_id". Textual feedback is produced
    by the module-level ``feedback_creator``.

    Args:
        target: Example whose ``selected_profiles`` are the expected results.
        predictied: Prediction carrying the retrieved ``profiles`` (may be
            None) and the ``search_parameters`` that produced them.
        trace: Accepted as part of the metric signature; unused here.
        pred_name: Accepted as part of the metric signature; unused here.
        pred_trace: Accepted as part of the metric signature; unused here.

    Returns:
        Prediction with ``score`` (fraction of target profiles retrieved;
        0.0 when there are no target profiles) and ``feedback``.
    """
    # ``profiles`` is declared Optional; treat a missing list as "nothing
    # was retrieved" instead of failing on iteration.
    retrieved = predictied.profiles or []
    targets = target.selected_profiles

    retrieved_ids = {result["profile_id"] for result in retrieved}
    target_ids = {result["profile_id"] for result in targets}

    # How many of the target profiles were retrieved.
    hits = sum(1 for result in targets if result["profile_id"] in retrieved_ids)
    score = hits / len(targets) if targets else 0.0

    # Split the retrieved profiles into targets and non-targets.
    selected_preds = [
        result for result in retrieved if result["profile_id"] in target_ids
    ]
    unselected_preds = [
        result for result in retrieved if result["profile_id"] not in target_ids
    ]

    # Reuse the searcher's feedback creator to get feedback for prompt creation.
    feedback = feedback_creator(
        search_parameters=predictied.search_parameters,
        selected_profiles=selected_preds,
        unselected_profiles=unselected_preds,
        user_feedback=None,
    ).feedback

    return Prediction(
        score=score,
        feedback=feedback,
    )
def evaluate_results_cheap(
    target: SearchExample,
    predictied: SearchPrediction,
    trace=None,
    pred_name=None,
    pred_trace=None,
) -> Prediction:
    """Score a prediction from its search parameters alone.

    Uses get_search_eval to check the predicted ``search_parameters`` against
    the target profiles; the prediction's retrieved ``profiles`` are not read.

    Args:
        target: Example whose ``selected_profiles`` are the expected results.
        predictied: Prediction carrying the ``search_parameters`` to evaluate.
        trace: Accepted as part of the metric signature; unused here.
        pred_name: Accepted as part of the metric signature; unused here.
        pred_trace: Accepted as part of the metric signature; unused here.

    Returns:
        Prediction with ``score`` (as computed by get_search_eval) and a
        ``feedback`` string listing why target profiles were rejected.
    """
    score, miss_reasons = get_search_eval(
        predictied.search_parameters, target.selected_profiles
    )

    if miss_reasons:
        feedback = (
            "The model failed to retrieve the following profiles in the search: "
            + ", ".join(str(reason) for reason in miss_reasons)
        )
    else:
        # Do not claim a failure followed by an empty list when nothing was
        # rejected; that would mislead whoever reads the feedback.
        feedback = "No target profiles were rejected by the search parameters."

    return Prediction(
        score=score,
        feedback=feedback,
    )
if __name__ == "__main__":
    # Script entry point: load the dataset, optimise the searcher with GEPA
    # using the cheap metric, then save and publish the compiled program.
    load_dotenv()

    # Read the JSONL dataset through a context manager so the file handle is
    # closed, and skip blank lines (e.g. a trailing newline) that would make
    # json.loads fail.
    with open("dataset.jsonl", "r", encoding="utf-8") as dataset_file:
        data = [json.loads(line) for line in dataset_file if line.strip()]

    trainset = [
        dspy.Example(
            company_description=e["company_description"],
            target_customer=e["target_customer"],
            selected_profiles=e["selected_profiles"],
        ).with_inputs("company_description", "target_customer")
        for e in data
    ]

    # NOTE: to sanity-check the metric before compiling, call
    # searcher(company_description=..., target_customer=...) on an example
    # from trainset and pass the example and result to evaluate_results_cheap.

    compiler = dspy.GEPA(
        metric=evaluate_results_cheap,
        reflection_lm=dspy.LM("openai/gpt-5", temperature=1.0, max_tokens=32000),
        auto="light",
    )
    compiled_searcher = compiler.compile(
        searcher,
        trainset=trainset,
    )

    compiled_searcher.save("compiled_searcher.json")
    compiled_searcher.push_to_hub("swagginty/persana-lead-gen", with_code=True)