Use structured outputs with Pydantic models instead of text parsing
This commit is contained in:
@@ -27,24 +27,96 @@ class Usage:
|
||||
|
||||
|
||||
def is_pydantic_model(type_hint: Any) -> bool:
|
||||
"""Check if a type hint is a Pydantic model."""
|
||||
"""Check if a type hint is a Pydantic model or contains one (e.g., list[Model]).
|
||||
|
||||
Returns True for:
|
||||
- Pydantic models: MyModel
|
||||
- Generic types containing Pydantic: list[MyModel], dict[str, MyModel]
|
||||
"""
|
||||
try:
|
||||
return isinstance(type_hint, type) and issubclass(type_hint, BaseModel)
|
||||
# Direct Pydantic model
|
||||
if isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
||||
return True
|
||||
|
||||
# Generic type (list, dict, etc.)
|
||||
origin = get_origin(type_hint)
|
||||
if origin is not None:
|
||||
args = get_args(type_hint)
|
||||
# Check if any type argument is a Pydantic model
|
||||
for arg in args:
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel):
|
||||
return True
|
||||
|
||||
return False
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def get_json_schema(pydantic_model: type[BaseModel]) -> dict[str, Any]:
|
||||
"""Generate JSON schema from Pydantic model.
|
||||
def get_json_schema(type_hint: Any) -> dict[str, Any]:
|
||||
"""Generate JSON schema from type hint.
|
||||
|
||||
Sets additionalProperties to false to match Codex behavior.
|
||||
Handles:
|
||||
- Pydantic models: MyModel
|
||||
- Generic types: list[MyModel], dict[str, MyModel]
|
||||
|
||||
Note: Claude API requires root type to be "object" for structured outputs (tools).
|
||||
For list/dict types, we wrap them in an object with a single property.
|
||||
|
||||
Sets additionalProperties to false for all objects.
|
||||
"""
|
||||
schema = pydantic_model.model_json_schema()
|
||||
origin = get_origin(type_hint)
|
||||
args = get_args(type_hint)
|
||||
|
||||
# Recursively set additionalProperties: false for all objects
|
||||
# Handle generic types (list, dict, etc.)
|
||||
if origin is list:
|
||||
# list[Model] - wrap in object since API requires root type = "object"
|
||||
# {"type": "object", "properties": {"items": {"type": "array", "items": {...}}}}
|
||||
if args and isinstance(args[0], type) and issubclass(args[0], BaseModel):
|
||||
model = args[0]
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": model.model_json_schema()
|
||||
}
|
||||
},
|
||||
"required": ["items"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported list type: {type_hint}")
|
||||
|
||||
elif origin is dict:
|
||||
# dict[str, Model] - wrap in object since API requires root type = "object"
|
||||
# {"type": "object", "properties": {"values": {"type": "object", "additionalProperties": {...}}}}
|
||||
if len(args) >= 2 and isinstance(args[1], type) and issubclass(args[1], BaseModel):
|
||||
model = args[1]
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"values": {
|
||||
"type": "object",
|
||||
"additionalProperties": model.model_json_schema()
|
||||
}
|
||||
},
|
||||
"required": ["values"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported dict type: {type_hint}")
|
||||
|
||||
elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
||||
# Direct Pydantic model - already an object
|
||||
schema = type_hint.model_json_schema()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported type for structured output: {type_hint}")
|
||||
|
||||
# Recursively set additionalProperties: false for all nested objects
|
||||
def set_additional_properties(obj: dict[str, Any]) -> None:
|
||||
if isinstance(obj, dict):
|
||||
if obj.get("type") == "object":
|
||||
if obj.get("type") == "object" and "additionalProperties" not in obj:
|
||||
obj["additionalProperties"] = False
|
||||
for value in obj.values():
|
||||
if isinstance(value, dict):
|
||||
@@ -58,14 +130,74 @@ def get_json_schema(pydantic_model: type[BaseModel]) -> dict[str, Any]:
|
||||
return schema
|
||||
|
||||
|
||||
def parse_json_response(response: str, pydantic_model: type[BaseModel]) -> BaseModel:
|
||||
"""Parse JSON response into Pydantic model.
|
||||
def parse_json_response(
|
||||
response: str | dict | list, type_hint: Any
|
||||
) -> BaseModel | list[BaseModel] | dict[str, BaseModel]:
|
||||
"""Parse JSON response into typed output.
|
||||
|
||||
Handles:
|
||||
- Pydantic models: MyModel
|
||||
- Generic types: list[MyModel], dict[str, MyModel]
|
||||
|
||||
Note: When schema has list/dict at root, the SDK wraps them in {"items": [...]}
|
||||
or {"values": {...}} because API requires root type = "object".
|
||||
|
||||
Args:
|
||||
response: JSON string or already-parsed dict/list from structured_output
|
||||
type_hint: The output type annotation
|
||||
|
||||
Returns:
|
||||
Validated and typed output
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: If response is not valid JSON
|
||||
pydantic.ValidationError: If JSON doesn't match model schema
|
||||
json.JSONDecodeError: If response string is not valid JSON
|
||||
pydantic.ValidationError: If JSON doesn't match schema
|
||||
"""
|
||||
return pydantic_model.model_validate_json(response)
|
||||
import json
|
||||
|
||||
origin = get_origin(type_hint)
|
||||
args = get_args(type_hint)
|
||||
|
||||
# Parse string to dict/list if needed
|
||||
if isinstance(response, str):
|
||||
parsed = json.loads(response)
|
||||
else:
|
||||
parsed = response
|
||||
|
||||
# Handle list[Model]
|
||||
if origin is list:
|
||||
if args and isinstance(args[0], type) and issubclass(args[0], BaseModel):
|
||||
model = args[0]
|
||||
|
||||
# Unwrap from {"items": [...]} if present (from structured_output)
|
||||
if isinstance(parsed, dict) and "items" in parsed:
|
||||
parsed = parsed["items"]
|
||||
|
||||
if not isinstance(parsed, list):
|
||||
raise ValueError(f"Expected list, got {type(parsed)}")
|
||||
return [model.model_validate(item) for item in parsed]
|
||||
|
||||
# Handle dict[str, Model]
|
||||
elif origin is dict:
|
||||
if len(args) >= 2 and isinstance(args[1], type) and issubclass(args[1], BaseModel):
|
||||
model = args[1]
|
||||
|
||||
# Unwrap from {"values": {...}} if present (from structured_output)
|
||||
if isinstance(parsed, dict) and "values" in parsed:
|
||||
parsed = parsed["values"]
|
||||
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError(f"Expected dict, got {type(parsed)}")
|
||||
return {key: model.model_validate(value) for key, value in parsed.items()}
|
||||
|
||||
# Handle direct Pydantic model
|
||||
elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
||||
if isinstance(response, str):
|
||||
return type_hint.model_validate_json(response)
|
||||
else:
|
||||
return type_hint.model_validate(parsed)
|
||||
|
||||
raise ValueError(f"Unsupported type for parsing: {type_hint}")
|
||||
|
||||
|
||||
def extract_text_from_response(response: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user