Fix init to accept kwargs
This commit is contained in:
19
ce_ranker.py
19
ce_ranker.py
@@ -8,7 +8,6 @@ import weaviate
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
class RelevanceAssessment(dspy.Signature):
|
class RelevanceAssessment(dspy.Signature):
|
||||||
"""Assess the relevance of a document to a query."""
|
"""Assess the relevance of a document to a query."""
|
||||||
query: str = dspy.InputField()
|
query: str = dspy.InputField()
|
||||||
@@ -17,9 +16,10 @@ class RelevanceAssessment(dspy.Signature):
|
|||||||
|
|
||||||
|
|
||||||
class CERankerConfig(PrecompiledConfig):
|
class CERankerConfig(PrecompiledConfig):
|
||||||
lm: str = "openai/gpt-4.1-mini"
|
|
||||||
collection_name: str
|
collection_name: str
|
||||||
return_properties: list[str]
|
return_properties: list[str]
|
||||||
|
k: int
|
||||||
|
lm: str = "openai/gpt-4.1-mini"
|
||||||
|
|
||||||
|
|
||||||
class CERankerAgent(PrecompiledAgent):
|
class CERankerAgent(PrecompiledAgent):
|
||||||
@@ -32,6 +32,7 @@ class CERankerAgent(PrecompiledAgent):
|
|||||||
dspy.configure(lm=lm)
|
dspy.configure(lm=lm)
|
||||||
|
|
||||||
self._connect_to_weaviate()
|
self._connect_to_weaviate()
|
||||||
|
self.k = config.k
|
||||||
self.reranker = dspy.ChainOfThought(RelevanceAssessment)
|
self.reranker = dspy.ChainOfThought(RelevanceAssessment)
|
||||||
|
|
||||||
def _connect_to_weaviate(self):
|
def _connect_to_weaviate(self):
|
||||||
@@ -47,8 +48,11 @@ class CERankerAgent(PrecompiledAgent):
|
|||||||
result = await self.reranker.acall(query=query, document=document)
|
result = await self.reranker.acall(query=query, document=document)
|
||||||
return (document, result.relevance_score)
|
return (document, result.relevance_score)
|
||||||
|
|
||||||
async def __acall__(self, query: str, k: int = 1) -> str:
|
async def __acall__(self, query: str, k: int | None = None) -> str:
|
||||||
response = self.collection.query.hybrid(query=query, limit=k)
|
if k is None:
|
||||||
|
k = self.k
|
||||||
|
|
||||||
|
response = self.collection.query.hybrid(query=query, limit=self.k)
|
||||||
|
|
||||||
documents = [o.properties["content"] for o in response.objects]
|
documents = [o.properties["content"] for o in response.objects]
|
||||||
|
|
||||||
@@ -58,16 +62,17 @@ class CERankerAgent(PrecompiledAgent):
|
|||||||
|
|
||||||
scored_results.sort(key=lambda x: x[1], reverse=True)
|
scored_results.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
return "\n".join([doc for doc, score in scored_results[:k]])
|
return "\n".join([doc for doc, score in scored_results[:self.k]])
|
||||||
|
|
||||||
def __call__(self, query: str, k: int = 1) -> str:
|
def __call__(self, query: str, k: int | None = None) -> str:
|
||||||
return asyncio.run(self.__acall__(query, k))
|
return asyncio.run(self.__acall__(query, k))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
config = CERankerConfig(
|
config = CERankerConfig(
|
||||||
collection_name="IRPapersText_Default",
|
collection_name="IRPapersText_Default",
|
||||||
return_properties=["content"]
|
return_properties=["content"],
|
||||||
|
k=5
|
||||||
)
|
)
|
||||||
agent = CERankerAgent(config)
|
agent = CERankerAgent(config)
|
||||||
print(agent(query="What is HyDE?"))
|
print(agent(query="What is HyDE?"))
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
{
|
{
|
||||||
"lm": "openai/gpt-4.1-mini",
|
|
||||||
"collection_name": "IRPapersText_Default",
|
"collection_name": "IRPapersText_Default",
|
||||||
"return_properties": [
|
"return_properties": [
|
||||||
"content"
|
"content"
|
||||||
]
|
],
|
||||||
|
"k": 5,
|
||||||
|
"lm": "openai/gpt-4.1-mini"
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user