74 lines
2.1 KiB
Python
74 lines
2.1 KiB
Python
"""Retriever implementations used by the JTBD DSPy agent."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Iterable, List
|
|
|
|
from modaic import PrecompiledConfig, Retriever
|
|
|
|
|
|
class NullRetrieverConfig(PrecompiledConfig):
    """Empty configuration used by the no-op retriever.

    Declares no fields of its own; it gives ``NullRetriever`` a concrete
    config type to hand to its base class.
    """
|
|
|
|
|
|
class NotesRetrieverConfig(PrecompiledConfig):
    """Settings for the in-memory notes retriever.

    Meant to be serializable; any serialization support presumably comes
    from ``PrecompiledConfig`` (not visible in this file -- confirm there).
    """

    # Note texts searched by the retriever; empty by default.
    # NOTE(review): whether this ``[]`` default is shared between instances
    # depends on how PrecompiledConfig handles field defaults -- confirm.
    notes: List[str] = []
    # Upper bound on how many matching notes a single query returns.
    top_k: int = 3
|
|
|
|
|
|
class NullRetriever(Retriever):
    """Retriever that never returns any context.

    Intended for environments where no contextual data is available.
    """

    config: NullRetrieverConfig

    def __init__(self, config: NullRetrieverConfig | None = None, **kwargs):
        """Set up the retriever, falling back to an empty config.

        Args:
            config: Configuration to use. A fresh ``NullRetrieverConfig`` is
                created when this is ``None`` (or otherwise falsy).
            **kwargs: Forwarded unchanged to ``Retriever.__init__``.
        """
        effective_config = config if config else NullRetrieverConfig()
        super().__init__(effective_config, **kwargs)

    def retrieve(self, query: str) -> str:  # type: ignore[override]
        """Ignore ``query`` and return an empty string."""
        return ""
|
|
|
|
|
|
class NotesRetriever(Retriever):
    """Tiny keyword-overlap retriever over an in-memory list of notes.

    The notes and the result limit live on ``self.config``. A query is
    matched against each note by counting the lowercase,
    whitespace-separated tokens the two have in common.
    """

    config: NotesRetrieverConfig

    def __init__(
        self,
        notes: Iterable[str] | None = None,
        top_k: int | None = None,
        config: NotesRetrieverConfig | None = None,
        **kwargs,
    ):
        """Build the retriever from explicit arguments and/or a config.

        Args:
            notes: Note texts to search. Without ``config`` this defaults to
                an empty list; with ``config`` it replaces ``config.notes``
                only when given.
            top_k: Result limit; overrides the config's ``top_k`` when given.
            config: Existing configuration to use. Overrides are written onto
                this object in place, so the caller's instance is modified.
            **kwargs: Forwarded unchanged to ``Retriever.__init__``.
        """
        if config is None:
            config = NotesRetrieverConfig()
            # A freshly created config always receives its own list copy.
            config.notes = list(notes or [])
        elif notes is not None:
            config.notes = list(notes)

        # The top_k override applies the same way to fresh and supplied configs.
        if top_k is not None:
            config.top_k = int(top_k)

        super().__init__(config, **kwargs)

    def retrieve(self, query: str) -> str:  # type: ignore[override]
        """Return the best-matching notes for ``query`` joined by newlines.

        A note's score is the number of distinct query tokens it contains.
        Notes with no shared token are skipped; the rest are ordered by
        descending score (ties keep their order in ``config.notes``) and
        truncated to ``config.top_k`` entries.

        Args:
            query: Free-text query; tokenized by lowercasing and splitting on
                whitespace (punctuation is not stripped).

        Returns:
            The selected notes separated by ``"\\n"``, or ``""`` when the
            query has no tokens, nothing matches, or the limit is zero or
            negative.
        """
        # str.split() with no separator never yields empty strings, so no
        # extra filtering of tokens is needed.
        terms = set(query.lower().split())
        if not terms:
            return ""

        top_k = self.config.top_k
        # Clamp negatives to zero: a negative slice bound would otherwise
        # drop the lowest-ranked matches instead of limiting the result.
        # ``None`` is passed through and slices as "no limit", as before.
        limit = top_k if top_k is None else max(top_k, 0)

        scored: List[tuple[int, str]] = []
        for note in self.config.notes:
            overlap = len(terms.intersection(note.lower().split()))
            if overlap > 0:
                scored.append((overlap, note))

        # list.sort is stable, so equally scored notes keep insertion order.
        scored.sort(key=lambda item: item[0], reverse=True)
        return "\n".join(note for _, note in scored[:limit])
|