diff --git a/ce_ranker.py b/ce_ranker.py index fbe68d9..49a9a04 100644 --- a/ce_ranker.py +++ b/ce_ranker.py @@ -48,23 +48,24 @@ class CERankerAgent(PrecompiledAgent): result = await self.reranker.acall(query=query, document=document) return (document, result.relevance_score) - async def __acall__(self, query: str, k: int | None = None) -> str: + async def __acall__(self, query: str, k: int | None = None) -> list[str]: if k is None: k = self.k - - response = self.collection.query.hybrid(query=query, limit=self.k) - + + response = self.collection.query.hybrid(query=query, limit=k) documents = [o.properties["content"] for o in response.objects] - + scored_results = await asyncio.gather( *[self._score_document(query, doc) for doc in documents] ) - - scored_results.sort(key=lambda x: x[1], reverse=True) - - return "\n".join([doc for doc, score in scored_results[:self.k]]) - def __call__(self, query: str, k: int | None = None) -> str: + # Sort by the score (descending). If score is True, it comes first. + scored_results.sort(key=lambda x: x[1], reverse=True) + + # Return a list of documents (not joined as a string), up to k results + return [doc for doc, score in scored_results[:k]] + + def __call__(self, query: str, k: int | None = None) -> list[str]: return asyncio.run(self.__acall__(query, k))