215 lines
7.1 KiB
Python
215 lines
7.1 KiB
Python
from dataclasses import dataclass
|
|
from typing import Any, get_origin, get_args
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
@dataclass
class Usage:
    """Token counters for a request/response exchange.

    All counters default to zero. ``total_tokens`` is derived, not stored.
    """

    # Number of input tokens.
    input_tokens: int = 0
    # Number of cached input tokens; NOT added into total_tokens.
    # NOTE(review): presumably a subset of input_tokens, hence excluded — confirm.
    cached_input_tokens: int = 0
    # Number of output tokens.
    output_tokens: int = 0

    @property
    def total_tokens(self) -> int:
        """Input tokens plus output tokens (cached_input_tokens is not added)."""
        consumed = self.input_tokens
        produced = self.output_tokens
        return consumed + produced

    def __repr__(self) -> str:
        # Short labels keep the repr compact; "total" is the derived property.
        labelled = (
            ("input", self.input_tokens),
            ("cached", self.cached_input_tokens),
            ("output", self.output_tokens),
            ("total", self.total_tokens),
        )
        body = ", ".join(f"{label}={value}" for label, value in labelled)
        return f"Usage({body})"
|
|
|
|
|
|
def is_pydantic_model(type_hint: Any) -> bool:
    """Check if a type hint is a Pydantic model or directly contains one.

    Returns True for:
    - Pydantic models: MyModel
    - Generic types whose direct type arguments include a Pydantic model:
      list[MyModel], dict[str, MyModel] (any generic origin is inspected,
      so e.g. ``MyModel | None`` also qualifies)

    Only the immediate type arguments are inspected; a model nested deeper
    (e.g. list[list[MyModel]]) is not detected. Never raises.
    """

    def _is_model_class(candidate: Any) -> bool:
        # On Python 3.9/3.10 a parameterized generic such as list[MyModel]
        # satisfies isinstance(..., type), and issubclass() then raises
        # TypeError. Exclude anything with a generic origin first so generics
        # fall through to the argument check below instead of returning False.
        if get_origin(candidate) is not None or not isinstance(candidate, type):
            return False
        try:
            return issubclass(candidate, BaseModel)
        except TypeError:
            return False

    # direct Pydantic model
    if _is_model_class(type_hint):
        return True

    # generic type (list, dict, etc.): check its direct type arguments
    if get_origin(type_hint) is None:
        return False
    return any(_is_model_class(arg) for arg in get_args(type_hint))
|
|
|
|
|
|
def get_json_schema(type_hint: Any) -> dict[str, Any]:
    """Generate JSON schema from type hint.

    Handles:
    - Pydantic models: MyModel
    - Generic types: list[MyModel], dict[str, MyModel]

    Note: Claude API requires root type to be "object" for structured outputs (tools).
    For list/dict types, we wrap them in an object with a single property:
    "items" for list[MyModel] and "values" for dict[str, MyModel].

    When the model schema is nested inside such a wrapper, its "$defs"
    section is moved to the wrapper's root. This keeps "#/$defs/..."
    references (resolved from the document root) pointing at real
    definitions.

    Sets additionalProperties to false on every object schema that does not
    already declare additionalProperties.

    Args:
        type_hint: The output type annotation.

    Returns:
        A JSON schema dict whose root "type" is "object".

    Raises:
        ValueError: If the type hint is not one of the supported shapes.
    """

    def _is_model_class(candidate: Any) -> bool:
        # On Python 3.9/3.10 parameterized generics such as list[int] satisfy
        # isinstance(..., type), and issubclass() then raises TypeError.
        # Rule them out first so unsupported nesting yields ValueError below.
        return (
            get_origin(candidate) is None
            and isinstance(candidate, type)
            and issubclass(candidate, BaseModel)
        )

    def _wrap(
        property_name: str,
        property_schema: dict[str, Any],
        model_schema: dict[str, Any],
    ) -> dict[str, Any]:
        # Wrap a non-object root in a single-property object schema.
        wrapper: dict[str, Any] = {
            "type": "object",
            "properties": {property_name: property_schema},
            "required": [property_name],
            "additionalProperties": False,
        }
        # model_schema is nested inside property_schema; lift its definitions
        # to the root so "#/$defs/..." refs inside it still resolve.
        defs = model_schema.pop("$defs", None)
        if defs:
            wrapper["$defs"] = defs
        return wrapper

    def _set_additional_properties(obj: dict[str, Any]) -> None:
        # Recursively default additionalProperties to false for object
        # schemas; explicit values (including nested model schemas used as
        # additionalProperties) are left untouched but still traversed.
        if obj.get("type") == "object" and "additionalProperties" not in obj:
            obj["additionalProperties"] = False
        for value in obj.values():
            if isinstance(value, dict):
                _set_additional_properties(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        _set_additional_properties(item)

    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is list:
        # list[Model] -> {"type": "object", "properties": {"items": {"type": "array", "items": {...}}}}
        if not (args and _is_model_class(args[0])):
            raise ValueError(f"Unsupported list type: {type_hint}")
        model_schema = args[0].model_json_schema()
        schema = _wrap(
            "items", {"type": "array", "items": model_schema}, model_schema
        )

    elif origin is dict:
        # dict[str, Model] -> {"type": "object", "properties": {"values": {"type": "object", "additionalProperties": {...}}}}
        if not (len(args) >= 2 and _is_model_class(args[1])):
            raise ValueError(f"Unsupported dict type: {type_hint}")
        model_schema = args[1].model_json_schema()
        schema = _wrap(
            "values",
            {"type": "object", "additionalProperties": model_schema},
            model_schema,
        )

    elif _is_model_class(type_hint):
        # direct Pydantic model - already an object, $defs already at root
        schema = type_hint.model_json_schema()

    else:
        raise ValueError(f"Unsupported type for structured output: {type_hint}")

    _set_additional_properties(schema)
    return schema
|
|
|
|
|
|
def parse_json_response(
    response: str | dict | list, type_hint: Any
) -> BaseModel | list[BaseModel] | dict[str, BaseModel]:
    """Parse JSON response into typed output.

    Handles:
    - Pydantic models: MyModel
    - Generic types: list[MyModel], dict[str, MyModel]

    Note: When schema has list/dict at root, the SDK wraps them in {"items": [...]}
    or {"values": {...}} because API requires root type = "object". These
    wrappers are removed before validation.

    A dict response is only treated as the {"values": ...} wrapper when
    "values" is its sole key. This stops a plain dict[str, MyModel] payload
    that merely has a "values" entry among others from being misread. A
    one-entry payload keyed "values" remains ambiguous and is still
    unwrapped.

    Args:
        response: JSON string or already-parsed dict/list from structured_output
        type_hint: The output type annotation

    Returns:
        Validated and typed output

    Raises:
        json.JSONDecodeError: If response string is not valid JSON
        pydantic.ValidationError: If JSON doesn't match schema
        ValueError: If type_hint is unsupported, or the (unwrapped) payload is
            not a list/dict as the type hint requires
    """
    import json

    def _is_model_class(candidate: Any) -> bool:
        # On Python 3.9/3.10 parameterized generics satisfy
        # isinstance(..., type) and make issubclass() raise TypeError;
        # rule them out before the subclass check.
        return (
            get_origin(candidate) is None
            and isinstance(candidate, type)
            and issubclass(candidate, BaseModel)
        )

    origin = get_origin(type_hint)
    args = get_args(type_hint)

    # parse string to dict/list if needed (malformed JSON raises here)
    parsed = json.loads(response) if isinstance(response, str) else response

    # handle list[Model]
    if origin is list and args and _is_model_class(args[0]):
        model = args[0]
        # unwrap from {"items": [...]} (from structured_output); no other
        # dict is a valid list payload, so any "items" key is taken as the wrapper
        if isinstance(parsed, dict) and "items" in parsed:
            parsed = parsed["items"]
        if not isinstance(parsed, list):
            raise ValueError(f"Expected list, got {type(parsed)}")
        return [model.model_validate(item) for item in parsed]

    # handle dict[str, Model]
    if origin is dict and len(args) >= 2 and _is_model_class(args[1]):
        model = args[1]
        # unwrap only the exact {"values": {...}} wrapper shape
        if isinstance(parsed, dict) and len(parsed) == 1 and "values" in parsed:
            parsed = parsed["values"]
        if not isinstance(parsed, dict):
            raise ValueError(f"Expected dict, got {type(parsed)}")
        return {key: model.model_validate(value) for key, value in parsed.items()}

    # handle direct Pydantic model
    if _is_model_class(type_hint):
        # string input is validated with pydantic's JSON-mode entry point
        # (it was already checked as JSON by json.loads above)
        if isinstance(response, str):
            return type_hint.model_validate_json(response)
        return type_hint.model_validate(parsed)

    raise ValueError(f"Unsupported type for parsing: {type_hint}")
|
|
|
|
|
|
def extract_text_from_response(response: str) -> str:
    """Return the response text with surrounding whitespace removed.

    String outputs are otherwise passed through unchanged; any markdown or
    other formatting inside the text is left as-is.
    """
    trimmed = response.strip()
    return trimmed
|