Fix init to accept kwargs
This commit is contained in:
27
agent.json
27
agent.json
@@ -1,4 +1,31 @@
|
|||||||
{
|
{
|
||||||
|
"reranker.predict": {
|
||||||
|
"traces": [],
|
||||||
|
"train": [],
|
||||||
|
"demos": [],
|
||||||
|
"signature": {
|
||||||
|
"instructions": "Assess the relevance of a document to a query.",
|
||||||
|
"fields": [
|
||||||
|
{
|
||||||
|
"prefix": "Query:",
|
||||||
|
"description": "${query}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prefix": "Document:",
|
||||||
|
"description": "${document}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prefix": "Reasoning: Let's think step by step in order to",
|
||||||
|
"description": "${reasoning}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prefix": "Relevance Score:",
|
||||||
|
"description": "${relevance_score}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"lm": null
|
||||||
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"dependency_versions": {
|
"dependency_versions": {
|
||||||
"python": "3.11",
|
"python": "3.11",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
{
|
{
|
||||||
"AutoConfig": "hello.EchoConfig",
|
"AutoConfig": "ce_ranker.CERankerConfig",
|
||||||
"AutoAgent": "hello.EchoAgent"
|
"AutoAgent": "ce_ranker.CERankerAgent"
|
||||||
}
|
}
|
||||||
78
ce_ranker.py
Normal file
78
ce_ranker.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import dspy
|
||||||
|
from modaic import PrecompiledAgent, PrecompiledConfig
|
||||||
|
import weaviate
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class RelevanceAssessment(dspy.Signature):
|
||||||
|
"""Assess the relevance of a document to a query."""
|
||||||
|
query: str = dspy.InputField()
|
||||||
|
document: str = dspy.InputField()
|
||||||
|
relevance_score: bool = dspy.OutputField()
|
||||||
|
|
||||||
|
|
||||||
|
class CERankerConfig(PrecompiledConfig):
|
||||||
|
lm: str = "openai/gpt-4.1-mini"
|
||||||
|
collection_name: str
|
||||||
|
return_properties: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class CERankerAgent(PrecompiledAgent):
|
||||||
|
config: CERankerConfig
|
||||||
|
|
||||||
|
def __init__(self, config: CERankerConfig, **kwargs):
|
||||||
|
super().__init__(config, **kwargs)
|
||||||
|
|
||||||
|
lm = dspy.LM(self.config.lm)
|
||||||
|
dspy.configure(lm=lm)
|
||||||
|
|
||||||
|
self._connect_to_weaviate()
|
||||||
|
self.reranker = dspy.ChainOfThought(RelevanceAssessment)
|
||||||
|
|
||||||
|
def _connect_to_weaviate(self):
|
||||||
|
self.weaviate_client = weaviate.connect_to_weaviate_cloud(
|
||||||
|
cluster_url=os.getenv("WEAVIATE_URL"),
|
||||||
|
auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WEAVIATE_API_KEY")),
|
||||||
|
)
|
||||||
|
self.collection = self.weaviate_client.collections.get(
|
||||||
|
self.config.collection_name
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _score_document(self, query: str, document: str) -> tuple[str, bool]:
|
||||||
|
result = await self.reranker.acall(query=query, document=document)
|
||||||
|
return (document, result.relevance_score)
|
||||||
|
|
||||||
|
async def __acall__(self, query: str, k: int = 1) -> str:
|
||||||
|
response = self.collection.query.hybrid(query=query, limit=k)
|
||||||
|
|
||||||
|
documents = [o.properties["content"] for o in response.objects]
|
||||||
|
|
||||||
|
scored_results = await asyncio.gather(
|
||||||
|
*[self._score_document(query, doc) for doc in documents]
|
||||||
|
)
|
||||||
|
|
||||||
|
scored_results.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
return "\n".join([doc for doc, score in scored_results[:k]])
|
||||||
|
|
||||||
|
def __call__(self, query: str, k: int = 1) -> str:
|
||||||
|
return asyncio.run(self.__acall__(query, k))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
config = CERankerConfig(
|
||||||
|
collection_name="IRPapersText_Default",
|
||||||
|
return_properties=["content"]
|
||||||
|
)
|
||||||
|
agent = CERankerAgent(config)
|
||||||
|
print(agent(query="What is HyDE?"))
|
||||||
|
agent.push_to_hub(
|
||||||
|
"connor/CrossEncoderRanker",
|
||||||
|
with_code=True,
|
||||||
|
commit_message="Fix init to accept kwargs"
|
||||||
|
)
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
{
|
{
|
||||||
|
"lm": "openai/gpt-4.1-mini",
|
||||||
"collection_name": "IRPapersText_Default",
|
"collection_name": "IRPapersText_Default",
|
||||||
"return_properties": [
|
"return_properties": [
|
||||||
"content"
|
"content"
|
||||||
|
|||||||
49
hello.py
49
hello.py
@@ -1,49 +0,0 @@
|
|||||||
from dotenv import load_dotenv
|
|
||||||
import os
|
|
||||||
|
|
||||||
from modaic import PrecompiledAgent, PrecompiledConfig
|
|
||||||
import weaviate
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
class EchoConfig(PrecompiledConfig):
|
|
||||||
collection_name: str
|
|
||||||
return_properties: list[str]
|
|
||||||
|
|
||||||
class EchoAgent(PrecompiledAgent):
|
|
||||||
config: EchoConfig
|
|
||||||
|
|
||||||
def __init__(self, config: EchoConfig, **kwargs):
|
|
||||||
super().__init__(config, **kwargs)
|
|
||||||
self.weaviate_client = weaviate.connect_to_weaviate_cloud(
|
|
||||||
cluster_url=os.getenv("WEAVIATE_URL"),
|
|
||||||
auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WEAVIATE_API_KEY")),
|
|
||||||
)
|
|
||||||
self.collection = self.weaviate_client.collections.get(
|
|
||||||
self.config.collection_name
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, query: str) -> str:
|
|
||||||
response = self.collection.query.hybrid(
|
|
||||||
query=query,
|
|
||||||
limit=1
|
|
||||||
)
|
|
||||||
results = []
|
|
||||||
for o in response.objects:
|
|
||||||
results.append(o.properties["content"])
|
|
||||||
|
|
||||||
return "\n".join(results)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
config = EchoConfig(
|
|
||||||
collection_name="IRPapersText_Default",
|
|
||||||
return_properties=["content"]
|
|
||||||
)
|
|
||||||
agent = EchoAgent(config)
|
|
||||||
print(agent(query="What is HyDE?"))
|
|
||||||
agent.push_to_hub(
|
|
||||||
"connor/CrossEncoderRanker",
|
|
||||||
with_code=True,
|
|
||||||
commit_message="Fix init to accept kwargs"
|
|
||||||
)
|
|
||||||
Reference in New Issue
Block a user