From 313ab4c8c6b1c1c06b77a816c5812d85728d2720 Mon Sep 17 00:00:00 2001 From: Connor Shorten Date: Thu, 27 Nov 2025 11:36:01 -0500 Subject: [PATCH] Fix init to accept kwargs --- ce_ranker.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/ce_ranker.py b/ce_ranker.py index fbe68d9..49a9a04 100644 --- a/ce_ranker.py +++ b/ce_ranker.py @@ -48,23 +48,24 @@ class CERankerAgent(PrecompiledAgent): result = await self.reranker.acall(query=query, document=document) return (document, result.relevance_score) - async def __acall__(self, query: str, k: int | None = None) -> str: + async def __acall__(self, query: str, k: int | None = None) -> list[str]: if k is None: k = self.k - - response = self.collection.query.hybrid(query=query, limit=self.k) - + + response = self.collection.query.hybrid(query=query, limit=k) documents = [o.properties["content"] for o in response.objects] - + scored_results = await asyncio.gather( *[self._score_document(query, doc) for doc in documents] ) - - scored_results.sort(key=lambda x: x[1], reverse=True) - - return "\n".join([doc for doc, score in scored_results[:self.k]]) - def __call__(self, query: str, k: int | None = None) -> str: + # Sort by the score (descending). If score is True, it comes first. + scored_results.sort(key=lambda x: x[1], reverse=True) + + # Return a list of documents (not joined as a string), up to k results + return [doc for doc, score in scored_results[:k]] + + def __call__(self, query: str, k: int | None = None) -> list[str]: return asyncio.run(self.__acall__(query, k))