Optimized program with code
This commit is contained in:
78
optimize.py
Normal file
78
optimize.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import dspy
|
||||
from dspy import GEPA
|
||||
from modules import generate_cypher
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
def _to_example(record):
    """Convert one dataset record into a `dspy.Example`.

    The record's "question", "schema" and "cypher" fields are copied to
    "question", "neo4j_schema" and "expected_cypher"; the first two are
    marked as the example's inputs.
    """
    return dspy.Example(
        {
            "question": record["question"],
            "neo4j_schema": record["schema"],
            "expected_cypher": record["cypher"],
        }
    ).with_inputs("question", "neo4j_schema")


def process_dataset(sample_size=200, train_fraction=0.5, seed=0):
    """Build train, validation and test example lists from text2cypher.

    The "train" split is converted to examples, shuffled with a seeded RNG
    and cut to `sample_size`; that sample is then divided into a training
    part and a validation part. The test set is the first `sample_size`
    records of the "test" split, in dataset order (not shuffled).

    Args:
        sample_size: Maximum number of train-split examples to keep (before
            the train/validation division) and maximum number of test
            examples.
        train_fraction: Fraction of the sampled train examples that goes to
            the training set; the remainder becomes the validation set.
        seed: Seed for the shuffle, so the sample is reproducible.

    Returns:
        A `(train_set, val_set, test_set)` tuple of lists of `dspy.Example`.
    """
    import random
    from itertools import islice

    # Load once and reuse for both splits instead of calling load_dataset twice.
    dataset = load_dataset("neo4j/text2cypher-2025v1")

    train_examples = [_to_example(record) for record in dataset["train"]]
    random.Random(seed).shuffle(train_examples)
    train_examples = train_examples[:sample_size]

    # The boundary is computed on the sampled size, not the full split size.
    split_at = int(train_fraction * len(train_examples))
    train_set = train_examples[:split_at]
    val_set = train_examples[split_at:]

    # Only the leading records are needed, so stop converting after them.
    test_set = [
        _to_example(record) for record in islice(dataset["test"], sample_size)
    ]

    return train_set, val_set, test_set
|
||||
|
||||
|
||||
def metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
    """Score a generated Cypher statement against the reference using BLEU.

    Args:
        example: Gold example; its "expected_cypher" entry is the reference.
        prediction: Program output; its `statement` attribute is the
            generated Cypher query.
        trace: Unused; accepted to match the metric call signature.
        pred_name: Unused; accepted to match the metric call signature.
        pred_trace: Unused; accepted to match the metric call signature.

    Returns:
        A `dspy.Prediction` with `score` (sentence-level BLEU scaled to the
        0-1 range) and `feedback` (a short text reporting that score).
    """
    from sacrebleu import sentence_bleu

    expected_cypher = example["expected_cypher"]
    generated_cypher = getattr(prediction, "statement", None)

    # A failed or malformed generation can leave `statement` missing or
    # non-string; sacreBLEU expects a string hypothesis, so give the worst
    # score instead of letting the metric raise mid-optimization.
    if not isinstance(generated_cypher, str):
        return dspy.Prediction(
            score=0.0,
            feedback="BLEU score: 0.000 (no Cypher statement was generated)",
        )

    # sacreBLEU sentence-level BLEU; `.score` is divided by 100 to
    # normalize it to 0-1.
    bleu_score = sentence_bleu(
        generated_cypher,
        [expected_cypher],  # references are passed as a list
    ).score / 100.0

    feedback = f"BLEU score: {bleu_score:.3f}"
    return dspy.Prediction(score=bleu_score, feedback=feedback)
|
||||
|
||||
|
||||
# Materialize the three example lists once, when the module is loaded.
train_set, val_set, test_set = process_dataset()

# GEPA optimizer scored by `metric`; a separate LM is configured for the
# reflection step.
optimizer = GEPA(
    metric=metric,
    auto="light",
    track_stats=True,
    num_threads=32,
    reflection_minibatch_size=3,
    reflection_lm=dspy.LM(
        model="gpt-5.2",
        temperature=1.0,
        max_tokens=32000,
    ),
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Compile only when run as a script: optimize `generate_cypher` on the
    # training examples, using the validation examples as the valset.
    optimized_program = optimizer.compile(
        generate_cypher,
        valset=val_set,
        trainset=train_set,
    )

    # Publish the optimized program under a fixed repository name and tag.
    optimized_program.push_to_hub(
        "farouk1/text-to-cypher-gepa",
        tag="v1.0.1",
        commit_message="Optimized program with code",
    )
|
||||
Reference in New Issue
Block a user