(no commit message)
This commit is contained in:
88
compile.py
88
compile.py
@@ -1,88 +0,0 @@
|
||||
from agent.agent import PersanaSearcher, PersanaConfig
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import dspy
|
||||
import json
|
||||
from dspy import Prediction, Example
|
||||
|
||||
searcher = PersanaSearcher(config=PersanaConfig(), api_key=os.getenv("PERSANA_KEY"))
|
||||
feedback_creator = searcher.feedback_creator
|
||||
searcher.push_to_hub(
|
||||
"swagginty/persana-lead-gen", with_code=True, commit_message="Uncompiled"
|
||||
)
|
||||
|
||||
|
||||
class SearchExample(Example):
|
||||
company_description: str
|
||||
target_customer: str
|
||||
selected_profiles: list[dict]
|
||||
|
||||
|
||||
class SearchPrediction(Prediction):
|
||||
profiles: list[dict]
|
||||
search_parameters: dict
|
||||
|
||||
|
||||
def evaluate_results(
|
||||
target: SearchExample,
|
||||
predictied: SearchPrediction,
|
||||
trace=None,
|
||||
pred_name=None,
|
||||
pred_trace=None,
|
||||
) -> Prediction:
|
||||
"""
|
||||
Evaluates the search results target results were retrieved
|
||||
"""
|
||||
# How many of the target profiles were retrieved
|
||||
pred_ids = {result["profile_id"] for result in predictied.profiles}
|
||||
count = 0
|
||||
for t_result in target.selected_profiles:
|
||||
if t_result["profile_id"] in pred_ids:
|
||||
count += 1
|
||||
|
||||
score = count / len(target.selected_profiles)
|
||||
|
||||
target_ids = {result["profile_id"] for result in target.selected_profiles}
|
||||
# Which retrieved profiles were target profiles
|
||||
selected_preds = [
|
||||
result for result in predictied.profiles if result["profile_id"] in target_ids
|
||||
]
|
||||
# Which retrieved profiles were not target profiles
|
||||
unselected_preds = [
|
||||
result
|
||||
for result in predictied.profiles
|
||||
if result["profile_id"] not in target_ids
|
||||
]
|
||||
# Resuse feedback creator to get feedback for prompt creation
|
||||
feedback = feedback_creator(
|
||||
search_parameters=predictied.search_parameters,
|
||||
selected_profiles=selected_preds,
|
||||
unselected_profiles=unselected_preds,
|
||||
user_feedback=None,
|
||||
).feedback
|
||||
return Prediction(
|
||||
score=score,
|
||||
feedback=feedback,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
|
||||
data = [json.loads(line) for line in open("dataset.jsonl", "r")]
|
||||
trainset = [
|
||||
dspy.Example(
|
||||
company_description=e["company_description"],
|
||||
target_customer=e["target_customer"],
|
||||
selected_profiles=e["selected_profiles"],
|
||||
).with_inputs("company_description", "target_customer")
|
||||
for e in data
|
||||
]
|
||||
compiler = dspy.GEPA(
|
||||
metric=evaluate_results,
|
||||
auto="light",
|
||||
reflection_lm=dspy.LM("openai/gpt-5", temperature=1.0, max_tokens=32000),
|
||||
)
|
||||
compiled_searcher = compiler.compile(searcher, trainset=trainset)
|
||||
compiled_searcher.save("compiled_searcher.json")
|
||||
compiled_searcher.push_to_hub("swagginty/persana-lead-gen")
|
||||
Reference in New Issue
Block a user