(no commit message)
This commit is contained in:
336
main.py
336
main.py
@@ -1,336 +0,0 @@
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import (
|
||||
Footer,
|
||||
Header,
|
||||
SelectionList,
|
||||
Button,
|
||||
TextArea,
|
||||
Markdown,
|
||||
MarkdownViewer,
|
||||
)
|
||||
from textual.containers import VerticalScroll, Horizontal, Vertical
|
||||
from textual.binding import Binding
|
||||
from textual.widgets.selection_list import Selection
|
||||
from typing import Optional
|
||||
from typing_extensions import TypedDict
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
import json
|
||||
import dspy
|
||||
from dspy.streaming import (
|
||||
StatusMessage,
|
||||
StatusMessageProvider,
|
||||
StreamListener,
|
||||
StreamResponse,
|
||||
)
|
||||
from litellm.types.utils import ModelResponseStream
|
||||
from modaic import AutoAgent
|
||||
|
||||
load_dotenv()
|
||||
|
||||
searcher = AutoAgent.from_precompiled(
|
||||
"swagginty/persana-lead-gen", api_key=os.getenv("PERSANA_KEY")
|
||||
)
|
||||
|
||||
|
||||
class StartScreen(Screen):
|
||||
BINDINGS = [
|
||||
("ctrl+s", "start", "Start"),
|
||||
Binding(key="ctrl+x", action="quit", description="Quit", priority=True),
|
||||
]
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header()
|
||||
yield TextArea(
|
||||
placeholder="Enter your company description",
|
||||
id="company_description",
|
||||
)
|
||||
yield TextArea(
|
||||
placeholder="Enter your target customer",
|
||||
id="target_customer",
|
||||
)
|
||||
yield Button("Start", classes="buttons", id="start", action="start")
|
||||
yield Footer()
|
||||
|
||||
def action_start(self) -> None:
|
||||
self.app.company_description = self.query_one(
|
||||
"#company_description", TextArea
|
||||
).text
|
||||
self.app.target_customer = self.query_one("#target_customer", TextArea).text
|
||||
self.app.profiles = None
|
||||
self.app.all_selected_profiles = []
|
||||
self.app.push_screen("select")
|
||||
|
||||
|
||||
class Status(StatusMessageProvider):
|
||||
def tool_start_status_message(self, instance, inputs):
|
||||
return "\n\nRunning search...\n\n" + self.dict_to_markdown_table(
|
||||
inputs["kwargs"]
|
||||
)
|
||||
|
||||
def tool_end_status_message(self, instance, outputs):
|
||||
"\n\n Done \n\n"
|
||||
|
||||
def dict_to_markdown_table(self, d: dict) -> str:
|
||||
"""Convert a dict into a Markdown table with 'Name' and 'Value' columns."""
|
||||
headers = ["Param", "Value"]
|
||||
header_row = f"| {' | '.join(headers)} |"
|
||||
separator = f"| {' | '.join(['---'] * len(headers))} |"
|
||||
|
||||
rows = []
|
||||
for key, value in d.items():
|
||||
formatted_value = self.format_value(value)
|
||||
rows.append(f"| {key} | {formatted_value} |")
|
||||
|
||||
return "\n".join([header_row, separator, *rows])
|
||||
|
||||
def format_value(self, v):
|
||||
if isinstance(v, list):
|
||||
return ", ".join(v)
|
||||
return str(v)
|
||||
|
||||
|
||||
streaming_searcher = dspy.streamify(
|
||||
searcher,
|
||||
status_message_provider=Status(),
|
||||
stream_listeners=[
|
||||
StreamListener(
|
||||
predict=searcher.feedback_creator,
|
||||
signature_field_name="feedback",
|
||||
allow_reuse=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class SelectScreen(Screen):
|
||||
BINDINGS = [
|
||||
Binding(key="ctrl+r", action="refine", description="Refine", priority=True),
|
||||
Binding(key="ctrl+s", action="submit", description="Submit", priority=True),
|
||||
Binding(key="escape", action="back", description="Back"),
|
||||
]
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header()
|
||||
yield Horizontal(
|
||||
Vertical(
|
||||
VerticalScroll(SelectionList[int](id="profiles"), id="profiles_scroll"),
|
||||
TextArea(placeholder="Feedback for next search", id="feedback"),
|
||||
id="left_container",
|
||||
),
|
||||
MarkdownViewer(show_table_of_contents=False),
|
||||
id="main_container",
|
||||
)
|
||||
yield Horizontal(
|
||||
Button("Refine", id="refine", action="screen.refine"),
|
||||
Button("Submit", id="submit", action="screen.submit"),
|
||||
classes="buttons",
|
||||
)
|
||||
yield Footer(id="footer")
|
||||
|
||||
def on_mount(self) -> None:
|
||||
self.query_one("#profiles", SelectionList)._bindings.bind(
|
||||
"right", "screen.view_profile", description="View Profile", priority=True
|
||||
)
|
||||
self.query_one("#footer", Footer).refresh_bindings()
|
||||
|
||||
def on_screen_resume(self, event) -> None:
|
||||
if self.app.profiles is None:
|
||||
self.run_worker(self.run_search(with_feedback=False), exclusive=True)
|
||||
|
||||
def action_refine(self) -> None:
|
||||
selected_profiles = self.query_one("#profiles", SelectionList).selected
|
||||
self.app.all_selected_profiles.extend(self.get_profiles(selected_profiles))
|
||||
self.run_worker(self.run_search(with_feedback=True), exclusive=True)
|
||||
|
||||
def action_submit(self) -> None:
|
||||
selected_profiles = self.query_one("#profiles", SelectionList).selected
|
||||
self.app.all_selected_profiles.extend(self.get_profiles(selected_profiles))
|
||||
visited = set()
|
||||
new_selected_profiles = []
|
||||
for profile in self.app.all_selected_profiles:
|
||||
if profile["profile_id"] not in visited:
|
||||
new_selected_profiles.append(profile)
|
||||
visited.add(profile["profile_id"])
|
||||
self.app.all_selected_profiles = new_selected_profiles
|
||||
entry = {
|
||||
"company_description": self.app.company_description,
|
||||
"target_customer": self.app.target_customer,
|
||||
"selected_profiles": self.app.all_selected_profiles,
|
||||
}
|
||||
with open("dataset.jsonl", "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry))
|
||||
f.write("\n")
|
||||
|
||||
self.app.all_selected_profiles = []
|
||||
self.clear_content()
|
||||
self.app.push_screen("start")
|
||||
|
||||
def action_view_profile(self) -> None:
|
||||
selected_profile = self.query_one("#profiles", SelectionList).highlighted
|
||||
self.app.push_screen(ProfileScreen(selected_profile))
|
||||
|
||||
def action_back(self) -> None:
|
||||
if self.focused.id != "feedback":
|
||||
self.app.push_screen("start")
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
if event.button.id == "refine":
|
||||
self.action_refine()
|
||||
elif event.button.id == "submit":
|
||||
self.action_submit()
|
||||
|
||||
async def run_search(
|
||||
self,
|
||||
with_feedback: bool = False,
|
||||
):
|
||||
"""
|
||||
Runs a search using the search agent and updates the pofiles list in UI. If with_feedback is True,
|
||||
it will enter feedback and update the profiles list.
|
||||
"""
|
||||
viewer = self.query_one(MarkdownViewer)
|
||||
ui_log = self.query_one(Markdown)
|
||||
sl = self.query_one("#profiles", SelectionList)
|
||||
sl.clear_options()
|
||||
scroll = self.query_one("#profiles_scroll", VerticalScroll)
|
||||
scroll.loading = True
|
||||
feedback_dict = None
|
||||
if with_feedback:
|
||||
selected_profiles = self.query_one("#profiles", SelectionList).selected
|
||||
num_profiles = len(self.query_one("#profiles", SelectionList).options)
|
||||
unselected_profiles = list(
|
||||
set(range(num_profiles)) - set(selected_profiles)
|
||||
)
|
||||
feedback_dict = {
|
||||
"selected_profiles": self.get_profiles(selected_profiles),
|
||||
"unselected_profiles": self.get_profiles(unselected_profiles),
|
||||
"user_feedback": self.query_one("#feedback", TextArea).text,
|
||||
}
|
||||
stream = streaming_searcher(
|
||||
company_description=self.app.company_description,
|
||||
target_customer=self.app.target_customer,
|
||||
feedback_dict=feedback_dict,
|
||||
)
|
||||
async for message in stream:
|
||||
if isinstance(message, StatusMessage):
|
||||
ui_log.append(message.message)
|
||||
elif isinstance(message, dspy.Prediction):
|
||||
self.app.profiles = message.profiles
|
||||
elif isinstance(message, ModelResponseStream):
|
||||
if message.choices[0].delta.content is not None:
|
||||
ui_log.append(message.choices[0].delta.content)
|
||||
elif isinstance(message, StreamResponse):
|
||||
ui_log.append(message.chunk)
|
||||
viewer.scroll_end(force=True)
|
||||
|
||||
for i, profile in enumerate(self.app.profiles):
|
||||
summary = f"{profile['name']} - {profile['experience_data']['title']} | {profile['experience_data']['company_name']}"
|
||||
option = Selection(summary, i)
|
||||
sl.add_option(option)
|
||||
scroll.loading = False
|
||||
|
||||
def get_profiles(self, indices: list[int]) -> list[dict]:
|
||||
return [self.app.profiles[i] for i in indices]
|
||||
|
||||
def clear_content(self) -> None:
|
||||
self.query_one("#profiles", SelectionList).clear_options()
|
||||
self.query_one("#profiles_scroll", VerticalScroll).loading = False
|
||||
self.query_one("#feedback", TextArea).text = ""
|
||||
self.query_one(Markdown).update("")
|
||||
|
||||
|
||||
class ProfileScreen(Screen):
|
||||
BINDINGS = [
|
||||
Binding(
|
||||
key="left",
|
||||
action="app.push_screen('select')",
|
||||
description="Back to Select",
|
||||
show=True,
|
||||
priority=True,
|
||||
),
|
||||
Binding(
|
||||
key="escape",
|
||||
action="app.pop_screen",
|
||||
description="Back",
|
||||
priority=True,
|
||||
show=True,
|
||||
),
|
||||
]
|
||||
|
||||
def __init__(self, profile: dict) -> None:
|
||||
super().__init__()
|
||||
self.profile = profile
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header()
|
||||
yield Markdown(
|
||||
markdown=self.get_profile_markdown(self.app.profiles[self.profile]),
|
||||
id="profile",
|
||||
)
|
||||
yield Footer()
|
||||
|
||||
def get_profile_markdown(self, profile: dict):
|
||||
headline = f"{profile['headline']} \n" if profile["headline"] else ""
|
||||
md_text = (
|
||||
f"# {profile['name']} - {profile['experience_data']['title']} · {profile['locality']} [🔗]({profile['url']})",
|
||||
f"{headline}",
|
||||
f"## {profile['experience_data']['company_name']} - {profile['experience_data']['company_locality']} · {profile['experience_data']['company_type']} [🔗]({profile['experience_data']['company_linkedin_url']})",
|
||||
f"{profile['experience_data']['company_description']}",
|
||||
)
|
||||
md_text = "\n".join(md_text)
|
||||
return md_text
|
||||
|
||||
|
||||
class DatasetEntry(TypedDict):
|
||||
company_description: str
|
||||
target_customer: str
|
||||
selected_profiles: list[dict]
|
||||
|
||||
|
||||
class PersanaApp(App):
|
||||
SCREENS = {"start": StartScreen, "select": SelectScreen, "profile": ProfileScreen}
|
||||
BINDINGS = [
|
||||
Binding(key="ctrl+x", action="quit", description="Quit", priority=True),
|
||||
]
|
||||
CSS = """
|
||||
Horizontal {
|
||||
height: 3;
|
||||
}
|
||||
|
||||
Button {
|
||||
width: 1fr; /* 👈 expands equally like flex:1 */
|
||||
margin: 1;
|
||||
background: #bae0f5;
|
||||
}
|
||||
#start{
|
||||
height: 3;
|
||||
}
|
||||
#left_container {
|
||||
height: 1fr;
|
||||
}
|
||||
#main_container {
|
||||
height: 1fr;
|
||||
}
|
||||
#log {
|
||||
height: 5;
|
||||
width: 1fr;
|
||||
}
|
||||
"""
|
||||
company_description: Optional[str] = None
|
||||
target_customer: Optional[str] = None
|
||||
profiles: Optional[list[dict]] = None
|
||||
all_selected_profiles: list[dict] = []
|
||||
dataset: list[DatasetEntry] = []
|
||||
log_buffer: str = ""
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield StartScreen()
|
||||
|
||||
def on_mount(self) -> None:
|
||||
self.push_screen("start")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = PersanaApp()
|
||||
app.run()
|
||||
Reference in New Issue
Block a user