(no commit message)
This commit is contained in:
73
service/retrievers.py
Normal file
73
service/retrievers.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Retriever implementations used by the JTBD DSPy agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, List
|
||||
|
||||
from modaic import PrecompiledConfig, Retriever
|
||||
|
||||
|
||||
class NullRetrieverConfig(PrecompiledConfig):
|
||||
"""Configuration placeholder for the null retriever."""
|
||||
|
||||
|
||||
class NotesRetrieverConfig(PrecompiledConfig):
|
||||
"""Serializable configuration for the in-memory notes retriever."""
|
||||
|
||||
notes: List[str] = []
|
||||
top_k: int = 3
|
||||
|
||||
|
||||
class NullRetriever(Retriever):
|
||||
"""No-op retriever for environments without contextual data."""
|
||||
|
||||
config: NullRetrieverConfig
|
||||
|
||||
def __init__(self, config: NullRetrieverConfig | None = None, **kwargs):
|
||||
super().__init__(config or NullRetrieverConfig(), **kwargs)
|
||||
|
||||
def retrieve(self, query: str) -> str: # type: ignore[override]
|
||||
return ""
|
||||
|
||||
|
||||
class NotesRetriever(Retriever):
|
||||
"""Very small keyword-based retriever backed by an in-memory list of notes."""
|
||||
|
||||
config: NotesRetrieverConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
notes: Iterable[str] | None = None,
|
||||
top_k: int | None = None,
|
||||
config: NotesRetrieverConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if config is None:
|
||||
cfg = NotesRetrieverConfig()
|
||||
cfg.notes = list(notes or [])
|
||||
if top_k is not None:
|
||||
cfg.top_k = int(top_k)
|
||||
else:
|
||||
cfg = config
|
||||
if notes is not None:
|
||||
cfg.notes = list(notes)
|
||||
if top_k is not None:
|
||||
cfg.top_k = int(top_k)
|
||||
|
||||
super().__init__(cfg, **kwargs)
|
||||
|
||||
def retrieve(self, query: str) -> str: # type: ignore[override]
|
||||
terms = {token for token in query.lower().split() if token}
|
||||
if not terms:
|
||||
return ""
|
||||
|
||||
scored: List[tuple[int, str]] = []
|
||||
for note in self.config.notes:
|
||||
tokens = {token for token in note.lower().split() if token}
|
||||
score = len(terms & tokens)
|
||||
if score > 0:
|
||||
scored.append((score, note))
|
||||
|
||||
scored.sort(key=lambda item: item[0], reverse=True)
|
||||
top_matches = [note for _, note in scored[: self.config.top_k]]
|
||||
return "\n".join(top_matches)
|
||||
Reference in New Issue
Block a user