From d178be7c0e94ce43716a69463d809643a3867981 Mon Sep 17 00:00:00 2001
From: Connor Shorten
Date: Thu, 27 Nov 2025 11:29:18 -0500
Subject: [PATCH] Fix init to accept kwargs

---
 agent.json        | 27 ++++++++++++++++
 auto_classes.json |  4 +--
 ce_ranker.py      | 78 +++++++++++++++++++++++++++++++++++++++++++++++
 config.json       |  1 +
 hello.py          | 49 -----------------------------
 5 files changed, 108 insertions(+), 51 deletions(-)
 create mode 100644 ce_ranker.py
 delete mode 100644 hello.py

diff --git a/agent.json b/agent.json
index 23da2b4..51bdb7a 100644
--- a/agent.json
+++ b/agent.json
@@ -1,4 +1,31 @@
 {
+  "reranker.predict": {
+    "traces": [],
+    "train": [],
+    "demos": [],
+    "signature": {
+      "instructions": "Assess the relevance of a document to a query.",
+      "fields": [
+        {
+          "prefix": "Query:",
+          "description": "${query}"
+        },
+        {
+          "prefix": "Document:",
+          "description": "${document}"
+        },
+        {
+          "prefix": "Reasoning: Let's think step by step in order to",
+          "description": "${reasoning}"
+        },
+        {
+          "prefix": "Relevance Score:",
+          "description": "${relevance_score}"
+        }
+      ]
+    },
+    "lm": null
+  },
   "metadata": {
     "dependency_versions": {
       "python": "3.11",
diff --git a/auto_classes.json b/auto_classes.json
index 6685e21..7031b8b 100644
--- a/auto_classes.json
+++ b/auto_classes.json
@@ -1,4 +1,4 @@
 {
-  "AutoConfig": "hello.EchoConfig",
-  "AutoAgent": "hello.EchoAgent"
+  "AutoConfig": "ce_ranker.CERankerConfig",
+  "AutoAgent": "ce_ranker.CERankerAgent"
 }
\ No newline at end of file
diff --git a/ce_ranker.py b/ce_ranker.py
new file mode 100644
index 0000000..81e937a
--- /dev/null
+++ b/ce_ranker.py
@@ -0,0 +1,78 @@
+from dotenv import load_dotenv
+import os
+import asyncio
+
+import dspy
+from modaic import PrecompiledAgent, PrecompiledConfig
+import weaviate
+
+load_dotenv()
+
+
+class RelevanceAssessment(dspy.Signature):
+    """Assess the relevance of a document to a query."""
+    query: str = dspy.InputField()
+    document: str = dspy.InputField()
+    relevance_score: bool = dspy.OutputField()
+
+
+class CERankerConfig(PrecompiledConfig):
+    lm: str = "openai/gpt-4.1-mini"
+    collection_name: str
+    return_properties: list[str]
+
+
+class CERankerAgent(PrecompiledAgent):
+    config: CERankerConfig
+
+    def __init__(self, config: CERankerConfig, **kwargs):
+        super().__init__(config, **kwargs)
+
+        lm = dspy.LM(self.config.lm)
+        dspy.configure(lm=lm)
+
+        self._connect_to_weaviate()
+        self.reranker = dspy.ChainOfThought(RelevanceAssessment)
+
+    def _connect_to_weaviate(self):
+        self.weaviate_client = weaviate.connect_to_weaviate_cloud(
+            cluster_url=os.getenv("WEAVIATE_URL"),
+            auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WEAVIATE_API_KEY")),
+        )
+        self.collection = self.weaviate_client.collections.get(
+            self.config.collection_name
+        )
+
+    async def _score_document(self, query: str, document: str) -> tuple[str, bool]:
+        result = await self.reranker.acall(query=query, document=document)
+        return (document, result.relevance_score)
+
+    async def __acall__(self, query: str, k: int = 1) -> str:
+        response = self.collection.query.hybrid(query=query, limit=k)
+
+        documents = [o.properties["content"] for o in response.objects]
+
+        scored_results = await asyncio.gather(
+            *[self._score_document(query, doc) for doc in documents]
+        )
+
+        scored_results.sort(key=lambda x: x[1], reverse=True)
+
+        return "\n".join([doc for doc, score in scored_results[:k]])
+
+    def __call__(self, query: str, k: int = 1) -> str:
+        return asyncio.run(self.__acall__(query, k))
+
+
+if __name__ == "__main__":
+    config = CERankerConfig(
+        collection_name="IRPapersText_Default",
+        return_properties=["content"]
+    )
+    agent = CERankerAgent(config)
+    print(agent(query="What is HyDE?"))
+    agent.push_to_hub(
+        "connor/CrossEncoderRanker",
+        with_code=True,
+        commit_message="Fix init to accept kwargs"
+    )
\ No newline at end of file
diff --git a/config.json b/config.json
index f93b409..5c2646d 100644
--- a/config.json
+++ b/config.json
@@ -1,4 +1,5 @@
 {
+  "lm": "openai/gpt-4.1-mini",
   "collection_name": "IRPapersText_Default",
   "return_properties": [
     "content"
diff --git a/hello.py b/hello.py
deleted file mode 100644
index 0206b9e..0000000
--- a/hello.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from dotenv import load_dotenv
-import os
-
-from modaic import PrecompiledAgent, PrecompiledConfig
-import weaviate
-
-load_dotenv()
-
-class EchoConfig(PrecompiledConfig):
-    collection_name: str
-    return_properties: list[str]
-
-class EchoAgent(PrecompiledAgent):
-    config: EchoConfig
-
-    def __init__(self, config: EchoConfig, **kwargs):
-        super().__init__(config, **kwargs)
-        self.weaviate_client = weaviate.connect_to_weaviate_cloud(
-            cluster_url=os.getenv("WEAVIATE_URL"),
-            auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WEAVIATE_API_KEY")),
-        )
-        self.collection = self.weaviate_client.collections.get(
-            self.config.collection_name
-        )
-
-    def forward(self, query: str) -> str:
-        response = self.collection.query.hybrid(
-            query=query,
-            limit=1
-        )
-        results = []
-        for o in response.objects:
-            results.append(o.properties["content"])
-
-        return "\n".join(results)
-
-
-if __name__ == "__main__":
-    config = EchoConfig(
-        collection_name="IRPapersText_Default",
-        return_properties=["content"]
-    )
-    agent = EchoAgent(config)
-    print(agent(query="What is HyDE?"))
-    agent.push_to_hub(
-        "connor/CrossEncoderRanker",
-        with_code=True,
-        commit_message="Fix init to accept kwargs"
-    )
\ No newline at end of file