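"""TableRAG indexer.

Ingests tabular files (.csv/.xlsx/.xls) and JSON documents into a hybrid
store: full tables land in a SQL database for exact SQL querying, while
markdown/text chunks are embedded into a Milvus "table_rag" collection for
semantic recall with Pinecone reranking.
"""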
import json
import os
from typing import List, Literal

from dotenv import load_dotenv
from langchain_text_splitters import RecursiveCharacterTextSplitter
from tqdm.auto import tqdm  # auto picks the right frontend

from modaic import Indexer
from modaic.context import (
    ContextSchema,
    LongText,
    Table,
    Text,
    TextSchema,
)
from modaic.context.query_language import Filter
from modaic.databases import (
    MilvusVDBConfig,
    SearchResult,
    SQLDatabase,
    SQLiteConfig,
    VectorDatabase,
)
from modaic.indexing import Embedder, PineconeReranker

from agent.config import TableRAGConfig

load_dotenv()


class TableRAGIndexer(Indexer):
    """Indexer for TableRAG: tables go into both a SQL database (for exact
    SQL queries) and a Milvus vector collection (for semantic recall); JSON
    documents are chunked and indexed as plain text."""

    config_class = TableRAGConfig

    def __init__(
        self,
        *args,
        vdb_config: MilvusVDBConfig,
        sql_config: SQLiteConfig,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.embedder = Embedder(model="openai/text-embedding-3-small")
        self.vector_database = VectorDatabase(
            config=vdb_config,
            embedder=self.embedder,
            payload_schema=Text.schema,
        )
        self.sql_db = SQLDatabase(config=sql_config)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )
        self.reranker = PineconeReranker(
            model="bge-reranker-v2-m3", api_key=os.getenv("PINECONE_API_KEY")
        )
        self.last_query = None
        # Reuse the collection across runs: append if it already exists.
        self.vector_database.create_collection(
            "table_rag", Text.schema, exists_behavior="append"
        )

    def ingest(self, files: List[str] | str, *args, **kwargs):
        """Ingest a list of files, or every file in a directory.

        Tables (.csv/.xlsx/.xls) are registered in the SQL database and
        chunked as markdown; .json files are flattened into "key value"
        lines and chunked as plain text. All chunks are embedded into the
        "table_rag" collection.
        """
        if isinstance(files, str):
            files = [os.path.join(files, file) for file in os.listdir(files)]
        records = []
        with self.sql_db.connect_and_begin():
            for file in tqdm(files, desc="Ingesting files", position=0):
                if file.endswith((".csv", ".xlsx", ".xls")):
                    if file.endswith(".csv"):
                        table = Table.from_csv(file)
                    else:
                        table = Table.from_excel(file)
                    # Keep the schema alongside the table so retrieved
                    # chunks can point back to a queryable SQL table.
                    table.metadata["schema"] = table.schema_info()
                    self.sql_db.add_table(table)
                    table.chunk_with(self.chunk_table)
                    records.extend(table.chunks)
                elif file.endswith(".json"):
                    with open(file, "r", encoding="utf-8") as f:
                        data_split = json.load(f)
                    key_value_doc = ""
                    for key, item in data_split.items():
                        key_value_doc += f"{key} {item}\n"
                    text_document = LongText(text=key_value_doc)
                    text_document.chunk_text(self.text_splitter.split_text)
                    text_document.apply_to_chunks(
                        lambda chunk: chunk.add_metadata({"type": "text"})
                    )
                    records.extend(text_document.chunks)
        print(f"Adding {len(records)} records to vector database")
        self.vector_database.add_records("table_rag", records, batch_size=10000)

    def add(self):
        # Incremental adds are not implemented yet.
        pass

    def delete(self):
        # Deletion is not implemented yet.
        pass

    def retrieve(
        self,
        user_query: str,
        k_recall: int = 10,
        k_rerank: int = 10,
        type: Literal["table", "text", "all"] = "all",
    ) -> List[ContextSchema]:
        """Recall k_recall candidates from the vector database, rerank them,
        and return the top k_rerank context schemas."""
        results = self.recall(user_query, k_recall, type)
        # The reranker scores (text, payload) pairs: plain text chunks are
        # scored on their text, table chunks on their markdown slice.
        records = [
            (result["context_schema"].text, result["context_schema"])
            if result["context_schema"].context_class == "Text"
            else (
                result["context_schema"].metadata["md_chunk"],
                result["context_schema"],
            )
            for result in results
        ]
        reranked = self.reranker(user_query, records, k_rerank)
        return [record for _, record in reranked]

    def recall(
        self,
        user_query: str,
        k: int = 10,
        type: Literal["table", "text", "all"] = "all",
    ) -> List[SearchResult]:
        """Embed the query and run a vector search over the "table_rag"
        collection, optionally filtered to table or text chunks."""
        embedding = self.embedder([user_query])[0]
        if type == "table":
            filter = TextSchema.metadata["type"] == "table"
        elif type == "text":
            filter = TextSchema.metadata["type"] == "text"
        else:
            filter = None
        return self.vector_database.search("table_rag", embedding, k, Filter(filter))

    def chunk_table(self, table: Table) -> List[Text]:
        """Render the table as markdown, split it into overlapping text
        chunks, and tag each chunk with the table's SQL schema so retrieval
        results can be traced back to a queryable table."""
        table_md = LongText(text=table.markdown())
        table_md.chunk_text(self.text_splitter.split_text)
        table_md.apply_to_chunks(
            lambda chunk: chunk.add_metadata(
                {"type": "table", "schema": table.metadata["schema"]}
            )
        )
        return table_md.chunks

    def sql_query(self, query: str) -> str:
        """Query the SQL database and return the result as a string.

        Args:
            query: The SQL query to execute.

        Returns:
            The result of the SQL query as a string, or an error message if
            execution fails.
        """
        self.last_query = query
        try:
            return str(self.sql_db.fetchall(query))
        except Exception as e:
            return f"Error executing sql query: {e}"

    def get_table(self, table_id: str) -> Table:
        """Fetch a full table from the SQL database by its id."""
        return self.sql_db.get_table(table_id)
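

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the indexer itself. The config
# constructor arguments (and any positional args the Indexer base class
# expects, e.g. a TableRAGConfig instance) are placeholders -- check the
# modaic docs for the real signatures. The method calls mirror the class
# above; the "data/" path and the query string are illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vdb_config = MilvusVDBConfig()  # placeholder: supply real Milvus connection settings
    sql_config = SQLiteConfig()  # placeholder: supply real SQLite settings
    indexer = TableRAGIndexer(vdb_config=vdb_config, sql_config=sql_config)

    # Ingest a directory of .csv/.xlsx/.xls/.json files, then ask a question.
    indexer.ingest("data/")
    for context in indexer.retrieve("example question", k_recall=20, k_rerank=5):
        print(context)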