From 01159000b82bc9ee73cd86ae138b278ab17f849d Mon Sep 17 00:00:00 2001 From: Connor Shorten Date: Thu, 27 Nov 2025 11:06:37 -0500 Subject: [PATCH] Fix init to accept kwargs --- config.json | 5 ++++- hello.py | 31 ++++++++++++++++++++++++++----- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/config.json b/config.json index 9d13255..f93b409 100644 --- a/config.json +++ b/config.json @@ -1,3 +1,6 @@ { - "lm": "gpt-4o" + "collection_name": "IRPapersText_Default", + "return_properties": [ + "content" + ] } \ No newline at end of file diff --git a/hello.py b/hello.py index 79b8d0f..0206b9e 100644 --- a/hello.py +++ b/hello.py @@ -1,4 +1,5 @@ from dotenv import load_dotenv +import os from modaic import PrecompiledAgent, PrecompiledConfig import weaviate @@ -6,21 +7,41 @@ import weaviate load_dotenv() class EchoConfig(PrecompiledConfig): - lm: str = "gpt-4o" + collection_name: str + return_properties: list[str] class EchoAgent(PrecompiledAgent): config: EchoConfig def __init__(self, config: EchoConfig, **kwargs): super().__init__(config, **kwargs) + self.weaviate_client = weaviate.connect_to_weaviate_cloud( + cluster_url=os.getenv("WEAVIATE_URL"), + auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WEAVIATE_API_KEY")), + ) + self.collection = self.weaviate_client.collections.get( + self.config.collection_name + ) - def forward(self, text: str) -> str: - return f"Echo: {text}" + def forward(self, query: str) -> str: + response = self.collection.query.hybrid( + query=query, + limit=1 + ) + results = [] + for o in response.objects: + results.append(o.properties["content"]) + + return "\n".join(results) if __name__ == "__main__": - agent = EchoAgent(EchoConfig()) - print(agent(text="hello")) + config = EchoConfig( + collection_name="IRPapersText_Default", + return_properties=["content"] + ) + agent = EchoAgent(config) + print(agent(query="What is HyDE?")) agent.push_to_hub( "connor/CrossEncoderRanker", with_code=True,