Fix init to accept kwargs
This commit is contained in:
78
ce_ranker.py
Normal file
78
ce_ranker.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
import dspy
|
||||
from modaic import PrecompiledAgent, PrecompiledConfig
|
||||
import weaviate
|
||||
|
||||
# Load settings from a .env file into the process environment at import time,
# so the os.getenv("WEAVIATE_URL") / os.getenv("WEAVIATE_API_KEY") lookups made
# when connecting to Weaviate can find them.
load_dotenv()
|
||||
|
||||
|
||||
class RelevanceAssessment(dspy.Signature):
    """Assess the relevance of a document to a query."""
    # Inputs: the search query and the text of one candidate document.
    query: str = dspy.InputField()
    document: str = dspy.InputField()
    # Output: despite the "score" name this is declared as a bool, so the
    # judgement is binary (relevant / not relevant) rather than graded.
    relevance_score: bool = dspy.OutputField()
|
||||
|
||||
|
||||
class CERankerConfig(PrecompiledConfig):
    """Settings consumed by CERankerAgent."""
    # Model identifier handed to dspy.LM when the agent is constructed.
    lm: str = "openai/gpt-4.1-mini"
    # Name of the Weaviate collection the agent queries (required, no default).
    collection_name: str
    # Required, no default.
    # NOTE(review): nothing in this file reads this field -- the agent always
    # reads the "content" property of each hit. Confirm whether it is meant to
    # be forwarded to the Weaviate query.
    return_properties: list[str]
|
||||
|
||||
|
||||
class CERankerAgent(PrecompiledAgent):
    """Retrieve documents from Weaviate and order them by LM-judged relevance.

    A hybrid query fetches candidate objects from the configured collection,
    each candidate's ``content`` property is judged against the query with a
    ``dspy.ChainOfThought(RelevanceAssessment)`` module, and the candidates
    judged relevant are placed ahead of the others.
    """

    config: CERankerConfig

    def __init__(self, config: CERankerConfig, **kwargs):
        """Set up the language model, the Weaviate connection and the reranker.

        Args:
            config: Agent settings (LM name, collection name, ...).
            **kwargs: Forwarded unchanged to ``PrecompiledAgent.__init__``.

        Raises:
            RuntimeError: If ``WEAVIATE_URL`` or ``WEAVIATE_API_KEY`` is not
                set in the environment.
        """
        super().__init__(config, **kwargs)

        # NOTE: dspy.configure sets the LM globally, not just for this agent.
        lm = dspy.LM(self.config.lm)
        dspy.configure(lm=lm)

        self._connect_to_weaviate()
        self.reranker = dspy.ChainOfThought(RelevanceAssessment)

    def _connect_to_weaviate(self):
        """Open the Weaviate Cloud client and look up the configured collection.

        Reads ``WEAVIATE_URL`` and ``WEAVIATE_API_KEY`` from the environment
        and stores the client on ``self.weaviate_client`` and the collection
        handle on ``self.collection``.

        Raises:
            RuntimeError: If either environment variable is unset or empty.
        """
        cluster_url = os.getenv("WEAVIATE_URL")
        api_key = os.getenv("WEAVIATE_API_KEY")

        # os.getenv returns None for unset variables; fail here with a clear
        # message instead of letting None reach the Weaviate client.
        missing = [
            name
            for name, value in (
                ("WEAVIATE_URL", cluster_url),
                ("WEAVIATE_API_KEY", api_key),
            )
            if not value
        ]
        if missing:
            raise RuntimeError(
                "Missing required environment variable(s): " + ", ".join(missing)
            )

        self.weaviate_client = weaviate.connect_to_weaviate_cloud(
            cluster_url=cluster_url,
            auth_credentials=weaviate.auth.AuthApiKey(api_key),
        )
        self.collection = self.weaviate_client.collections.get(
            self.config.collection_name
        )

    async def _score_document(self, query: str, document: str) -> tuple[str, bool]:
        """Judge one document against the query.

        Returns:
            A ``(document, relevance_score)`` pair, where ``relevance_score``
            is the value produced by the reranker module.
        """
        result = await self.reranker.acall(query=query, document=document)
        return (document, result.relevance_score)

    async def __acall__(self, query: str, k: int = 1) -> str:
        """Retrieve up to ``k`` documents and return them, relevant ones first.

        Args:
            query: The search query.
            k: Limit passed to the hybrid query and maximum number of
                documents returned.

        Returns:
            The ``content`` of the selected documents joined by newlines
            (an empty string when the query returns no objects).
        """
        # The retrieval limit is also k, so reranking only reorders the k hits.
        response = self.collection.query.hybrid(query=query, limit=k)

        documents = [o.properties["content"] for o in response.objects]

        # Score all candidates concurrently; gather keeps the input order.
        scored_results = await asyncio.gather(
            *[self._score_document(query, doc) for doc in documents]
        )

        # Stable sort: True before False, ties keep the retrieval order.
        ranked = sorted(scored_results, key=lambda pair: pair[1], reverse=True)

        return "\n".join(doc for doc, _ in ranked[:k])

    def __call__(self, query: str, k: int = 1) -> str:
        """Synchronous wrapper around ``__acall__``.

        Uses ``asyncio.run``, so it must not be called from code that is
        already running inside an event loop.
        """
        return asyncio.run(self.__acall__(query, k))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke-test the agent against a sample query, then publish it to the hub.
    config = CERankerConfig(
        collection_name="IRPapersText_Default",
        return_properties=["content"],
    )
    agent = CERankerAgent(config)
    try:
        print(agent(query="What is HyDE?"))
        agent.push_to_hub(
            "connor/CrossEncoderRanker",
            with_code=True,
            commit_message="Fix init to accept kwargs",
        )
    finally:
        # Release the Weaviate connection opened by the agent even if the
        # query or the push fails.
        agent.weaviate_client.close()
|
||||
Reference in New Issue
Block a user