335 lines
11 KiB
Python
335 lines
11 KiB
Python
from textual.app import App, ComposeResult
|
|
from textual.screen import Screen
|
|
from textual.widgets import (
|
|
Footer,
|
|
Header,
|
|
SelectionList,
|
|
Button,
|
|
TextArea,
|
|
Markdown,
|
|
MarkdownViewer,
|
|
)
|
|
from textual.containers import VerticalScroll, Horizontal, Vertical
|
|
from textual.binding import Binding
|
|
from textual.widgets.selection_list import Selection
|
|
from agent.agent import PersanaSearcher, PersanaConfig
|
|
from typing import Optional
|
|
from typing_extensions import TypedDict
|
|
import os
|
|
from dotenv import load_dotenv
|
|
import json
|
|
import dspy
|
|
from dspy.streaming import (
|
|
StatusMessage,
|
|
StatusMessageProvider,
|
|
StreamListener,
|
|
StreamResponse,
|
|
)
|
|
from litellm.types.utils import ModelResponseStream
|
|
|
|
# Load variables from a local .env file into the process environment before
# any of them are read below.
load_dotenv()

# Module-level search agent shared by the whole UI. The API key is taken from
# the PERSANA_KEY environment variable; os.getenv yields None when it is unset.
searcher = PersanaSearcher(
    config=PersanaConfig(),
    api_key=os.getenv("PERSANA_KEY"),
)
|
|
|
|
|
|
class StartScreen(Screen):
    """Landing screen that collects the two free-text inputs for a search.

    The user fills in a company description and a target customer, then
    starts the search with the "Start" button or ``ctrl+s``.
    """

    BINDINGS = [
        ("ctrl+s", "start", "Start"),
        Binding(key="ctrl+x", action="quit", description="Quit", priority=True),
    ]

    def compose(self) -> ComposeResult:
        """Yield the header, both input areas, the start button and the footer."""
        text_inputs = (
            ("Enter your company description", "company_description"),
            ("Enter your target customer", "target_customer"),
        )

        yield Header()
        for placeholder, widget_id in text_inputs:
            yield TextArea(placeholder=placeholder, id=widget_id)
        yield Button("Start", classes="buttons", id="start", action="start")
        yield Footer()

    def action_start(self) -> None:
        """Store the entered texts on the app and switch to the select screen."""
        app = self.app

        description_area = self.query_one("#company_description", TextArea)
        app.company_description = description_area.text

        customer_area = self.query_one("#target_customer", TextArea)
        app.target_customer = customer_area.text

        # Reset per-session state. SelectScreen.on_screen_resume only starts a
        # search while `profiles` is None.
        app.profiles = None
        app.all_selected_profiles = []

        app.push_screen("select")
|
|
|
|
|
|
class Status(StatusMessageProvider):
    """Status-message provider that reports tool calls as Markdown snippets.

    The strings returned here are what the UI appends to its Markdown log
    whenever the stream yields a ``StatusMessage``.
    """

    def tool_start_status_message(self, instance, inputs):
        """Return the message announcing that a tool call has started.

        Args:
            instance: The tool being called (not used).
            inputs: Call inputs; the tool's keyword arguments are read from
                ``inputs["kwargs"]`` and rendered as a two-column table.

        Returns:
            A Markdown string with a "Running search..." line and the table.
        """
        return "\n\nRunning search...\n\n" + self.dict_to_markdown_table(
            inputs["kwargs"]
        )

    def tool_end_status_message(self, instance, outputs=None):
        """Return the message announcing that a tool call has finished.

        Neither argument is inspected. ``outputs`` is optional so the hook
        works whether it is invoked with one or two positional arguments.
        NOTE(review): the dspy callback appears to pass only the tool outputs
        here -- confirm against the installed dspy version.

        Returns:
            The fixed Markdown string ``"\\n\\n Done \\n\\n"``.
        """
        # The original evaluated this string without returning it, so the
        # method returned None and no end message was ever produced.
        return "\n\n Done \n\n"

    def dict_to_markdown_table(self, d: dict) -> str:
        """Convert a dict into a Markdown table with 'Param' and 'Value' columns.

        Args:
            d: Mapping whose items become the table rows, in iteration order.

        Returns:
            The header row, separator row and one row per item, joined with
            newlines (no trailing newline).
        """
        headers = ["Param", "Value"]
        header_row = f"| {' | '.join(headers)} |"
        separator = f"| {' | '.join(['---'] * len(headers))} |"

        rows = [
            f"| {key} | {self.format_value(value)} |" for key, value in d.items()
        ]

        return "\n".join([header_row, separator, *rows])

    def format_value(self, v):
        """Return a display string for a single table cell.

        Lists are rendered as their items joined by ``", "``; anything else
        is passed through ``str``.
        """
        if isinstance(v, list):
            # Stringify each item first: str.join raises TypeError on
            # non-string items such as ints.
            return ", ".join(str(item) for item in v)
        return str(v)
|
|
|
|
|
|
# Streaming wrapper around the module-level `searcher`. Status messages come
# from `Status`, and a listener is attached to the "feedback" field of
# `searcher.feedback_creator`. SelectScreen.run_search iterates the resulting
# async stream.
streaming_searcher = dspy.streamify(
    searcher,
    status_message_provider=Status(),
    stream_listeners=[
        StreamListener(
            signature_field_name="feedback",
            predict=searcher.feedback_creator,
            allow_reuse=True,
        ),
    ],
)
|
|
|
|
|
|
class SelectScreen(Screen):
    """Screen for reviewing found profiles, refining the search and saving picks.

    The left column holds the selectable profile list and a feedback box; the
    right side is a Markdown viewer used as a live log of the running search.
    Search results and accumulated selections live on the app object
    (``app.profiles`` and ``app.all_selected_profiles``).
    """

    BINDINGS = [
        Binding(key="ctrl+r", action="refine", description="Refine", priority=True),
        Binding(key="ctrl+s", action="submit", description="Submit", priority=True),
        Binding(key="escape", action="back", description="Back"),
    ]

    def compose(self) -> ComposeResult:
        """Yield the header, the list/feedback/log area, the buttons and the footer."""
        yield Header()
        yield Horizontal(
            Vertical(
                VerticalScroll(SelectionList[int](id="profiles"), id="profiles_scroll"),
                TextArea(placeholder="Feedback for next search", id="feedback"),
                id="left_container",
            ),
            MarkdownViewer(show_table_of_contents=False),
            id="main_container",
        )
        yield Horizontal(
            Button("Refine", id="refine", action="screen.refine"),
            Button("Submit", id="submit", action="screen.submit"),
            classes="buttons",
        )
        yield Footer(id="footer")

    def on_mount(self) -> None:
        """Bind the right-arrow key on the profile list to the view-profile action."""
        # NOTE(review): this reaches into the widget's private `_bindings`
        # attribute; it may break on a Textual upgrade.
        self.query_one("#profiles", SelectionList)._bindings.bind(
            "right", "screen.view_profile", description="View Profile", priority=True
        )
        self.query_one("#footer", Footer).refresh_bindings()

    def on_screen_resume(self, event) -> None:
        """Start the initial search when the screen is shown without results."""
        if self.app.profiles is None:
            self.run_worker(self.run_search(with_feedback=False), exclusive=True)

    def action_refine(self) -> None:
        """Remember the current selection and rerun the search with feedback."""
        selected_profiles = self.query_one("#profiles", SelectionList).selected
        self.app.all_selected_profiles.extend(self.get_profiles(selected_profiles))
        self.run_worker(self.run_search(with_feedback=True), exclusive=True)

    def action_submit(self) -> None:
        """Append the de-duplicated selection to ``dataset.jsonl`` and restart.

        The entry holds the company description, the target customer and all
        profiles selected so far (first occurrence of each ``profile_id``
        wins). Afterwards the accumulated state and the widgets are cleared
        and the start screen is pushed.
        """
        selected_profiles = self.query_one("#profiles", SelectionList).selected
        self.app.all_selected_profiles.extend(self.get_profiles(selected_profiles))

        # Dicts keep insertion order, so setdefault retains the first profile
        # seen for every profile_id while preserving the original ordering.
        unique_by_id: dict = {}
        for profile in self.app.all_selected_profiles:
            unique_by_id.setdefault(profile["profile_id"], profile)
        self.app.all_selected_profiles = list(unique_by_id.values())

        entry = {
            "company_description": self.app.company_description,
            "target_customer": self.app.target_customer,
            "selected_profiles": self.app.all_selected_profiles,
        }
        with open("dataset.jsonl", "a", encoding="utf-8") as f:
            f.write(json.dumps(entry))
            f.write("\n")

        self.app.all_selected_profiles = []
        self.clear_content()
        self.app.push_screen("start")

    def action_view_profile(self) -> None:
        """Open the detail screen for the currently highlighted profile."""
        highlighted = self.query_one("#profiles", SelectionList).highlighted
        if highlighted is None:
            # Nothing is highlighted (for example, the list is still empty),
            # so there is no profile to show.
            return
        self.app.push_screen(ProfileScreen(highlighted))

    def action_back(self) -> None:
        """Return to the start screen unless the feedback box has focus."""
        focused = self.focused
        # `focused` is None when no widget has focus; treat that like any
        # widget other than the feedback box instead of raising.
        if focused is None or focused.id != "feedback":
            self.app.push_screen("start")

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Dispatch a button press to the matching action by button id.

        NOTE(review): both buttons also declare an ``action=``; confirm
        whether this handler is still reached for them.
        """
        if event.button.id == "refine":
            self.action_refine()
        elif event.button.id == "submit":
            self.action_submit()

    async def run_search(
        self,
        with_feedback: bool = False,
    ):
        """Run the search agent and repopulate the profile list.

        Streamed status messages and text chunks are appended to the Markdown
        log while the search runs; the final prediction's ``profiles`` are
        stored on the app and listed in the selection list.

        Args:
            with_feedback: When True, the current selected and unselected
                profiles plus the text of the feedback box are sent to the
                agent as ``feedback_dict``; otherwise ``None`` is sent.
        """
        viewer = self.query_one(MarkdownViewer)
        ui_log = self.query_one(Markdown)
        sl = self.query_one("#profiles", SelectionList)
        scroll = self.query_one("#profiles_scroll", VerticalScroll)

        feedback_dict = None
        if with_feedback:
            # Capture the selection state BEFORE clearing the list: once
            # clear_options() has run, both `selected` and `options` are
            # empty and the feedback would carry no profiles at all.
            selected_profiles = list(sl.selected)
            chosen = set(selected_profiles)
            unselected_profiles = [
                index for index in range(len(sl.options)) if index not in chosen
            ]
            feedback_dict = {
                "selected_profiles": self.get_profiles(selected_profiles),
                "unselected_profiles": self.get_profiles(unselected_profiles),
                "user_feedback": self.query_one("#feedback", TextArea).text,
            }

        sl.clear_options()
        scroll.loading = True

        stream = streaming_searcher(
            company_description=self.app.company_description,
            target_customer=self.app.target_customer,
            feedback_dict=feedback_dict,
        )
        async for message in stream:
            if isinstance(message, StatusMessage):
                ui_log.append(message.message)
            elif isinstance(message, dspy.Prediction):
                self.app.profiles = message.profiles
            elif isinstance(message, ModelResponseStream):
                if message.choices[0].delta.content is not None:
                    ui_log.append(message.choices[0].delta.content)
            elif isinstance(message, StreamResponse):
                ui_log.append(message.chunk)
            viewer.scroll_end(force=True)

        # `profiles` is still None if the stream ended without a prediction;
        # show an empty list rather than failing on iteration.
        for i, profile in enumerate(self.app.profiles or []):
            summary = f"{profile['name']} - {profile['experience_data']['title']} | {profile['experience_data']['company_name']}"
            sl.add_option(Selection(summary, i))
        scroll.loading = False

    def get_profiles(self, indices: list[int]) -> list[dict]:
        """Return the entries of ``app.profiles`` at the given indices, in order."""
        return [self.app.profiles[i] for i in indices]

    def clear_content(self) -> None:
        """Empty the profile list, the feedback box and the Markdown log."""
        self.query_one("#profiles", SelectionList).clear_options()
        self.query_one("#profiles_scroll", VerticalScroll).loading = False
        self.query_one("#feedback", TextArea).text = ""
        self.query_one(Markdown).update("")
|
|
|
|
|
|
class ProfileScreen(Screen):
    """Detail view that renders a single profile as Markdown.

    The screen is created with the index of a profile in ``app.profiles`` and
    is pushed on top of the select screen; both bindings pop it again.
    """

    BINDINGS = [
        Binding(
            key="left",
            # Pop instead of pushing 'select' again: pushing left this screen
            # (and another 'select' entry) on the stack for every profile
            # viewed, so the stack grew without bound.
            action="app.pop_screen",
            description="Back to Select",
            show=True,
            priority=True,
        ),
        Binding(
            key="escape",
            action="app.pop_screen",
            description="Back",
            priority=True,
            show=True,
        ),
    ]

    def __init__(self, profile: int) -> None:
        """Remember which profile to show.

        Args:
            profile: Index into ``app.profiles`` (it is used as a subscript
                in ``compose``), not the profile dict itself.
        """
        super().__init__()
        self.profile = profile

    def compose(self) -> ComposeResult:
        """Yield the header, the rendered profile and the footer."""
        yield Header()
        yield Markdown(
            markdown=self.get_profile_markdown(self.app.profiles[self.profile]),
            id="profile",
        )
        yield Footer()

    def get_profile_markdown(self, profile: dict) -> str:
        """Build the Markdown document for one profile.

        Args:
            profile: Profile dict. The keys read here are ``name``,
                ``headline``, ``locality``, ``url`` and ``experience_data``
                (with ``title``, ``company_name``, ``company_locality``,
                ``company_type``, ``company_linkedin_url`` and
                ``company_description``).

        Returns:
            Four newline-joined parts: a person heading, the optional
            headline, a company heading and the company description.
        """
        experience = profile["experience_data"]
        # A falsy headline contributes an empty line instead of the text.
        headline = f"{profile['headline']} \n" if profile["headline"] else ""
        lines = (
            f"# {profile['name']} - {experience['title']} · {profile['locality']} [🔗]({profile['url']})",
            f"{headline}",
            f"## {experience['company_name']} - {experience['company_locality']} · {experience['company_type']} [🔗]({experience['company_linkedin_url']})",
            f"{experience['company_description']}",
        )
        return "\n".join(lines)
|
|
|
|
|
|
class DatasetEntry(TypedDict):
    """Shape of one saved dataset record.

    The keys match the entry that the select screen serialises to
    ``dataset.jsonl`` on submit.
    """

    # Free-text description of the user's company.
    company_description: str
    # Free-text description of the customer being targeted.
    target_customer: str
    # Profile dicts the user selected for this description/customer pair.
    selected_profiles: list[dict]
|
|
|
|
|
|
class PersanaApp(App):
    """Textual application that wires the start, select and profile screens.

    Cross-screen state (the two input texts, the latest search results and
    the accumulated selections) is kept as attributes on the app so every
    screen can reach it through ``self.app``.
    """

    SCREENS = {"start": StartScreen, "select": SelectScreen, "profile": ProfileScreen}
    BINDINGS = [
        Binding(key="ctrl+x", action="quit", description="Quit", priority=True),
    ]
    CSS = """
    Horizontal {
        height: 3;
    }

    Button {
        width: 1fr; /* 👈 expands equally like flex:1 */
        margin: 1;
        background: #bae0f5;
    }
    #start{
        height: 3;
    }
    #left_container {
        height: 1fr;
    }
    #main_container {
        height: 1fr;
    }
    #log {
        height: 5;
        width: 1fr;
    }
    """

    # Texts entered on the start screen.
    company_description: Optional[str] = None
    target_customer: Optional[str] = None
    # Latest search results; None means "no search has produced results yet".
    profiles: Optional[list[dict]] = None
    # Profiles picked so far. StartScreen.action_start rebinds this to a fresh
    # list on the instance before anything is appended.
    all_selected_profiles: list[dict] = []
    dataset: list[DatasetEntry] = []
    log_buffer: str = ""

    # No compose() override: a Screen is meant to be pushed onto the screen
    # stack, not yielded as a child widget. The previous compose() yielded a
    # second StartScreen that sat, hidden, underneath the one pushed below.

    def on_mount(self) -> None:
        """Show the start screen once the app is mounted."""
        self.push_screen("start")
|
|
|
|
|
|
# Script entry point: build the app and run it when the file is executed
# directly.
if __name__ == "__main__":
    app = PersanaApp()
    app.run()
|