Fix init to accept kwargs
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
{
|
||||
"lm": "gpt-4o"
|
||||
"collection_name": "IRPapersText_Default",
|
||||
"return_properties": [
|
||||
"content"
|
||||
]
|
||||
}
|
||||
31
hello.py
31
hello.py
@@ -1,4 +1,5 @@
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
from modaic import PrecompiledAgent, PrecompiledConfig
|
||||
import weaviate
|
||||
@@ -6,21 +7,41 @@ import weaviate
|
||||
load_dotenv()
|
||||
|
||||
class EchoConfig(PrecompiledConfig):
|
||||
lm: str = "gpt-4o"
|
||||
collection_name: str
|
||||
return_properties: list[str]
|
||||
|
||||
class EchoAgent(PrecompiledAgent):
|
||||
config: EchoConfig
|
||||
|
||||
def __init__(self, config: EchoConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.weaviate_client = weaviate.connect_to_weaviate_cloud(
|
||||
cluster_url=os.getenv("WEAVIATE_URL"),
|
||||
auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WEAVIATE_API_KEY")),
|
||||
)
|
||||
self.collection = self.weaviate_client.collections.get(
|
||||
self.config.collection_name
|
||||
)
|
||||
|
||||
def forward(self, text: str) -> str:
|
||||
return f"Echo: {text}"
|
||||
def forward(self, query: str) -> str:
|
||||
response = self.collection.query.hybrid(
|
||||
query=query,
|
||||
limit=1
|
||||
)
|
||||
results = []
|
||||
for o in response.objects:
|
||||
results.append(o.properties["content"])
|
||||
|
||||
return "\n".join(results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
agent = EchoAgent(EchoConfig())
|
||||
print(agent(text="hello"))
|
||||
config = EchoConfig(
|
||||
collection_name="IRPapersText_Default",
|
||||
return_properties=["content"]
|
||||
)
|
||||
agent = EchoAgent(config)
|
||||
print(agent(query="What is HyDE?"))
|
||||
agent.push_to_hub(
|
||||
"connor/CrossEncoderRanker",
|
||||
with_code=True,
|
||||
|
||||
Reference in New Issue
Block a user