(no commit message)
agent.json
@@ -1,22 +1,26 @@
 {
-"signature_generator.generator": {
+"signature_generator.generator.predict": {
 "traces": [],
 "train": [],
 "demos": [],
 "signature": {
-"instructions": "Given the fields `prompt`, produce the fields `task_description`, `signature_fields`, `signature_name`.",
+"instructions": "Generate a DSPy signature with proper field types.\n\nIMPORTANT: For ANY structured/nested JSON output with multiple fields or nested objects:\n- DO NOT use a simple str field with JSON instructions\n- ALWAYS use type='pydantic' with a complete PydanticModelSchema\n- Define the full nested structure with all fields properly typed\n\nExamples:\n- Simple output: \"answer: str\" is fine\n- Complex output like medical records, forms, multi-field objects: MUST use pydantic models",
 "fields": [
 {
 "prefix": "Prompt:",
-"description": "Natural language description of the desired functionality"
+"description": "Natural language description of the desired functionality."
+},
+{
+"prefix": "Reasoning: Let's think step by step in order to",
+"description": "${reasoning}"
 },
 {
 "prefix": "Task Description:",
-"description": "Clear description of what the signature accomplishes"
+"description": "Clear description of what the signature aims to accomplish."
 },
 {
 "prefix": "Signature Fields:",
-"description": "List of input and output fields for the signature"
+"description": "List of input and output fields for the signature.\n\nCRITICAL RULES:\n1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema\n2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs\n3. For Pydantic models: define ALL nested fields properly in the schema with correct types\n4. Simple outputs (single values) can use basic types like str, int, bool\n5. Use Literal types for enumerated values (e.g., severity levels, status codes)\n\nExamples:\n- BAD: structured_output: str = \"A JSON containing patient data...\"\n- GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)"
 },
 {
 "prefix": "Signature Name:",
@@ -25,15 +29,15 @@
 ]
 },
 "lm": {
-"model": "gemini/gemini-2.5-pro-preview-03-25",
+"model": "gpt-5-2025-08-07",
 "model_type": "chat",
 "cache": true,
 "num_retries": 3,
 "finetuning_model": null,
 "launch_kwargs": {},
 "train_kwargs": {},
-"temperature": 0.7,
-"max_tokens": 4096
+"temperature": 1.0,
+"max_completion_tokens": 16000
 }
 },
 "metadata": {
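The rewritten instructions steer the generator toward typed Pydantic outputs instead of JSON-in-a-string fields. A minimal sketch of the intended difference, with illustrative names that are not part of this commit:

import dspy
from pydantic import BaseModel

# BAD (per the new rules): the schema only lives in prose, so nothing validates it.
# structured_output: str = dspy.OutputField(desc="A JSON string containing patient data...")

# GOOD: declare the nested structure as a Pydantic model and type the output with it.
class MedicalRecord(BaseModel):
    patient_name: str
    age: int

class ExtractMedicalRecord(dspy.Signature):
    """Extract a structured medical record from clinical notes."""

    notes: str = dspy.InputField(desc="Raw clinical notes")
    medical_record: MedicalRecord = dspy.OutputField(desc="Typed, validated record")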
@@ -3,10 +3,10 @@ from .modules import signature_generator
 from modaic import PrecompiledAgent, PrecompiledConfig
 
 
 class PromptToSignatureConfig(PrecompiledConfig):
-    lm: str = "gemini/gemini-2.5-pro-preview-03-25"
-    refine_lm: str = "gemini/gemini-2.5-pro-preview-03-25"
-    max_tokens: int = 4096
-    temperature: float = 0.7
+    lm: str = "gpt-5-2025-08-07"
+    refine_lm: str = "gpt-5-2025-08-07"
+    max_tokens: int = 16000
+    temperature: float = 1.0
     max_attempts_to_refine: int = 5
 
 
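With these defaults, constructing the config with no arguments now targets gpt-5-2025-08-07 at temperature 1.0. A rough usage sketch, assuming PrecompiledConfig accepts keyword overrides like a Pydantic model:

# Hypothetical override of the new defaults shown above.
config = PromptToSignatureConfig(max_tokens=16000, temperature=1.0)
agent = PromptToSignatureAgent(config)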
agent/modules.py
@@ -1,8 +1,9 @@
 import dspy
 from enum import Enum
-from typing import Any, Dict, List, Optional, Literal
+from typing import Any, Dict, List, Optional, Literal, Union
 from pydantic import BaseModel, Field, ValidationError
-from .utils import *
+
+dspy.configure(lm=dspy.LM("gpt-5-2025-08-07", temperature=1.0, max_tokens=16000))
 
 
 class FieldType(str, Enum):
@@ -15,13 +16,18 @@ class FieldType(str, Enum):
     LIST_FLOAT = "list[float]"
     DICT_STR_STR = "dict[str, str]"
     DICT_STR_INT = "dict[str, int]"
+    DICT_STR_FLOAT = "dict[str, float]"
+    DICT_STR_BOOL = "dict[str, bool]"
     DICT_STR_ANY = "dict[str, Any]"
+    LIST_DICT = "list[dict]"
     IMAGE = "dspy.Image"
     AUDIO = "dspy.Audio"
     LITERAL = "Literal"
     OPTIONAL_STR = "Optional[str]"
     OPTIONAL_INT = "Optional[int]"
     OPTIONAL_FLOAT = "Optional[float]"
+    OPTIONAL_BOOL = "Optional[bool]"
+    PYDANTIC_MODEL = "pydantic"  # For nested structures
 
 
 class FieldRole(str, Enum):
@@ -29,11 +35,91 @@ class FieldRole(str, Enum):
     OUTPUT = "output"
 
 
+class PydanticFieldDef(BaseModel):
+    """Represents a single field in a Pydantic model"""
+    name: str = Field(description="Field name in the Pydantic model")
+    type: FieldType = Field(description="Field type")
+    description: Optional[str] = Field(default=None, description="Field description")
+    required: bool = Field(default=True, description="Whether the field is required")
+    literal_values: Optional[List[str]] = Field(
+        default=None, description="For Literal types, the allowed values"
+    )
+    nested_model: Optional["PydanticModelSchema"] = Field(
+        default=None, description="For nested Pydantic models"
+    )
+
+
+class PydanticModelSchema(BaseModel):
+    """Represents a complete Pydantic model schema"""
+    model_name: str = Field(description="Name of the Pydantic model class (PascalCase)")
+    description: Optional[str] = Field(default=None, description="Model docstring")
+    fields: List[PydanticFieldDef] = Field(description="Fields in the model")
+
+    def to_code(self, indent: int = 0) -> str:
+        """Generate Python code for this Pydantic model"""
+        indent_str = "    " * indent
+        lines = []
+
+        lines.append(f"{indent_str}class {self.model_name}(BaseModel):")
+        if self.description:
+            lines.append(f'{indent_str}    """{self.description}"""')
+        lines.append("")
+
+        for field_def in self.fields:
+            field_code = self._generate_field_code(field_def, indent + 1)
+            lines.append(field_code)
+
+        return "\n".join(lines)
+
+    def _generate_field_code(self, field_def: PydanticFieldDef, indent: int) -> str:
+        """Generate code for a single Pydantic field"""
+        indent_str = "    " * indent
+
+        # Determine type annotation
+        if field_def.type == FieldType.PYDANTIC_MODEL and field_def.nested_model:
+            type_annotation = field_def.nested_model.model_name
+        elif field_def.type == FieldType.LITERAL and field_def.literal_values:
+            type_annotation = f"Literal[{', '.join(repr(v) for v in field_def.literal_values)}]"
+        else:
+            type_annotation = field_def.type.value
+
+        # Add Optional wrapper if not required
+        if not field_def.required:
+            type_annotation = f"Optional[{type_annotation}]"
+
+        # Build Field() arguments
+        field_args = []
+        if field_def.description:
+            field_args.append(f'description="{field_def.description}"')
+        if not field_def.required:
+            field_args.append("default=None")
+
+        if field_args:
+            field_init = f"Field({', '.join(field_args)})"
+        else:
+            field_init = "Field()"
+
+        return f"{indent_str}{field_def.name}: {type_annotation} = {field_init}"
+
+    def get_all_nested_models(self) -> List["PydanticModelSchema"]:
+        """Recursively collect all nested models"""
+        nested = []
+        for field_def in self.fields:
+            if field_def.nested_model:
+                # Add nested models first (depth-first)
+                nested.extend(field_def.nested_model.get_all_nested_models())
+                nested.append(field_def.nested_model)
+        return nested
+
+
 class GeneratedField(BaseModel):
     name: str = Field(description="The field name (snake_case, descriptive)")
     type: FieldType = Field(description="The Python type for this field")
     role: FieldRole = Field(description="Whether this is an input or output field")
     description: str = Field(description="Description of what this field represents")
+    pydantic_model_schema: Optional[PydanticModelSchema] = Field(
+        default=None, description="For PYDANTIC_MODEL type, the model schema"
+    )
     literal_values: Optional[List[str]] = Field(
         default=None, description="For Literal types, the allowed values"
     )
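A small sketch of how the new schema classes compose (the model and field names are made up for illustration; only enum members visible in this diff are used):

allergy = PydanticModelSchema(
    model_name="Allergy",
    description="A single allergy entry",
    fields=[
        PydanticFieldDef(
            name="severity",
            type=FieldType.LITERAL,
            literal_values=["mild", "moderate", "severe"],
            description="Reaction severity",
        ),
        PydanticFieldDef(name="reaction", type=FieldType.OPTIONAL_STR),
    ],
)
print(allergy.to_code())
# Roughly:
# class Allergy(BaseModel):
#     """A single allergy entry"""
#
#     severity: Literal['mild', 'moderate', 'severe'] = Field(description="Reaction severity")
#     reaction: Optional[str] = Field()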
@@ -43,7 +129,10 @@ class GeneratedField(BaseModel):
     )
 
     def to_dspy_field_code(self) -> str:
-        if self.type == FieldType.LITERAL and self.literal_values:
+        """Generate DSPy field code for use in a Signature class"""
+        if self.type == FieldType.PYDANTIC_MODEL and self.pydantic_model_schema:
+            type_annotation = self.pydantic_model_schema.model_name
+        elif self.type == FieldType.LITERAL and self.literal_values:
             type_annotation = (
                 f"Literal[{', '.join(repr(v) for v in self.literal_values)}]"
             )
@@ -59,14 +148,37 @@ class GeneratedField(BaseModel):
 
 
 class SignatureGeneration(dspy.Signature):
+    """Generate a DSPy signature with proper field types.
+
+    IMPORTANT: For ANY structured/nested JSON output with multiple fields or nested objects:
+    - DO NOT use a simple str field with JSON instructions
+    - ALWAYS use type='pydantic' with a complete PydanticModelSchema
+    - Define the full nested structure with all fields properly typed
+
+    Examples:
+    - Simple output: "answer: str" is fine
+    - Complex output like medical records, forms, multi-field objects: MUST use pydantic models
+    """
+
     prompt: str = dspy.InputField(
-        desc="Natural language description of the desired functionality"
+        desc="Natural language description of the desired functionality."
     )
     task_description: str = dspy.OutputField(
-        desc="Clear description of what the signature accomplishes"
+        desc="Clear description of what the signature aims to accomplish."
     )
     signature_fields: list[GeneratedField] = dspy.OutputField(
-        desc="List of input and output fields for the signature"
+        desc="""List of input and output fields for the signature.
+
+CRITICAL RULES:
+1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema
+2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs
+3. For Pydantic models: define ALL nested fields properly in the schema with correct types
+4. Simple outputs (single values) can use basic types like str, int, bool
+5. Use Literal types for enumerated values (e.g., severity levels, status codes)
+
+Examples:
+- BAD: structured_output: str = "A JSON containing patient data..."
+- GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)"""
     )
     signature_name: str = dspy.OutputField(
         desc="Suggested class name for the signature (PascalCase)"
@@ -76,7 +188,7 @@ class SignatureGeneration(dspy.Signature):
 class SignatureGenerator(dspy.Module):
     def __init__(self):
         super().__init__()
-        self.generator = dspy.Predict(SignatureGeneration)
+        self.generator = dspy.ChainOfThought(SignatureGeneration)
 
     def forward(self, prompt: str):
         """Generate DSPy signature and return raw prediction attributes"""
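Swapping dspy.Predict for dspy.ChainOfThought is what introduces the extra "Reasoning: Let's think step by step in order to" field seen in the agent.json hunk above: ChainOfThought keeps the same signature but emits an additional reasoning output. A rough usage sketch with a hypothetical prompt:

generator = dspy.ChainOfThought(SignatureGeneration)
pred = generator(prompt="Turn a support ticket into category and priority fields.")
print(pred.reasoning)        # intermediate reasoning added by ChainOfThought
print(pred.signature_name)   # one of the declared outputs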
@@ -166,10 +278,14 @@ class SignatureGenerator(dspy.Module):
             "list[float]": List[float],
             "dict[str, str]": Dict[str, str],
             "dict[str, int]": Dict[str, int],
+            "dict[str, float]": Dict[str, float],
+            "dict[str, bool]": Dict[str, bool],
             "dict[str, Any]": Dict[str, Any],
+            "list[dict]": List[dict],
             "Optional[str]": Optional[str],
             "Optional[int]": Optional[int],
             "Optional[float]": Optional[float],
+            "Optional[bool]": Optional[bool],
             "dspy.Image": dspy.Image,
             "dspy.Audio": dspy.Audio,
         }
@@ -177,6 +293,10 @@ class SignatureGenerator(dspy.Module):
         if field.type == FieldType.LITERAL and field.literal_values:
             return Literal[tuple(field.literal_values)]
 
+        if field.type == FieldType.PYDANTIC_MODEL and field.pydantic_model_schema:
+            # Dynamically create the Pydantic model class
+            return SignatureGenerator._create_dynamic_pydantic_model(field.pydantic_model_schema)
+
         if type_str in type_map:
             return type_map[type_str]
 
@@ -184,6 +304,55 @@ class SignatureGenerator(dspy.Module):
             f"Unsupported field type for dynamic class creation: {type_str}"
         )
 
+    @staticmethod
+    def _create_dynamic_pydantic_model(schema: PydanticModelSchema) -> type:
+        """Dynamically create a Pydantic model class from a schema"""
+        class_attrs = {"__annotations__": {}}
+
+        if schema.description:
+            class_attrs["__doc__"] = schema.description
+
+        for field_def in schema.fields:
+            field_name = field_def.name
+
+            # Determine Python type
+            if field_def.type == FieldType.PYDANTIC_MODEL and field_def.nested_model:
+                py_type = SignatureGenerator._create_dynamic_pydantic_model(field_def.nested_model)
+            elif field_def.type == FieldType.LITERAL and field_def.literal_values:
+                py_type = Literal[tuple(field_def.literal_values)]
+            else:
+                # Use the type map from _get_python_type_from_field
+                type_str = field_def.type.value
+                type_map = {
+                    "str": str, "int": int, "float": float, "bool": bool,
+                    "list[str]": List[str], "list[int]": List[int], "list[float]": List[float],
+                    "dict[str, str]": Dict[str, str], "dict[str, int]": Dict[str, int],
+                    "dict[str, float]": Dict[str, float], "dict[str, bool]": Dict[str, bool],
+                    "dict[str, Any]": Dict[str, Any], "list[dict]": List[dict],
+                    "Optional[str]": Optional[str], "Optional[int]": Optional[int],
+                    "Optional[float]": Optional[float], "Optional[bool]": Optional[bool],
+                }
+                py_type = type_map.get(type_str, str)
+
+            # Wrap in Optional if not required
+            if not field_def.required and not isinstance(py_type, type(Optional[str])):
+                py_type = Optional[py_type]
+
+            # Create Pydantic field
+            field_kwargs = {}
+            if field_def.description:
+                field_kwargs["description"] = field_def.description
+            if not field_def.required:
+                field_kwargs["default"] = None
+
+            class_attrs["__annotations__"][field_name] = py_type
+            if field_kwargs:
+                class_attrs[field_name] = Field(**field_kwargs)
+
+        # Create the dynamic class
+        DynamicModel = type(schema.model_name, (BaseModel,), class_attrs)
+        return DynamicModel
+
     @classmethod
     def generate_code(cls, prediction) -> str:
         """Generate Python code from a signature prediction"""
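A short sketch of what the new helper produces (illustrative names, pydantic v2 assumed): because type() defers to BaseModel's metaclass, the returned class behaves like a hand-written Pydantic model and validates input normally.

schema = PydanticModelSchema(
    model_name="Vitals",
    fields=[
        PydanticFieldDef(name="heart_rate", type=FieldType.OPTIONAL_INT, required=False),
        PydanticFieldDef(
            name="flag", type=FieldType.LITERAL, literal_values=["normal", "high", "low"]
        ),
    ],
)
Vitals = SignatureGenerator._create_dynamic_pydantic_model(schema)
record = Vitals(heart_rate=72, flag="normal")
print(record.model_dump())  # {'heart_rate': 72, 'flag': 'normal'}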
@@ -192,6 +361,16 @@ class SignatureGenerator(dspy.Module):
         code_lines = []
         code_lines.extend(imports)
         code_lines.append("")
+
+        # Generate Pydantic model classes first (if any)
+        pydantic_models = cls._collect_pydantic_models(prediction.signature_fields)
+        if pydantic_models:
+            for model_schema in pydantic_models:
+                code_lines.append(model_schema.to_code())
+                code_lines.append("")
+                code_lines.append("")
+
+        # Generate the main signature class
         code_lines.append(f"class {prediction.signature_name}(dspy.Signature):")
         code_lines.append(f'    """{prediction.task_description}"""')
         code_lines.append("")
@@ -201,30 +380,62 @@ class SignatureGenerator(dspy.Module):
 
         return "\n".join(code_lines)
 
+    @classmethod
+    def _collect_pydantic_models(cls, fields: List[GeneratedField]) -> List[PydanticModelSchema]:
+        """Collect all Pydantic models from fields, including nested ones"""
+        models = []
+        seen_names = set()
+
+        for field in fields:
+            if field.pydantic_model_schema:
+                # Get all nested models first (depth-first)
+                nested_models = field.pydantic_model_schema.get_all_nested_models()
+                for nested in nested_models:
+                    if nested.model_name not in seen_names:
+                        models.append(nested)
+                        seen_names.add(nested.model_name)
+
+                # Then add the top-level model
+                if field.pydantic_model_schema.model_name not in seen_names:
+                    models.append(field.pydantic_model_schema)
+                    seen_names.add(field.pydantic_model_schema.model_name)
+
+        return models
+
     @classmethod
     def get_required_imports(cls, fields: List[GeneratedField]) -> List[str]:
         """Determine required imports based on field types"""
         imports = ["import dspy"]
         typing_imports = set()
+        needs_pydantic = False
 
         for field in fields:
-            if field.type == FieldType.LITERAL:
+            if field.type == FieldType.PYDANTIC_MODEL:
+                needs_pydantic = True
+                # Check nested models for their typing requirements
+                if field.pydantic_model_schema:
+                    cls._collect_typing_imports_from_schema(field.pydantic_model_schema, typing_imports)
+            elif field.type == FieldType.LITERAL:
                 typing_imports.add("Literal")
             elif field.type in [
                 FieldType.OPTIONAL_STR,
                 FieldType.OPTIONAL_INT,
                 FieldType.OPTIONAL_FLOAT,
+                FieldType.OPTIONAL_BOOL,
             ]:
                 typing_imports.add("Optional")
             elif field.type in [
                 FieldType.LIST_STRING,
                 FieldType.LIST_INT,
                 FieldType.LIST_FLOAT,
+                FieldType.LIST_DICT,
             ]:
                 typing_imports.add("List")
             elif field.type in [
                 FieldType.DICT_STR_STR,
                 FieldType.DICT_STR_INT,
+                FieldType.DICT_STR_FLOAT,
+                FieldType.DICT_STR_BOOL,
                 FieldType.DICT_STR_ANY,
             ]:
                 typing_imports.add("Dict")
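For example, get_required_imports should now report the pydantic import whenever any field is a generated model. A rough sketch, assuming FieldRole.INPUT exists and that the fields visible in this diff are the only required ones on GeneratedField:

fields = [
    GeneratedField(
        name="notes",
        type=FieldType.OPTIONAL_STR,
        role=FieldRole.INPUT,
        description="Raw clinical notes",
    ),
    GeneratedField(
        name="record",
        type=FieldType.PYDANTIC_MODEL,
        role=FieldRole.OUTPUT,
        description="Structured record",
        pydantic_model_schema=PydanticModelSchema(model_name="Record", fields=[]),
    ),
]
print(SignatureGenerator.get_required_imports(fields))
# Roughly: ['import dspy', 'from typing import Optional', 'from pydantic import BaseModel, Field']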
@@ -237,6 +448,126 @@ class SignatureGenerator(dspy.Module):
                 f"from typing import {', '.join(sorted(list(typing_imports)))}"
             )
 
+        if needs_pydantic:
+            imports.append("from pydantic import BaseModel, Field")
+
         return imports
 
+    @classmethod
+    def _collect_typing_imports_from_schema(cls, schema: PydanticModelSchema, typing_imports: set):
+        """Recursively collect typing imports needed for a Pydantic model schema"""
+        for field_def in schema.fields:
+            if field_def.type == FieldType.LITERAL:
+                typing_imports.add("Literal")
+            elif field_def.type in [FieldType.LIST_STRING, FieldType.LIST_INT, FieldType.LIST_FLOAT, FieldType.LIST_DICT]:
+                typing_imports.add("List")
+            elif field_def.type in [FieldType.DICT_STR_STR, FieldType.DICT_STR_INT,
+                                    FieldType.DICT_STR_FLOAT, FieldType.DICT_STR_BOOL, FieldType.DICT_STR_ANY]:
+                typing_imports.add("Dict")
+
+            if field_def.type == FieldType.DICT_STR_ANY:
+                typing_imports.add("Any")
+
+            if not field_def.required:
+                typing_imports.add("Optional")
+
+            # Recurse for nested models
+            if field_def.nested_model:
+                cls._collect_typing_imports_from_schema(field_def.nested_model, typing_imports)
+
 signature_generator = SignatureGenerator()
+
+
+CR_PROMPT = """ 1. Medical Record Information Extraction and Clinical Priority Classification
+
+**INPUT FORMAT**: Raw clinical notes in free-text format, typically 200-2000 words, containing unstructured medical documentation from patient encounters including history, examination findings, diagnoses, treatment plans, and follow-up instructions.
+
+**TASK DESCRIPTION**: You are a medical information extraction system. Analyze the provided clinical notes and perform comprehensive structured data extraction along with risk stratification. Your task involves multiple sub-tasks executed simultaneously.
+
+**EXTRACTION REQUIREMENTS**:
+- **Patient Demographics**: Extract full legal name, age (in years), biological sex/gender, date of birth (format: YYYY-MM-DD), medical record number (MRN) if present
+- **Primary Diagnosis**: Identify the main diagnosis with corresponding ICD-10 code, include diagnostic certainty level (confirmed/suspected/rule-out)
+- **Secondary Diagnoses**: List all comorbidities and additional conditions mentioned, each with ICD-10 codes where applicable
+- **Medications**: Extract complete medication list including generic and brand names, dosages with units (mg, mcg, mL), frequency (BID, TID, QID, PRN), route of administration (PO, IV, IM, topical), and duration if specified
+- **Allergies**: Document all allergies with allergen name, reaction type (rash, anaphylaxis, nausea, etc.), and severity classification (mild/moderate/severe/life-threatening)
+- **Vital Signs**: Extract most recent measurements - blood pressure (systolic/diastolic in mmHg), heart rate (bpm), temperature (°F or °C with unit), respiratory rate (breaths/min), oxygen saturation (%), and pain score (0-10 scale)
+- **Laboratory Results**: Identify all lab values mentioned with test name, numerical result, unit of measurement, reference range, and flag if abnormal (high/low/critical)
+- **Appointments**: Extract scheduled follow-up dates, appointment types (follow-up, specialist referral, procedure), and provider names
+
+**CLASSIFICATION REQUIREMENTS**:
+- **Urgency Level**: Classify the case into one of four categories:
+- ROUTINE: Stable patient, chronic condition management, no acute concerns
+- URGENT: Requires attention within 24-48 hours, acute but not life-threatening condition
+- EMERGENCY: Immediate intervention required, potentially life-threatening presentation
+- CRITICAL: Life-threatening emergency requiring immediate intervention (ICU-level care)
+
+**OUTPUT FORMAT**: Return the following schema as a pydantic model:
+{
+"patient_demographics": {
+"name": "string",
+"age": "integer",
+"gender": "string",
+"dob": "YYYY-MM-DD",
+"mrn": "string or null"
+},
+"primary_diagnosis": {
+"condition": "string",
+"icd10_code": "integer",
+"certainty": "confirmed|suspected|rule-out"
+},
+"secondary_diagnoses": [
+{"condition": "string", "icd10_code": "string"}
+],
+"medications": [
+{
+"name": "string",
+"dosage": "string",
+"frequency": "string",
+"route": "string",
+"duration": "string or null"
+}
+],
+"allergies": [
+{
+"allergen": "string",
+"reaction": "string",
+"severity": "mild|moderate|severe|life-threatening"
+}
+],
+"vital_signs": {
+"blood_pressure": "string (systolic/diastolic)",
+"heart_rate": "integer",
+"temperature": "float with unit",
+"respiratory_rate": "integer",
+"oxygen_saturation": "integer",
+"pain_score": "integer (0-10)"
+},
+"lab_results": [
+{
+"test_name": "string",
+"value": "float",
+"unit": "string",
+"reference_range": "string",
+"flag": "normal|high|low|critical"
+}
+],
+"appointments": [
+{
+"date": "YYYY-MM-DD",
+"type": "string",
+"provider": "string"
+}
+],
+"urgency_classification": {
+"level": "ROUTINE|URGENT|EMERGENCY|CRITICAL",
+"reasoning": "string (brief explanation for classification)"
+}
+}
+"""
+
+def main():
+    result = signature_generator(CR_PROMPT)
+    print(signature_generator.generate_code(result))
+
+if __name__ == "__main__":
+    main()
@@ -1,24 +0,0 @@
-from typing import Dict, Any
-import re
-
-
-def to_snake_case(name: str) -> str:
-    """Convert PascalCase or camelCase string to snake_case."""
-    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
-    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
-
-
-def save_signature_to_file(
-    result: Dict[str, Any], output_path: str | None = None
-) -> str:
-    """Save generated signature code to a Python file"""
-    if not output_path:
-        signature_name = result.get("signature_name", "generated_signature")
-        output_path = f"{to_snake_case(signature_name)}.py"
-
-    if result.get("code"):
-        with open(output_path, "w") as f:
-            f.write(result["code"])
-        return output_path
-    else:
-        raise ValueError("No code to save")
@@ -1,7 +1,7 @@
 {
-"lm": "gemini/gemini-2.5-pro-preview-03-25",
-"refine_lm": "gemini/gemini-2.5-pro-preview-03-25",
-"max_tokens": 4096,
-"temperature": 0.7,
+"lm": "gpt-5-2025-08-07",
+"refine_lm": "gpt-5-2025-08-07",
+"max_tokens": 16000,
+"temperature": 1.0,
 "max_attempts_to_refine": 5
 }
main.py
@@ -1,8 +1,98 @@
 from agent import PromptToSignatureAgent, PromptToSignatureConfig
+from modaic import AutoAgent
 
 agent = PromptToSignatureAgent(PromptToSignatureConfig())
+
+CR_PROMPT = """ 1. Medical Record Information Extraction and Clinical Priority Classification
+
+**INPUT FORMAT**: Raw clinical notes in free-text format, typically 200-2000 words, containing unstructured medical documentation from patient encounters including history, examination findings, diagnoses, treatment plans, and follow-up instructions.
+
+**TASK DESCRIPTION**: You are a medical information extraction system. Analyze the provided clinical notes and perform comprehensive structured data extraction along with risk stratification. Your task involves multiple sub-tasks executed simultaneously.
+
+**EXTRACTION REQUIREMENTS**:
+- **Patient Demographics**: Extract full legal name, age (in years), biological sex/gender, date of birth (format: YYYY-MM-DD), medical record number (MRN) if present
+- **Primary Diagnosis**: Identify the main diagnosis with corresponding ICD-10 code, include diagnostic certainty level (confirmed/suspected/rule-out)
+- **Secondary Diagnoses**: List all comorbidities and additional conditions mentioned, each with ICD-10 codes where applicable
+- **Medications**: Extract complete medication list including generic and brand names, dosages with units (mg, mcg, mL), frequency (BID, TID, QID, PRN), route of administration (PO, IV, IM, topical), and duration if specified
+- **Allergies**: Document all allergies with allergen name, reaction type (rash, anaphylaxis, nausea, etc.), and severity classification (mild/moderate/severe/life-threatening)
+- **Vital Signs**: Extract most recent measurements - blood pressure (systolic/diastolic in mmHg), heart rate (bpm), temperature (°F or °C with unit), respiratory rate (breaths/min), oxygen saturation (%), and pain score (0-10 scale)
+- **Laboratory Results**: Identify all lab values mentioned with test name, numerical result, unit of measurement, reference range, and flag if abnormal (high/low/critical)
+- **Appointments**: Extract scheduled follow-up dates, appointment types (follow-up, specialist referral, procedure), and provider names
+
+**CLASSIFICATION REQUIREMENTS**:
+- **Urgency Level**: Classify the case into one of four categories:
+- ROUTINE: Stable patient, chronic condition management, no acute concerns
+- URGENT: Requires attention within 24-48 hours, acute but not life-threatening condition
+- EMERGENCY: Immediate intervention required, potentially life-threatening presentation
+- CRITICAL: Life-threatening emergency requiring immediate intervention (ICU-level care)
+
+**OUTPUT FORMAT**: Return the following schema as a pydantic model:
+{
+"patient_demographics": {
+"name": "string",
+"age": "integer",
+"gender": "string",
+"dob": "YYYY-MM-DD",
+"mrn": "string or null"
+},
+"primary_diagnosis": {
+"condition": "string",
+"icd10_code": "integer",
+"certainty": "confirmed|suspected|rule-out"
+},
+"secondary_diagnoses": [
+{"condition": "string", "icd10_code": "string"}
+],
+"medications": [
+{
+"name": "string",
+"dosage": "string",
+"frequency": "string",
+"route": "string",
+"duration": "string or null"
+}
+],
+"allergies": [
+{
+"allergen": "string",
+"reaction": "string",
+"severity": "mild|moderate|severe|life-threatening"
+}
+],
+"vital_signs": {
+"blood_pressure": "string (systolic/diastolic)",
+"heart_rate": "integer",
+"temperature": "float with unit",
+"respiratory_rate": "integer",
+"oxygen_saturation": "integer",
+"pain_score": "integer (0-10)"
+},
+"lab_results": [
+{
+"test_name": "string",
+"value": "float",
+"unit": "string",
+"reference_range": "string",
+"flag": "normal|high|low|critical"
+}
+],
+"appointments": [
+{
+"date": "YYYY-MM-DD",
+"type": "string",
+"provider": "string"
+}
+],
+"urgency_classification": {
+"level": "ROUTINE|URGENT|EMERGENCY|CRITICAL",
+"reasoning": "string (brief explanation for classification)"
+}
+}
+"""
+
+agent = PromptToSignatureAgent(PromptToSignatureConfig())
 def main():
 
     agent.push_to_hub("fadeleke/prompt-to-signature", with_code=True)
 
 
@@ -1,7 +1,7 @@
 [project]
 name = "prompt-to-signature"
 version = "0.1.0"
-description = "Add your description here"
+description = "A DSPy agent to convert prompts into DSPY signatures."
 readme = "README.md"
 requires-python = ">=3.11"
 dependencies = ["dspy>=3.0.3", "modaic>=0.4.0", "rich>=14.2.0", "typer>=0.20.0"]