Optimized program with code

This commit is contained in:
2025-12-27 07:14:57 -08:00
parent 1dc1669bcb
commit 86b1c6834d
4 changed files with 163 additions and 0 deletions

78
optimize.py Normal file
View File

@@ -0,0 +1,78 @@
import dspy
from dspy import GEPA
from modules import generate_cypher
from datasets import load_dataset
def process_dataset():
train_split = load_dataset("neo4j/text2cypher-2025v1")["train"]
train_split = [
dspy.Example(
{
"question": x["question"],
"neo4j_schema": x["schema"],
"expected_cypher": x["cypher"],
}
).with_inputs("question", "neo4j_schema")
for x in train_split
]
import random
random.Random(0).shuffle(train_split)
train_split = train_split[:200]
tot_num = len(train_split)
test_split = load_dataset("neo4j/text2cypher-2025v1")["test"]
test_split = [
dspy.Example(
{
"question": x["question"],
"neo4j_schema": x["schema"],
"expected_cypher": x["cypher"],
}
).with_inputs("question", "neo4j_schema")
for x in test_split
]
train_set = train_split[: int(0.5 * tot_num)]
val_set = train_split[int(0.5 * tot_num) :]
test_set = test_split[:200]
return train_set, val_set, test_set
def metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
from sacrebleu import sentence_bleu
expected_cypher = example["expected_cypher"]
generated_cypher = prediction.statement
# Calculate sentence-level BLEU (Google BLEU)
bleu_score = sentence_bleu(
generated_cypher,
[expected_cypher], # Reference as a list
).score / 100.0 # Normalize to 0-1
feedback = f"BLEU score: {bleu_score:.3f}"
return dspy.Prediction(score=bleu_score, feedback=feedback)
train_set, val_set, test_set = process_dataset()
optimizer = GEPA(
metric=metric,
auto="light",
num_threads=32,
track_stats=True,
reflection_minibatch_size=3,
reflection_lm=dspy.LM(model="gpt-5.2", temperature=1.0, max_tokens=32000),
)
if __name__ == "__main__":
optimized_program = optimizer.compile(
generate_cypher,
trainset=train_set,
valset=val_set,
)
optimized_program.push_to_hub("farouk1/text-to-cypher-gepa", tag="v1.0.1", commit_message="Optimized program with code")