Compare commits

3 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 1087651663 |  |
|  | 71b7d59419 |  |
|  | 86b1c6834d |  |
auto_classes.json (new file, 4 lines)

```json
{
    "AutoConfig": "modules.GenerateCypherConfig",
    "AutoProgram": "modules.GenerateCypher"
}
```
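The two entries point at the dotted import paths of the config and program classes introduced in modules.py below. As a rough illustration only (this is plain importlib resolution, not modaic's actual loading code), the mapping resolves like this:

```python
# Sketch: resolve the dotted paths from auto_classes.json with plain importlib.
# This is NOT modaic's loader, just an illustration of what the strings refer to.
import importlib
import json

with open("auto_classes.json") as f:
    auto_classes = json.load(f)

module_name, class_name = auto_classes["AutoProgram"].rsplit(".", 1)
program_cls = getattr(importlib.import_module(module_name), class_name)
print(program_cls)  # <class 'modules.GenerateCypher'>
```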
Modified file:

```diff
@@ -1,4 +1,5 @@
 {
     "model": "openrouter/openai/gpt-4o",
-    "max_tokens": 1024
+    "max_tokens": 1024,
+    "cache": true
 }
```
modules.py (new file, 76 lines)

```python
import os
import dspy
from dotenv import load_dotenv
from modaic import PrecompiledProgram, PrecompiledConfig

load_dotenv()


class CypherFromQuestion(dspy.Signature):
    """Task: Generate Cypher statement to query a graph database.
    Instructions: Use only the provided relationship types and properties in the schema.
    Do not use any other relationship types or properties that are not provided in the schema.
    Do not include any explanations or apologies in your responses.
    Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
    Do not include any text except the generated Cypher statement.
    """

    question = dspy.InputField(
        desc="Question to model using a cypher statement. Use only the provided relationship types and properties in the schema."
    )
    neo4j_schema = dspy.InputField(
        desc="Current graph schema in Neo4j as a list of NODES and RELATIONSHIPS."
    )
    statement = dspy.OutputField(desc="Cypher statement to query the graph database.")


class GenerateCypherConfig(PrecompiledConfig):
    model: str = "openrouter/openai/gpt-4o"  # OPENROUTER ONLY
    max_tokens: int = 1024
    cache: bool = True


class GenerateCypher(PrecompiledProgram):
    config: GenerateCypherConfig

    def __init__(self, config: GenerateCypherConfig, **kwargs):
        super().__init__(config=config, **kwargs)
        self.lm = dspy.LM(
            model=config.model,
            max_tokens=config.max_tokens,
            api_base="https://openrouter.ai/api/v1",
            cache=config.cache,
        )
        self.generate_cypher = dspy.ChainOfThought(CypherFromQuestion)
        self.generate_cypher.set_lm(self.lm)

    def forward(self, question: str, neo4j_schema: list[str]):
        return self.generate_cypher(question=question, neo4j_schema=neo4j_schema)


generate_cypher = GenerateCypher(GenerateCypherConfig())

if __name__ == "__main__":
    """
    from pathlib import Path
    import json

    examples_path = Path(__file__).parent / "examples" / "wikipedia-abstracts-v0_0_1.ndjson"
    with open(examples_path, "r") as f:
        for line in f:
            data = json.loads(line)
            text = data["text"]
            print("TEXT TO PROCESS:\n", text[:50])
            cypher = generate_cypher(text=text, neo4j_schema=neo4j.fmt_schema())
            neo4j.query(cypher.statement.replace('```', ''))
            print("CYPHER STATEMENT:\n", cypher.statement)

    schema = neo4j.fmt_schema()
    print("SCHEMA:\n", schema)
    """
    generate_cypher.push_to_hub(
        "farouk1/text-to-cypher",
        with_code=True,
        tag="v1.0.1",
        commit_message="Don't cache results",
    )
```
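Once modules.py is importable, the module-level generate_cypher instance can be exercised directly. A minimal local sketch, assuming an OpenRouter API key is available in the environment for dspy.LM; the schema string is a made-up placeholder:

```python
# Sketch: call the GenerateCypher program defined in modules.py directly.
# Assumes the environment (e.g. the .env loaded by load_dotenv) provides OpenRouter credentials.
from modules import generate_cypher

# Hypothetical schema; in practice this comes from the live Neo4j instance.
schema = "NODES: (:Person {name: STRING}) RELATIONSHIPS: (:Person)-[:KNOWS]->(:Person)"

pred = generate_cypher(
    question="Who does Alice know?",
    neo4j_schema=[schema],
)
print(pred.statement)  # the generated Cypher, e.g. a MATCH ... RETURN statement
```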
optimize.py (new file, 78 lines)

```python
import dspy
from dspy import GEPA
from modules import generate_cypher
from datasets import load_dataset


def process_dataset():
    train_split = load_dataset("neo4j/text2cypher-2025v1")["train"]
    train_split = [
        dspy.Example(
            {
                "question": x["question"],
                "neo4j_schema": x["schema"],
                "expected_cypher": x["cypher"],
            }
        ).with_inputs("question", "neo4j_schema")
        for x in train_split
    ]
    import random

    random.Random(0).shuffle(train_split)
    train_split = train_split[:200]
    tot_num = len(train_split)

    test_split = load_dataset("neo4j/text2cypher-2025v1")["test"]
    test_split = [
        dspy.Example(
            {
                "question": x["question"],
                "neo4j_schema": x["schema"],
                "expected_cypher": x["cypher"],
            }
        ).with_inputs("question", "neo4j_schema")
        for x in test_split
    ]

    train_set = train_split[: int(0.5 * tot_num)]
    val_set = train_split[int(0.5 * tot_num) :]
    test_set = test_split[:200]

    return train_set, val_set, test_set


def metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
    from sacrebleu import sentence_bleu

    expected_cypher = example["expected_cypher"]
    generated_cypher = prediction.statement

    # Calculate sentence-level BLEU (Google BLEU)
    bleu_score = sentence_bleu(
        generated_cypher,
        [expected_cypher],  # Reference as a list
    ).score / 100.0  # Normalize to 0-1

    feedback = f"BLEU score: {bleu_score:.3f}"
    return dspy.Prediction(score=bleu_score, feedback=feedback)


train_set, val_set, test_set = process_dataset()

optimizer = GEPA(
    metric=metric,
    auto="medium",
    num_threads=32,
    track_stats=True,
    reflection_minibatch_size=3,
    reflection_lm=dspy.LM(model="gpt-5.2", temperature=1.0, max_tokens=32000),
)


if __name__ == "__main__":
    optimized_program = optimizer.compile(
        generate_cypher,
        trainset=train_set,
        valset=val_set,
    )
    optimized_program.push_to_hub("farouk1/text-to-cypher-gepa", tag="v1.0.4", commit_message="Optimized program with code")
```
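For intuition about how metric() scores a candidate, the same sacrebleu call can be run on a hand-written pair; the strings below are made up and the number is only illustrative:

```python
# Sketch: the BLEU computation used inside metric(), on a toy reference/hypothesis pair.
from sacrebleu import sentence_bleu

expected = "MATCH (p:Person {name: 'Alice'})-[:KNOWS]->(f) RETURN f.name"
generated = "MATCH (p:Person)-[:KNOWS]->(f) WHERE p.name = 'Alice' RETURN f.name"

# sacrebleu reports BLEU on a 0-100 scale; metric() normalizes it to 0-1 for GEPA.
bleu_score = sentence_bleu(generated, [expected]).score / 100.0
print(f"BLEU score: {bleu_score:.3f}")
```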
pyproject.toml (new file, 7 lines)

```toml
[project]
name = "text-to-cypher-gepa"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = ["datasets>=4.4.2", "dspy>=3.0.4", "modaic>=0.8.3", "neo4j~=5.18.0", "python-dotenv~=1.0.1", "sacrebleu>=2.5.1"]
```