Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 71b7d59419 | |||
| 86b1c6834d |
4
auto_classes.json
Normal file
4
auto_classes.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"AutoConfig": "modules.GenerateCypherConfig",
|
||||
"AutoProgram": "modules.GenerateCypher"
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
{
|
||||
"model": "openrouter/openai/gpt-4o",
|
||||
"max_tokens": 1024
|
||||
"max_tokens": 1024,
|
||||
"cache": false
|
||||
}
|
||||
76
modules.py
Normal file
76
modules.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import os
|
||||
import dspy
|
||||
from dotenv import load_dotenv
|
||||
from modaic import PrecompiledProgram, PrecompiledConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class CypherFromQuestion(dspy.Signature):
|
||||
"""Task: Generate Cypher statement to query a graph database.
|
||||
Instructions: Use only the provided relationship types and properties in the schema.
|
||||
Do not use any other relationship types or properties that are not provided in the schema.
|
||||
Do not include any explanations or apologies in your responses.
|
||||
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
|
||||
Do not include any text except the generated Cypher statement.
|
||||
"""
|
||||
|
||||
question = dspy.InputField(
|
||||
desc="Question to model using a cypher statement. Use only the provided relationship types and properties in the schema."
|
||||
)
|
||||
neo4j_schema = dspy.InputField(
|
||||
desc="Current graph schema in Neo4j as a list of NODES and RELATIONSHIPS."
|
||||
)
|
||||
statement = dspy.OutputField(desc="Cypher statement to query the graph database.")
|
||||
|
||||
|
||||
class GenerateCypherConfig(PrecompiledConfig):
|
||||
model: str = "openrouter/openai/gpt-4o" # OPENROUTER ONLY
|
||||
max_tokens: int = 1024
|
||||
cache: bool = False
|
||||
|
||||
|
||||
class GenerateCypher(PrecompiledProgram):
|
||||
config: GenerateCypherConfig
|
||||
|
||||
def __init__(self, config: GenerateCypherConfig, **kwargs):
|
||||
super().__init__(config=config, **kwargs)
|
||||
self.lm = dspy.LM(
|
||||
model=config.model,
|
||||
max_tokens=config.max_tokens,
|
||||
api_base="https://openrouter.ai/api/v1",
|
||||
cache=config.cache,
|
||||
)
|
||||
self.generate_cypher = dspy.ChainOfThought(CypherFromQuestion)
|
||||
self.generate_cypher.set_lm(self.lm)
|
||||
|
||||
def forward(self, question: str, neo4j_schema: list[str]):
|
||||
return self.generate_cypher(question=question, neo4j_schema=neo4j_schema)
|
||||
|
||||
|
||||
generate_cypher = GenerateCypher(GenerateCypherConfig())
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
examples_path = Path(__file__).parent / "examples" / "wikipedia-abstracts-v0_0_1.ndjson"
|
||||
with open(examples_path, "r") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
text = data["text"]
|
||||
print("TEXT TO PROCESS:\n", text[:50])
|
||||
cypher = generate_cypher(text=text, neo4j_schema=neo4j.fmt_schema())
|
||||
neo4j.query(cypher.statement.replace('```', ''))
|
||||
print("CYPHER STATEMENT:\n", cypher.statement)
|
||||
|
||||
schema = neo4j.fmt_schema()
|
||||
print("SCHEMA:\n", schema)
|
||||
"""
|
||||
generate_cypher.push_to_hub(
|
||||
"farouk1/text-to-cypher",
|
||||
with_code=True,
|
||||
tag="v1.0.1",
|
||||
commit_message="Don't cache results",
|
||||
)
|
||||
78
optimize.py
Normal file
78
optimize.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import dspy
|
||||
from dspy import GEPA
|
||||
from modules import generate_cypher
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
def process_dataset():
|
||||
train_split = load_dataset("neo4j/text2cypher-2025v1")["train"]
|
||||
train_split = [
|
||||
dspy.Example(
|
||||
{
|
||||
"question": x["question"],
|
||||
"neo4j_schema": x["schema"],
|
||||
"expected_cypher": x["cypher"],
|
||||
}
|
||||
).with_inputs("question", "neo4j_schema")
|
||||
for x in train_split
|
||||
]
|
||||
import random
|
||||
|
||||
random.Random(0).shuffle(train_split)
|
||||
train_split = train_split[:200]
|
||||
tot_num = len(train_split)
|
||||
|
||||
test_split = load_dataset("neo4j/text2cypher-2025v1")["test"]
|
||||
test_split = [
|
||||
dspy.Example(
|
||||
{
|
||||
"question": x["question"],
|
||||
"neo4j_schema": x["schema"],
|
||||
"expected_cypher": x["cypher"],
|
||||
}
|
||||
).with_inputs("question", "neo4j_schema")
|
||||
for x in test_split
|
||||
]
|
||||
|
||||
train_set = train_split[: int(0.5 * tot_num)]
|
||||
val_set = train_split[int(0.5 * tot_num) :]
|
||||
test_set = test_split[:200]
|
||||
|
||||
return train_set, val_set, test_set
|
||||
|
||||
|
||||
def metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
|
||||
from sacrebleu import sentence_bleu
|
||||
|
||||
expected_cypher = example["expected_cypher"]
|
||||
generated_cypher = prediction.statement
|
||||
|
||||
# Calculate sentence-level BLEU (Google BLEU)
|
||||
bleu_score = sentence_bleu(
|
||||
generated_cypher,
|
||||
[expected_cypher], # Reference as a list
|
||||
).score / 100.0 # Normalize to 0-1
|
||||
|
||||
feedback = f"BLEU score: {bleu_score:.3f}"
|
||||
return dspy.Prediction(score=bleu_score, feedback=feedback)
|
||||
|
||||
|
||||
train_set, val_set, test_set = process_dataset()
|
||||
|
||||
optimizer = GEPA(
|
||||
metric=metric,
|
||||
auto="light",
|
||||
num_threads=32,
|
||||
track_stats=True,
|
||||
reflection_minibatch_size=3,
|
||||
reflection_lm=dspy.LM(model="gpt-5.2", temperature=1.0, max_tokens=32000),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
optimized_program = optimizer.compile(
|
||||
generate_cypher,
|
||||
trainset=train_set,
|
||||
valset=val_set,
|
||||
)
|
||||
optimized_program.push_to_hub("farouk1/text-to-cypher-gepa", tag="v1.0.2", commit_message="Optimized program with code")
|
||||
@@ -4,7 +4,7 @@
|
||||
"train": [],
|
||||
"demos": [],
|
||||
"signature": {
|
||||
"instructions": "text\nTask: Given (1) a natural-language question and (2) a Neo4j schema description, output exactly ONE Cypher query that answers the question.\n\nINPUTS\n- question: the user request in natural language.\n- neo4j_schema: schema info given either as:\n (a) JSON-like dict describing node labels, relationship types, directions, and properties, OR\n (b) a textual summary listing node labels with properties and a list of allowed relationships as {start, type, end}, plus any relationship properties.\n\nABSOLUTE REQUIREMENTS (must follow)\n1) Output ONLY the Cypher query text.\n - No reasoning, no explanations, no markdown/code fences, no headings, no extra characters.\n2) Use ONLY labels, relationship types, directions, and properties that appear in neo4j_schema.\n - Do NOT invent labels/properties/relationships.\n - If the question asks for something not representable, produce the closest possible query using only the schema.\n3) Respect relationship direction exactly as specified.\n - If schema says Article -[:PUBLISHED_IN]-> Journal, do not reverse it.\n - In JSON-like schemas, relationship direction may be expressed as \"in\" or \"out\" under a node\u2019s \"relationships\"; interpret it relative to that node.\n4) Return ONLY what the question asks for.\n - If it asks for \u201ctitle values\u201d, return a.title (not whole nodes).\n - If it asks for counts, return counts with clear aliases.\n - Use DISTINCT when the question implies uniqueness.\n5) Produce exactly one valid Cypher statement.\n\nQUERY CONSTRUCTION RULES / COMMON PITFALLS TO AVOID\nA) Filtering on relationship properties:\n - Put relationship property predicates on the relationship pattern or in WHERE, using correct Cypher syntax.\n - Example: MATCH (a)-[r:PUBLISHED_IN]->(j) WHERE r.meta = '220'\n - IMPORTANT: use Cypher string literals with single quotes (e.g., '220'), not JSON-style quotes.\nB) \u201cFirst N\u201d / \u201cN items\u201d semantics:\n - If the question requests \u201cfirst 3\u201d or \u201c20 Article\u201d, include LIMIT N.\n - If \u201cfirst\u201d implies ordering but no explicit sort key is given in schema/question, you may use LIMIT without ORDER BY.\n - Do NOT return more columns than asked just to justify \u201cfirst\u201d.\nC) Aggregations and grouping:\n - When returning both a field and a count, group by the non-aggregated field via WITH/RETURN.\n - Apply HAVING-like filters using WITH ... WHERE (e.g., cities with >1 student).\n - Example pattern:\n MATCH (s:Student)\n WITH s.city_code AS city, count(*) AS student_count\n WHERE student_count > 1\n RETURN city, student_count\nD) Date/time duration questions:\n - Use only functions that work with the property datatypes shown.\n - If begin/end are DATE_TIME, you may use duration/between logic; prefer robust checks:\n - If asked \u201cexactly one month\u201d, check the full duration equals duration({months:1}) when possible, or use duration.between(f.begin, f.end) and compare appropriately.\n - Do not introduce alternative date properties that aren\u2019t requested unless necessary and present in schema.\nE) String matching:\n - For prefix constraints, use STARTS WITH.\n - For exact text match, use equality.\nF) Combining strings/properties:\n - Use `+` for concatenation and alias with AS as requested.\n\nOUTPUT\n- Exactly one Cypher query, and nothing else.",
|
||||
"instructions": "Task: Generate Cypher statement to query a graph database.\nInstructions: Use only the provided relationship types and properties in the schema.\nDo not use any other relationship types or properties that are not provided in the schema.\nDo not include any explanations or apologies in your responses.\nDo not respond to any questions that might ask anything else than for you to construct a Cypher statement.\nDo not include any text except the generated Cypher statement.",
|
||||
"fields": [
|
||||
{
|
||||
"prefix": "Question:",
|
||||
@@ -27,7 +27,7 @@
|
||||
"lm": {
|
||||
"model": "openrouter/openai/gpt-4o",
|
||||
"model_type": "chat",
|
||||
"cache": true,
|
||||
"cache": false,
|
||||
"num_retries": 3,
|
||||
"finetuning_model": null,
|
||||
"launch_kwargs": {},
|
||||
|
||||
7
pyproject.toml
Normal file
7
pyproject.toml
Normal file
@@ -0,0 +1,7 @@
|
||||
[project]
|
||||
name = "text-to-cypher-gepa"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = ["datasets>=4.4.2", "dspy>=3.0.4", "modaic>=0.8.3", "neo4j~=5.18.0", "python-dotenv~=1.0.1", "sacrebleu>=2.5.1"]
|
||||
Reference in New Issue
Block a user