(no commit message)
This commit is contained in:
@@ -24,7 +24,7 @@ class FieldType(str, Enum):
|
||||
OPTIONAL_INT = "Optional[int]"
|
||||
OPTIONAL_FLOAT = "Optional[float]"
|
||||
OPTIONAL_BOOL = "Optional[bool]"
|
||||
PYDANTIC_MODEL = "pydantic" # For nested structures
|
||||
PYDANTIC_MODEL = "pydantic"
|
||||
|
||||
|
||||
class FieldRole(str, Enum):
|
||||
@@ -72,7 +72,7 @@ class PydanticModelSchema(BaseModel):
|
||||
"""Generate code for a single Pydantic field"""
|
||||
indent_str = " " * indent
|
||||
|
||||
# Determine type annotation
|
||||
# determine type annotation
|
||||
if field_def.type == FieldType.PYDANTIC_MODEL and field_def.nested_model:
|
||||
type_annotation = field_def.nested_model.model_name
|
||||
elif field_def.type == FieldType.LITERAL and field_def.literal_values:
|
||||
@@ -80,11 +80,12 @@ class PydanticModelSchema(BaseModel):
|
||||
else:
|
||||
type_annotation = field_def.type.value
|
||||
|
||||
# Add Optional wrapper if not required (but avoid double-wrapping)
|
||||
if not field_def.required and not type_annotation.startswith("Optional["):
|
||||
type_annotation = f"Optional[{type_annotation}]"
|
||||
|
||||
# Build Field() arguments
|
||||
# add Optional wrapper if not required (but avoid double-wrapping)
|
||||
if not field_def.required and not type_annotation.startswith("Optional["):
|
||||
type_annotation = f"{type_annotation}"
|
||||
|
||||
# build Field() arguments
|
||||
field_args = []
|
||||
if field_def.description:
|
||||
field_args.append(f'description="{field_def.description}"')
|
||||
@@ -103,7 +104,7 @@ class PydanticModelSchema(BaseModel):
|
||||
nested = []
|
||||
for field_def in self.fields:
|
||||
if field_def.nested_model:
|
||||
# Add nested models first (depth-first)
|
||||
# add nested models first (depth-first)
|
||||
nested.extend(field_def.nested_model.get_all_nested_models())
|
||||
nested.append(field_def.nested_model)
|
||||
return nested
|
||||
@@ -312,7 +313,7 @@ class SignatureGenerator(dspy.Module):
|
||||
for field_def in schema.fields:
|
||||
field_name = field_def.name
|
||||
|
||||
# Determine Python type
|
||||
# determine python type
|
||||
if field_def.type == FieldType.PYDANTIC_MODEL and field_def.nested_model:
|
||||
py_type = SignatureGenerator._create_dynamic_pydantic_model(field_def.nested_model)
|
||||
elif field_def.type == FieldType.LITERAL and field_def.literal_values:
|
||||
@@ -331,12 +332,12 @@ class SignatureGenerator(dspy.Module):
|
||||
}
|
||||
py_type = type_map.get(type_str, str)
|
||||
|
||||
# Wrap in Optional if not required (but avoid double-wrapping)
|
||||
# Optional[X] is Union[X, None], so check if already a Union type
|
||||
# wrap in Optional if not required (but avoid double-wrapping)
|
||||
# optional[X] is Union[X, None], so check if already a Union type
|
||||
if not field_def.required and get_origin(py_type) is not Union:
|
||||
py_type = Optional[py_type]
|
||||
|
||||
# Create Pydantic field
|
||||
# create Pydantic field
|
||||
field_kwargs = {}
|
||||
if field_def.description:
|
||||
field_kwargs["description"] = field_def.description
|
||||
@@ -347,7 +348,7 @@ class SignatureGenerator(dspy.Module):
|
||||
if field_kwargs:
|
||||
class_attrs[field_name] = Field(**field_kwargs)
|
||||
|
||||
# Create the dynamic class
|
||||
# create the dynamic class
|
||||
DynamicModel = type(schema.model_name, (BaseModel,), class_attrs)
|
||||
return DynamicModel
|
||||
|
||||
@@ -360,7 +361,7 @@ class SignatureGenerator(dspy.Module):
|
||||
code_lines.extend(imports)
|
||||
code_lines.append("")
|
||||
|
||||
# Generate Pydantic model classes first (if any)
|
||||
# generate Pydantic model classes first (if any)
|
||||
pydantic_models = cls._collect_pydantic_models(prediction.signature_fields)
|
||||
if pydantic_models:
|
||||
for model_schema in pydantic_models:
|
||||
@@ -368,7 +369,7 @@ class SignatureGenerator(dspy.Module):
|
||||
code_lines.append("")
|
||||
code_lines.append("")
|
||||
|
||||
# Generate the main signature class
|
||||
# generate the main signature class
|
||||
code_lines.append(f"class {prediction.signature_name}(dspy.Signature):")
|
||||
code_lines.append(f' """{prediction.task_description}"""')
|
||||
code_lines.append("")
|
||||
@@ -386,14 +387,14 @@ class SignatureGenerator(dspy.Module):
|
||||
|
||||
for field in fields:
|
||||
if field.pydantic_model_schema:
|
||||
# Get all nested models first (depth-first)
|
||||
# get all nested models first (depth-first)
|
||||
nested_models = field.pydantic_model_schema.get_all_nested_models()
|
||||
for nested in nested_models:
|
||||
if nested.model_name not in seen_names:
|
||||
models.append(nested)
|
||||
seen_names.add(nested.model_name)
|
||||
|
||||
# Then add the top-level model
|
||||
# then add the top-level model
|
||||
if field.pydantic_model_schema.model_name not in seen_names:
|
||||
models.append(field.pydantic_model_schema)
|
||||
seen_names.add(field.pydantic_model_schema.model_name)
|
||||
@@ -410,7 +411,7 @@ class SignatureGenerator(dspy.Module):
|
||||
for field in fields:
|
||||
if field.type == FieldType.PYDANTIC_MODEL:
|
||||
needs_pydantic = True
|
||||
# Check nested models for their typing requirements
|
||||
# check nested models for their typing requirements
|
||||
if field.pydantic_model_schema:
|
||||
cls._collect_typing_imports_from_schema(field.pydantic_model_schema, typing_imports)
|
||||
elif field.type == FieldType.LITERAL:
|
||||
@@ -469,7 +470,7 @@ class SignatureGenerator(dspy.Module):
|
||||
if not field_def.required:
|
||||
typing_imports.add("Optional")
|
||||
|
||||
# Recurse for nested models
|
||||
# recurse for nested models
|
||||
if field_def.nested_model:
|
||||
cls._collect_typing_imports_from_schema(field_def.nested_model, typing_imports)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user