Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 57e7b1fd36 | |||
| 501c224540 |
@@ -5,9 +5,6 @@ LLM-driven automated knowledge graph construction from text using DSPy and Neo4j
|
|||||||
```sh
|
```sh
|
||||||
text-to-cypher/
|
text-to-cypher/
|
||||||
├── README.md
|
├── README.md
|
||||||
├── examples/
|
|
||||||
│ └── wikipedia-abstracts-v0_0_1.ndjson
|
|
||||||
├── img/
|
|
||||||
├── main.py
|
├── main.py
|
||||||
├── pyproject.toml
|
├── pyproject.toml
|
||||||
├── uv.lock
|
├── uv.lock
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
{
|
{
|
||||||
"AutoConfig": "main.GenerateCypherConfig",
|
"AutoConfig": "modules.GenerateCypherConfig",
|
||||||
"AutoProgram": "main.GenerateCypher"
|
"AutoProgram": "modules.GenerateCypher"
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
{
|
{
|
||||||
"model": "openai/gpt-4o",
|
"model": "openrouter/openai/gpt-4o",
|
||||||
"max_tokens": 1024
|
"max_tokens": 1024
|
||||||
}
|
}
|
||||||
@@ -2,19 +2,11 @@ import os
|
|||||||
import dspy
|
import dspy
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from modaic import PrecompiledProgram, PrecompiledConfig
|
from modaic import PrecompiledProgram, PrecompiledConfig
|
||||||
from src.neo4j import Neo4j
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# set up Neo4j using NEO4J_URI
|
|
||||||
neo4j = Neo4j(
|
|
||||||
uri=os.getenv("NEO4J_URI"),
|
|
||||||
user=os.getenv("NEO4J_USER"),
|
|
||||||
password=os.getenv("NEO4J_PASSWORD"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
class CypherFromQuestion(dspy.Signature):
|
||||||
class CypherFromText(dspy.Signature):
|
|
||||||
"""Task: Generate Cypher statement to query a graph database.
|
"""Task: Generate Cypher statement to query a graph database.
|
||||||
Instructions: Use only the provided relationship types and properties in the schema.
|
Instructions: Use only the provided relationship types and properties in the schema.
|
||||||
Do not use any other relationship types or properties that are not provided in the schema.
|
Do not use any other relationship types or properties that are not provided in the schema.
|
||||||
@@ -22,19 +14,18 @@ class CypherFromText(dspy.Signature):
|
|||||||
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
|
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
|
||||||
Do not include any text except the generated Cypher statement.
|
Do not include any text except the generated Cypher statement.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
question = dspy.InputField(
|
question = dspy.InputField(
|
||||||
desc="Question to model using a cypher statement."
|
desc="Question to model using a cypher statement. Use only the provided relationship types and properties in the schema."
|
||||||
)
|
)
|
||||||
neo4j_schema = dspy.InputField(
|
neo4j_schema = dspy.InputField(
|
||||||
desc="Current graph schema in Neo4j as a list of NODES and RELATIONSHIPS."
|
desc="Current graph schema in Neo4j as a list of NODES and RELATIONSHIPS."
|
||||||
)
|
)
|
||||||
statement = dspy.OutputField(
|
statement = dspy.OutputField(desc="Cypher statement to query the graph database.")
|
||||||
desc="Cypher statement to merge nodes and relationships found in the text."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GenerateCypherConfig(PrecompiledConfig):
|
class GenerateCypherConfig(PrecompiledConfig):
|
||||||
model: str = "openai/gpt-4o"
|
model: str = "openrouter/openai/gpt-4o" # OPENROUTER ONLY
|
||||||
max_tokens: int = 1024
|
max_tokens: int = 1024
|
||||||
|
|
||||||
|
|
||||||
@@ -46,12 +37,13 @@ class GenerateCypher(PrecompiledProgram):
|
|||||||
self.lm = dspy.LM(
|
self.lm = dspy.LM(
|
||||||
model=config.model,
|
model=config.model,
|
||||||
max_tokens=config.max_tokens,
|
max_tokens=config.max_tokens,
|
||||||
|
api_base="https://openrouter.ai/api/v1",
|
||||||
)
|
)
|
||||||
self.generate_cypher = dspy.ChainOfThought(CypherFromText)
|
self.generate_cypher = dspy.ChainOfThought(CypherFromQuestion)
|
||||||
self.generate_cypher.set_lm(self.lm)
|
self.generate_cypher.set_lm(self.lm)
|
||||||
|
|
||||||
def forward(self, text: str, neo4j_schema: list[str]):
|
def forward(self, question: str, neo4j_schema: list[str]):
|
||||||
return self.generate_cypher(text=text, neo4j_schema=neo4j_schema)
|
return self.generate_cypher(question=question, neo4j_schema=neo4j_schema)
|
||||||
|
|
||||||
|
|
||||||
generate_cypher = GenerateCypher(GenerateCypherConfig())
|
generate_cypher = GenerateCypher(GenerateCypherConfig())
|
||||||
@@ -77,6 +69,6 @@ if __name__ == "__main__":
|
|||||||
generate_cypher.push_to_hub(
|
generate_cypher.push_to_hub(
|
||||||
"farouk1/text-to-cypher",
|
"farouk1/text-to-cypher",
|
||||||
with_code=True,
|
with_code=True,
|
||||||
tag="v0.0.8",
|
tag="v1.0.0",
|
||||||
commit_message="Update README.md",
|
commit_message="Update README.md",
|
||||||
)
|
)
|
||||||
@@ -8,7 +8,7 @@
|
|||||||
"fields": [
|
"fields": [
|
||||||
{
|
{
|
||||||
"prefix": "Question:",
|
"prefix": "Question:",
|
||||||
"description": "Question to model using a cypher statement."
|
"description": "Question to model using a cypher statement. Use only the provided relationship types and properties in the schema."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"prefix": "Neo 4 J Schema:",
|
"prefix": "Neo 4 J Schema:",
|
||||||
@@ -20,12 +20,12 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"prefix": "Statement:",
|
"prefix": "Statement:",
|
||||||
"description": "Cypher statement to merge nodes and relationships found in the text."
|
"description": "Cypher statement to query the graph database."
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"lm": {
|
"lm": {
|
||||||
"model": "openai/gpt-4o",
|
"model": "openrouter/openai/gpt-4o",
|
||||||
"model_type": "chat",
|
"model_type": "chat",
|
||||||
"cache": true,
|
"cache": true,
|
||||||
"num_retries": 3,
|
"num_retries": 3,
|
||||||
@@ -33,7 +33,8 @@
|
|||||||
"launch_kwargs": {},
|
"launch_kwargs": {},
|
||||||
"train_kwargs": {},
|
"train_kwargs": {},
|
||||||
"temperature": null,
|
"temperature": null,
|
||||||
"max_tokens": 1024
|
"max_tokens": 1024,
|
||||||
|
"api_base": "https://openrouter.ai/api/v1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|||||||
@@ -4,4 +4,4 @@ version = "0.1.0"
|
|||||||
description = "Add your description here"
|
description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = ["datasets>=4.4.2", "dspy>=3.0.4", "modaic>=0.8.2", "neo4j~=5.18.0", "python-dotenv~=1.0.1"]
|
dependencies = ["datasets>=4.4.2", "dspy>=3.0.4", "modaic>=0.8.3", "neo4j~=5.18.0", "python-dotenv~=1.0.1", "sacrebleu>=2.5.1"]
|
||||||
|
|||||||
180
src/neo4j.py
180
src/neo4j.py
@@ -1,180 +0,0 @@
|
|||||||
import json
|
|
||||||
import neo4j
|
|
||||||
|
|
||||||
|
|
||||||
def parse_relationships(schema: dict) -> str:
|
|
||||||
# Parse the JSON string into a Python object if it's not already
|
|
||||||
if isinstance(schema, str):
|
|
||||||
data = json.loads(schema)
|
|
||||||
else:
|
|
||||||
data = schema
|
|
||||||
|
|
||||||
data = data[0]["relationships"]
|
|
||||||
|
|
||||||
# Initialize a list to hold the formatted relationship strings
|
|
||||||
relationships = []
|
|
||||||
|
|
||||||
# Iterate through each relationship in the data
|
|
||||||
for relationship in data:
|
|
||||||
entity1, relation, entity2 = relationship
|
|
||||||
# Extract the names of the entities and the relationship
|
|
||||||
entity1_name = entity1["name"]
|
|
||||||
entity2_name = entity2["name"]
|
|
||||||
# Format the string as specified and add it to the list
|
|
||||||
formatted_relationship = f"{entity1_name}-{relation}->{entity2_name}"
|
|
||||||
relationships.append(formatted_relationship)
|
|
||||||
|
|
||||||
# Join all formatted strings with a newline character
|
|
||||||
result = "\n".join(relationships)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def parse_nodes(schema):
|
|
||||||
schema = schema
|
|
||||||
nodes = [node["name"] for node in schema[0]["nodes"]]
|
|
||||||
return "\n".join(nodes)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_node_properties(node_properties) -> str:
    """Group node properties by label and render them as an indented list.

    Args:
        node_properties: Iterable of mappings with the keys ``"nodeLabels"``
            (a non-empty sequence; only its first element is used),
            ``"propertyName"`` and ``"mandatory"``.

    Returns:
        A newline-joined listing: each label on its own line, followed by one
        `` - <property>`` line per property. Mandatory properties are
        suffixed with `` (required)``; optional ones are listed bare. Labels
        and properties keep their first-seen order.
    """
    properties_by_label: dict = {}
    for item in node_properties:
        label = item["nodeLabels"][0]  # only the first label is considered
        name = item["propertyName"]
        entry = f"{name} (required)" if item["mandatory"] else name
        properties_by_label.setdefault(label, []).append(entry)

    lines = []
    for label, entries in properties_by_label.items():
        lines.append(f"{label}")
        lines.extend(f" - {entry}" for entry in entries)
    return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_rel_properties(rel_properties) -> str:
    """Group relationship properties by type and render them as an indented list.

    Args:
        rel_properties: Iterable of mappings with the keys ``"relType"``
            (expected in the form ``":`TYPE`"``), ``"propertyName"`` (may be
            ``None`` for a type without properties) and ``"mandatory"``
            (read only when ``"propertyName"`` is not ``None``).

    Returns:
        A newline-joined listing: each relationship type on its own line,
        followed by one `` - <property> (required|optional)`` line per
        property. Types without properties are still listed. Types and
        properties keep their first-seen order.
    """
    properties_by_type: dict = {}
    for item in rel_properties:
        # Drop the leading ":" and the surrounding backticks without assuming
        # that both are always present (a blind [2:] slice would eat the
        # first letter of an unquoted type name).
        rel_type = item["relType"].removeprefix(":").strip("`")
        # Register the type even when it carries no properties.
        entries = properties_by_type.setdefault(rel_type, [])

        name = item["propertyName"]
        if name is not None:
            requirement = "required" if item["mandatory"] else "optional"
            entries.append(f"{name} ({requirement})")

    lines = []
    for rel_type, entries in properties_by_type.items():
        lines.append(f"{rel_type}")
        lines.extend(f" - {entry}" for entry in entries)
    return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
class Neo4j:
|
|
||||||
def __init__(self, uri, user: str = None, password: str = None):
|
|
||||||
self._uri = uri
|
|
||||||
self._user = user
|
|
||||||
self._password = password
|
|
||||||
self._auth = (
|
|
||||||
None
|
|
||||||
if (self._user is None and self._password is None)
|
|
||||||
else (self._user, self._password)
|
|
||||||
)
|
|
||||||
self._driver = neo4j.GraphDatabase.driver(
|
|
||||||
self._uri, auth=(self._user, self._password)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._verify_connection()
|
|
||||||
print("CONNECTION ESTABLISHED")
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self._driver.close()
|
|
||||||
print("CONNECTION CLOSED")
|
|
||||||
|
|
||||||
def _verify_connection(self):
|
|
||||||
with self._driver as driver:
|
|
||||||
driver.verify_connectivity()
|
|
||||||
print("CONNECTION VERIFIED")
|
|
||||||
|
|
||||||
def query(self, query, parameters=None, db=None):
|
|
||||||
assert db is None, (
|
|
||||||
"The Neo4j implementation does not support multiple databases."
|
|
||||||
)
|
|
||||||
with self._driver.session(database=db) as session:
|
|
||||||
result = session.run(query, parameters)
|
|
||||||
return result.data()
|
|
||||||
|
|
||||||
def schema(self, parsed=False):
|
|
||||||
query = """
|
|
||||||
CALL db.schema.visualization()
|
|
||||||
"""
|
|
||||||
schema = self.query(query)
|
|
||||||
|
|
||||||
if parsed:
|
|
||||||
return parse_nodes(schema), parse_relationships(schema)
|
|
||||||
|
|
||||||
return schema
|
|
||||||
|
|
||||||
def schema_properties(self, parsed=False):
|
|
||||||
props = self._schema_node_properties(), self._schema_relationship_properties()
|
|
||||||
if parsed:
|
|
||||||
return parse_node_properties(props[0]), parse_rel_properties(props[1])
|
|
||||||
|
|
||||||
return props
|
|
||||||
|
|
||||||
def _schema_node_properties(self):
|
|
||||||
query = """
|
|
||||||
CALL db.schema.nodeTypeProperties()
|
|
||||||
"""
|
|
||||||
return self.query(query)
|
|
||||||
|
|
||||||
def _schema_relationship_properties(self):
|
|
||||||
query = """
|
|
||||||
CALL db.schema.relTypeProperties()
|
|
||||||
"""
|
|
||||||
return self.query(query)
|
|
||||||
|
|
||||||
def fmt_schema(self):
|
|
||||||
parsed_schema = self.schema(parsed=True)
|
|
||||||
parsed_props = self.schema_properties(parsed=True)
|
|
||||||
parsed = (*parsed_props, parsed_schema[1])
|
|
||||||
return "\n".join(
|
|
||||||
[
|
|
||||||
f"{element}:\n{parsed[idx]}\n"
|
|
||||||
for idx, element in enumerate(
|
|
||||||
[
|
|
||||||
"NODE LABELS & PROPERTIES",
|
|
||||||
"RELATIONSHIP LABELS & PROPERTIES",
|
|
||||||
"RELATIONSHIPS",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
Reference in New Issue
Block a user