Use structured outputs with Pydantic models instead of text parsing
This commit is contained in:
@@ -34,15 +34,15 @@ def is_pydantic_model(type_hint: Any) -> bool:
|
||||
- Generic types containing Pydantic: list[MyModel], dict[str, MyModel]
|
||||
"""
|
||||
try:
|
||||
# Direct Pydantic model
|
||||
# direct Pydantic model
|
||||
if isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
||||
return True
|
||||
|
||||
# Generic type (list, dict, etc.)
|
||||
# generic type (list, dict, etc.)
|
||||
origin = get_origin(type_hint)
|
||||
if origin is not None:
|
||||
args = get_args(type_hint)
|
||||
# Check if any type argument is a Pydantic model
|
||||
# check if any type argument is a Pydantic model
|
||||
for arg in args:
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel):
|
||||
return True
|
||||
@@ -67,7 +67,7 @@ def get_json_schema(type_hint: Any) -> dict[str, Any]:
|
||||
origin = get_origin(type_hint)
|
||||
args = get_args(type_hint)
|
||||
|
||||
# Handle generic types (list, dict, etc.)
|
||||
# handle generic types (list, dict, etc.)
|
||||
if origin is list:
|
||||
# list[Model] - wrap in object since API requires root type = "object"
|
||||
# {"type": "object", "properties": {"items": {"type": "array", "items": {...}}}}
|
||||
@@ -76,13 +76,10 @@ def get_json_schema(type_hint: Any) -> dict[str, Any]:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": model.model_json_schema()
|
||||
}
|
||||
"items": {"type": "array", "items": model.model_json_schema()}
|
||||
},
|
||||
"required": ["items"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported list type: {type_hint}")
|
||||
@@ -90,30 +87,34 @@ def get_json_schema(type_hint: Any) -> dict[str, Any]:
|
||||
elif origin is dict:
|
||||
# dict[str, Model] - wrap in object since API requires root type = "object"
|
||||
# {"type": "object", "properties": {"values": {"type": "object", "additionalProperties": {...}}}}
|
||||
if len(args) >= 2 and isinstance(args[1], type) and issubclass(args[1], BaseModel):
|
||||
if (
|
||||
len(args) >= 2
|
||||
and isinstance(args[1], type)
|
||||
and issubclass(args[1], BaseModel)
|
||||
):
|
||||
model = args[1]
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"values": {
|
||||
"type": "object",
|
||||
"additionalProperties": model.model_json_schema()
|
||||
"additionalProperties": model.model_json_schema(),
|
||||
}
|
||||
},
|
||||
"required": ["values"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported dict type: {type_hint}")
|
||||
|
||||
elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
||||
# Direct Pydantic model - already an object
|
||||
# direct Pydantic model - already an object
|
||||
schema = type_hint.model_json_schema()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported type for structured output: {type_hint}")
|
||||
|
||||
# Recursively set additionalProperties: false for all nested objects
|
||||
# recursively set additionalProperties: false for all nested objects
|
||||
def set_additional_properties(obj: dict[str, Any]) -> None:
|
||||
if isinstance(obj, dict):
|
||||
if obj.get("type") == "object" and "additionalProperties" not in obj:
|
||||
@@ -158,18 +159,18 @@ def parse_json_response(
|
||||
origin = get_origin(type_hint)
|
||||
args = get_args(type_hint)
|
||||
|
||||
# Parse string to dict/list if needed
|
||||
# parse string to dict/list if needed
|
||||
if isinstance(response, str):
|
||||
parsed = json.loads(response)
|
||||
else:
|
||||
parsed = response
|
||||
|
||||
# Handle list[Model]
|
||||
# handle list[Model]
|
||||
if origin is list:
|
||||
if args and isinstance(args[0], type) and issubclass(args[0], BaseModel):
|
||||
model = args[0]
|
||||
|
||||
# Unwrap from {"items": [...]} if present (from structured_output)
|
||||
# unwrap from {"items": [...]} if present (from structured_output)
|
||||
if isinstance(parsed, dict) and "items" in parsed:
|
||||
parsed = parsed["items"]
|
||||
|
||||
@@ -177,12 +178,16 @@ def parse_json_response(
|
||||
raise ValueError(f"Expected list, got {type(parsed)}")
|
||||
return [model.model_validate(item) for item in parsed]
|
||||
|
||||
# Handle dict[str, Model]
|
||||
# handle dict[str, Model]
|
||||
elif origin is dict:
|
||||
if len(args) >= 2 and isinstance(args[1], type) and issubclass(args[1], BaseModel):
|
||||
if (
|
||||
len(args) >= 2
|
||||
and isinstance(args[1], type)
|
||||
and issubclass(args[1], BaseModel)
|
||||
):
|
||||
model = args[1]
|
||||
|
||||
# Unwrap from {"values": {...}} if present (from structured_output)
|
||||
# unwrap from {"values": {...}} if present (from structured_output)
|
||||
if isinstance(parsed, dict) and "values" in parsed:
|
||||
parsed = parsed["values"]
|
||||
|
||||
@@ -190,7 +195,7 @@ def parse_json_response(
|
||||
raise ValueError(f"Expected dict, got {type(parsed)}")
|
||||
return {key: model.model_validate(value) for key, value in parsed.items()}
|
||||
|
||||
# Handle direct Pydantic model
|
||||
# handle direct Pydantic model
|
||||
elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
||||
if isinstance(response, str):
|
||||
return type_hint.model_validate_json(response)
|
||||
|
||||
Reference in New Issue
Block a user