Fix init to accept kwargs

This commit is contained in:
Connor Shorten
2025-11-27 11:29:18 -05:00
parent 01159000b8
commit d178be7c0e
5 changed files with 108 additions and 51 deletions

View File

@@ -1,4 +1,31 @@
{ {
"reranker.predict": {
"traces": [],
"train": [],
"demos": [],
"signature": {
"instructions": "Assess the relevance of a document to a query.",
"fields": [
{
"prefix": "Query:",
"description": "${query}"
},
{
"prefix": "Document:",
"description": "${document}"
},
{
"prefix": "Reasoning: Let's think step by step in order to",
"description": "${reasoning}"
},
{
"prefix": "Relevance Score:",
"description": "${relevance_score}"
}
]
},
"lm": null
},
"metadata": { "metadata": {
"dependency_versions": { "dependency_versions": {
"python": "3.11", "python": "3.11",

View File

@@ -1,4 +1,4 @@
{ {
"AutoConfig": "hello.EchoConfig", "AutoConfig": "ce_ranker.CERankerConfig",
"AutoAgent": "hello.EchoAgent" "AutoAgent": "ce_ranker.CERankerAgent"
} }

78
ce_ranker.py Normal file
View File

@@ -0,0 +1,78 @@
from dotenv import load_dotenv
import os
import asyncio
import dspy
from modaic import PrecompiledAgent, PrecompiledConfig
import weaviate
load_dotenv()
class RelevanceAssessment(dspy.Signature):
"""Assess the relevance of a document to a query."""
query: str = dspy.InputField()
document: str = dspy.InputField()
relevance_score: bool = dspy.OutputField()
class CERankerConfig(PrecompiledConfig):
lm: str = "openai/gpt-4.1-mini"
collection_name: str
return_properties: list[str]
class CERankerAgent(PrecompiledAgent):
config: CERankerConfig
def __init__(self, config: CERankerConfig, **kwargs):
super().__init__(config, **kwargs)
lm = dspy.LM(self.config.lm)
dspy.configure(lm=lm)
self._connect_to_weaviate()
self.reranker = dspy.ChainOfThought(RelevanceAssessment)
def _connect_to_weaviate(self):
self.weaviate_client = weaviate.connect_to_weaviate_cloud(
cluster_url=os.getenv("WEAVIATE_URL"),
auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WEAVIATE_API_KEY")),
)
self.collection = self.weaviate_client.collections.get(
self.config.collection_name
)
async def _score_document(self, query: str, document: str) -> tuple[str, bool]:
result = await self.reranker.acall(query=query, document=document)
return (document, result.relevance_score)
async def __acall__(self, query: str, k: int = 1) -> str:
response = self.collection.query.hybrid(query=query, limit=k)
documents = [o.properties["content"] for o in response.objects]
scored_results = await asyncio.gather(
*[self._score_document(query, doc) for doc in documents]
)
scored_results.sort(key=lambda x: x[1], reverse=True)
return "\n".join([doc for doc, score in scored_results[:k]])
def __call__(self, query: str, k: int = 1) -> str:
return asyncio.run(self.__acall__(query, k))
if __name__ == "__main__":
config = CERankerConfig(
collection_name="IRPapersText_Default",
return_properties=["content"]
)
agent = CERankerAgent(config)
print(agent(query="What is HyDE?"))
agent.push_to_hub(
"connor/CrossEncoderRanker",
with_code=True,
commit_message="Fix init to accept kwargs"
)

View File

@@ -1,4 +1,5 @@
{ {
"lm": "openai/gpt-4.1-mini",
"collection_name": "IRPapersText_Default", "collection_name": "IRPapersText_Default",
"return_properties": [ "return_properties": [
"content" "content"

View File

@@ -1,49 +0,0 @@
from dotenv import load_dotenv
import os
from modaic import PrecompiledAgent, PrecompiledConfig
import weaviate
load_dotenv()
class EchoConfig(PrecompiledConfig):
collection_name: str
return_properties: list[str]
class EchoAgent(PrecompiledAgent):
config: EchoConfig
def __init__(self, config: EchoConfig, **kwargs):
super().__init__(config, **kwargs)
self.weaviate_client = weaviate.connect_to_weaviate_cloud(
cluster_url=os.getenv("WEAVIATE_URL"),
auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WEAVIATE_API_KEY")),
)
self.collection = self.weaviate_client.collections.get(
self.config.collection_name
)
def forward(self, query: str) -> str:
response = self.collection.query.hybrid(
query=query,
limit=1
)
results = []
for o in response.objects:
results.append(o.properties["content"])
return "\n".join(results)
if __name__ == "__main__":
config = EchoConfig(
collection_name="IRPapersText_Default",
return_properties=["content"]
)
agent = EchoAgent(config)
print(agent(query="What is HyDE?"))
agent.push_to_hub(
"connor/CrossEncoderRanker",
with_code=True,
commit_message="Fix init to accept kwargs"
)