first ever agent uploaded to modaic hub
agent/indexer.py (new file, 205 lines)
@@ -0,0 +1,205 @@
from modaic import Indexer
from modaic.databases import VectorDatabase, MilvusVDBConfig, SearchResult
from typing import List, Literal, Tuple
import dspy
from langchain_text_splitters import RecursiveCharacterTextSplitter
import os
import json
from modaic.databases import SQLDatabase, SQLiteConfig
from modaic.context import (
    ContextSchema,
    Table,
    LongText,
    Text,
    Source,
    SourceType,
    TextSchema,
)
from modaic.indexing import PineconeReranker, Embedder
from dotenv import load_dotenv
from tqdm.auto import tqdm  # auto picks the right frontend
from modaic.context.query_language import Filter
from agent.config import TableRAGConfig

load_dotenv()


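# TableRagIndexer backs a TableRAG agent with three stores: a Milvus collection
# ("table_rag") for dense recall over chunks, a SQLite database holding full
# tables for SQL follow-ups, and a Pinecone reranker for final ordering.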
class TableRagIndexer(Indexer):
    config_class = TableRAGConfig

    def __init__(
        self,
        *args,
        vdb_config: MilvusVDBConfig,
        sql_config: SQLiteConfig,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.embedder = Embedder(model="openai/text-embedding-3-small")
        self.vector_database = VectorDatabase(
            config=vdb_config,
            embedder=self.embedder,
            payload_schema=Text.schema,
        )
        self.sql_db = SQLDatabase(config=sql_config)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )
        self.reranker = PineconeReranker(
            model="bge-reranker-v2-m3", api_key=os.getenv("PINECONE_API_KEY")
        )
        self.last_query = None
        # self.vector_database.load_collection("table_rag")

        self.vector_database.create_collection(
            "table_rag", Text.schema, exists_behavior="append"
        )

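    # ingest() accepts a list of file paths or a directory. Tables
    # (.csv/.xlsx/.xls) are stored whole in SQLite and chunked for the vector
    # index; .json files are flattened into "key value" lines and chunked as text.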
    def ingest(self, files: List[str] | str, *args, **kwargs):
        if isinstance(files, str):
            files = [os.path.join(files, file) for file in os.listdir(files)]
        records = []
        with self.sql_db.connect_and_begin():
            for file in tqdm(files, desc="Ingesting files", position=0):
                if file.endswith((".csv", ".xlsx", ".xls")):
                    if file.endswith(".csv"):
                        table = Table.from_csv(file)
                    elif file.endswith((".xlsx", ".xls")):
                        table = Table.from_excel(file)
                    # Add table to file system context store
                    table.metadata["schema"] = table.schema_info()
                    self.sql_db.add_table(table)
                    table.chunk_with(self.chunk_table)
                    records.extend(table.chunks)
                elif file.endswith(".json"):
                    with open(file, "r", encoding="utf-8") as f:
                        data_split = json.load(f)
                    key_value_doc = ""
                    for key, item in data_split.items():
                        key_value_doc += f"{key} {item}\n"
                    text_document = LongText(text=key_value_doc)
                    text_document.chunk_text(self.text_splitter.split_text)
                    text_document.apply_to_chunks(
                        lambda chunk: chunk.add_metadata({"type": "text"})
                    )
                    records.extend(text_document.chunks)
        print("Adding records to vector database")
        print("number of records", len(records))
        self.vector_database.add_records("table_rag", records, batch_size=10000)

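    # add() and delete() are presumably placeholders for the Indexer interface;
    # neither is implemented yet.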
    def add(self):
        pass

    def delete(self):
        pass

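    # retrieve() is the two-stage pipeline: vector recall of k_recall candidates,
    # then reranking down to k_rerank. Table chunks are reranked on their
    # markdown ("md_chunk" metadata); plain text chunks on their raw text.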
    def retrieve(
        self,
        user_query: str,
        k_recall: int = 10,
        k_rerank: int = 10,
        type: Literal["table", "text", "all"] = "all",
    ) -> List[ContextSchema]:
        results = self.recall(user_query, k_recall, type)
        records = [
            (result["context_schema"].text, result["context_schema"])
            if result["context_schema"].context_class == "Text"
            else (
                result["context_schema"].metadata["md_chunk"],
                result["context_schema"],
            )
            for result in results
        ]

        results = self.reranker(user_query, records, k_rerank)
        results = [result[1] for result in results]
        return results

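    # recall() is dense retrieval only. The optional metadata filter restricts
    # the search to table chunks or text chunks, using modaic's query-language
    # expressions (e.g. TextSchema.metadata["type"] == "table").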
    def recall(
        self,
        user_query: str,
        k: int = 10,
        type: Literal["table", "text", "all"] = "all",
    ) -> List[SearchResult]:
        embedding = self.embedder([user_query])[0]
        if type == "table":
            filter = TextSchema.metadata["type"] == "table"
        elif type == "text":
            filter = TextSchema.metadata["type"] == "text"
        else:
            filter = None
        return self.vector_database.search("table_rag", embedding, k, Filter(filter))

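    # chunk_table() serializes a table to markdown and splits it with the same
    # recursive splitter used for text; each chunk is stamped with the table's
    # schema, presumably so downstream SQL generation can see the source structure.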
    def chunk_table(self, table: Table) -> List[Text]:
        table_md = LongText(text=table.markdown())
        table_md.chunk_text(self.text_splitter.split_text)
        table_md.apply_to_chunks(
            lambda chunk: chunk.add_metadata(
                {"type": "table", "schema": table.metadata["schema"]}
            )
        )
        return table_md.chunks

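    # sql_query() returns errors as strings instead of raising, which keeps it
    # safe to expose as an agent tool; the last query is kept for inspection.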
    def sql_query(self, query: str) -> str:
        """
        Query the SQL database and get the result as a string.

        Args:
            query: The SQL query to execute.

        Returns:
            The result of the SQL query as a string.
        """
        self.last_query = query
        try:
            return str(self.sql_db.fetchall(query))
        except Exception as e:
            return f"Error executing sql query: {e}"

    def get_table(self, table_id: str) -> Table:
        return self.sql_db.get_table(table_id)


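# Smoke test: build the indexer against the local example artifacts, then run a
# retrieval. The ingest calls are commented out, presumably because the
# collection was already populated on an earlier run (create_collection uses
# exists_behavior="append").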
if __name__ == "__main__":
    indexer = TableRagIndexer(
        vdb_config=MilvusVDBConfig.from_local("examples/TableRAG/index2.db"),
        sql_config=SQLiteConfig(db_path="examples/TableRAG/tables.db"),
    )
    excel_dir = "examples/TableRAG/dev_excel"
    docs_dir = "examples/TableRAG/dev_doc"
    # excels = [os.path.join(excel_dir, file) for file in os.listdir(excel_dir)]
    # docs = [os.path.join(docs_dir, file) for file in os.listdir(docs_dir)]
    # all_files = excels + docs
    # NOTE: will get bad results because not using the entire dataset
    # excels = excels[:20]

    # indexer.ingest(docs)
    # indexer.ingest(excels)
    x = indexer.retrieve("Who is the New Zealand Parliament Member for Canterbury")
    # retrieve() returns ContextSchema objects directly, so index the list, not a tuple
    print(x[0])
    # print(x[0].source)
    # print(x[0].metadata["schema"])

    # x = indexer.sql_query("SELECT * FROM t_5th_new_zealand_parliament_0")
    # print(x)