(no commit message)
This commit is contained in:
@@ -24,7 +24,7 @@ class FieldType(str, Enum):
|
|||||||
OPTIONAL_INT = "Optional[int]"
|
OPTIONAL_INT = "Optional[int]"
|
||||||
OPTIONAL_FLOAT = "Optional[float]"
|
OPTIONAL_FLOAT = "Optional[float]"
|
||||||
OPTIONAL_BOOL = "Optional[bool]"
|
OPTIONAL_BOOL = "Optional[bool]"
|
||||||
PYDANTIC_MODEL = "pydantic" # For nested structures
|
PYDANTIC_MODEL = "pydantic"
|
||||||
|
|
||||||
|
|
||||||
class FieldRole(str, Enum):
|
class FieldRole(str, Enum):
|
||||||
@@ -72,7 +72,7 @@ class PydanticModelSchema(BaseModel):
|
|||||||
"""Generate code for a single Pydantic field"""
|
"""Generate code for a single Pydantic field"""
|
||||||
indent_str = " " * indent
|
indent_str = " " * indent
|
||||||
|
|
||||||
# Determine type annotation
|
# determine type annotation
|
||||||
if field_def.type == FieldType.PYDANTIC_MODEL and field_def.nested_model:
|
if field_def.type == FieldType.PYDANTIC_MODEL and field_def.nested_model:
|
||||||
type_annotation = field_def.nested_model.model_name
|
type_annotation = field_def.nested_model.model_name
|
||||||
elif field_def.type == FieldType.LITERAL and field_def.literal_values:
|
elif field_def.type == FieldType.LITERAL and field_def.literal_values:
|
||||||
@@ -80,11 +80,12 @@ class PydanticModelSchema(BaseModel):
|
|||||||
else:
|
else:
|
||||||
type_annotation = field_def.type.value
|
type_annotation = field_def.type.value
|
||||||
|
|
||||||
# Add Optional wrapper if not required (but avoid double-wrapping)
|
|
||||||
|
# add Optional wrapper if not required (but avoid double-wrapping)
|
||||||
if not field_def.required and not type_annotation.startswith("Optional["):
|
if not field_def.required and not type_annotation.startswith("Optional["):
|
||||||
type_annotation = f"Optional[{type_annotation}]"
|
type_annotation = f"{type_annotation}"
|
||||||
|
|
||||||
# Build Field() arguments
|
# build Field() arguments
|
||||||
field_args = []
|
field_args = []
|
||||||
if field_def.description:
|
if field_def.description:
|
||||||
field_args.append(f'description="{field_def.description}"')
|
field_args.append(f'description="{field_def.description}"')
|
||||||
@@ -103,7 +104,7 @@ class PydanticModelSchema(BaseModel):
|
|||||||
nested = []
|
nested = []
|
||||||
for field_def in self.fields:
|
for field_def in self.fields:
|
||||||
if field_def.nested_model:
|
if field_def.nested_model:
|
||||||
# Add nested models first (depth-first)
|
# add nested models first (depth-first)
|
||||||
nested.extend(field_def.nested_model.get_all_nested_models())
|
nested.extend(field_def.nested_model.get_all_nested_models())
|
||||||
nested.append(field_def.nested_model)
|
nested.append(field_def.nested_model)
|
||||||
return nested
|
return nested
|
||||||
@@ -312,7 +313,7 @@ class SignatureGenerator(dspy.Module):
|
|||||||
for field_def in schema.fields:
|
for field_def in schema.fields:
|
||||||
field_name = field_def.name
|
field_name = field_def.name
|
||||||
|
|
||||||
# Determine Python type
|
# determine python type
|
||||||
if field_def.type == FieldType.PYDANTIC_MODEL and field_def.nested_model:
|
if field_def.type == FieldType.PYDANTIC_MODEL and field_def.nested_model:
|
||||||
py_type = SignatureGenerator._create_dynamic_pydantic_model(field_def.nested_model)
|
py_type = SignatureGenerator._create_dynamic_pydantic_model(field_def.nested_model)
|
||||||
elif field_def.type == FieldType.LITERAL and field_def.literal_values:
|
elif field_def.type == FieldType.LITERAL and field_def.literal_values:
|
||||||
@@ -331,12 +332,12 @@ class SignatureGenerator(dspy.Module):
|
|||||||
}
|
}
|
||||||
py_type = type_map.get(type_str, str)
|
py_type = type_map.get(type_str, str)
|
||||||
|
|
||||||
# Wrap in Optional if not required (but avoid double-wrapping)
|
# wrap in Optional if not required (but avoid double-wrapping)
|
||||||
# Optional[X] is Union[X, None], so check if already a Union type
|
# optional[X] is Union[X, None], so check if already a Union type
|
||||||
if not field_def.required and get_origin(py_type) is not Union:
|
if not field_def.required and get_origin(py_type) is not Union:
|
||||||
py_type = Optional[py_type]
|
py_type = Optional[py_type]
|
||||||
|
|
||||||
# Create Pydantic field
|
# create Pydantic field
|
||||||
field_kwargs = {}
|
field_kwargs = {}
|
||||||
if field_def.description:
|
if field_def.description:
|
||||||
field_kwargs["description"] = field_def.description
|
field_kwargs["description"] = field_def.description
|
||||||
@@ -347,7 +348,7 @@ class SignatureGenerator(dspy.Module):
|
|||||||
if field_kwargs:
|
if field_kwargs:
|
||||||
class_attrs[field_name] = Field(**field_kwargs)
|
class_attrs[field_name] = Field(**field_kwargs)
|
||||||
|
|
||||||
# Create the dynamic class
|
# create the dynamic class
|
||||||
DynamicModel = type(schema.model_name, (BaseModel,), class_attrs)
|
DynamicModel = type(schema.model_name, (BaseModel,), class_attrs)
|
||||||
return DynamicModel
|
return DynamicModel
|
||||||
|
|
||||||
@@ -360,7 +361,7 @@ class SignatureGenerator(dspy.Module):
|
|||||||
code_lines.extend(imports)
|
code_lines.extend(imports)
|
||||||
code_lines.append("")
|
code_lines.append("")
|
||||||
|
|
||||||
# Generate Pydantic model classes first (if any)
|
# generate Pydantic model classes first (if any)
|
||||||
pydantic_models = cls._collect_pydantic_models(prediction.signature_fields)
|
pydantic_models = cls._collect_pydantic_models(prediction.signature_fields)
|
||||||
if pydantic_models:
|
if pydantic_models:
|
||||||
for model_schema in pydantic_models:
|
for model_schema in pydantic_models:
|
||||||
@@ -368,7 +369,7 @@ class SignatureGenerator(dspy.Module):
|
|||||||
code_lines.append("")
|
code_lines.append("")
|
||||||
code_lines.append("")
|
code_lines.append("")
|
||||||
|
|
||||||
# Generate the main signature class
|
# generate the main signature class
|
||||||
code_lines.append(f"class {prediction.signature_name}(dspy.Signature):")
|
code_lines.append(f"class {prediction.signature_name}(dspy.Signature):")
|
||||||
code_lines.append(f' """{prediction.task_description}"""')
|
code_lines.append(f' """{prediction.task_description}"""')
|
||||||
code_lines.append("")
|
code_lines.append("")
|
||||||
@@ -386,14 +387,14 @@ class SignatureGenerator(dspy.Module):
|
|||||||
|
|
||||||
for field in fields:
|
for field in fields:
|
||||||
if field.pydantic_model_schema:
|
if field.pydantic_model_schema:
|
||||||
# Get all nested models first (depth-first)
|
# get all nested models first (depth-first)
|
||||||
nested_models = field.pydantic_model_schema.get_all_nested_models()
|
nested_models = field.pydantic_model_schema.get_all_nested_models()
|
||||||
for nested in nested_models:
|
for nested in nested_models:
|
||||||
if nested.model_name not in seen_names:
|
if nested.model_name not in seen_names:
|
||||||
models.append(nested)
|
models.append(nested)
|
||||||
seen_names.add(nested.model_name)
|
seen_names.add(nested.model_name)
|
||||||
|
|
||||||
# Then add the top-level model
|
# then add the top-level model
|
||||||
if field.pydantic_model_schema.model_name not in seen_names:
|
if field.pydantic_model_schema.model_name not in seen_names:
|
||||||
models.append(field.pydantic_model_schema)
|
models.append(field.pydantic_model_schema)
|
||||||
seen_names.add(field.pydantic_model_schema.model_name)
|
seen_names.add(field.pydantic_model_schema.model_name)
|
||||||
@@ -410,7 +411,7 @@ class SignatureGenerator(dspy.Module):
|
|||||||
for field in fields:
|
for field in fields:
|
||||||
if field.type == FieldType.PYDANTIC_MODEL:
|
if field.type == FieldType.PYDANTIC_MODEL:
|
||||||
needs_pydantic = True
|
needs_pydantic = True
|
||||||
# Check nested models for their typing requirements
|
# check nested models for their typing requirements
|
||||||
if field.pydantic_model_schema:
|
if field.pydantic_model_schema:
|
||||||
cls._collect_typing_imports_from_schema(field.pydantic_model_schema, typing_imports)
|
cls._collect_typing_imports_from_schema(field.pydantic_model_schema, typing_imports)
|
||||||
elif field.type == FieldType.LITERAL:
|
elif field.type == FieldType.LITERAL:
|
||||||
@@ -469,7 +470,7 @@ class SignatureGenerator(dspy.Module):
|
|||||||
if not field_def.required:
|
if not field_def.required:
|
||||||
typing_imports.add("Optional")
|
typing_imports.add("Optional")
|
||||||
|
|
||||||
# Recurse for nested models
|
# recurse for nested models
|
||||||
if field_def.nested_model:
|
if field_def.nested_model:
|
||||||
cls._collect_typing_imports_from_schema(field_def.nested_model, typing_imports)
|
cls._collect_typing_imports_from_schema(field_def.nested_model, typing_imports)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user