(no commit message)
This commit is contained in:
217
compile.py
Normal file
217
compile.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from sqlalchemy.sql import true
|
||||
from agent.agent import PersanaSearcher, PersanaConfig
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import dspy
|
||||
import json
|
||||
from dspy import Prediction, Example
|
||||
from typing import Optional, Tuple
|
||||
from agent.persana import CompanyType
|
||||
|
||||
|
||||
# Load .env before reading PERSANA_KEY. The lookup below runs at import time,
# which is earlier than the load_dotenv() call in the script entry point, so
# without this a key that only lives in .env would be read as None.
# load_dotenv() does not override variables already set in the environment.
load_dotenv()

# Shared searcher program, built with train=True. This is the object that is
# handed to the optimizer when the file is run as a script.
searcher = PersanaSearcher(
    config=PersanaConfig(), api_key=os.getenv("PERSANA_KEY"), train=True
)
# Reused by evaluate_results_expensive to produce textual feedback.
feedback_creator = searcher.feedback_creator
|
||||
|
||||
|
||||
class SearchExample(Example):
    """Typed view of one dataset example consumed by the evaluation metrics.

    The field names mirror the keys read from each dataset record in the
    script entry point.
    """

    # Text describing the company (presumably the one running the search;
    # confirm against the dataset).
    company_description: str
    # Text describing the customer being targeted.
    target_customer: str
    # Profiles chosen for this example. The metrics read "profile_id" and
    # "experience_data" from each entry.
    selected_profiles: list[dict]
|
||||
|
||||
|
||||
class SearchPrediction(Prediction):
    """Typed view of the searcher output consumed by the evaluation metrics."""

    # Retrieved profiles; declared Optional, so it may be None.
    # evaluate_results_expensive reads "profile_id" from each entry.
    profiles: Optional[list[dict]]
    # Filter dictionary produced by the searcher. include_profile and
    # exclude_profile read keys such as "include_job_titles" from it.
    search_parameters: dict
|
||||
|
||||
|
||||
def any_in(list: list[str], string: Optional[str]) -> bool:
    """Return True if any item of *list* occurs in *string*, ignoring case.

    Args:
        list: Candidate substrings. The name shadows the builtin but is kept
            so that existing keyword callers keep working.
        string: Text to search in. ``None`` (a missing profile field) is
            treated as "no match" instead of raising ``AttributeError``.

    Returns:
        True when at least one item is a case-insensitive substring of
        *string*, otherwise False.
    """
    if string is None:
        return False
    # Lower-case the haystack once rather than once per candidate.
    haystack = string.lower()
    return any(item.lower() in haystack for item in list)
|
||||
|
||||
|
||||
def include_profile(
    search_parameters: dict,
    profile: dict,
) -> Tuple[bool, Optional[str]]:
    """Check *profile* against the include-style filters of a search.

    Filters are checked in this order: job titles, company names, company
    types, company keywords. A filter that is absent or empty in
    *search_parameters* is skipped. The first active filter that does not
    match decides the result.

    Args:
        search_parameters: Filter dictionary; the keys read here are
            ``include_job_titles``, ``include_companies``, ``company_types``
            and ``company_include_keywords``.
        profile: Profile dictionary with an ``experience_data`` mapping.

    Returns:
        ``(True, None)`` if every active filter matches, otherwise
        ``(False, reason)`` where ``reason`` names the failing filter.
    """
    titles = search_parameters.get("include_job_titles")
    if titles:
        if not any_in(titles, profile["experience_data"]["title"]):
            return (
                False,
                f"include_job_titles: {titles} not in {profile}['experience_data']['title']",
            )

    companies = search_parameters.get("include_companies")
    if companies:
        if not any_in(companies, profile["experience_data"]["company_name"]):
            return (
                False,
                f"include_companies: {companies} not in {profile}['experience_data']['company_name']",
            )

    company_types = search_parameters.get("company_types")
    if company_types:
        if not any_in(company_types, profile["experience_data"]["company_type"]):
            return (
                False,
                f"company_types: {company_types} not in {profile}['experience_data']['company_type']",
            )

    company_keywords = search_parameters.get("company_include_keywords")
    if company_keywords:
        # A keyword hit in either the headline or the description is enough;
        # the description is only consulted when the headline has no hit.
        found = any_in(
            company_keywords, profile["experience_data"]["company_company_headline"]
        ) or any_in(
            company_keywords, profile["experience_data"]["company_description"]
        )
        if not found:
            return (
                False,
                f"company_include_keywords: {company_keywords} not in {profile}['experience_data']['company_company_headline'] or {profile}['experience_data']['company_description']",
            )

    return True, None
|
||||
|
||||
|
||||
def exclude_profile(
    search_parameters: dict,
    profile: dict,
) -> Tuple[bool, Optional[str]]:
    """Check whether *profile* is rejected by the exclude-style filters.

    Filters are checked in this order: job titles, company names, company
    keywords. A filter that is absent or empty in *search_parameters* is
    skipped. The first active filter that matches decides the result.

    Args:
        search_parameters: Filter dictionary; the keys read here are
            ``exclude_job_titles``, ``exclude_companies`` and
            ``company_exclude_keywords``.
        profile: Profile dictionary with an ``experience_data`` mapping.

    Returns:
        ``(True, reason)`` if some exclusion filter matches, otherwise
        ``(False, None)``.
    """
    titles = search_parameters.get("exclude_job_titles")
    if titles and any_in(titles, profile["experience_data"]["title"]):
        return (
            True,
            f"exclude_job_titles: {titles} in {profile}['experience_data']['title']",
        )

    companies = search_parameters.get("exclude_companies")
    if companies and any_in(companies, profile["experience_data"]["company_name"]):
        return (
            True,
            f"exclude_companies: {companies} in {profile}['experience_data']['company_name']",
        )

    company_keywords = search_parameters.get("company_exclude_keywords")
    # An excluded keyword in EITHER the headline or the description rejects
    # the profile. This mirrors include_profile and the "or" in the reason
    # text; the previous "and" required a hit in both fields.
    if company_keywords and (
        any_in(company_keywords, profile["experience_data"]["company_company_headline"])
        or any_in(company_keywords, profile["experience_data"]["company_description"])
    ):
        return (
            True,
            f"company_exclude_keywords: {company_keywords} in {profile}['experience_data']['company_company_headline'] or {profile}['experience_data']['company_description']",
        )

    return False, None
|
||||
|
||||
|
||||
def get_search_eval(
    search_parameters: dict,
    target_profiles: list[dict],
) -> Tuple[float, list[str]]:
    """Score search parameters by the share of target profiles they let through.

    A profile counts as matched when include_profile accepts it and
    exclude_profile does not reject it.

    Args:
        search_parameters: Filter dictionary passed to include_profile and
            exclude_profile.
        target_profiles: Profiles the search is expected to match.

    Returns:
        ``(score, reasons)``: ``score`` is matched / total, and ``reasons``
        holds one explanatory string per profile that was filtered out.
        For an empty *target_profiles* the result is ``(0.0, [])`` rather
        than a ``ZeroDivisionError``.
    """
    if not target_profiles:
        return 0.0, []

    matched = 0
    miss_reasons = []
    for profile in target_profiles:
        included, reason = include_profile(search_parameters, profile)
        if included:
            # The include step returned no reason; from here on the reason,
            # if any, comes from the exclusion filters.
            excluded, reason = exclude_profile(search_parameters, profile)
            if not excluded:
                matched += 1
        if reason:
            miss_reasons.append(reason)

    return matched / len(target_profiles), miss_reasons
|
||||
|
||||
|
||||
def evaluate_results_expensive(
    target: SearchExample,
    predictied: SearchPrediction,
    trace=None,
    pred_name=None,
    pred_trace=None,
) -> Prediction:
    """Score a prediction by how many target profiles were actually retrieved.

    Args:
        target: Example whose ``selected_profiles`` are the expected results.
        predictied: Prediction carrying the retrieved ``profiles`` and the
            ``search_parameters`` that produced them.
        trace, pred_name, pred_trace: Accepted but unused here (presumably
            required by the optimizer's metric signature; confirm against
            the DSPy documentation).

    Returns:
        A ``Prediction`` with ``score`` (fraction of target profiles whose
        ``profile_id`` appears among the retrieved ones; 0.0 when there are
        no target profiles) and ``feedback`` produced by ``feedback_creator``.
    """
    # profiles is declared Optional; treat a missing result set as empty.
    retrieved = predictied.profiles or []
    targets = target.selected_profiles

    pred_ids = {result["profile_id"] for result in retrieved}
    target_ids = {result["profile_id"] for result in targets}

    # How many of the target profiles were retrieved.
    hits = sum(1 for t_result in targets if t_result["profile_id"] in pred_ids)
    score = hits / len(targets) if targets else 0.0

    # Split the retrieved profiles into target hits and non-targets.
    selected_preds = [
        result for result in retrieved if result["profile_id"] in target_ids
    ]
    unselected_preds = [
        result for result in retrieved if result["profile_id"] not in target_ids
    ]

    # Reuse the searcher's feedback creator to get feedback for prompt creation.
    feedback = feedback_creator(
        search_parameters=predictied.search_parameters,
        selected_profiles=selected_preds,
        unselected_profiles=unselected_preds,
        user_feedback=None,
    ).feedback

    return Prediction(
        score=score,
        feedback=feedback,
    )
|
||||
|
||||
|
||||
def evaluate_results_cheap(
    target: SearchExample,
    predictied: SearchPrediction,
    trace=None,
    pred_name=None,
    pred_trace=None,
) -> Prediction:
    """Score a prediction from its search parameters alone.

    The predicted filters are applied locally to the target profiles through
    get_search_eval; the retrieved ``profiles`` are not consulted.

    Args:
        target: Example whose ``selected_profiles`` should pass the filters.
        predictied: Prediction whose ``search_parameters`` are evaluated.
        trace, pred_name, pred_trace: Accepted but unused here.

    Returns:
        A ``Prediction`` with the ``score`` from get_search_eval and a
        ``feedback`` string listing the reasons for filtered-out profiles.
    """
    params = predictied.search_parameters
    score, miss_reasons = get_search_eval(params, target.selected_profiles)

    # One reason string per target profile that the filters rejected.
    listed = ", ".join(str(reason) for reason in miss_reasons)
    feedback = (
        "The model failed to retrieve the following profiles in the search: " + listed
    )

    return Prediction(score=score, feedback=feedback)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: build the training set, optimize the searcher with
    # GEPA, then save and publish the compiled program.
    load_dotenv()

    # Read the JSONL dataset inside a context manager so the file handle is
    # closed, and skip blank lines (such as a trailing newline), which
    # json.loads would reject.
    with open("dataset.jsonl", "r", encoding="utf-8") as dataset_file:
        data = [json.loads(line) for line in dataset_file if line.strip()]

    # company_description and target_customer are the inputs;
    # selected_profiles stays as the label the metric reads.
    trainset = [
        dspy.Example(
            company_description=e["company_description"],
            target_customer=e["target_customer"],
            selected_profiles=e["selected_profiles"],
        ).with_inputs("company_description", "target_customer")
        for e in data
    ]

    # Manual smoke test of the metric, kept for debugging:
    # for d in trainset:
    #     pred = searcher(
    #         company_description=d.company_description,
    #         target_customer=d.target_customer,
    #     )
    #     x = evaluate_results_cheap(d, pred)

    compiler = dspy.GEPA(
        metric=evaluate_results_cheap,
        reflection_lm=dspy.LM("openai/gpt-5", temperature=1.0, max_tokens=32000),
        auto="light",
    )
    compiled_searcher = compiler.compile(
        searcher,
        trainset=trainset,
    )
    compiled_searcher.save("compiled_searcher.json")
    compiled_searcher.push_to_hub("swagginty/persana-lead-gen", with_code=True)
|
||||
Reference in New Issue
Block a user