49 lines
1.3 KiB
Python
49 lines
1.3 KiB
Python
from dotenv import load_dotenv
|
|
import os
|
|
|
|
from modaic import PrecompiledAgent, PrecompiledConfig
|
|
import weaviate
|
|
|
|
# Runs at import time, before any os.getenv() lookup below; presumably this is
# what makes WEAVIATE_URL / WEAVIATE_API_KEY from a local .env file visible.
load_dotenv()
|
|
|
|
class EchoConfig(PrecompiledConfig):
    """Settings carried by an ``EchoAgent``.

    Both fields are required; neither declares a default.
    """

    # Name of the Weaviate collection the agent looks up when it is built.
    collection_name: str

    # Presumably the object properties the agent should return from a query;
    # verify against ``EchoAgent.forward``.
    return_properties: list[str]
|
|
|
|
class EchoAgent(PrecompiledAgent):
    """Agent that answers a query with text pulled from a Weaviate collection.

    Building an instance opens a Weaviate Cloud connection and keeps it on
    ``self.weaviate_client``. This class never closes that connection; the
    owner of the instance should call ``weaviate_client.close()`` when done.
    """

    config: EchoConfig

    def __init__(self, config: EchoConfig, **kwargs):
        """Connect to Weaviate Cloud and bind the configured collection.

        Args:
            config: Supplies ``collection_name`` and ``return_properties``.
            **kwargs: Forwarded unchanged to ``PrecompiledAgent.__init__``.

        Raises:
            ValueError: If ``WEAVIATE_URL`` or ``WEAVIATE_API_KEY`` is unset
                or empty in the environment.
        """
        super().__init__(config, **kwargs)

        cluster_url = os.getenv("WEAVIATE_URL")
        api_key = os.getenv("WEAVIATE_API_KEY")
        # Fail fast with a readable message instead of handing None to the
        # client and getting an obscure error from deep inside it.
        missing = [
            name
            for name, value in (
                ("WEAVIATE_URL", cluster_url),
                ("WEAVIATE_API_KEY", api_key),
            )
            if not value
        ]
        if missing:
            raise ValueError(
                "Missing required environment variable(s): " + ", ".join(missing)
            )

        self.weaviate_client = weaviate.connect_to_weaviate_cloud(
            cluster_url=cluster_url,
            auth_credentials=weaviate.auth.AuthApiKey(api_key),
        )
        self.collection = self.weaviate_client.collections.get(
            self.config.collection_name
        )

    def forward(self, query: str) -> str:
        """Run a hybrid query and return the top hit's configured properties.

        Args:
            query: Text passed to the collection's hybrid query.

        Returns:
            The requested property values of the returned objects, one per
            line. An empty string when the query returns no objects.
        """
        # Honor the configured property list rather than a hard-coded key.
        # An empty list falls back to "content", the key this agent always
        # read before the list was wired in.
        property_names = list(self.config.return_properties) or ["content"]

        response = self.collection.query.hybrid(
            query=query,
            limit=1,
            return_properties=property_names,
        )
        # str() leaves text values untouched and keeps join() from failing
        # if a configured property holds a non-string value.
        return "\n".join(
            str(obj.properties[name])
            for obj in response.objects
            for name in property_names
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke-test the agent against a live collection, then publish it.
    config = EchoConfig(
        collection_name="IRPapersText_Default",
        return_properties=["content"],
    )
    agent = EchoAgent(config)
    try:
        print(agent(query="What is HyDE?"))
        agent.push_to_hub(
            "connor/CrossEncoderRanker",
            with_code=True,
            commit_message="Fix init to accept kwargs",
        )
    finally:
        # The agent opens a Weaviate connection in __init__ and never closes
        # it; release it here even if the query or the push fails.
        agent.weaviate_client.close()