Fix init to accept kwargs

This commit is contained in:
Connor Shorten
2025-11-27 11:31:18 -05:00
parent d178be7c0e
commit b0578059e5
2 changed files with 15 additions and 9 deletions

View File

@@ -8,7 +8,6 @@ import weaviate
load_dotenv() load_dotenv()
class RelevanceAssessment(dspy.Signature): class RelevanceAssessment(dspy.Signature):
"""Assess the relevance of a document to a query.""" """Assess the relevance of a document to a query."""
query: str = dspy.InputField() query: str = dspy.InputField()
@@ -17,9 +16,10 @@ class RelevanceAssessment(dspy.Signature):
class CERankerConfig(PrecompiledConfig): class CERankerConfig(PrecompiledConfig):
lm: str = "openai/gpt-4.1-mini"
collection_name: str collection_name: str
return_properties: list[str] return_properties: list[str]
k: int
lm: str = "openai/gpt-4.1-mini"
class CERankerAgent(PrecompiledAgent): class CERankerAgent(PrecompiledAgent):
@@ -32,6 +32,7 @@ class CERankerAgent(PrecompiledAgent):
dspy.configure(lm=lm) dspy.configure(lm=lm)
self._connect_to_weaviate() self._connect_to_weaviate()
self.k = config.k
self.reranker = dspy.ChainOfThought(RelevanceAssessment) self.reranker = dspy.ChainOfThought(RelevanceAssessment)
def _connect_to_weaviate(self): def _connect_to_weaviate(self):
@@ -47,8 +48,11 @@ class CERankerAgent(PrecompiledAgent):
result = await self.reranker.acall(query=query, document=document) result = await self.reranker.acall(query=query, document=document)
return (document, result.relevance_score) return (document, result.relevance_score)
async def __acall__(self, query: str, k: int = 1) -> str: async def __acall__(self, query: str, k: int | None = None) -> str:
response = self.collection.query.hybrid(query=query, limit=k) if k is None:
k = self.k
response = self.collection.query.hybrid(query=query, limit=self.k)
documents = [o.properties["content"] for o in response.objects] documents = [o.properties["content"] for o in response.objects]
@@ -58,16 +62,17 @@ class CERankerAgent(PrecompiledAgent):
scored_results.sort(key=lambda x: x[1], reverse=True) scored_results.sort(key=lambda x: x[1], reverse=True)
return "\n".join([doc for doc, score in scored_results[:k]]) return "\n".join([doc for doc, score in scored_results[:self.k]])
def __call__(self, query: str, k: int = 1) -> str: def __call__(self, query: str, k: int | None = None) -> str:
return asyncio.run(self.__acall__(query, k)) return asyncio.run(self.__acall__(query, k))
if __name__ == "__main__": if __name__ == "__main__":
config = CERankerConfig( config = CERankerConfig(
collection_name="IRPapersText_Default", collection_name="IRPapersText_Default",
return_properties=["content"] return_properties=["content"],
k=5
) )
agent = CERankerAgent(config) agent = CERankerAgent(config)
print(agent(query="What is HyDE?")) print(agent(query="What is HyDE?"))

View File

@@ -1,7 +1,8 @@
{ {
"lm": "openai/gpt-4.1-mini",
"collection_name": "IRPapersText_Default", "collection_name": "IRPapersText_Default",
"return_properties": [ "return_properties": [
"content" "content"
] ],
"k": 5,
"lm": "openai/gpt-4.1-mini"
} }