Uncompiled
This commit is contained in:
2
.example.env
Normal file
2
.example.env
Normal file
@@ -0,0 +1,2 @@
|
||||
PERSANA_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
0
__init__.py
Normal file
0
__init__.py
Normal file
0
agent/__init__.py
Normal file
0
agent/__init__.py
Normal file
236
agent/agent.py
Normal file
236
agent/agent.py
Normal file
@@ -0,0 +1,236 @@
|
||||
from modaic import PrecompiledAgent, PrecompiledConfig
|
||||
import dspy
|
||||
from typing import Literal, Optional, List
|
||||
from typing_extensions import TypedDict
|
||||
from .persana import PersanaClient, truncate_profiles
|
||||
|
||||
# Persana API Company Type Options
|
||||
CompanyType = Literal[
|
||||
"Public Company",
|
||||
"Educational",
|
||||
"Self Employed",
|
||||
"Government Agency",
|
||||
"Non Profit",
|
||||
"Self Owned",
|
||||
"Privately Held",
|
||||
"Partnership",
|
||||
]
|
||||
|
||||
|
||||
# used to store feedback on the search parameters
|
||||
class SearchFeedback(TypedDict):
|
||||
search_parameters: dict
|
||||
feedback: Optional[str]
|
||||
|
||||
|
||||
# Defines the prompt signature for the agent that takes in the user's company description
|
||||
# and target customer description and runs a search on the Persana API
|
||||
class ICPTOSearchSignature(dspy.Signature):
|
||||
"""
|
||||
You are a marketing agent that comes up with the ideal people search paramaters for Persana given a
|
||||
company description and a target customer description. The search API allows the following parameters:
|
||||
- Included Job Titles
|
||||
- Excluded Job Titles
|
||||
- Included Companies
|
||||
- Excluded Companies
|
||||
- Company Type
|
||||
- Company Include Keywords
|
||||
- Company Exclude Keywords
|
||||
- Included Industries
|
||||
- Excluded Industries
|
||||
|
||||
You may use all or some of these parameters to craft your search. It is recommended not to use all parameters unless necessary.
|
||||
"""
|
||||
|
||||
company_description: str = dspy.InputField(
|
||||
description="The description of the company"
|
||||
)
|
||||
target_customer: str = dspy.InputField(
|
||||
description="The target customer of the company"
|
||||
)
|
||||
feedback_history: List[SearchFeedback] = dspy.InputField(
|
||||
description="Feedback on previous searches"
|
||||
)
|
||||
tools: list[dspy.Tool] = dspy.InputField()
|
||||
tool_calls: dspy.ToolCalls = dspy.OutputField()
|
||||
|
||||
|
||||
# Defines the prompt signature for the agent that takes in the search parameters, its search results,
|
||||
# and optional user feedback on the search and generates feedback for the AI agent that generated the search parameters
|
||||
class FeedbackSignature(dspy.Signature):
|
||||
"""
|
||||
You are a helpful assistant that helps a user give feedback to an AI agent that generates search parameters for a Persana API,
|
||||
given a company description and a target customer description from the user. Use the search result profiles the user selected, along with the
|
||||
results the user deselected, along with the feedback from the user to give feedback to the AI agent.
|
||||
"""
|
||||
|
||||
search_parameters: dict = dspy.InputField(
|
||||
description="The search parameters for the Persana API"
|
||||
)
|
||||
|
||||
selected_profiles: List[dict] = dspy.InputField(
|
||||
description="The profiles the user selected"
|
||||
)
|
||||
unselected_profiles: List[dict] = dspy.InputField(
|
||||
description="The profiles the user did not select"
|
||||
)
|
||||
user_feedback: Optional[str] = dspy.InputField(
|
||||
description="Feedback from the user on the previous search"
|
||||
)
|
||||
|
||||
feedback: str = dspy.OutputField(
|
||||
description="Feedback to the AI agent that generated the search parameters"
|
||||
)
|
||||
|
||||
|
||||
# Config for main agent, handles performance-altering parameters for the agent
|
||||
class PersanaConfig(PrecompiledConfig):
|
||||
model: str = "anthropic/claude-3-5-haiku-20241022"
|
||||
|
||||
|
||||
class FeedbackDict(TypedDict):
|
||||
selected_profiles: List[dict]
|
||||
unselected_profiles: List[dict]
|
||||
user_feedback: Optional[str]
|
||||
|
||||
|
||||
# Main agent that handles the search and feedback
|
||||
class PersanaSearcher(PrecompiledAgent):
|
||||
config: PersanaConfig # ! important, used to link the config class to the agent
|
||||
|
||||
feedback_history: List[SearchFeedback] = []
|
||||
search_parameters: Optional[dict] = None
|
||||
response: Optional[dspy.Prediction] = (
|
||||
None # stores last response (used for debugging purposes)
|
||||
)
|
||||
|
||||
# Constructor, all PrecompiledAgent subclasses must have a constructor that takes in a config.
|
||||
# You can optionally take in other variables specific to the environment
|
||||
# It is recommended to track any performance-altering variables (model, temperature, etc.) in the config, not the constructor
|
||||
def __init__(self, config: PersanaConfig, api_key: Optional[str] = None):
|
||||
super().__init__(config)
|
||||
self.client = PersanaClient(api_key=api_key)
|
||||
self.search_creator = dspy.Predict(ICPTOSearchSignature)
|
||||
self.search_creator.set_lm(dspy.LM(config.model))
|
||||
self.feedback_creator = dspy.Predict(FeedbackSignature)
|
||||
self.feedback_creator.set_lm(dspy.LM(config.model))
|
||||
|
||||
# all Modules must define a forward, contains the logic for the module. Called when you run PrecompiledAgent.__call__
|
||||
def forward(
|
||||
self,
|
||||
company_description: str,
|
||||
target_customer: str,
|
||||
feedback_dict: Optional[FeedbackDict] = None,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Runs a search on the Persana API with the given company description and target customer description.
|
||||
Also saves the search parameters to self.search_parameters so they can be used for feedback later
|
||||
Args:
|
||||
company_description: The description of the company
|
||||
target_customer: The target customer of the company
|
||||
Returns:
|
||||
The profiles returned from the search
|
||||
"""
|
||||
feedback_response = None
|
||||
if feedback_dict is not None:
|
||||
feedback_response = self.enter_feedback(
|
||||
selected_profiles=feedback_dict["selected_profiles"],
|
||||
unselected_profiles=feedback_dict["unselected_profiles"],
|
||||
user_feedback=feedback_dict["user_feedback"],
|
||||
).feedback
|
||||
|
||||
# retry up to 10 times if model generates invalid search parameters
|
||||
search_tool = dspy.Tool(self.client.people_search)
|
||||
for _ in range(10):
|
||||
response = self.search_creator(
|
||||
company_description=company_description,
|
||||
target_customer=target_customer,
|
||||
feedback_history=self.feedback_history,
|
||||
tools=[search_tool],
|
||||
)
|
||||
tool_call = response.tool_calls.tool_calls[0]
|
||||
self.search_parameters = tool_call.args
|
||||
try:
|
||||
profiles = search_tool(**tool_call.args)
|
||||
self.response = response
|
||||
if len(profiles) == 0:
|
||||
self.feedback_history.append(
|
||||
SearchFeedback(
|
||||
search_parameters=self.search_parameters,
|
||||
feedback="No profiles found. Try loosening the search parameters",
|
||||
)
|
||||
)
|
||||
continue
|
||||
return dspy.Prediction(
|
||||
search_parameters=self.search_parameters,
|
||||
profiles=profiles,
|
||||
feedback=feedback_response,
|
||||
)
|
||||
except TypeError as e:
|
||||
feedback = str(e)
|
||||
self.feedback_history.append(
|
||||
SearchFeedback(
|
||||
search_parameters=self.search_parameters,
|
||||
feedback=feedback,
|
||||
)
|
||||
)
|
||||
raise Exception("Model failed to generate valid search parameters")
|
||||
|
||||
def enter_feedback(
|
||||
self,
|
||||
selected_profiles: List[dict],
|
||||
unselected_profiles: List[dict],
|
||||
user_feedback: Optional[str] = None,
|
||||
) -> dspy.Prediction:
|
||||
"""
|
||||
Allows user to enter feedback on the previous search results
|
||||
Args:
|
||||
selected_profiles: The profiles the user selected
|
||||
unselected_profiles: The profiles the user did not select
|
||||
user_feedback: User's feedback on the previous search
|
||||
Returns:
|
||||
The feedback Prediction object (for debugging purposes)
|
||||
"""
|
||||
feedback_response = self.feedback_creator(
|
||||
search_parameters=self.search_parameters,
|
||||
selected_profiles=truncate_profiles(selected_profiles),
|
||||
unselected_profiles=truncate_profiles(unselected_profiles),
|
||||
user_feedback=user_feedback,
|
||||
)
|
||||
self.feedback_history.append(
|
||||
SearchFeedback(
|
||||
search_parameters=self.search_parameters,
|
||||
feedback=feedback_response.feedback,
|
||||
)
|
||||
)
|
||||
return feedback_response
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
searcher = PersanaSearcher(config=PersanaConfig(), api_key=os.getenv("PERSANA_KEY"))
|
||||
results = searcher(
|
||||
company_description="A company that makes software for businesses",
|
||||
target_customer="A business that needs software",
|
||||
)
|
||||
|
||||
print(results)
|
||||
searcher.enter_feedback(
|
||||
selected_profiles=results[:-1],
|
||||
unselected_profiles=[results[-1]],
|
||||
user_feedback="Look for businesses that would need mobile apps",
|
||||
)
|
||||
results = searcher(
|
||||
company_description="A company that makes software for businesses",
|
||||
target_customer="A business that needs software",
|
||||
)
|
||||
print("--------------------------------")
|
||||
print("--------------------------------")
|
||||
print("--------------------------------")
|
||||
print("--------------------------------")
|
||||
print(results)
|
||||
print(searcher.feedback_history)
|
||||
print(searcher.search_parameters)
|
||||
113
agent/persana.py
Normal file
113
agent/persana.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import requests
|
||||
from typing import Optional, Literal, List
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
CompanyType = Literal[
|
||||
"Public Company",
|
||||
"Educational",
|
||||
"Self Employed",
|
||||
"Government Agency",
|
||||
"Non Profit",
|
||||
"Self Owned",
|
||||
"Privately Held",
|
||||
"Partnership",
|
||||
]
|
||||
|
||||
|
||||
class PersanaClient:
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key or os.getenv("PERSANA_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("PERSANA_KEY is not set")
|
||||
|
||||
def people_search(
|
||||
self,
|
||||
include_job_titles: Optional[list[str]] = None,
|
||||
exclude_job_titles: Optional[list[str]] = None,
|
||||
include_companies: Optional[list[str]] = None,
|
||||
exclude_companies: Optional[list[str]] = None,
|
||||
company_types: Optional[list[CompanyType]] = None,
|
||||
company_include_keywords: Optional[list[str]] = None,
|
||||
company_exclude_keywords: Optional[list[str]] = None,
|
||||
include_industries: Optional[list[str]] = None,
|
||||
exclude_industries: Optional[list[str]] = None,
|
||||
):
|
||||
"""
|
||||
Runs a persana people search with the given filter parameters and returns the results
|
||||
Args:
|
||||
include_job_titles: The job titles to include in the search. (Optional) should stay set to None unless requested
|
||||
exclude_job_titles: The job titles to exclude from the search. (Optional) should stay set to None unless requested
|
||||
include_companies: The companies to include in the search. (Optional) should stay set to None unless requested
|
||||
exclude_companies: The companies to exclude from the search. (Optional) should stay set to None unless requested
|
||||
company_type: The type of the company. (Optional) should stay set to None unless requested
|
||||
company_include_keywords: The keywords to use to search for the company. (Optional) should stay set to None unless requested
|
||||
company_exclude_keywords: The keywords to exclude from the search for the company. (Optional) should stay set to None unless requested
|
||||
include_industries: The industries to include in the search. (Optional) should stay set to None unless requested
|
||||
exclude_industries: The industries to exclude from the search. (Optional) should stay set to None unless requested
|
||||
Returns:
|
||||
The results of the people search
|
||||
"""
|
||||
params = {
|
||||
"title_includes": include_job_titles,
|
||||
"title_excludes": exclude_job_titles,
|
||||
"companies_includes": include_companies,
|
||||
"companies_excludes": exclude_companies,
|
||||
"company_types": company_types,
|
||||
"company_keywords_includes": company_include_keywords,
|
||||
"company_keywords_excludes": company_exclude_keywords,
|
||||
"industries_includes": include_industries,
|
||||
"industries_excludes": exclude_industries,
|
||||
}
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
response = requests.post(
|
||||
"https://prod.api.persana.ai/api/v1/people/search",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": self.api_key,
|
||||
},
|
||||
json=params,
|
||||
)
|
||||
json_data = response.json()
|
||||
return json_data["data"]["profiles"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
client = PersanaClient()
|
||||
result = client.people_search(
|
||||
include_job_titles=["Software Engineer"],
|
||||
exclude_job_titles=["CTO"],
|
||||
include_companies=["Google"],
|
||||
exclude_companies=["Apple"],
|
||||
company_types=["Public Company"],
|
||||
company_include_keywords=["Software"],
|
||||
company_exclude_keywords=["Data"],
|
||||
include_industries=["Technology"],
|
||||
exclude_industries=["Finance"],
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
def truncate_profiles(profiles: List[dict]) -> List[dict]:
|
||||
new_profiles = []
|
||||
experience_data_keep_keys = [
|
||||
"title",
|
||||
"company_name",
|
||||
"company_company_headline",
|
||||
"company_hero_image",
|
||||
"company_description",
|
||||
"company_type",
|
||||
]
|
||||
profile_keep_keys = ["name", "headline", "company", "title"]
|
||||
for profile in profiles:
|
||||
new_experience_data = {
|
||||
k: v
|
||||
for k, v in profile["experience_data"].items()
|
||||
if k in experience_data_keep_keys
|
||||
}
|
||||
new_profile = {k: v for k, v in profile.items() if k in profile_keep_keys}
|
||||
new_profile["experience_data"] = new_experience_data
|
||||
new_profiles.append(new_profile)
|
||||
return new_profiles
|
||||
4
auto_classes.json
Normal file
4
auto_classes.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"AutoConfig": "agent.agent.PersanaConfig",
|
||||
"AutoAgent": "agent.agent.PersanaSearcher"
|
||||
}
|
||||
88
compile.py
Normal file
88
compile.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from agent.agent import PersanaSearcher, PersanaConfig
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import dspy
|
||||
import json
|
||||
from dspy import Prediction, Example
|
||||
|
||||
searcher = PersanaSearcher(config=PersanaConfig(), api_key=os.getenv("PERSANA_KEY"))
|
||||
feedback_creator = searcher.feedback_creator
|
||||
searcher.push_to_hub(
|
||||
"swagginty/persana-lead-gen", with_code=True, commit_message="Uncompiled"
|
||||
)
|
||||
|
||||
|
||||
class SearchExample(Example):
|
||||
company_description: str
|
||||
target_customer: str
|
||||
selected_profiles: list[dict]
|
||||
|
||||
|
||||
class SearchPrediction(Prediction):
|
||||
profiles: list[dict]
|
||||
search_parameters: dict
|
||||
|
||||
|
||||
def evaluate_results(
|
||||
target: SearchExample,
|
||||
predictied: SearchPrediction,
|
||||
trace=None,
|
||||
pred_name=None,
|
||||
pred_trace=None,
|
||||
) -> Prediction:
|
||||
"""
|
||||
Evaluates the search results target results were retrieved
|
||||
"""
|
||||
# How many of the target profiles were retrieved
|
||||
pred_ids = {result["profile_id"] for result in predictied.profiles}
|
||||
count = 0
|
||||
for t_result in target.selected_profiles:
|
||||
if t_result["profile_id"] in pred_ids:
|
||||
count += 1
|
||||
|
||||
score = count / len(target.selected_profiles)
|
||||
|
||||
target_ids = {result["profile_id"] for result in target.selected_profiles}
|
||||
# Which retrieved profiles were target profiles
|
||||
selected_preds = [
|
||||
result for result in predictied.profiles if result["profile_id"] in target_ids
|
||||
]
|
||||
# Which retrieved profiles were not target profiles
|
||||
unselected_preds = [
|
||||
result
|
||||
for result in predictied.profiles
|
||||
if result["profile_id"] not in target_ids
|
||||
]
|
||||
# Resuse feedback creator to get feedback for prompt creation
|
||||
feedback = feedback_creator(
|
||||
search_parameters=predictied.search_parameters,
|
||||
selected_profiles=selected_preds,
|
||||
unselected_profiles=unselected_preds,
|
||||
user_feedback=None,
|
||||
).feedback
|
||||
return Prediction(
|
||||
score=score,
|
||||
feedback=feedback,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
|
||||
data = [json.loads(line) for line in open("dataset.jsonl", "r")]
|
||||
trainset = [
|
||||
dspy.Example(
|
||||
company_description=e["company_description"],
|
||||
target_customer=e["target_customer"],
|
||||
selected_profiles=e["selected_profiles"],
|
||||
).with_inputs("company_description", "target_customer")
|
||||
for e in data
|
||||
]
|
||||
compiler = dspy.GEPA(
|
||||
metric=evaluate_results,
|
||||
auto="light",
|
||||
reflection_lm=dspy.LM("openai/gpt-5", temperature=1.0, max_tokens=32000),
|
||||
)
|
||||
compiled_searcher = compiler.compile(searcher, trainset=trainset)
|
||||
compiled_searcher.save("compiled_searcher.json")
|
||||
compiled_searcher.push_to_hub("swagginty/persana-lead-gen")
|
||||
4
dataset.jsonl
Normal file
4
dataset.jsonl
Normal file
File diff suppressed because one or more lines are too long
336
main.py
Normal file
336
main.py
Normal file
@@ -0,0 +1,336 @@
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import (
|
||||
Footer,
|
||||
Header,
|
||||
SelectionList,
|
||||
Button,
|
||||
TextArea,
|
||||
Markdown,
|
||||
MarkdownViewer,
|
||||
)
|
||||
from textual.containers import VerticalScroll, Horizontal, Vertical
|
||||
from textual.binding import Binding
|
||||
from textual.widgets.selection_list import Selection
|
||||
from typing import Optional
|
||||
from typing_extensions import TypedDict
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
import json
|
||||
import dspy
|
||||
from dspy.streaming import (
|
||||
StatusMessage,
|
||||
StatusMessageProvider,
|
||||
StreamListener,
|
||||
StreamResponse,
|
||||
)
|
||||
from litellm.types.utils import ModelResponseStream
|
||||
from modaic import AutoAgent
|
||||
|
||||
load_dotenv()
|
||||
|
||||
searcher = AutoAgent.from_precompiled(
|
||||
"swagginty/persana-lead-gen", api_key=os.getenv("PERSANA_KEY")
|
||||
)
|
||||
|
||||
|
||||
class StartScreen(Screen):
|
||||
BINDINGS = [
|
||||
("ctrl+s", "start", "Start"),
|
||||
Binding(key="ctrl+x", action="quit", description="Quit", priority=True),
|
||||
]
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header()
|
||||
yield TextArea(
|
||||
placeholder="Enter your company description",
|
||||
id="company_description",
|
||||
)
|
||||
yield TextArea(
|
||||
placeholder="Enter your target customer",
|
||||
id="target_customer",
|
||||
)
|
||||
yield Button("Start", classes="buttons", id="start", action="start")
|
||||
yield Footer()
|
||||
|
||||
def action_start(self) -> None:
|
||||
self.app.company_description = self.query_one(
|
||||
"#company_description", TextArea
|
||||
).text
|
||||
self.app.target_customer = self.query_one("#target_customer", TextArea).text
|
||||
self.app.profiles = None
|
||||
self.app.all_selected_profiles = []
|
||||
self.app.push_screen("select")
|
||||
|
||||
|
||||
class Status(StatusMessageProvider):
|
||||
def tool_start_status_message(self, instance, inputs):
|
||||
return "\n\nRunning search...\n\n" + self.dict_to_markdown_table(
|
||||
inputs["kwargs"]
|
||||
)
|
||||
|
||||
def tool_end_status_message(self, instance, outputs):
|
||||
"\n\n Done \n\n"
|
||||
|
||||
def dict_to_markdown_table(self, d: dict) -> str:
|
||||
"""Convert a dict into a Markdown table with 'Name' and 'Value' columns."""
|
||||
headers = ["Param", "Value"]
|
||||
header_row = f"| {' | '.join(headers)} |"
|
||||
separator = f"| {' | '.join(['---'] * len(headers))} |"
|
||||
|
||||
rows = []
|
||||
for key, value in d.items():
|
||||
formatted_value = self.format_value(value)
|
||||
rows.append(f"| {key} | {formatted_value} |")
|
||||
|
||||
return "\n".join([header_row, separator, *rows])
|
||||
|
||||
def format_value(self, v):
|
||||
if isinstance(v, list):
|
||||
return ", ".join(v)
|
||||
return str(v)
|
||||
|
||||
|
||||
streaming_searcher = dspy.streamify(
|
||||
searcher,
|
||||
status_message_provider=Status(),
|
||||
stream_listeners=[
|
||||
StreamListener(
|
||||
predict=searcher.feedback_creator,
|
||||
signature_field_name="feedback",
|
||||
allow_reuse=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class SelectScreen(Screen):
|
||||
BINDINGS = [
|
||||
Binding(key="ctrl+r", action="refine", description="Refine", priority=True),
|
||||
Binding(key="ctrl+s", action="submit", description="Submit", priority=True),
|
||||
Binding(key="escape", action="back", description="Back"),
|
||||
]
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header()
|
||||
yield Horizontal(
|
||||
Vertical(
|
||||
VerticalScroll(SelectionList[int](id="profiles"), id="profiles_scroll"),
|
||||
TextArea(placeholder="Feedback for next search", id="feedback"),
|
||||
id="left_container",
|
||||
),
|
||||
MarkdownViewer(show_table_of_contents=False),
|
||||
id="main_container",
|
||||
)
|
||||
yield Horizontal(
|
||||
Button("Refine", id="refine", action="screen.refine"),
|
||||
Button("Submit", id="submit", action="screen.submit"),
|
||||
classes="buttons",
|
||||
)
|
||||
yield Footer(id="footer")
|
||||
|
||||
def on_mount(self) -> None:
|
||||
self.query_one("#profiles", SelectionList)._bindings.bind(
|
||||
"right", "screen.view_profile", description="View Profile", priority=True
|
||||
)
|
||||
self.query_one("#footer", Footer).refresh_bindings()
|
||||
|
||||
def on_screen_resume(self, event) -> None:
|
||||
if self.app.profiles is None:
|
||||
self.run_worker(self.run_search(with_feedback=False), exclusive=True)
|
||||
|
||||
def action_refine(self) -> None:
|
||||
selected_profiles = self.query_one("#profiles", SelectionList).selected
|
||||
self.app.all_selected_profiles.extend(self.get_profiles(selected_profiles))
|
||||
self.run_worker(self.run_search(with_feedback=True), exclusive=True)
|
||||
|
||||
def action_submit(self) -> None:
|
||||
selected_profiles = self.query_one("#profiles", SelectionList).selected
|
||||
self.app.all_selected_profiles.extend(self.get_profiles(selected_profiles))
|
||||
visited = set()
|
||||
new_selected_profiles = []
|
||||
for profile in self.app.all_selected_profiles:
|
||||
if profile["profile_id"] not in visited:
|
||||
new_selected_profiles.append(profile)
|
||||
visited.add(profile["profile_id"])
|
||||
self.app.all_selected_profiles = new_selected_profiles
|
||||
entry = {
|
||||
"company_description": self.app.company_description,
|
||||
"target_customer": self.app.target_customer,
|
||||
"selected_profiles": self.app.all_selected_profiles,
|
||||
}
|
||||
with open("dataset.jsonl", "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry))
|
||||
f.write("\n")
|
||||
|
||||
self.app.all_selected_profiles = []
|
||||
self.clear_content()
|
||||
self.app.push_screen("start")
|
||||
|
||||
def action_view_profile(self) -> None:
|
||||
selected_profile = self.query_one("#profiles", SelectionList).highlighted
|
||||
self.app.push_screen(ProfileScreen(selected_profile))
|
||||
|
||||
def action_back(self) -> None:
|
||||
if self.focused.id != "feedback":
|
||||
self.app.push_screen("start")
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
if event.button.id == "refine":
|
||||
self.action_refine()
|
||||
elif event.button.id == "submit":
|
||||
self.action_submit()
|
||||
|
||||
async def run_search(
|
||||
self,
|
||||
with_feedback: bool = False,
|
||||
):
|
||||
"""
|
||||
Runs a search using the search agent and updates the pofiles list in UI. If with_feedback is True,
|
||||
it will enter feedback and update the profiles list.
|
||||
"""
|
||||
viewer = self.query_one(MarkdownViewer)
|
||||
ui_log = self.query_one(Markdown)
|
||||
sl = self.query_one("#profiles", SelectionList)
|
||||
sl.clear_options()
|
||||
scroll = self.query_one("#profiles_scroll", VerticalScroll)
|
||||
scroll.loading = True
|
||||
feedback_dict = None
|
||||
if with_feedback:
|
||||
selected_profiles = self.query_one("#profiles", SelectionList).selected
|
||||
num_profiles = len(self.query_one("#profiles", SelectionList).options)
|
||||
unselected_profiles = list(
|
||||
set(range(num_profiles)) - set(selected_profiles)
|
||||
)
|
||||
feedback_dict = {
|
||||
"selected_profiles": self.get_profiles(selected_profiles),
|
||||
"unselected_profiles": self.get_profiles(unselected_profiles),
|
||||
"user_feedback": self.query_one("#feedback", TextArea).text,
|
||||
}
|
||||
stream = streaming_searcher(
|
||||
company_description=self.app.company_description,
|
||||
target_customer=self.app.target_customer,
|
||||
feedback_dict=feedback_dict,
|
||||
)
|
||||
async for message in stream:
|
||||
if isinstance(message, StatusMessage):
|
||||
ui_log.append(message.message)
|
||||
elif isinstance(message, dspy.Prediction):
|
||||
self.app.profiles = message.profiles
|
||||
elif isinstance(message, ModelResponseStream):
|
||||
if message.choices[0].delta.content is not None:
|
||||
ui_log.append(message.choices[0].delta.content)
|
||||
elif isinstance(message, StreamResponse):
|
||||
ui_log.append(message.chunk)
|
||||
viewer.scroll_end(force=True)
|
||||
|
||||
for i, profile in enumerate(self.app.profiles):
|
||||
summary = f"{profile['name']} - {profile['experience_data']['title']} | {profile['experience_data']['company_name']}"
|
||||
option = Selection(summary, i)
|
||||
sl.add_option(option)
|
||||
scroll.loading = False
|
||||
|
||||
def get_profiles(self, indices: list[int]) -> list[dict]:
|
||||
return [self.app.profiles[i] for i in indices]
|
||||
|
||||
def clear_content(self) -> None:
|
||||
self.query_one("#profiles", SelectionList).clear_options()
|
||||
self.query_one("#profiles_scroll", VerticalScroll).loading = False
|
||||
self.query_one("#feedback", TextArea).text = ""
|
||||
self.query_one(Markdown).update("")
|
||||
|
||||
|
||||
class ProfileScreen(Screen):
|
||||
BINDINGS = [
|
||||
Binding(
|
||||
key="left",
|
||||
action="app.push_screen('select')",
|
||||
description="Back to Select",
|
||||
show=True,
|
||||
priority=True,
|
||||
),
|
||||
Binding(
|
||||
key="escape",
|
||||
action="app.pop_screen",
|
||||
description="Back",
|
||||
priority=True,
|
||||
show=True,
|
||||
),
|
||||
]
|
||||
|
||||
def __init__(self, profile: dict) -> None:
|
||||
super().__init__()
|
||||
self.profile = profile
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header()
|
||||
yield Markdown(
|
||||
markdown=self.get_profile_markdown(self.app.profiles[self.profile]),
|
||||
id="profile",
|
||||
)
|
||||
yield Footer()
|
||||
|
||||
def get_profile_markdown(self, profile: dict):
|
||||
headline = f"{profile['headline']} \n" if profile["headline"] else ""
|
||||
md_text = (
|
||||
f"# {profile['name']} - {profile['experience_data']['title']} · {profile['locality']} [🔗]({profile['url']})",
|
||||
f"{headline}",
|
||||
f"## {profile['experience_data']['company_name']} - {profile['experience_data']['company_locality']} · {profile['experience_data']['company_type']} [🔗]({profile['experience_data']['company_linkedin_url']})",
|
||||
f"{profile['experience_data']['company_description']}",
|
||||
)
|
||||
md_text = "\n".join(md_text)
|
||||
return md_text
|
||||
|
||||
|
||||
class DatasetEntry(TypedDict):
|
||||
company_description: str
|
||||
target_customer: str
|
||||
selected_profiles: list[dict]
|
||||
|
||||
|
||||
class PersanaApp(App):
|
||||
SCREENS = {"start": StartScreen, "select": SelectScreen, "profile": ProfileScreen}
|
||||
BINDINGS = [
|
||||
Binding(key="ctrl+x", action="quit", description="Quit", priority=True),
|
||||
]
|
||||
CSS = """
|
||||
Horizontal {
|
||||
height: 3;
|
||||
}
|
||||
|
||||
Button {
|
||||
width: 1fr; /* 👈 expands equally like flex:1 */
|
||||
margin: 1;
|
||||
background: #bae0f5;
|
||||
}
|
||||
#start{
|
||||
height: 3;
|
||||
}
|
||||
#left_container {
|
||||
height: 1fr;
|
||||
}
|
||||
#main_container {
|
||||
height: 1fr;
|
||||
}
|
||||
#log {
|
||||
height: 5;
|
||||
width: 1fr;
|
||||
}
|
||||
"""
|
||||
company_description: Optional[str] = None
|
||||
target_customer: Optional[str] = None
|
||||
profiles: Optional[list[dict]] = None
|
||||
all_selected_profiles: list[dict] = []
|
||||
dataset: list[DatasetEntry] = []
|
||||
log_buffer: str = ""
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield StartScreen()
|
||||
|
||||
def on_mount(self) -> None:
|
||||
self.push_screen("start")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = PersanaApp()
|
||||
app.run()
|
||||
8
pyproject.toml
Normal file
8
pyproject.toml
Normal file
@@ -0,0 +1,8 @@
|
||||
[project]
|
||||
name = "persana-lead-gen"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = ["modaic>=0.3.0", "rich>=14.1.0", "textual>=6.2.1", "textual-dev>=1.7.0"]
|
||||
|
||||
Reference in New Issue
Block a user