# Listing metadata: 476 lines, 19 KiB, Python
import dspy
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Optional, Literal, Union, get_origin
|
|
from pydantic import BaseModel, Field, ValidationError
|
|
|
|
# Closed set of type names a generated field may declare. Each value is the
# annotation text written into generated code and also the key used to look up
# the runtime type when classes are built dynamically.
class FieldType(str, Enum):
    # Scalars
    STRING = "str"
    INTEGER = "int"
    FLOAT = "float"
    BOOLEAN = "bool"
    # Homogeneous lists
    LIST_STRING = "list[str]"
    LIST_INT = "list[int]"
    LIST_FLOAT = "list[float]"
    # String-keyed dictionaries
    DICT_STR_STR = "dict[str, str]"
    DICT_STR_INT = "dict[str, int]"
    DICT_STR_FLOAT = "dict[str, float]"
    DICT_STR_BOOL = "dict[str, bool]"
    DICT_STR_ANY = "dict[str, Any]"
    LIST_DICT = "list[dict]"
    # DSPy media types
    IMAGE = "dspy.Image"
    AUDIO = "dspy.Audio"
    # Marker: the real annotation is built from the field's literal_values
    LITERAL = "Literal"
    # Optional scalars
    OPTIONAL_STR = "Optional[str]"
    OPTIONAL_INT = "Optional[int]"
    OPTIONAL_FLOAT = "Optional[float]"
    OPTIONAL_BOOL = "Optional[bool]"
    # Marker: the real annotation is the name of the attached model schema
    PYDANTIC_MODEL = "pydantic"  # For nested structures
|
|
|
|
|
|
# Direction of a signature field; selects dspy.InputField vs dspy.OutputField
# wherever fields are rendered or instantiated in this module.
class FieldRole(str, Enum):
    INPUT = "input"
    OUTPUT = "output"
|
|
|
|
|
|
class PydanticFieldDef(BaseModel):
    """Represents a single field in a Pydantic model"""

    # Attribute name written into the generated model class.
    name: str = Field(description="Field name in the Pydantic model")
    # Declared type; LITERAL and PYDANTIC_MODEL are refined by the two
    # optional attributes below.
    type: FieldType = Field(description="Field type")
    description: Optional[str] = Field(default=None, description="Field description")
    # When False the generated annotation is wrapped in Optional[...] and the
    # field defaults to None.
    required: bool = Field(default=True, description="Whether the field is required")
    # Only consulted when type is FieldType.LITERAL.
    literal_values: Optional[List[str]] = Field(
        default=None, description="For Literal types, the allowed values"
    )
    # Only consulted when type is FieldType.PYDANTIC_MODEL. Forward reference:
    # PydanticModelSchema is declared after this class.
    nested_model: Optional["PydanticModelSchema"] = Field(
        default=None, description="For nested Pydantic models"
    )
|
|
|
|
|
|
class PydanticModelSchema(BaseModel):
    """Represents a complete Pydantic model schema"""

    model_name: str = Field(description="Name of the Pydantic model class (PascalCase)")
    description: Optional[str] = Field(default=None, description="Model docstring")
    fields: List[PydanticFieldDef] = Field(description="Fields in the model")

    def to_code(self, indent: int = 0) -> str:
        """Generate Python source for this Pydantic model.

        Args:
            indent: Number of indentation units to prefix to the class line;
                fields are rendered one unit deeper.

        Returns:
            The class definition as a newline-joined string. Nested models are
            referenced by name only; their own definitions are not included.
        """
        # NOTE(review): the indentation unit is a single space as it appears in
        # this file -- confirm whether four spaces were intended.
        indent_str = " " * indent
        lines = [f"{indent_str}class {self.model_name}(BaseModel):"]

        if self.description:
            # Escape backslashes and double quotes so a description containing
            # them cannot terminate or corrupt the emitted docstring literal.
            doc_text = self.description.replace("\\", "\\\\").replace('"', '\\"')
            lines.append(f'{indent_str} """{doc_text}"""')
            lines.append("")

        for field_def in self.fields:
            lines.append(self._generate_field_code(field_def, indent + 1))

        return "\n".join(lines)

    def _generate_field_code(self, field_def: PydanticFieldDef, indent: int) -> str:
        """Generate one ``name: type = Field(...)`` line for a model field.

        Args:
            field_def: The field to render.
            indent: Number of indentation units to prefix to the line.

        Returns:
            A single line of Python source.
        """
        indent_str = " " * indent

        # Determine type annotation
        if field_def.type == FieldType.PYDANTIC_MODEL and field_def.nested_model:
            type_annotation = field_def.nested_model.model_name
        elif field_def.type == FieldType.LITERAL and field_def.literal_values:
            type_annotation = f"Literal[{', '.join(repr(v) for v in field_def.literal_values)}]"
        else:
            type_annotation = field_def.type.value

        # Add Optional wrapper if not required (but avoid double-wrapping)
        if not field_def.required and not type_annotation.startswith("Optional["):
            type_annotation = f"Optional[{type_annotation}]"

        # Build Field() arguments
        field_args = []
        if field_def.description:
            # repr() yields a valid Python literal even when the description
            # contains quotes, backslashes or newlines.
            field_args.append(f"description={field_def.description!r}")
        if not field_def.required:
            field_args.append("default=None")

        # An empty argument list renders as "Field()".
        field_init = f"Field({', '.join(field_args)})"
        return f"{indent_str}{field_def.name}: {type_annotation} = {field_init}"

    def get_all_nested_models(self) -> List["PydanticModelSchema"]:
        """Recursively collect all nested models.

        Returns:
            Every model reachable through ``nested_model`` attributes, ordered
            so that a model's own nested models come before it. The receiver
            itself is not included, and duplicates are not removed.
        """
        nested = []
        for field_def in self.fields:
            if field_def.nested_model:
                # Add nested models first (depth-first)
                nested.extend(field_def.nested_model.get_all_nested_models())
                nested.append(field_def.nested_model)
        return nested
|
|
|
|
|
|
# One input or output field of a generated DSPy signature; this is the element
# type of SignatureGeneration.signature_fields.
class GeneratedField(BaseModel):
    name: str = Field(description="The field name (snake_case, descriptive)")
    type: FieldType = Field(description="The Python type for this field")
    role: FieldRole = Field(description="Whether this is an input or output field")
    description: str = Field(description="Description of what this field represents")
    pydantic_model_schema: Optional[PydanticModelSchema] = Field(
        default=None, description="For PYDANTIC_MODEL type, the model schema"
    )
    literal_values: Optional[List[str]] = Field(
        default=None, description="For Literal types, the allowed values"
    )
    default_value: Optional[str] = Field(
        default=None,
        description="Default value for the field (as string representation)",
    )

    def to_dspy_field_code(self) -> str:
        """Generate DSPy field code for use in a Signature class.

        Returns:
            One line of the form ``name: annotation = dspy.InputField(...)``
            or ``dspy.OutputField(...)``. ``default_value`` is not used here.
        """
        if self.type == FieldType.PYDANTIC_MODEL and self.pydantic_model_schema:
            type_annotation = self.pydantic_model_schema.model_name
        elif self.type == FieldType.LITERAL and self.literal_values:
            type_annotation = (
                f"Literal[{', '.join(repr(v) for v in self.literal_values)}]"
            )
        else:
            type_annotation = self.type.value

        field_type = "InputField" if self.role == FieldRole.INPUT else "OutputField"

        if self.description:
            # repr() keeps the emitted literal valid when the description
            # contains quotes, backslashes or newlines.
            return f"{self.name}: {type_annotation} = dspy.{field_type}(desc={self.description!r})"
        return f"{self.name}: {type_annotation} = dspy.{field_type}()"
|
|
|
|
|
|
# Signature that asks the model to design another signature: it takes a
# free-text prompt and returns a task description, a list of typed fields and
# a class name. The docstring (presumably read by dspy as the task
# instructions -- confirm) and every desc string are kept verbatim.
class SignatureGeneration(dspy.Signature):
    """Generate a DSPy signature with proper field types.

    IMPORTANT: For ANY structured/nested JSON output with multiple fields or nested objects:
    - DO NOT use a simple str field with JSON instructions
    - ALWAYS use type='pydantic' with a complete PydanticModelSchema
    - Define the full nested structure with all fields properly typed

    Examples:
    - Simple output: "answer: str" is fine
    - Complex output like medical records, forms, multi-field objects: MUST use pydantic models
    """

    prompt: str = dspy.InputField(
        desc="Natural language description of the desired functionality."
    )
    task_description: str = dspy.OutputField(
        desc="Clear description of what the signature aims to accomplish."
    )
    # The continuation lines of this desc string are flush-left so the string
    # content matches the text as it appears in this file.
    signature_fields: list[GeneratedField] = dspy.OutputField(
        desc="""List of input and output fields for the signature.

CRITICAL RULES:
1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema
2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs
3. For Pydantic models: define ALL nested fields properly in the schema with correct types
4. Simple outputs (single values) can use basic types like str, int, bool
5. Use Literal types for enumerated values (e.g., severity levels, status codes)

Examples:
- BAD: structured_output: str = "A JSON containing patient data..."
- GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)"""
    )
    signature_name: str = dspy.OutputField(
        desc="Suggested class name for the signature (PascalCase)"
    )
|
|
|
|
|
|
class SignatureGenerator(dspy.Module):
    """Generate DSPy signatures from a natural-language prompt.

    Wraps a ``dspy.ChainOfThought`` over ``SignatureGeneration`` and provides
    helpers that turn its prediction into Python source (``generate_code``)
    or into live classes (``create_signature_class``).
    """

    # Field types grouped by the ``typing`` name their annotation calls for.
    _OPTIONAL_TYPES = frozenset(
        {
            FieldType.OPTIONAL_STR,
            FieldType.OPTIONAL_INT,
            FieldType.OPTIONAL_FLOAT,
            FieldType.OPTIONAL_BOOL,
        }
    )
    _LIST_TYPES = frozenset(
        {
            FieldType.LIST_STRING,
            FieldType.LIST_INT,
            FieldType.LIST_FLOAT,
            FieldType.LIST_DICT,
        }
    )
    _DICT_TYPES = frozenset(
        {
            FieldType.DICT_STR_STR,
            FieldType.DICT_STR_INT,
            FieldType.DICT_STR_FLOAT,
            FieldType.DICT_STR_BOOL,
            FieldType.DICT_STR_ANY,
        }
    )

    def __init__(self):
        super().__init__()
        self.generator = dspy.ChainOfThought(SignatureGeneration)

    def forward(self, prompt: str):
        """Generate DSPy signature and return raw prediction attributes.

        Args:
            prompt: Natural-language description passed to the generator.

        Returns:
            A ``dspy.Prediction`` with ``signature_name``,
            ``task_description``, ``signature_fields`` and ``reasoning``
            (``None`` when the generator result has no such attribute).
        """
        result = self.generator(prompt=prompt)

        return dspy.Prediction(
            signature_name=result.signature_name,
            task_description=result.task_description,
            signature_fields=result.signature_fields,
            reasoning=getattr(result, "reasoning", None),
        )

    def generate_signature(self, prompt: str) -> Dict[str, Any]:
        """Legacy method for backward compatibility - returns formatted dict.

        Args:
            prompt: Natural-language description passed to ``forward``.

        Returns:
            On success a dict with ``signature_name``, ``task_description``,
            ``fields``, ``code`` and ``reasoning``. On any exception the dict
            produced by ``_format_error``; nothing is re-raised.
        """
        try:
            result = self.forward(prompt=prompt)

            return {
                "signature_name": result.signature_name,
                "task_description": result.task_description,
                "fields": [field.model_dump() for field in result.signature_fields],
                "code": self.generate_code(result),
                "reasoning": result.reasoning,
            }
        except ValidationError as e:
            error_msg = f"Data validation error from Pydantic: {e}"
            return self._format_error(error_msg)
        except Exception as e:
            # Catch other potential errors (e.g., from dspy)
            error_msg = f"An unexpected error occurred: {e}"
            return self._format_error(error_msg)

    def _format_error(self, error_message: str) -> Dict[str, Any]:
        """Helper to create a standardized error dictionary."""
        return {
            "error": error_message,
            "signature_name": None,
            "task_description": None,
            "fields": [],
            "code": None,
        }

    @classmethod
    def create_signature_class(cls, prediction: dspy.Prediction) -> type:
        """
        Dynamically creates a dspy.Signature class from a prediction object.

        Args:
            prediction: An object with attributes `signature_name`, `task_description`,
                and `signature_fields`.

        Returns:
            A new class that inherits from dspy.Signature.

        Raises:
            TypeError: If a field's type cannot be mapped to a Python type.
        """
        class_name = prediction.signature_name
        docstring = prediction.task_description

        class_attrs = {"__doc__": docstring, "__annotations__": {}}

        for field in prediction.signature_fields:
            field_name = field.name
            py_type = cls._get_python_type_from_field(field)

            dspy_field_class = (
                dspy.InputField if field.role == FieldRole.INPUT else dspy.OutputField
            )
            dspy_field_instance = dspy_field_class(desc=field.description)

            class_attrs[field_name] = dspy_field_instance
            class_attrs["__annotations__"][field_name] = py_type

        DynamicSignature = type(class_name, (dspy.Signature,), class_attrs)
        return DynamicSignature

    @staticmethod
    def _basic_type_map() -> Dict[str, Any]:
        """Map FieldType values to runtime types (excluding Literal and models).

        Built on each call rather than at class-definition time so that
        ``dspy.Image`` / ``dspy.Audio`` are only looked up when a type is
        actually being resolved.
        """
        return {
            "str": str,
            "int": int,
            "float": float,
            "bool": bool,
            "list[str]": List[str],
            "list[int]": List[int],
            "list[float]": List[float],
            "dict[str, str]": Dict[str, str],
            "dict[str, int]": Dict[str, int],
            "dict[str, float]": Dict[str, float],
            "dict[str, bool]": Dict[str, bool],
            "dict[str, Any]": Dict[str, Any],
            "list[dict]": List[dict],
            "Optional[str]": Optional[str],
            "Optional[int]": Optional[int],
            "Optional[float]": Optional[float],
            "Optional[bool]": Optional[bool],
            "dspy.Image": dspy.Image,
            "dspy.Audio": dspy.Audio,
        }

    @staticmethod
    def _get_python_type_from_field(field: "GeneratedField") -> type:
        """Converts a GeneratedField into a Python type for annotations.

        Raises:
            TypeError: If the field type has no runtime mapping, e.g. a
                ``Literal`` without values or a ``pydantic`` type without a
                schema.
        """
        if field.type == FieldType.LITERAL and field.literal_values:
            return Literal[tuple(field.literal_values)]

        if field.type == FieldType.PYDANTIC_MODEL and field.pydantic_model_schema:
            # Dynamically create the Pydantic model class
            return SignatureGenerator._create_dynamic_pydantic_model(field.pydantic_model_schema)

        type_str = field.type.value
        basic_types = SignatureGenerator._basic_type_map()
        if type_str in basic_types:
            return basic_types[type_str]

        raise TypeError(
            f"Unsupported field type for dynamic class creation: {type_str}"
        )

    @staticmethod
    def _create_dynamic_pydantic_model(schema: PydanticModelSchema) -> type:
        """Dynamically create a Pydantic model class from a schema.

        Nested models are created recursively. A field whose type has no
        runtime mapping falls back to ``str``.
        """
        class_attrs = {"__annotations__": {}}

        if schema.description:
            class_attrs["__doc__"] = schema.description

        # Same map as top-level fields use, so dspy.Image / dspy.Audio inside a
        # nested model keep their type instead of silently becoming str.
        basic_types = SignatureGenerator._basic_type_map()

        for field_def in schema.fields:
            field_name = field_def.name

            # Determine Python type
            if field_def.type == FieldType.PYDANTIC_MODEL and field_def.nested_model:
                py_type = SignatureGenerator._create_dynamic_pydantic_model(field_def.nested_model)
            elif field_def.type == FieldType.LITERAL and field_def.literal_values:
                py_type = Literal[tuple(field_def.literal_values)]
            else:
                py_type = basic_types.get(field_def.type.value, str)

            # Wrap in Optional if not required (but avoid double-wrapping)
            # Optional[X] is Union[X, None], so check if already a Union type
            if not field_def.required and get_origin(py_type) is not Union:
                py_type = Optional[py_type]

            # Create Pydantic field
            field_kwargs = {}
            if field_def.description:
                field_kwargs["description"] = field_def.description
            if not field_def.required:
                field_kwargs["default"] = None

            class_attrs["__annotations__"][field_name] = py_type
            if field_kwargs:
                class_attrs[field_name] = Field(**field_kwargs)

        # Create the dynamic class
        DynamicModel = type(schema.model_name, (BaseModel,), class_attrs)
        return DynamicModel

    @staticmethod
    def _escape_docstring(text: str) -> str:
        """Escape *text* so it can sit between triple double quotes."""
        return text.replace("\\", "\\\\").replace('"', '\\"')

    @classmethod
    def generate_code(cls, prediction) -> str:
        """Generate Python code from a signature prediction.

        Args:
            prediction: Object with ``signature_name``, ``task_description``
                and ``signature_fields``.

        Returns:
            Source text: imports, then any Pydantic model classes, then the
            ``dspy.Signature`` subclass.
        """
        code_lines = list(cls.get_required_imports(prediction.signature_fields))
        code_lines.append("")

        # Generate Pydantic model classes first (if any)
        for model_schema in cls._collect_pydantic_models(prediction.signature_fields):
            code_lines.append(model_schema.to_code())
            code_lines.append("")
            code_lines.append("")

        # Generate the main signature class
        # NOTE(review): the class body is indented by one space as it appears
        # in this file -- confirm whether four spaces were intended.
        code_lines.append(f"class {prediction.signature_name}(dspy.Signature):")
        # Escaped so quotes or backslashes in the description cannot break the
        # emitted docstring literal.
        doc_text = cls._escape_docstring(f"{prediction.task_description}")
        code_lines.append(f' """{doc_text}"""')
        code_lines.append("")

        for field in prediction.signature_fields:
            code_lines.append(f" {field.to_dspy_field_code()}")

        return "\n".join(code_lines)

    @classmethod
    def _collect_pydantic_models(cls, fields: List[GeneratedField]) -> List[PydanticModelSchema]:
        """Collect all Pydantic models from fields, including nested ones.

        Models are returned dependencies-first and de-duplicated by
        ``model_name`` (the first occurrence wins).
        """
        models = []
        seen_names = set()

        for field in fields:
            if field.pydantic_model_schema:
                # Get all nested models first (depth-first)
                nested_models = field.pydantic_model_schema.get_all_nested_models()
                for nested in nested_models:
                    if nested.model_name not in seen_names:
                        models.append(nested)
                        seen_names.add(nested.model_name)

                # Then add the top-level model
                if field.pydantic_model_schema.model_name not in seen_names:
                    models.append(field.pydantic_model_schema)
                    seen_names.add(field.pydantic_model_schema.model_name)

        return models

    @classmethod
    def _add_typing_imports_for_type(cls, field_type: FieldType, typing_imports: set) -> None:
        """Add to *typing_imports* the ``typing`` names *field_type* calls for."""
        if field_type == FieldType.LITERAL:
            typing_imports.add("Literal")
        elif field_type in cls._OPTIONAL_TYPES:
            typing_imports.add("Optional")
        elif field_type in cls._LIST_TYPES:
            typing_imports.add("List")
        elif field_type in cls._DICT_TYPES:
            typing_imports.add("Dict")
            if field_type == FieldType.DICT_STR_ANY:
                typing_imports.add("Any")

    @classmethod
    def get_required_imports(cls, fields: List[GeneratedField]) -> List[str]:
        """Determine required imports based on field types.

        Returns:
            Import statements in emission order: ``import dspy``, then a
            single ``from typing import ...`` line if needed, then the
            pydantic import if any model class will be emitted.
        """
        imports = ["import dspy"]
        typing_imports = set()
        needs_pydantic = False

        for field in fields:
            schema = field.pydantic_model_schema
            # _collect_pydantic_models emits a class for every attached schema,
            # whatever the field's declared type, so the pydantic import must
            # follow the same condition.
            if field.type == FieldType.PYDANTIC_MODEL or schema:
                needs_pydantic = True
            if schema:
                # Check nested models for their typing requirements
                cls._collect_typing_imports_from_schema(schema, typing_imports)

            cls._add_typing_imports_for_type(field.type, typing_imports)

        if typing_imports:
            imports.append(f"from typing import {', '.join(sorted(typing_imports))}")

        if needs_pydantic:
            imports.append("from pydantic import BaseModel, Field")

        return imports

    @classmethod
    def _collect_typing_imports_from_schema(cls, schema: PydanticModelSchema, typing_imports: set):
        """Recursively collect typing imports needed for a Pydantic model schema.

        Mutates *typing_imports* in place.
        """
        for field_def in schema.fields:
            # Covers Literal, List, Dict, Any and the Optional[...] field
            # types, which render as "Optional[...]" even when required.
            cls._add_typing_imports_for_type(field_def.type, typing_imports)

            # Non-required fields are wrapped in Optional[...] when rendered.
            if not field_def.required:
                typing_imports.add("Optional")

            # Recurse for nested models
            if field_def.nested_model:
                cls._collect_typing_imports_from_schema(field_def.nested_model, typing_imports)
|
|
|
|
# Module-level generator instance created at import time.
signature_generator = SignatureGenerator()