Files
prompt-to-signature/agent/modules.py
2025-11-04 11:01:43 -05:00

475 lines
19 KiB
Python

import json
from enum import Enum
from typing import Any, Dict, List, Optional, Literal

import dspy
from pydantic import BaseModel, Field, ValidationError
class FieldType(str, Enum):
    # Closed set of types a generated field may have. Each member's value is
    # the annotation text that the code generators in this module emit as-is;
    # LITERAL and PYDANTIC_MODEL are markers whose real annotation is built
    # from the accompanying literal values / model schema.
    # NOTE(review): documented with comments instead of a class docstring
    # because class docstrings may surface in generated schemas - confirm.

    # Scalars
    STRING = "str"
    INTEGER = "int"
    FLOAT = "float"
    BOOLEAN = "bool"

    # Homogeneous lists
    LIST_STRING = "list[str]"
    LIST_INT = "list[int]"
    LIST_FLOAT = "list[float]"

    # String-keyed dictionaries
    DICT_STR_STR = "dict[str, str]"
    DICT_STR_INT = "dict[str, int]"
    DICT_STR_FLOAT = "dict[str, float]"
    DICT_STR_BOOL = "dict[str, bool]"
    DICT_STR_ANY = "dict[str, Any]"

    LIST_DICT = "list[dict]"

    # DSPy media types
    IMAGE = "dspy.Image"
    AUDIO = "dspy.Audio"

    # Marker: the allowed values are carried separately on the field.
    LITERAL = "Literal"

    # Optional scalars
    OPTIONAL_STR = "Optional[str]"
    OPTIONAL_INT = "Optional[int]"
    OPTIONAL_FLOAT = "Optional[float]"
    OPTIONAL_BOOL = "Optional[bool]"

    # Marker: for nested structures; the schema is carried separately.
    PYDANTIC_MODEL = "pydantic"
class FieldRole(str, Enum):
    # Which side of a signature a field belongs to; the code generators in
    # this module map INPUT to dspy.InputField and anything else to
    # dspy.OutputField.
    INPUT = "input"
    OUTPUT = "output"
class PydanticFieldDef(BaseModel):
    """Represents a single field in a Pydantic model"""

    # Identity and type of the field as it will appear in the model class.
    name: str = Field(description="Field name in the Pydantic model")
    type: FieldType = Field(description="Field type")

    # Optional human-readable description.
    description: Optional[str] = Field(
        default=None,
        description="Field description",
    )

    # Fields are required unless stated otherwise.
    required: bool = Field(
        default=True,
        description="Whether the field is required",
    )

    # Only consulted when ``type`` is FieldType.LITERAL.
    literal_values: Optional[List[str]] = Field(
        default=None,
        description="For Literal types, the allowed values",
    )

    # Only consulted when ``type`` is FieldType.PYDANTIC_MODEL. Forward
    # reference: PydanticModelSchema is defined later in this module.
    nested_model: Optional["PydanticModelSchema"] = Field(
        default=None,
        description="For nested Pydantic models",
    )
class PydanticModelSchema(BaseModel):
    """Represents a complete Pydantic model schema"""

    model_name: str = Field(description="Name of the Pydantic model class (PascalCase)")
    description: Optional[str] = Field(default=None, description="Model docstring")
    fields: List[PydanticFieldDef] = Field(description="Fields in the model")

    def to_code(self, indent: int = 0) -> str:
        """Render this schema as the source of a ``BaseModel`` subclass.

        Only this model's own class statement is produced; a field that refers
        to a nested model is annotated with that model's name, and the nested
        class itself must be rendered separately (see
        ``get_all_nested_models``).

        Args:
            indent: Nesting level of the ``class`` statement. Each level is
                four spaces.

        Returns:
            The class source without a trailing newline.
        """
        pad = "    " * indent
        body_pad = "    " * (indent + 1)

        lines = [f"{pad}class {self.model_name}(BaseModel):"]
        if self.description:
            lines.append(f"{body_pad}{self._docstring_literal(self.description)}")
            lines.append("")
        lines.extend(
            self._generate_field_code(field_def, indent + 1)
            for field_def in self.fields
        )
        # A class statement needs at least one statement in its body.
        if len(lines) == 1:
            lines.append(f"{body_pad}pass")
        return "\n".join(lines)

    @staticmethod
    def _docstring_literal(text: str) -> str:
        """Return ``text`` as a triple-quoted literal that always parses.

        Backslashes are doubled so they are not read as escape sequences, and
        any quote that would terminate the literal early (an embedded triple
        quote, or a quote at the very end) is escaped.
        """
        escaped = text.replace("\\", "\\\\").replace('"""', '\\"\\"\\"')
        if escaped.endswith('"'):
            escaped = escaped[:-1] + '\\"'
        return f'"""{escaped}"""'

    def _generate_field_code(self, field_def: PydanticFieldDef, indent: int) -> str:
        """Render one field as ``name: annotation = Field(...)``.

        Args:
            field_def: The field to render.
            indent: Nesting level of the line; each level is four spaces.

        Returns:
            A single line of source, indented but without a trailing newline.
        """
        pad = "    " * indent

        # Determine type annotation
        if field_def.type == FieldType.PYDANTIC_MODEL and field_def.nested_model:
            annotation = field_def.nested_model.model_name
        elif field_def.type == FieldType.LITERAL and field_def.literal_values:
            annotation = (
                f"Literal[{', '.join(repr(v) for v in field_def.literal_values)}]"
            )
        else:
            annotation = field_def.type.value

        # Wrap in Optional when not required, unless the annotation is already
        # one of the Optional[...] field types.
        if not field_def.required and not annotation.startswith("Optional["):
            annotation = f"Optional[{annotation}]"

        # Build Field() arguments
        field_args = []
        if field_def.description:
            # json.dumps yields a double-quoted literal with quotes,
            # backslashes and control characters escaped, so any description
            # text produces valid source.
            quoted = json.dumps(field_def.description, ensure_ascii=False)
            field_args.append(f"description={quoted}")
        if not field_def.required:
            field_args.append("default=None")

        return f"{pad}{field_def.name}: {annotation} = Field({', '.join(field_args)})"

    def get_all_nested_models(self) -> List["PydanticModelSchema"]:
        """Recursively collect all nested models.

        Returns:
            Every model reachable through ``nested_model`` links, ordered so
            that each model appears after the models it contains. This schema
            itself is not included, and duplicates are not removed.
        """
        nested = []
        for field_def in self.fields:
            child = field_def.nested_model
            if child:
                # Depth-first: the child's own dependencies come before it.
                nested.extend(child.get_all_nested_models())
                nested.append(child)
        return nested
class GeneratedField(BaseModel):
    # One input or output field of a generated DSPy signature.

    name: str = Field(description="The field name (snake_case, descriptive)")
    type: FieldType = Field(description="The Python type for this field")
    role: FieldRole = Field(description="Whether this is an input or output field")
    description: str = Field(description="Description of what this field represents")
    # Only consulted when ``type`` is FieldType.PYDANTIC_MODEL.
    pydantic_model_schema: Optional[PydanticModelSchema] = Field(
        default=None, description="For PYDANTIC_MODEL type, the model schema"
    )
    # Only consulted when ``type`` is FieldType.LITERAL.
    literal_values: Optional[List[str]] = Field(
        default=None, description="For Literal types, the allowed values"
    )
    # Not used by to_dspy_field_code below.
    default_value: Optional[str] = Field(
        default=None,
        description="Default value for the field (as string representation)",
    )

    def to_dspy_field_code(self) -> str:
        """Generate DSPy field code for use in a Signature class.

        Returns:
            One unindented line of the form
            ``name: annotation = dspy.InputField(desc="...")`` (or
            ``dspy.OutputField``); the ``desc`` argument is left out when the
            description is empty.
        """
        if self.type == FieldType.PYDANTIC_MODEL and self.pydantic_model_schema:
            type_annotation = self.pydantic_model_schema.model_name
        elif self.type == FieldType.LITERAL and self.literal_values:
            type_annotation = (
                f"Literal[{', '.join(repr(v) for v in self.literal_values)}]"
            )
        else:
            type_annotation = self.type.value

        field_type = "InputField" if self.role == FieldRole.INPUT else "OutputField"
        if not self.description:
            return f"{self.name}: {type_annotation} = dspy.{field_type}()"

        # json.dumps yields a double-quoted literal with quotes, backslashes
        # and control characters escaped, so any description text produces
        # valid source.
        quoted = json.dumps(self.description, ensure_ascii=False)
        return f"{self.name}: {type_annotation} = dspy.{field_type}(desc={quoted})"
class SignatureGeneration(dspy.Signature):
    """Generate a DSPy signature with proper field types.
IMPORTANT: For ANY structured/nested JSON output with multiple fields or nested objects:
- DO NOT use a simple str field with JSON instructions
- ALWAYS use type='pydantic' with a complete PydanticModelSchema
- Define the full nested structure with all fields properly typed
Examples:
- Simple output: "answer: str" is fine
- Complex output like medical records, forms, multi-field objects: MUST use pydantic models
"""

    # The docstring above and every ``desc`` below are prompt text, so their
    # contents (including the flush-left continuation lines) are kept verbatim.

    # Input: the user's free-form request.
    prompt: str = dspy.InputField(desc="Natural language description of the desired functionality.")

    # Outputs, in declaration order.
    task_description: str = dspy.OutputField(desc="Clear description of what the signature aims to accomplish.")

    signature_fields: list[GeneratedField] = dspy.OutputField(
        desc="""List of input and output fields for the signature.
CRITICAL RULES:
1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema
2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs
3. For Pydantic models: define ALL nested fields properly in the schema with correct types
4. Simple outputs (single values) can use basic types like str, int, bool
5. Use Literal types for enumerated values (e.g., severity levels, status codes)
Examples:
- BAD: structured_output: str = "A JSON containing patient data..."
- GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)"""
    )

    signature_name: str = dspy.OutputField(desc="Suggested class name for the signature (PascalCase)")
class SignatureGenerator(dspy.Module):
    """Generate DSPy signatures from natural-language prompts.

    Wraps a ``dspy.ChainOfThought`` over ``SignatureGeneration`` and provides
    helpers that turn the resulting prediction into Python source
    (``generate_code``) or into live classes (``create_signature_class``).
    """

    def __init__(self):
        super().__init__()
        self.generator = dspy.ChainOfThought(SignatureGeneration)

    def forward(self, prompt: str):
        """Generate DSPy signature and return raw prediction attributes.

        Args:
            prompt: Natural-language description of the desired functionality.

        Returns:
            A ``dspy.Prediction`` with ``signature_name``, ``task_description``,
            ``signature_fields`` and ``reasoning`` (``None`` when the generator
            result carries no reasoning).
        """
        result = self.generator(prompt=prompt)
        return dspy.Prediction(
            signature_name=result.signature_name,
            task_description=result.task_description,
            signature_fields=result.signature_fields,
            reasoning=getattr(result, "reasoning", None),
        )

    def generate_signature(self, prompt: str) -> Dict[str, Any]:
        """Legacy method for backward compatibility - returns formatted dict.

        Never raises: any failure is reported through the ``error`` key of the
        dictionary built by ``_format_error``.
        """
        try:
            result = self.forward(prompt=prompt)
            return {
                "signature_name": result.signature_name,
                "task_description": result.task_description,
                "fields": [field.model_dump() for field in result.signature_fields],
                "code": self.generate_code(result),
                "reasoning": result.reasoning,
            }
        except ValidationError as e:
            error_msg = f"Data validation error from Pydantic: {e}"
            return self._format_error(error_msg)
        except Exception as e:
            # Deliberate catch-all at this boundary: callers of the legacy API
            # expect an error dictionary rather than an exception.
            error_msg = f"An unexpected error occurred: {e}"
            return self._format_error(error_msg)

    def _format_error(self, error_message: str) -> Dict[str, Any]:
        """Helper to create a standardized error dictionary.

        Carries the same keys as the success dictionary of
        ``generate_signature`` (with empty values) plus ``error``.
        """
        return {
            "error": error_message,
            "signature_name": None,
            "task_description": None,
            "fields": [],
            "code": None,
            "reasoning": None,
        }

    @classmethod
    def create_signature_class(cls, prediction: dspy.Prediction) -> type:
        """Dynamically create a ``dspy.Signature`` class from a prediction.

        Args:
            prediction: An object with attributes ``signature_name``,
                ``task_description`` and ``signature_fields``.

        Returns:
            A new class that inherits from ``dspy.Signature``.

        Raises:
            TypeError: If a field's type cannot be mapped to a Python type.
        """
        class_attrs = {"__doc__": prediction.task_description, "__annotations__": {}}
        for field in prediction.signature_fields:
            dspy_field_class = (
                dspy.InputField if field.role == FieldRole.INPUT else dspy.OutputField
            )
            class_attrs[field.name] = dspy_field_class(desc=field.description)
            class_attrs["__annotations__"][field.name] = cls._get_python_type_from_field(field)
        return type(prediction.signature_name, (dspy.Signature,), class_attrs)

    @staticmethod
    def _basic_type_map() -> Dict[str, Any]:
        """Map ``FieldType`` values to Python types.

        Literal and nested-model types are not included; callers resolve those
        from the field's literal values / schema. Built on demand so the dspy
        media types are looked up at call time, not at import time.
        """
        type_map = {
            "str": str,
            "int": int,
            "float": float,
            "bool": bool,
            "list[str]": List[str],
            "list[int]": List[int],
            "list[float]": List[float],
            "dict[str, str]": Dict[str, str],
            "dict[str, int]": Dict[str, int],
            "dict[str, float]": Dict[str, float],
            "dict[str, bool]": Dict[str, bool],
            "dict[str, Any]": Dict[str, Any],
            "list[dict]": List[dict],
            "Optional[str]": Optional[str],
            "Optional[int]": Optional[int],
            "Optional[float]": Optional[float],
            "Optional[bool]": Optional[bool],
        }
        # Media types are registered only when the installed dspy exposes
        # them, so their absence cannot break unrelated lookups.
        for member, attr in ((FieldType.IMAGE, "Image"), (FieldType.AUDIO, "Audio")):
            media_type = getattr(dspy, attr, None)
            if media_type is not None:
                type_map[member.value] = media_type
        return type_map

    @staticmethod
    def _is_optional_field_type(field_type: FieldType) -> bool:
        """Return True for the ``Optional[...]`` members of ``FieldType``."""
        return field_type in (
            FieldType.OPTIONAL_STR,
            FieldType.OPTIONAL_INT,
            FieldType.OPTIONAL_FLOAT,
            FieldType.OPTIONAL_BOOL,
        )

    @staticmethod
    def _get_python_type_from_field(field: "GeneratedField") -> type:
        """Convert a GeneratedField into a Python type for annotations.

        Raises:
            TypeError: If the field type has no Python equivalent, including a
                Literal without values or a nested model without a schema.
        """
        if field.type == FieldType.LITERAL and field.literal_values:
            return Literal[tuple(field.literal_values)]
        if field.type == FieldType.PYDANTIC_MODEL and field.pydantic_model_schema:
            # Dynamically create the Pydantic model class
            return SignatureGenerator._create_dynamic_pydantic_model(
                field.pydantic_model_schema
            )
        type_str = field.type.value
        type_map = SignatureGenerator._basic_type_map()
        if type_str in type_map:
            return type_map[type_str]
        raise TypeError(
            f"Unsupported field type for dynamic class creation: {type_str}"
        )

    @staticmethod
    def _create_dynamic_pydantic_model(schema: PydanticModelSchema) -> type:
        """Dynamically create a Pydantic model class from a schema.

        Nested schemas are turned into classes recursively. A field whose type
        cannot be resolved falls back to ``str`` rather than raising.
        """
        type_map = SignatureGenerator._basic_type_map()
        class_attrs = {"__annotations__": {}}
        if schema.description:
            class_attrs["__doc__"] = schema.description

        for field_def in schema.fields:
            # Determine Python type
            if field_def.type == FieldType.PYDANTIC_MODEL and field_def.nested_model:
                py_type = SignatureGenerator._create_dynamic_pydantic_model(
                    field_def.nested_model
                )
                already_optional = False
            elif field_def.type == FieldType.LITERAL and field_def.literal_values:
                py_type = Literal[tuple(field_def.literal_values)]
                already_optional = False
            else:
                py_type = type_map.get(field_def.type.value, str)
                already_optional = SignatureGenerator._is_optional_field_type(
                    field_def.type
                )

            # Wrap in Optional if not required
            if not field_def.required and not already_optional:
                py_type = Optional[py_type]

            # Create Pydantic field
            field_kwargs = {}
            if field_def.description:
                field_kwargs["description"] = field_def.description
            if not field_def.required:
                field_kwargs["default"] = None

            class_attrs["__annotations__"][field_def.name] = py_type
            # Without kwargs the bare annotation already declares a required field.
            if field_kwargs:
                class_attrs[field_def.name] = Field(**field_kwargs)

        return type(schema.model_name, (BaseModel,), class_attrs)

    @staticmethod
    def _docstring_literal(text) -> str:
        """Return ``text`` as a triple-quoted literal that always parses.

        Backslashes are doubled so they are not read as escape sequences, and
        any quote that would terminate the literal early (an embedded triple
        quote, or a quote at the very end) is escaped.
        """
        escaped = str(text).replace("\\", "\\\\").replace('"""', '\\"\\"\\"')
        if escaped.endswith('"'):
            escaped = escaped[:-1] + '\\"'
        return f'"""{escaped}"""'

    @classmethod
    def generate_code(cls, prediction) -> str:
        """Generate Python code from a signature prediction.

        The output consists of the required imports, then every Pydantic model
        the fields refer to (dependencies first), then the signature class.

        Args:
            prediction: An object with attributes ``signature_name``,
                ``task_description`` and ``signature_fields``.

        Returns:
            The generated source without a trailing newline.
        """
        code_lines = list(cls.get_required_imports(prediction.signature_fields))
        code_lines.append("")

        # Generate Pydantic model classes first (if any)
        for model_schema in cls._collect_pydantic_models(prediction.signature_fields):
            code_lines.extend([model_schema.to_code(), "", ""])

        # Generate the main signature class
        code_lines.append(f"class {prediction.signature_name}(dspy.Signature):")
        code_lines.append(f"    {cls._docstring_literal(prediction.task_description)}")
        code_lines.append("")
        code_lines.extend(
            f"    {field.to_dspy_field_code()}" for field in prediction.signature_fields
        )
        return "\n".join(code_lines)

    @classmethod
    def _collect_pydantic_models(cls, fields: List[GeneratedField]) -> List[PydanticModelSchema]:
        """Collect all Pydantic models from fields, including nested ones.

        Returns:
            Schemas de-duplicated by ``model_name`` (first occurrence wins),
            each placed after the nested models it refers to.
        """
        models = []
        seen_names = set()
        for field in fields:
            top_level = field.pydantic_model_schema
            if not top_level:
                continue
            # Nested models first (depth-first), then the top-level model.
            for candidate in [*top_level.get_all_nested_models(), top_level]:
                if candidate.model_name not in seen_names:
                    models.append(candidate)
                    seen_names.add(candidate.model_name)
        return models

    @classmethod
    def _typing_names_for_type(cls, field_type: FieldType) -> set:
        """Return the ``typing`` names associated with one field type."""
        if field_type == FieldType.LITERAL:
            return {"Literal"}
        if cls._is_optional_field_type(field_type):
            return {"Optional"}
        if field_type in (
            FieldType.LIST_STRING,
            FieldType.LIST_INT,
            FieldType.LIST_FLOAT,
            FieldType.LIST_DICT,
        ):
            return {"List"}
        if field_type == FieldType.DICT_STR_ANY:
            return {"Dict", "Any"}
        if field_type in (
            FieldType.DICT_STR_STR,
            FieldType.DICT_STR_INT,
            FieldType.DICT_STR_FLOAT,
            FieldType.DICT_STR_BOOL,
        ):
            return {"Dict"}
        return set()

    @classmethod
    def get_required_imports(cls, fields: List[GeneratedField]) -> List[str]:
        """Determine required imports based on field types.

        Returns:
            Import statements in emission order: ``import dspy``, then one
            ``from typing import ...`` line with sorted names (if any), then
            the pydantic import (if any model class will be emitted).
        """
        imports = ["import dspy"]
        typing_imports = set()
        needs_pydantic = False

        for field in fields:
            schema = field.pydantic_model_schema
            if schema:
                # generate_code emits a model class for every attached schema,
                # whatever the field's declared type, so its imports are
                # always needed.
                needs_pydantic = True
                cls._collect_typing_imports_from_schema(schema, typing_imports)
            if field.type == FieldType.PYDANTIC_MODEL:
                needs_pydantic = True
            else:
                typing_imports.update(cls._typing_names_for_type(field.type))

        if typing_imports:
            imports.append(f"from typing import {', '.join(sorted(typing_imports))}")
        if needs_pydantic:
            imports.append("from pydantic import BaseModel, Field")
        return imports

    @classmethod
    def _collect_typing_imports_from_schema(cls, schema: PydanticModelSchema, typing_imports: set):
        """Recursively collect typing imports needed for a Pydantic model schema.

        Args:
            schema: The model schema to inspect.
            typing_imports: Set of ``typing`` names, updated in place.
        """
        for field_def in schema.fields:
            # Includes "Optional" for the Optional[...] field types, which are
            # rendered with that name even when the field is required.
            typing_imports.update(cls._typing_names_for_type(field_def.type))
            if not field_def.required:
                typing_imports.add("Optional")
            # Recurse for nested models
            if field_def.nested_model:
                cls._collect_typing_imports_from_schema(field_def.nested_model, typing_imports)
# Shared module-level instance, constructed when this module is imported.
signature_generator = SignatureGenerator()