87 lines
2.6 KiB
Python
87 lines
2.6 KiB
Python
from agent.agent import PersanaSearcher, PersanaConfig
|
|
from dotenv import load_dotenv
|
|
import os
|
|
import dspy
|
|
import json
|
|
from dspy import Prediction, Example
|
|
|
|
# Load variables from a local .env file *before* reading PERSANA_KEY.
# Previously the key was read first, so a key defined only in .env was
# never picked up when the searcher was constructed.
load_dotenv()

searcher = PersanaSearcher(config=PersanaConfig(), api_key=os.getenv("PERSANA_KEY"))

# Kept as a module-level name because evaluate_results() calls it to turn
# retrieval results into textual feedback.
feedback_creator = searcher.feedback_creator

# Publish the searcher as it exists right now (no compile step has run yet).
searcher.push_to_hub("swagginty/persana-lead-gen", with_code=True)

# NOTE(review): this unconditional raise stops module execution here, so
# nothing below this point (the class/function definitions and the
# __main__ optimisation run) is ever reached. It looks like a temporary
# short-circuit for a one-off publish -- confirm before removing.
raise Exception("Done")
|
|
|
|
|
|
class SearchExample(Example):
    """Typed view of one dataset row used by the search metric.

    The annotations only declare the fields that evaluate_results() and the
    training-set construction read; no behaviour is added on top of
    ``Example``.
    """

    # Free-text description of the company running the search.
    company_description: str
    # Free-text description of the customer the company is looking for.
    target_customer: str
    # Ground-truth profiles; evaluate_results() reads each item's "profile_id".
    selected_profiles: list[dict]
|
|
|
|
|
|
class SearchPrediction(Prediction):
    """Typed view of the searcher output consumed by the search metric.

    The annotations only declare the fields that evaluate_results() reads;
    no behaviour is added on top of ``Prediction``.
    """

    # Retrieved profiles; evaluate_results() reads each item's "profile_id".
    profiles: list[dict]
    # Parameters used for the search, forwarded to the feedback creator.
    search_parameters: dict
|
|
|
|
|
|
def evaluate_results(
    target: SearchExample,
    predictied: SearchPrediction,
    trace=None,
    pred_name=None,
    pred_trace=None,
) -> Prediction:
    """Score a search prediction by how many target profiles it retrieved.

    The score is the fraction of ``target.selected_profiles`` whose
    ``"profile_id"`` also appears in ``predictied.profiles``. Textual
    feedback is produced by the module-level ``feedback_creator``.

    Args:
        target: Example holding the ground-truth ``selected_profiles``.
        predictied: Prediction holding the retrieved ``profiles`` and the
            ``search_parameters`` that produced them. (The misspelled name
            is kept so existing keyword callers keep working.)
        trace: Unused; accepted so the function fits the metric signature.
        pred_name: Unused; accepted so the function fits the metric signature.
        pred_trace: Unused; accepted so the function fits the metric signature.

    Returns:
        A ``Prediction`` with ``score`` (float in [0, 1]) and ``feedback``.
        When the example has no selected profiles the score is ``0.0``
        instead of raising ``ZeroDivisionError``.
    """
    target_profiles = target.selected_profiles
    retrieved_profiles = predictied.profiles

    # Build both id sets once so every membership test below is O(1).
    pred_ids = {profile["profile_id"] for profile in retrieved_profiles}
    target_ids = {profile["profile_id"] for profile in target_profiles}

    # How many of the target profiles were retrieved.
    hits = sum(1 for profile in target_profiles if profile["profile_id"] in pred_ids)
    # Guard against an example with no target profiles (nothing to recall).
    score = hits / len(target_profiles) if target_profiles else 0.0

    # Split the retrieved profiles into those that were targets and those
    # that were not, preserving retrieval order.
    selected_preds = []
    unselected_preds = []
    for profile in retrieved_profiles:
        if profile["profile_id"] in target_ids:
            selected_preds.append(profile)
        else:
            unselected_preds.append(profile)

    # Reuse the feedback creator to get feedback for prompt creation.
    feedback = feedback_creator(
        search_parameters=predictied.search_parameters,
        selected_profiles=selected_preds,
        unselected_profiles=unselected_preds,
        user_feedback=None,
    ).feedback

    return Prediction(
        score=score,
        feedback=feedback,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Make variables from a local .env file available to the run.
    load_dotenv()

    # Read the JSON Lines dataset: one JSON object per line. The file is
    # closed deterministically, and blank lines (e.g. a trailing empty line)
    # are skipped rather than crashing json.loads.
    with open("dataset.jsonl", "r", encoding="utf-8") as dataset_file:
        data = [json.loads(line) for line in dataset_file if line.strip()]

    # company_description and target_customer are marked as inputs;
    # selected_profiles stays as the label that evaluate_results compares
    # against.
    trainset = [
        dspy.Example(
            company_description=e["company_description"],
            target_customer=e["target_customer"],
            selected_profiles=e["selected_profiles"],
        ).with_inputs("company_description", "target_customer")
        for e in data
    ]

    compiler = dspy.GEPA(
        metric=evaluate_results,
        auto="light",
        reflection_lm=dspy.LM("openai/gpt-5", temperature=1.0, max_tokens=32000),
    )

    # Optimise the searcher against the training set, then publish the result.
    compiled_searcher = compiler.compile(searcher, trainset=trainset)
    compiled_searcher.push_to_hub("swagginty/persana-lead-gen")
|