Fix init to accept kwargs

This commit is contained in:
Connor Shorten
2025-11-27 11:06:37 -05:00
parent 71c34ddd16
commit 01159000b8
2 changed files with 30 additions and 6 deletions

View File

@@ -1,3 +1,6 @@
{ {
"lm": "gpt-4o" "collection_name": "IRPapersText_Default",
"return_properties": [
"content"
]
} }

View File

@@ -1,4 +1,5 @@
from dotenv import load_dotenv from dotenv import load_dotenv
import os
from modaic import PrecompiledAgent, PrecompiledConfig from modaic import PrecompiledAgent, PrecompiledConfig
import weaviate import weaviate
@@ -6,21 +7,41 @@ import weaviate
load_dotenv() load_dotenv()
class EchoConfig(PrecompiledConfig): class EchoConfig(PrecompiledConfig):
lm: str = "gpt-4o" collection_name: str
return_properties: list[str]
class EchoAgent(PrecompiledAgent): class EchoAgent(PrecompiledAgent):
config: EchoConfig config: EchoConfig
def __init__(self, config: EchoConfig, **kwargs): def __init__(self, config: EchoConfig, **kwargs):
super().__init__(config, **kwargs) super().__init__(config, **kwargs)
self.weaviate_client = weaviate.connect_to_weaviate_cloud(
cluster_url=os.getenv("WEAVIATE_URL"),
auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WEAVIATE_API_KEY")),
)
self.collection = self.weaviate_client.collections.get(
self.config.collection_name
)
def forward(self, text: str) -> str: def forward(self, query: str) -> str:
return f"Echo: {text}" response = self.collection.query.hybrid(
query=query,
limit=1
)
results = []
for o in response.objects:
results.append(o.properties["content"])
return "\n".join(results)
if __name__ == "__main__": if __name__ == "__main__":
agent = EchoAgent(EchoConfig()) config = EchoConfig(
print(agent(text="hello")) collection_name="IRPapersText_Default",
return_properties=["content"]
)
agent = EchoAgent(config)
print(agent(query="What is HyDE?"))
agent.push_to_hub( agent.push_to_hub(
"connor/CrossEncoderRanker", "connor/CrossEncoderRanker",
with_code=True, with_code=True,