diff --git a/agent/modules.py b/agent/modules.py index c6c1590..5660821 100644 --- a/agent/modules.py +++ b/agent/modules.py @@ -1,6 +1,6 @@ import dspy from enum import Enum -from typing import Any, Dict, List, Optional, Literal +from typing import Any, Dict, List, Optional, Literal, Union, get_origin from pydantic import BaseModel, Field, ValidationError class FieldType(str, Enum): @@ -80,8 +80,8 @@ class PydanticModelSchema(BaseModel): else: type_annotation = field_def.type.value - # Add Optional wrapper if not required - if not field_def.required: + # Add Optional wrapper if not required (but avoid double-wrapping) + if not field_def.required and not type_annotation.startswith("Optional["): type_annotation = f"Optional[{type_annotation}]" # Build Field() arguments @@ -331,8 +331,9 @@ class SignatureGenerator(dspy.Module): } py_type = type_map.get(type_str, str) - # Wrap in Optional if not required - if not field_def.required and not isinstance(py_type, type(Optional[str])): + # Wrap in Optional if not required (but avoid double-wrapping) + # Optional[X] is Union[X, None], so check if already a Union type + if not field_def.required and get_origin(py_type) is not Union: py_type = Optional[py_type] # Create Pydantic field