Files
claude-code/claude_dspy/utils.py

210 lines
7.1 KiB
Python

from dataclasses import dataclass
from typing import Any, get_origin, get_args
from pydantic import BaseModel
@dataclass
class Usage:
    """Token counters for a request.

    All counters default to zero. ``total_tokens`` is derived from the input
    and output counters; the cached-input counter is reported separately and
    is not added to the total.
    """

    # Counters are plain ints and keep this positional order for __init__.
    input_tokens: int = 0
    cached_input_tokens: int = 0
    output_tokens: int = 0

    @property
    def total_tokens(self) -> int:
        """Return ``input_tokens + output_tokens``."""
        consumed, produced = self.input_tokens, self.output_tokens
        return consumed + produced

    def __repr__(self) -> str:
        # Compact representation that also shows the derived total.
        fields = ", ".join(
            [
                f"input={self.input_tokens}",
                f"cached={self.cached_input_tokens}",
                f"output={self.output_tokens}",
                f"total={self.total_tokens}",
            ]
        )
        return f"Usage({fields})"
def is_pydantic_model(type_hint: Any) -> bool:
    """Check if a type hint is a Pydantic model or directly contains one.

    Returns True for:
    - Pydantic model classes: ``MyModel``
    - Generic aliases with a Pydantic model among their direct type
      arguments: ``list[MyModel]``, ``dict[str, MyModel]``

    Only the top-level type arguments are inspected; deeper nesting such as
    ``list[list[MyModel]]`` yields False.

    Args:
        type_hint: The annotation to inspect.

    Returns:
        True if ``type_hint`` is a ``BaseModel`` subclass or has one as a
        direct type argument, False otherwise.
    """

    def is_model_class(candidate: Any) -> bool:
        # On Python 3.9/3.10 a parameterized alias such as ``list[int]``
        # passes ``isinstance(..., type)`` but makes ``issubclass`` raise
        # TypeError. Treat that as "this candidate is not a model class"
        # instead of aborting the whole check, so ``list[MyModel]`` is still
        # recognised through its type arguments below.
        try:
            return isinstance(candidate, type) and issubclass(candidate, BaseModel)
        except TypeError:
            return False

    # Direct Pydantic model class.
    if is_model_class(type_hint):
        return True

    # Not a parameterized generic: nothing further to inspect.
    if get_origin(type_hint) is None:
        return False

    # Generic alias (list, dict, ...): any direct type argument may be a model.
    return any(is_model_class(arg) for arg in get_args(type_hint))
def get_json_schema(type_hint: Any) -> dict[str, Any]:
    """Generate a JSON schema from a type hint.

    Handles:
    - Pydantic models: ``MyModel``
    - Generic types: ``list[MyModel]``, ``dict[str, MyModel]``

    Note: Claude API requires root type to be "object" for structured outputs
    (tools). For list/dict types the model schema is wrapped in an object with
    a single required property:

    - ``list[MyModel]``      -> ``{"items": [<MyModel>, ...]}``
    - ``dict[str, MyModel]`` -> ``{"values": {<key>: <MyModel>, ...}}``

    When the model schema carries a ``$defs`` section (emitted for nested
    models), it is hoisted to the root of the wrapper. ``$ref`` values of the
    form ``#/$defs/Name`` are resolved from the document root, so leaving
    ``$defs`` inside the wrapped sub-schema would leave those references
    dangling.

    Every object schema that does not already declare ``additionalProperties``
    gets ``additionalProperties: false``.

    Args:
        type_hint: A Pydantic model class, ``list[Model]`` or
            ``dict[str, Model]``.

    Returns:
        A JSON-schema dict whose root type is "object".

    Raises:
        ValueError: If ``type_hint`` is not one of the supported forms.
    """
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    def wrap_in_object(
        property_name: str,
        property_schema: dict[str, Any],
        shared_defs: Any,
    ) -> dict[str, Any]:
        # Build the single-property object wrapper required by the API.
        wrapper: dict[str, Any] = {
            "type": "object",
            "properties": {property_name: property_schema},
            "required": [property_name],
            "additionalProperties": False,
        }
        if shared_defs:
            # Keep "#/$defs/..." references resolvable from the root.
            wrapper["$defs"] = shared_defs
        return wrapper

    def set_additional_properties(obj: dict[str, Any]) -> None:
        # Recursively set additionalProperties: false on every object schema
        # that does not already specify it (explicit values are preserved,
        # e.g. the dict wrapper's value schema).
        if isinstance(obj, dict):
            if obj.get("type") == "object" and "additionalProperties" not in obj:
                obj["additionalProperties"] = False
            for value in obj.values():
                if isinstance(value, dict):
                    set_additional_properties(value)
                elif isinstance(value, list):
                    for item in value:
                        if isinstance(item, dict):
                            set_additional_properties(item)

    if origin is list:
        # list[Model] - wrap in object since API requires root type = "object"
        if not (args and isinstance(args[0], type) and issubclass(args[0], BaseModel)):
            raise ValueError(f"Unsupported list type: {type_hint}")
        # Shallow copy so popping "$defs" never touches the dict pydantic returned.
        model_schema = dict(args[0].model_json_schema())
        shared_defs = model_schema.pop("$defs", None)
        schema = wrap_in_object(
            "items",
            {"type": "array", "items": model_schema},
            shared_defs,
        )
    elif origin is dict:
        # dict[str, Model] - wrap in object since API requires root type = "object"
        if not (
            len(args) >= 2
            and isinstance(args[1], type)
            and issubclass(args[1], BaseModel)
        ):
            raise ValueError(f"Unsupported dict type: {type_hint}")
        model_schema = dict(args[1].model_json_schema())
        shared_defs = model_schema.pop("$defs", None)
        schema = wrap_in_object(
            "values",
            {"type": "object", "additionalProperties": model_schema},
            shared_defs,
        )
    elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
        # Direct Pydantic model - already an object with $defs at the root.
        schema = type_hint.model_json_schema()
    else:
        raise ValueError(f"Unsupported type for structured output: {type_hint}")

    set_additional_properties(schema)
    return schema
def parse_json_response(
    response: str | dict | list, type_hint: Any
) -> BaseModel | list[BaseModel] | dict[str, BaseModel]:
    """Turn a JSON response into validated model instances.

    Supported type hints:
    - a Pydantic model class: ``MyModel``
    - ``list[MyModel]`` and ``dict[str, MyModel]``

    Because the API requires an object at the schema root, list results may
    arrive wrapped as ``{"items": [...]}`` and dict results as
    ``{"values": {...}}``; the wrapper is removed when present.

    Args:
        response: JSON string or already-parsed dict/list from structured_output
        type_hint: The output type annotation

    Returns:
        Validated and typed output

    Raises:
        json.JSONDecodeError: If response string is not valid JSON
        pydantic.ValidationError: If JSON doesn't match schema
        ValueError: If the payload has the wrong container type, or the type
            hint is not one of the supported forms
    """
    import json

    # Strings are decoded up front so invalid JSON fails before any dispatch.
    payload = json.loads(response) if isinstance(response, str) else response

    container = get_origin(type_hint)
    type_args = get_args(type_hint)

    def is_model(candidate: Any) -> bool:
        return isinstance(candidate, type) and issubclass(candidate, BaseModel)

    def unwrap(data: Any, key: str) -> Any:
        # Strip the single-property wrapper added for root-level list/dict.
        if isinstance(data, dict) and key in data:
            return data[key]
        return data

    if container is list:
        # list[Model]
        if type_args and is_model(type_args[0]):
            item_model = type_args[0]
            entries = unwrap(payload, "items")
            if not isinstance(entries, list):
                raise ValueError(f"Expected list, got {type(entries)}")
            return [item_model.model_validate(entry) for entry in entries]
    elif container is dict:
        # dict[str, Model]
        if len(type_args) >= 2 and is_model(type_args[1]):
            value_model = type_args[1]
            mapping = unwrap(payload, "values")
            if not isinstance(mapping, dict):
                raise ValueError(f"Expected dict, got {type(mapping)}")
            return {
                name: value_model.model_validate(raw)
                for name, raw in mapping.items()
            }
    elif is_model(type_hint):
        # Plain model: validate the raw JSON text when it was given as text.
        if isinstance(response, str):
            return type_hint.model_validate_json(response)
        return type_hint.model_validate(payload)

    # Reached for unsupported hints, including list/dict with non-model args.
    raise ValueError(f"Unsupported type for parsing: {type_hint}")
def extract_text_from_response(response: str) -> str:
    """Return the response text with surrounding whitespace removed.

    Used for plain string outputs. No markdown or other formatting is
    unwrapped here; only leading and trailing whitespace is trimmed.
    """
    trimmed = response.strip()
    return trimmed