218 lines
6.9 KiB
Python
218 lines
6.9 KiB
Python
from sqlalchemy.sql import true
|
|
from agent.agent import PersanaSearcher, PersanaConfig
|
|
from dotenv import load_dotenv
|
|
import os
|
|
import dspy
|
|
import json
|
|
from dspy import Prediction, Example
|
|
from typing import Optional, Tuple
|
|
from agent.persana import CompanyType
|
|
|
|
|
|
# Populate the environment from .env *before* PERSANA_KEY is read below.
# The load_dotenv() call under ``__main__`` runs only after this module-level
# code has executed, which is too late for the api_key lookup here.
# load_dotenv() does not override variables that are already set, so calling
# it in both places is harmless.
load_dotenv()

# Shared searcher instance; it is the program handed to the optimizer in the
# ``__main__`` section of this file.
searcher = PersanaSearcher(
    config=PersanaConfig(),
    api_key=os.getenv("PERSANA_KEY"),
    train=True,
)

# Re-exported so evaluate_results_expensive can reuse the searcher's own
# feedback component.
feedback_creator = searcher.feedback_creator
|
|
|
|
|
|
class SearchExample(Example):
    """Typed view of one dataset row used by the evaluation metrics.

    The annotations below only declare the fields the metrics in this file
    read; they add no behavior on top of ``Example``.
    """

    # Passed to the searcher as an input field (see ``with_inputs`` in the
    # ``__main__`` section).
    company_description: str
    # Passed to the searcher as an input field.
    target_customer: str
    # Profiles the search is expected to return. The metrics read
    # "profile_id" and "experience_data" from each entry.
    selected_profiles: list[dict]
|
|
|
|
|
|
class SearchPrediction(Prediction):
    """Typed view of the searcher output consumed by the evaluation metrics.

    The annotations below only declare the fields the metrics in this file
    read; they add no behavior on top of ``Prediction``.
    """

    # Retrieved profiles, each carrying a "profile_id"; declared Optional, so
    # it may be None.
    profiles: Optional[list[dict]]
    # Filter settings produced by the searcher (for example
    # "include_job_titles", "exclude_companies", "company_types").
    search_parameters: dict
|
|
|
|
|
|
def any_in(list: list[str], string: str) -> bool:
    """Return True if any item occurs in ``string``, ignoring case.

    Matching is a plain substring test, so ``"vp"`` matches ``"SVP Sales"``.

    Args:
        list: Candidate substrings. The name shadows the builtin but is kept
            so existing keyword callers keep working.
        string: Text to search in. ``None`` is accepted and never matches,
            because a profile field may be null.

    Returns:
        True when at least one item is a case-insensitive substring of
        ``string``, otherwise False.
    """
    if string is None:
        return False
    # Lower-case the haystack once instead of once per candidate.
    haystack = string.lower()
    return any(item.lower() in haystack for item in list)
|
|
|
|
|
|
def include_profile(
    search_parameters: dict,
    profile: dict,
) -> Tuple[bool, Optional[str]]:
    """Check a profile against the positive ("include") search filters.

    Filters are evaluated in a fixed order: job titles, companies, company
    types, then company keywords. A filter that is absent or empty in
    ``search_parameters`` is skipped. The first filter that fails decides
    the result.

    Args:
        search_parameters: Search filters; the keys read here are
            "include_job_titles", "include_companies", "company_types" and
            "company_include_keywords".
        profile: Profile whose "experience_data" mapping is matched.

    Returns:
        ``(True, None)`` when every active filter matches, otherwise
        ``(False, reason)`` where ``reason`` names the failing filter.
    """
    # A missing or null "experience_data" / field is treated as empty text:
    # it cannot satisfy an include filter, and it must not crash the metric.
    experience = profile.get("experience_data") or {}

    def field(name: str) -> str:
        return experience.get(name) or ""

    # NOTE(review): the reason strings below embed the whole profile dict
    # followed by a literal key path. If only the field value was intended,
    # the f-strings need changing; they are left as-is here.
    if (titles := search_parameters.get("include_job_titles")) and (
        not any_in(titles, field("title"))
    ):
        return (
            False,
            f"include_job_titles: {titles} not in {profile}['experience_data']['title']",
        )

    if (companies := search_parameters.get("include_companies")) and (
        not any_in(companies, field("company_name"))
    ):
        return (
            False,
            f"include_companies: {companies} not in {profile}['experience_data']['company_name']",
        )

    if (company_types := search_parameters.get("company_types")) and (
        not any_in(company_types, field("company_type"))
    ):
        return (
            False,
            f"company_types: {company_types} not in {profile}['experience_data']['company_type']",
        )

    # Keywords pass if they appear in either the headline or the description.
    if (
        (company_keywords := search_parameters.get("company_include_keywords"))
        and not any_in(company_keywords, field("company_company_headline"))
        and not any_in(company_keywords, field("company_description"))
    ):
        return (
            False,
            f"company_include_keywords: {company_keywords} not in {profile}['experience_data']['company_company_headline'] or {profile}['experience_data']['company_description']",
        )

    return True, None
|
|
|
|
|
|
def exclude_profile(
    search_parameters: dict,
    profile: dict,
) -> Tuple[bool, Optional[str]]:
    """Check a profile against the negative ("exclude") search filters.

    Filters are evaluated in a fixed order: job titles, companies, then
    company keywords. A filter that is absent or empty in
    ``search_parameters`` is skipped. The first filter that matches decides
    the result.

    Args:
        search_parameters: Search filters; the keys read here are
            "exclude_job_titles", "exclude_companies" and
            "company_exclude_keywords".
        profile: Profile whose "experience_data" mapping is matched.

    Returns:
        ``(True, reason)`` when some exclude filter matches the profile,
        otherwise ``(False, None)``.
    """
    # A missing or null "experience_data" / field is treated as empty text:
    # it cannot trigger an exclusion, and it must not crash the metric.
    experience = profile.get("experience_data") or {}

    def field(name: str) -> str:
        return experience.get(name) or ""

    # NOTE(review): the reason strings below embed the whole profile dict
    # followed by a literal key path. If only the field value was intended,
    # the f-strings need changing; they are left as-is here.
    if (titles := search_parameters.get("exclude_job_titles")) and (
        any_in(titles, field("title"))
    ):
        return (
            True,
            f"exclude_job_titles: {titles} in {profile}['experience_data']['title']",
        )

    if (companies := search_parameters.get("exclude_companies")) and (
        any_in(companies, field("company_name"))
    ):
        return (
            True,
            f"exclude_companies: {companies} in {profile}['experience_data']['company_name']",
        )

    # Exclude when a keyword appears in EITHER the headline or the
    # description. This mirrors the include-side keyword filter and the
    # wording of the reason below; the previous ``and`` only excluded
    # profiles that matched in both fields.
    if (company_keywords := search_parameters.get("company_exclude_keywords")) and (
        any_in(company_keywords, field("company_company_headline"))
        or any_in(company_keywords, field("company_description"))
    ):
        return (
            True,
            f"company_exclude_keywords: {company_keywords} in {profile}['experience_data']['company_company_headline'] or {profile}['experience_data']['company_description']",
        )

    return False, None
|
|
|
|
|
|
def get_search_eval(
    search_parameters: dict,
    target_profiles: list[dict],
) -> Tuple[float, list[str]]:
    """Score search filters by how many target profiles they would keep.

    A profile counts as kept when it passes ``include_profile`` and is not
    rejected by ``exclude_profile``.

    Args:
        search_parameters: Search filters to evaluate.
        target_profiles: Profiles the search is expected to retrieve.

    Returns:
        ``(score, reasons)`` where ``score`` is the fraction of target
        profiles kept and ``reasons`` holds one explanation per profile that
        was filtered out. With no target profiles the result is
        ``(0.0, [])`` rather than a division by zero.
    """
    if not target_profiles:
        return 0.0, []

    kept = 0
    reasons = []
    for profile in target_profiles:
        included, reason = include_profile(search_parameters, profile)
        if included:
            # Only profiles that pass the include filters are checked against
            # the exclude filters; ``reason`` is overwritten with that result
            # and stays None when the profile is kept.
            excluded, reason = exclude_profile(search_parameters, profile)
            if not excluded:
                kept += 1
        if reason:
            reasons.append(reason)

    return kept / len(target_profiles), reasons
|
|
|
|
|
|
def evaluate_results_expensive(
    target: SearchExample,
    predictied: SearchPrediction,
    trace=None,
    pred_name=None,
    pred_trace=None,
) -> Prediction:
    """Score retrieved profiles against the targets and attach generated feedback.

    The score is the fraction of ``target.selected_profiles`` whose
    "profile_id" appears among ``predictied.profiles``. Feedback comes from
    the searcher's ``feedback_creator``, called with the retrieved profiles
    split into targets and non-targets.

    Args:
        target: Example holding the expected ``selected_profiles``.
        predictied: Searcher output with ``profiles`` and
            ``search_parameters``. The misspelled name is kept for interface
            compatibility.
        trace: Unused; presumably part of the optimizer's metric signature.
        pred_name: Unused; see ``trace``.
        pred_trace: Unused; see ``trace``.

    Returns:
        A ``Prediction`` with ``score`` and ``feedback`` fields.
    """
    # ``profiles`` is declared Optional; a missing result list means nothing
    # was retrieved.
    retrieved = predictied.profiles or []
    targets = target.selected_profiles

    pred_ids = {result["profile_id"] for result in retrieved}
    target_ids = {result["profile_id"] for result in targets}

    # How many of the target profiles were retrieved. An empty target list
    # scores 0.0 instead of dividing by zero.
    hits = sum(1 for result in targets if result["profile_id"] in pred_ids)
    score = hits / len(targets) if targets else 0.0

    # Which retrieved profiles were target profiles
    selected_preds = [
        result for result in retrieved if result["profile_id"] in target_ids
    ]
    # Which retrieved profiles were not target profiles
    unselected_preds = [
        result for result in retrieved if result["profile_id"] not in target_ids
    ]

    # Reuse the feedback creator to get feedback for prompt creation
    feedback = feedback_creator(
        search_parameters=predictied.search_parameters,
        selected_profiles=selected_preds,
        unselected_profiles=unselected_preds,
        user_feedback=None,
    ).feedback

    return Prediction(
        score=score,
        feedback=feedback,
    )
|
|
|
|
|
|
def evaluate_results_cheap(
    target: SearchExample,
    predictied: SearchPrediction,
    trace=None,
    pred_name=None,
    pred_trace=None,
) -> Prediction:
    """Score predicted search parameters against the targets without running a search.

    The predicted filters are applied locally to ``target.selected_profiles``
    via ``get_search_eval``. The score is the fraction of target profiles the
    filters keep.

    Args:
        target: Example holding the expected ``selected_profiles``.
        predictied: Searcher output; only ``search_parameters`` is read. The
            misspelled name is kept for interface compatibility.
        trace: Unused; presumably part of the optimizer's metric signature.
        pred_name: Unused; see ``trace``.
        pred_trace: Unused; see ``trace``.

    Returns:
        A ``Prediction`` with ``score`` and ``feedback`` fields.
    """
    # ``reasons`` holds one explanation per target profile that was filtered out.
    score, reasons = get_search_eval(
        predictied.search_parameters, target.selected_profiles
    )

    if reasons:
        feedback = (
            "The model failed to retrieve the following profiles in the search: "
            + ", ".join(str(reason) for reason in reasons)
        )
    else:
        # Avoid a "failed to retrieve ..." message followed by an empty list,
        # which would read as a failure when nothing was filtered out.
        feedback = "No target profiles were filtered out by the search parameters."

    return Prediction(
        score=score,
        feedback=feedback,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    load_dotenv()

    # dataset.jsonl holds one JSON object per line. The context manager closes
    # the handle, the encoding is explicit instead of platform-dependent, and
    # blank lines (e.g. a trailing newline) are skipped rather than crashing
    # json.loads.
    with open("dataset.jsonl", "r", encoding="utf-8") as dataset_file:
        data = [json.loads(line) for line in dataset_file if line.strip()]

    # company_description and target_customer are the program inputs;
    # selected_profiles is the label the metric compares against.
    trainset = [
        dspy.Example(
            company_description=e["company_description"],
            target_customer=e["target_customer"],
            selected_profiles=e["selected_profiles"],
        ).with_inputs("company_description", "target_customer")
        for e in data
    ]

    # Debugging aid: run the uncompiled searcher over the training set.
    # for d in trainset:
    #     pred = searcher(
    #         company_description=d.company_description,
    #         target_customer=d.target_customer,
    #     )
    #     x = evaluate_results_cheap(d, pred)

    # Optimize the searcher with GEPA using the cheap (local) metric.
    compiler = dspy.GEPA(
        metric=evaluate_results_cheap,
        reflection_lm=dspy.LM("openai/gpt-5", temperature=1.0, max_tokens=32000),
        auto="light",
    )
    compiled_searcher = compiler.compile(
        searcher,
        trainset=trainset,
    )

    # Persist the optimized program locally, then publish it.
    compiled_searcher.save("compiled_searcher.json")
    compiled_searcher.push_to_hub("swagginty/persana-lead-gen", with_code=True)
|