Add a configurable default `k` to CERankerAgent.

CERankerConfig gains a required `k` field (with `lm`, which has a default,
moved after the required fields). CERankerAgent stores it as `self.k`, and
`__acall__` / `__call__` now take `k=None`, meaning "use the configured
default".

Review fix: `__acall__` resolved `k` from `self.k` when it was None, but then
passed `self.k` to both the hybrid query limit and the final slice, so an
explicit per-call `k` was silently ignored. Both now use the resolved local
`k`; those two lines are therefore unchanged from the pre-image and appear
as context.

NOTE(review): this patch was re-flowed from a whitespace-collapsed copy.
Indentation widths and the position of blank context lines are reconstructed
from the hunk line counts - verify against the target tree before applying.
The post-image blob hash on the ce_ranker.py `index` line predates the
review fix.

diff --git a/ce_ranker.py b/ce_ranker.py
index 81e937a..fbe68d9 100644
--- a/ce_ranker.py
+++ b/ce_ranker.py
@@ -8,7 +8,6 @@ import weaviate
 
 load_dotenv()
 
-
 class RelevanceAssessment(dspy.Signature):
     """Assess the relevance of a document to a query."""
     query: str = dspy.InputField()
@@ -17,9 +16,10 @@ class RelevanceAssessment(dspy.Signature):
 
 
 class CERankerConfig(PrecompiledConfig):
-    lm: str = "openai/gpt-4.1-mini"
     collection_name: str
     return_properties: list[str]
+    k: int
+    lm: str = "openai/gpt-4.1-mini"
 
 
 class CERankerAgent(PrecompiledAgent):
@@ -32,6 +32,7 @@ class CERankerAgent(PrecompiledAgent):
         dspy.configure(lm=lm)
 
         self._connect_to_weaviate()
+        self.k = config.k
         self.reranker = dspy.ChainOfThought(RelevanceAssessment)
 
     def _connect_to_weaviate(self):
@@ -47,8 +48,11 @@ class CERankerAgent(PrecompiledAgent):
         result = await self.reranker.acall(query=query, document=document)
         return (document, result.relevance_score)
 
-    async def __acall__(self, query: str, k: int = 1) -> str:
+    async def __acall__(self, query: str, k: int | None = None) -> str:
+        if k is None:
+            k = self.k
+
         response = self.collection.query.hybrid(query=query, limit=k)
 
         documents = [o.properties["content"] for o in response.objects]
 
@@ -58,16 +62,17 @@ class CERankerAgent(PrecompiledAgent):
 
         scored_results.sort(key=lambda x: x[1], reverse=True)
 
         return "\n".join([doc for doc, score in scored_results[:k]])
 
-    def __call__(self, query: str, k: int = 1) -> str:
+    def __call__(self, query: str, k: int | None = None) -> str:
         return asyncio.run(self.__acall__(query, k))
 
 
 if __name__ == "__main__":
     config = CERankerConfig(
         collection_name="IRPapersText_Default",
-        return_properties=["content"]
+        return_properties=["content"],
+        k=5
     )
     agent = CERankerAgent(config)
     print(agent(query="What is HyDE?"))
diff --git a/config.json b/config.json
index 5c2646d..1f67b53 100644
--- a/config.json
+++ b/config.json
@@ -1,7 +1,8 @@
 {
-  "lm": "openai/gpt-4.1-mini",
   "collection_name": "IRPapersText_Default",
   "return_properties": [
     "content"
-  ]
+  ],
+  "k": 5,
+  "lm": "openai/gpt-4.1-mini"
 }
\ No newline at end of file