First agent uploaded to the Modaic hub

This commit is contained in:
2025-08-17 00:09:02 -04:00
parent aae5b37374
commit 6666c69576
14 changed files with 627 additions and 1 deletion

agent/indexer.py (new file, 205 lines)

@@ -0,0 +1,205 @@
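"""TableRAG indexer.

Ingests CSV/Excel tables into a SQLite store and indexes their markdown chunks
(plus chunked JSON documents) in a Milvus vector database; retrieval runs
vector recall followed by Pinecone reranking.
"""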
import json
import os
from typing import List, Literal

from dotenv import load_dotenv
from langchain_text_splitters import RecursiveCharacterTextSplitter
from tqdm.auto import tqdm  # auto picks the right frontend

from modaic import Indexer
from modaic.context import (
    ContextSchema,
    LongText,
    Table,
    Text,
    TextSchema,
)
from modaic.context.query_language import Filter
from modaic.databases import (
    MilvusVDBConfig,
    SearchResult,
    SQLDatabase,
    SQLiteConfig,
    VectorDatabase,
)
from modaic.indexing import Embedder, PineconeReranker
from agent.config import TableRAGConfig

load_dotenv()
class TableRagIndexer(Indexer):
    config_class = TableRAGConfig

    def __init__(
        self,
        *args,
        vdb_config: MilvusVDBConfig,
        sql_config: SQLiteConfig,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.embedder = Embedder(model="openai/text-embedding-3-small")
        self.vector_database = VectorDatabase(
            config=vdb_config,
            embedder=self.embedder,
            payload_schema=Text.schema,
        )
        self.sql_db = SQLDatabase(config=sql_config)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )
        self.reranker = PineconeReranker(
            model="bge-reranker-v2-m3", api_key=os.getenv("PINECONE_API_KEY")
        )
        self.last_query = None
        # Reuse the existing collection if it is already there.
        self.vector_database.create_collection(
            "table_rag", Text.schema, exists_behavior="append"
        )
    def ingest(self, files: List[str] | str, *args, **kwargs):
        # A single string is treated as a directory of files to ingest.
        if isinstance(files, str):
            files = [os.path.join(files, file) for file in os.listdir(files)]
        records = []
        with self.sql_db.connect_and_begin():
            for file in tqdm(files, desc="Ingesting files", position=0):
                if file.endswith((".csv", ".xlsx", ".xls")):
                    if file.endswith(".csv"):
                        table = Table.from_csv(file)
                    elif file.endswith((".xlsx", ".xls")):
                        table = Table.from_excel(file)
                    # Store the table itself in the SQL database and index its
                    # markdown chunks (tagged with the schema) in the vector store.
                    table.metadata["schema"] = table.schema_info()
                    self.sql_db.add_table(table)
                    table.chunk_with(self.chunk_table)
                    records.extend(table.chunks)
                elif file.endswith(".json"):
                    # Flatten the JSON object into "key value" lines, then chunk.
                    with open(file, "r", encoding="utf-8") as f:
                        data_split = json.load(f)
                    key_value_doc = ""
                    for key, item in data_split.items():
                        key_value_doc += f"{key} {item}\n"
                    text_document = LongText(text=key_value_doc)
                    text_document.chunk_text(self.text_splitter.split_text)
                    text_document.apply_to_chunks(
                        lambda chunk: chunk.add_metadata({"type": "text"})
                    )
                    records.extend(text_document.chunks)
        print(f"Adding {len(records)} records to the vector database")
        self.vector_database.add_records("table_rag", records, batch_size=10000)
    def add(self, *args, **kwargs):
        # Incremental additions are not implemented yet.
        pass

    def delete(self, *args, **kwargs):
        # Deletion is not implemented yet.
        pass
    def retrieve(
        self,
        user_query: str,
        k_recall: int = 10,
        k_rerank: int = 10,
        type: Literal["table", "text", "all"] = "all",
    ) -> List[ContextSchema]:
        """Recall k_recall candidates from the vector store, rerank them, and
        return the top k_rerank context schemas."""
        results = self.recall(user_query, k_recall, type)
        # The reranker scores (text, payload) pairs: plain text chunks are scored
        # on their text, other chunks on the markdown stored in their metadata.
        records = [
            (result["context_schema"].text, result["context_schema"])
            if result["context_schema"].context_class == "Text"
            else (
                result["context_schema"].metadata["md_chunk"],
                result["context_schema"],
            )
            for result in results
        ]
        reranked = self.reranker(user_query, records, k_rerank)
        return [result[1] for result in reranked]
    def recall(
        self,
        user_query: str,
        k: int = 10,
        type: Literal["table", "text", "all"] = "all",
    ) -> List[SearchResult]:
        embedding = self.embedder([user_query])[0]
        # Restrict the search to table or text chunks via a metadata filter.
        if type == "table":
            query_filter = TextSchema.metadata["type"] == "table"
        elif type == "text":
            query_filter = TextSchema.metadata["type"] == "text"
        else:
            query_filter = None
        return self.vector_database.search(
            "table_rag", embedding, k, Filter(query_filter)
        )
    def chunk_table(self, table: Table) -> List[Text]:
        """Render the table as markdown, split it into overlapping chunks, and
        tag each chunk with its table's schema."""
        table_md = LongText(text=table.markdown())
        table_md.chunk_text(self.text_splitter.split_text)
        table_md.apply_to_chunks(
            lambda chunk: chunk.add_metadata(
                {"type": "table", "schema": table.metadata["schema"]}
            )
        )
        return table_md.chunks
    def sql_query(self, query: str) -> str:
        """
        Query the SQL database and return the result as a string.

        Args:
            query: The SQL query to execute.

        Returns:
            The result of the SQL query as a string, or an error message if the
            query fails.
        """
        self.last_query = query
        try:
            return str(self.sql_db.fetchall(query))
        except Exception as e:
            return f"Error executing sql query: {e}"

    def get_table(self, table_id: str) -> Table:
        return self.sql_db.get_table(table_id)
if __name__ == "__main__":
    indexer = TableRagIndexer(
        vdb_config=MilvusVDBConfig.from_local("examples/TableRAG/index2.db"),
        sql_config=SQLiteConfig(db_path="examples/TableRAG/tables.db"),
    )
    excel_dir = "examples/TableRAG/dev_excel"
    docs_dir = "examples/TableRAG/dev_doc"
    # NOTE: retrieval quality will be poor unless the entire dataset is ingested.
    # indexer.ingest(docs_dir)
    # indexer.ingest(excel_dir)
    results = indexer.retrieve(
        "Who is the New Zealand Parliament Member for Canterbury"
    )
    print(results[0])
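For context, a minimal sketch of how the two stores are meant to compose at query time. It only uses methods defined in this file (retrieve, sql_query) plus the schema metadata attached in chunk_table; the follow-up SQL step is an assumed usage pattern rather than code from this commit, and the table name is just one example dev-set table.

# Hypothetical query-time loop (not part of this commit): recall a table
# chunk, read the schema attached in chunk_table(), then query the SQL store.
chunks = indexer.retrieve(
    "Who is the New Zealand Parliament Member for Canterbury", type="table"
)
print(chunks[0].metadata["schema"])  # schema_info() captured at ingest time
# With the schema in hand, an agent can write SQL against the backing table:
print(indexer.sql_query("SELECT * FROM t_5th_new_zealand_parliament_0 LIMIT 5"))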