Use structured outputs with Pydantic models instead of text parsing
This commit is contained in:
@@ -44,6 +44,7 @@ class ClaudeCodeConfig(PrecompiledConfig):
|
||||
|
||||
model: str = "claude-opus-4-5-20251101"
|
||||
|
||||
|
||||
class ClaudeCodeKwargs(BaseModel):
|
||||
"""Arguments for ClaudeCode initialization.
|
||||
|
||||
@@ -54,10 +55,10 @@ class ClaudeCodeKwargs(BaseModel):
|
||||
# DSPy-specific (required)
|
||||
signature: Any # str | dspy.Signature - validated manually in __init__
|
||||
|
||||
# Authentication
|
||||
# auth
|
||||
api_key: str | None = None
|
||||
|
||||
# Basic configuration
|
||||
# basic config
|
||||
working_directory: str = "."
|
||||
permission_mode: str | None = None
|
||||
allowed_tools: list[str] | None = None # Any Claude Code tools
|
||||
@@ -65,16 +66,16 @@ class ClaudeCodeKwargs(BaseModel):
|
||||
sandbox: dict[str, Any] | None = None
|
||||
system_prompt: str | dict[str, Any] | None = None
|
||||
|
||||
# MCP servers
|
||||
# mcp servers
|
||||
mcp_servers: dict[str, Any] | str | Path | None = None
|
||||
|
||||
# Session management
|
||||
# session management
|
||||
continue_conversation: bool = False
|
||||
resume: str | None = None
|
||||
max_turns: int | None = None
|
||||
fork_session: bool = False
|
||||
|
||||
# Advanced options
|
||||
# advanced options
|
||||
permission_prompt_tool_name: str | None = None
|
||||
settings: str | None = None
|
||||
add_dirs: list[str | Path] | None = None
|
||||
@@ -82,21 +83,23 @@ class ClaudeCodeKwargs(BaseModel):
|
||||
extra_args: dict[str, str | None] | None = None
|
||||
max_buffer_size: int | None = None
|
||||
|
||||
# Callbacks and hooks
|
||||
stderr: Any | None = None # Callable[[str], None] - can't type check callables in Pydantic easily
|
||||
# callbacks and hooks
|
||||
stderr: Any | None = (
|
||||
None # Callable[[str], None] - can't type check callables in Pydantic easily
|
||||
)
|
||||
can_use_tool: Any | None = None # CanUseTool callback
|
||||
hooks: dict[str, list[dict[str, Any]]] | None = None
|
||||
|
||||
# User and settings
|
||||
# user and settings
|
||||
user: str | None = None
|
||||
include_partial_messages: bool = False
|
||||
setting_sources: list[str] | None = None # List of "user" | "project" | "local"
|
||||
|
||||
# Subagents and plugins
|
||||
# subagents and plugins
|
||||
agents: dict[str, dict[str, Any]] | None = None
|
||||
plugins: list[dict[str, Any]] | None = None
|
||||
|
||||
# CLI configuration
|
||||
# cli configuration
|
||||
cli_path: str | Path | None = None
|
||||
|
||||
|
||||
@@ -105,18 +108,6 @@ class ClaudeCode(PrecompiledProgram):
|
||||
|
||||
Each agent instance maintains a stateful conversation session.
|
||||
Perfect for multi-turn agentic workflows with context preservation.
|
||||
|
||||
Example:
|
||||
>>> config = ClaudeCodeConfig()
|
||||
>>> agent = ClaudeCode(
|
||||
... config,
|
||||
... signature='message:str -> answer:str',
|
||||
... working_directory="."
|
||||
... )
|
||||
>>> result = agent(message="What files are in this directory?")
|
||||
>>> print(result.answer) # Typed access
|
||||
>>> print(result.trace) # Execution trace
|
||||
>>> print(result.usage) # Token usage
|
||||
"""
|
||||
|
||||
config: ClaudeCodeConfig
|
||||
@@ -130,14 +121,36 @@ class ClaudeCode(PrecompiledProgram):
|
||||
|
||||
args = ClaudeCodeKwargs(**kwargs)
|
||||
|
||||
# Parse and validate signature
|
||||
# validate signature
|
||||
# Note: Raw string signatures only work with built-in types.
|
||||
# For custom Pydantic models, users must pass:
|
||||
# 1. A class-based signature, OR
|
||||
# 2. Pre-constructed dspy.Signature (in their module where types are defined)
|
||||
signature = args.signature
|
||||
if isinstance(signature, str):
|
||||
self.signature = dspy.Signature(signature)
|
||||
try:
|
||||
self.signature = dspy.Signature(signature)
|
||||
except ValueError as e:
|
||||
if "Unknown name:" in str(e):
|
||||
type_name = str(e).split("Unknown name: ")[-1]
|
||||
raise ValueError(
|
||||
f"Cannot resolve type '{type_name}' in string signature.\n"
|
||||
f"String signatures only work with built-in types (str, int, list[str], etc.).\n\n"
|
||||
f"For custom Pydantic models, use one of these approaches:\n\n"
|
||||
f"Option 1 - Class-based signature (recommended):\n"
|
||||
f" class MySignature(dspy.Signature):\n"
|
||||
f" input: str = dspy.InputField()\n"
|
||||
f" output: {type_name} = dspy.OutputField()\n"
|
||||
f" agent = ClaudeCode(config, signature=MySignature, ...)\n\n"
|
||||
f"Option 2 - Pre-construct signature in your module:\n"
|
||||
f" sig = dspy.Signature('{signature}')\n"
|
||||
f" agent = ClaudeCode(config, signature=sig, ...)\n"
|
||||
) from e
|
||||
raise
|
||||
else:
|
||||
self.signature = signature
|
||||
|
||||
# Validate signature has exactly 1 input and 1 output
|
||||
# validate signature has exactly 1 input and 1 output TODO: support multiple inputs/outputs
|
||||
input_fields = list(self.signature.input_fields.keys())
|
||||
output_fields = list(self.signature.output_fields.keys())
|
||||
|
||||
@@ -158,28 +171,28 @@ class ClaudeCode(PrecompiledProgram):
|
||||
self.input_field = self.signature.input_fields[self.input_field_name]
|
||||
self.output_field = self.signature.output_fields[self.output_field_name]
|
||||
|
||||
# Store all configuration values
|
||||
# store all configuration values
|
||||
self.api_key = args.api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
self.working_directory = Path(args.working_directory).resolve()
|
||||
self.model = config.model
|
||||
|
||||
# Basic options
|
||||
# basic options
|
||||
self.permission_mode = args.permission_mode
|
||||
self.allowed_tools = args.allowed_tools
|
||||
self.disallowed_tools = args.disallowed_tools
|
||||
self.sandbox = args.sandbox
|
||||
self.system_prompt = args.system_prompt
|
||||
|
||||
# MCP servers
|
||||
# mcp servers
|
||||
self.mcp_servers = args.mcp_servers
|
||||
|
||||
# Session management
|
||||
# session management
|
||||
self.continue_conversation = args.continue_conversation
|
||||
self.resume = args.resume
|
||||
self.max_turns = args.max_turns
|
||||
self.fork_session = args.fork_session
|
||||
|
||||
# Advanced options
|
||||
# advanced options
|
||||
self.permission_prompt_tool_name = args.permission_prompt_tool_name
|
||||
self.settings = args.settings
|
||||
self.add_dirs = args.add_dirs
|
||||
@@ -187,21 +200,21 @@ class ClaudeCode(PrecompiledProgram):
|
||||
self.extra_args = args.extra_args
|
||||
self.max_buffer_size = args.max_buffer_size
|
||||
|
||||
# Callbacks and hooks
|
||||
# callbacks and hooks
|
||||
self.stderr = args.stderr
|
||||
self.can_use_tool = args.can_use_tool
|
||||
self.hooks = args.hooks
|
||||
|
||||
# User and settings
|
||||
# user and settings
|
||||
self.user = args.user
|
||||
self.include_partial_messages = args.include_partial_messages
|
||||
self.setting_sources = args.setting_sources
|
||||
|
||||
# Subagents and plugins
|
||||
# subagents and plugins
|
||||
self.agents = args.agents
|
||||
self.plugins = args.plugins
|
||||
|
||||
# CLI configuration
|
||||
# cli configuration
|
||||
self.cli_path = args.cli_path
|
||||
|
||||
# determine output format upfront
|
||||
@@ -222,14 +235,14 @@ class ClaudeCode(PrecompiledProgram):
|
||||
|
||||
def _create_client(self) -> ClaudeSDKClient:
|
||||
"""Create ClaudeSDKClient with configured options."""
|
||||
# Build options dict, only including non-None values
|
||||
# build options dict, only including non-None values
|
||||
options_dict = {
|
||||
"cwd": str(self.working_directory),
|
||||
"model": self.model,
|
||||
"output_format": self.output_format,
|
||||
}
|
||||
|
||||
# Add optional fields only if they're not None
|
||||
# add optional fields only if they're not None
|
||||
if self.permission_mode is not None:
|
||||
options_dict["permission_mode"] = self.permission_mode
|
||||
if self.allowed_tools is not None:
|
||||
@@ -251,7 +264,9 @@ class ClaudeCode(PrecompiledProgram):
|
||||
if self.fork_session:
|
||||
options_dict["fork_session"] = self.fork_session
|
||||
if self.permission_prompt_tool_name is not None:
|
||||
options_dict["permission_prompt_tool_name"] = self.permission_prompt_tool_name
|
||||
options_dict["permission_prompt_tool_name"] = (
|
||||
self.permission_prompt_tool_name
|
||||
)
|
||||
if self.settings is not None:
|
||||
options_dict["settings"] = self.settings
|
||||
if self.add_dirs is not None:
|
||||
@@ -283,7 +298,7 @@ class ClaudeCode(PrecompiledProgram):
|
||||
|
||||
options = ClaudeAgentOptions(**options_dict)
|
||||
|
||||
# Set API key if provided
|
||||
# set API key if provided
|
||||
if self.api_key:
|
||||
os.environ["ANTHROPIC_API_KEY"] = self.api_key
|
||||
|
||||
@@ -329,8 +344,7 @@ class ClaudeCode(PrecompiledProgram):
|
||||
if output_desc:
|
||||
prompt_parts.append(f"\nPlease produce the following output: {output_desc}")
|
||||
|
||||
# Don't add JSON instructions - the SDK handles structured outputs via output_format
|
||||
# The schema is passed through ClaudeAgentOptions and enforced by the SDK
|
||||
# the schema is passed through ClaudeAgentOptions and enforced by the SDK
|
||||
|
||||
return "\n\n".join(prompt_parts)
|
||||
|
||||
@@ -363,7 +377,9 @@ class ClaudeCode(PrecompiledProgram):
|
||||
- trace: Execution trace items
|
||||
- usage: Token usage statistics
|
||||
"""
|
||||
print(f"[ClaudeCode._run_async] Initializing client (connected={self._is_connected})")
|
||||
print(
|
||||
f"[ClaudeCode._run_async] Initializing client (connected={self._is_connected})"
|
||||
)
|
||||
|
||||
# create client if needed
|
||||
if self._client is None:
|
||||
@@ -394,26 +410,36 @@ class ClaudeCode(PrecompiledProgram):
|
||||
|
||||
# handle assistant messages
|
||||
if isinstance(message, AssistantMessage):
|
||||
print(f"[ClaudeCode._run_async] Received AssistantMessage #{message_count} with {len(message.content)} blocks")
|
||||
print(
|
||||
f"[ClaudeCode._run_async] Received AssistantMessage #{message_count} with {len(message.content)} blocks"
|
||||
)
|
||||
for block in message.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"[ClaudeCode._run_async] - TextBlock: {len(block.text)} chars")
|
||||
print(
|
||||
f"[ClaudeCode._run_async] - TextBlock: {len(block.text)} chars"
|
||||
)
|
||||
response_text += block.text
|
||||
trace.append(
|
||||
AgentMessageItem(text=block.text, model=message.model)
|
||||
)
|
||||
elif isinstance(block, ThinkingBlock):
|
||||
print(f"[ClaudeCode._run_async] - ThinkingBlock: {len(block.thinking)} chars")
|
||||
print(
|
||||
f"[ClaudeCode._run_async] - ThinkingBlock: {len(block.thinking)} chars"
|
||||
)
|
||||
trace.append(
|
||||
ThinkingItem(text=block.thinking, model=message.model)
|
||||
)
|
||||
elif isinstance(block, ToolUseBlock):
|
||||
print(f"[ClaudeCode._run_async] - ToolUseBlock: {block.name} (id={block.id})")
|
||||
print(
|
||||
f"[ClaudeCode._run_async] - ToolUseBlock: {block.name} (id={block.id})"
|
||||
)
|
||||
# handle StructuredOutput tool (contains JSON response)
|
||||
if block.name == "StructuredOutput":
|
||||
# the JSON is directly in the tool input (already a dict)
|
||||
response_text = json.dumps(block.input)
|
||||
print(f"[ClaudeCode._run_async] StructuredOutput captured ({len(response_text)} chars)")
|
||||
print(
|
||||
f"[ClaudeCode._run_async] StructuredOutput captured ({len(response_text)} chars)"
|
||||
)
|
||||
|
||||
trace.append(
|
||||
ToolUseItem(
|
||||
@@ -423,7 +449,9 @@ class ClaudeCode(PrecompiledProgram):
|
||||
)
|
||||
)
|
||||
elif isinstance(block, ToolResultBlock):
|
||||
print(f"[ClaudeCode._run_async] - ToolResultBlock: tool_use_id={block.tool_use_id}, is_error={block.is_error}")
|
||||
print(
|
||||
f"[ClaudeCode._run_async] - ToolResultBlock: tool_use_id={block.tool_use_id}, is_error={block.is_error}"
|
||||
)
|
||||
content_str = ""
|
||||
if isinstance(block.content, str):
|
||||
content_str = block.content
|
||||
@@ -447,7 +475,9 @@ class ClaudeCode(PrecompiledProgram):
|
||||
|
||||
# handle result messages (final message with usage info)
|
||||
elif isinstance(message, ResultMessage):
|
||||
print(f"[ClaudeCode._run_async] Received ResultMessage (is_error={getattr(message, 'is_error', False)})")
|
||||
print(
|
||||
f"[ClaudeCode._run_async] Received ResultMessage (is_error={getattr(message, 'is_error', False)})"
|
||||
)
|
||||
# store session ID
|
||||
if hasattr(message, "session_id"):
|
||||
self._session_id = message.session_id
|
||||
@@ -463,7 +493,9 @@ class ClaudeCode(PrecompiledProgram):
|
||||
),
|
||||
output_tokens=usage_data.get("output_tokens", 0),
|
||||
)
|
||||
print(f"[ClaudeCode._run_async] - Usage: {usage.input_tokens} in, {usage.output_tokens} out, {usage.cached_input_tokens} cached")
|
||||
print(
|
||||
f"[ClaudeCode._run_async] - Usage: {usage.input_tokens} in, {usage.output_tokens} out, {usage.cached_input_tokens} cached"
|
||||
)
|
||||
|
||||
# check for errors
|
||||
if hasattr(message, "is_error") and message.is_error:
|
||||
@@ -478,14 +510,21 @@ class ClaudeCode(PrecompiledProgram):
|
||||
)
|
||||
raise RuntimeError(f"Agent execution failed: {error_msg}")
|
||||
|
||||
# Prefer structured_output over result (when using output_format)
|
||||
if hasattr(message, "structured_output") and message.structured_output is not None:
|
||||
# prefer structured_output over result (when using output_format)
|
||||
if (
|
||||
hasattr(message, "structured_output")
|
||||
and message.structured_output is not None
|
||||
):
|
||||
structured_output = message.structured_output
|
||||
print(f"[ClaudeCode._run_async] - Structured output captured: {type(structured_output).__name__} ({len(str(structured_output))} chars)")
|
||||
# Fallback to result field for text outputs
|
||||
print(
|
||||
f"[ClaudeCode._run_async] - Structured output captured: {type(structured_output).__name__} ({len(str(structured_output))} chars)"
|
||||
)
|
||||
# fallback to result field for text outputs
|
||||
elif hasattr(message, "result") and message.result:
|
||||
response_text = message.result
|
||||
print(f"[ClaudeCode._run_async] - Result extracted from message ({len(response_text)} chars)")
|
||||
print(
|
||||
f"[ClaudeCode._run_async] - Result extracted from message ({len(response_text)} chars)"
|
||||
)
|
||||
|
||||
# handle system messages
|
||||
elif isinstance(message, SystemMessage):
|
||||
@@ -493,14 +532,20 @@ class ClaudeCode(PrecompiledProgram):
|
||||
# log system messages to trace but don't error
|
||||
if hasattr(message, "data") and message.data:
|
||||
data_str = str(message.data)
|
||||
print(f"[ClaudeCode._run_async] - Data: {data_str[:100]}..." if len(data_str) > 100 else f"[ClaudeCode._run_async] - Data: {data_str}")
|
||||
print(
|
||||
f"[ClaudeCode._run_async] - Data: {data_str[:100]}..."
|
||||
if len(data_str) > 100
|
||||
else f"[ClaudeCode._run_async] - Data: {data_str}"
|
||||
)
|
||||
trace.append(
|
||||
AgentMessageItem(text=f"[System: {data_str}]", model="system")
|
||||
)
|
||||
|
||||
print(f"[ClaudeCode._run_async] Completed: {message_count} messages processed, {len(trace)} trace items")
|
||||
print(
|
||||
f"[ClaudeCode._run_async] Completed: {message_count} messages processed, {len(trace)} trace items"
|
||||
)
|
||||
|
||||
# Return structured_output if available (for Pydantic outputs), otherwise text
|
||||
# return structured_output if available (for Pydantic outputs), otherwise text
|
||||
if structured_output is not None:
|
||||
print(f"[ClaudeCode._run_async] Returning structured output")
|
||||
return structured_output, trace, usage
|
||||
@@ -536,11 +581,17 @@ class ClaudeCode(PrecompiledProgram):
|
||||
)
|
||||
|
||||
input_value = kwargs[self.input_field_name]
|
||||
print(f"[ClaudeCode.forward] Input field '{self.input_field_name}': {input_value[:100]}..." if len(str(input_value)) > 100 else f"[ClaudeCode.forward] Input field '{self.input_field_name}': {input_value}")
|
||||
print(
|
||||
f"[ClaudeCode.forward] Input field '{self.input_field_name}': {input_value[:100]}..."
|
||||
if len(str(input_value)) > 100
|
||||
else f"[ClaudeCode.forward] Input field '{self.input_field_name}': {input_value}"
|
||||
)
|
||||
|
||||
# build prompt
|
||||
prompt = self._build_prompt(input_value)
|
||||
print(f"[ClaudeCode.forward] Built prompt ({len(prompt)} chars): {prompt[:200]}...")
|
||||
print(
|
||||
f"[ClaudeCode.forward] Built prompt ({len(prompt)} chars): {prompt[:200]}..."
|
||||
)
|
||||
|
||||
# run async execution in event loop
|
||||
print(f"[ClaudeCode.forward] Starting async execution (model={self.model})")
|
||||
@@ -562,16 +613,22 @@ class ClaudeCode(PrecompiledProgram):
|
||||
# no event loop, create one
|
||||
response_text, trace, usage = asyncio.run(self._run_async(prompt))
|
||||
|
||||
# Log response details
|
||||
# log response details
|
||||
response_type = type(response_text).__name__
|
||||
response_len = len(str(response_text)) if response_text else 0
|
||||
print(f"[ClaudeCode.forward] Received response (type={response_type}, {response_len} chars, {len(trace)} trace items)")
|
||||
print(f"[ClaudeCode.forward] Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.cached_input_tokens} cached")
|
||||
print(
|
||||
f"[ClaudeCode.forward] Received response (type={response_type}, {response_len} chars, {len(trace)} trace items)"
|
||||
)
|
||||
print(
|
||||
f"[ClaudeCode.forward] Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.cached_input_tokens} cached"
|
||||
)
|
||||
|
||||
# parse response based on output type
|
||||
output_type = self.output_field.annotation
|
||||
if is_pydantic_model(output_type):
|
||||
print(f"[ClaudeCode.forward] Parsing structured output (type: {output_type})")
|
||||
print(
|
||||
f"[ClaudeCode.forward] Parsing structured output (type: {output_type})"
|
||||
)
|
||||
try:
|
||||
# response_text can be dict/list (from structured_output) or str (legacy)
|
||||
parsed_output = parse_json_response(response_text, output_type)
|
||||
@@ -584,7 +641,9 @@ class ClaudeCode(PrecompiledProgram):
|
||||
f"Response: {response_text}"
|
||||
)
|
||||
else:
|
||||
print(f"[ClaudeCode.forward] Extracting text response (output type: {output_type})")
|
||||
print(
|
||||
f"[ClaudeCode.forward] Extracting text response (output type: {output_type})"
|
||||
)
|
||||
# string output - extract text
|
||||
if isinstance(response_text, str):
|
||||
parsed_output = extract_text_from_response(response_text)
|
||||
@@ -592,7 +651,9 @@ class ClaudeCode(PrecompiledProgram):
|
||||
# Shouldn't happen, but handle gracefully
|
||||
parsed_output = str(response_text)
|
||||
|
||||
print(f"[ClaudeCode.forward] Returning Prediction with session_id={self._session_id}\n")
|
||||
print(
|
||||
f"[ClaudeCode.forward] Returning Prediction with session_id={self._session_id}\n"
|
||||
)
|
||||
|
||||
# return prediction with typed output, trace, and usage
|
||||
return Prediction(
|
||||
@@ -624,7 +685,11 @@ class ClaudeCode(PrecompiledProgram):
|
||||
)
|
||||
|
||||
input_value = kwargs[self.input_field_name]
|
||||
print(f"[ClaudeCode.aforward] Input field '{self.input_field_name}': {input_value[:100]}..." if len(str(input_value)) > 100 else f"[ClaudeCode.aforward] Input field '{self.input_field_name}': {input_value}")
|
||||
print(
|
||||
f"[ClaudeCode.aforward] Input field '{self.input_field_name}': {input_value[:100]}..."
|
||||
if len(str(input_value)) > 100
|
||||
else f"[ClaudeCode.aforward] Input field '{self.input_field_name}': {input_value}"
|
||||
)
|
||||
|
||||
# build prompt
|
||||
prompt = self._build_prompt(input_value)
|
||||
@@ -637,13 +702,19 @@ class ClaudeCode(PrecompiledProgram):
|
||||
# Log response details
|
||||
response_type = type(response_text).__name__
|
||||
response_len = len(str(response_text)) if response_text else 0
|
||||
print(f"[ClaudeCode.aforward] Received response (type={response_type}, {response_len} chars, {len(trace)} trace items)")
|
||||
print(f"[ClaudeCode.aforward] Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.cached_input_tokens} cached")
|
||||
print(
|
||||
f"[ClaudeCode.aforward] Received response (type={response_type}, {response_len} chars, {len(trace)} trace items)"
|
||||
)
|
||||
print(
|
||||
f"[ClaudeCode.aforward] Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.cached_input_tokens} cached"
|
||||
)
|
||||
|
||||
# parse response based on output type
|
||||
output_type = self.output_field.annotation
|
||||
if is_pydantic_model(output_type):
|
||||
print(f"[ClaudeCode.aforward] Parsing structured output (type: {output_type})")
|
||||
print(
|
||||
f"[ClaudeCode.aforward] Parsing structured output (type: {output_type})"
|
||||
)
|
||||
try:
|
||||
# response_text can be dict/list (from structured_output) or str (legacy)
|
||||
parsed_output = parse_json_response(response_text, output_type)
|
||||
@@ -656,7 +727,9 @@ class ClaudeCode(PrecompiledProgram):
|
||||
f"Response: {response_text}"
|
||||
)
|
||||
else:
|
||||
print(f"[ClaudeCode.aforward] Extracting text response (output type: {output_type})")
|
||||
print(
|
||||
f"[ClaudeCode.aforward] Extracting text response (output type: {output_type})"
|
||||
)
|
||||
# string output - extract text
|
||||
if isinstance(response_text, str):
|
||||
parsed_output = extract_text_from_response(response_text)
|
||||
@@ -664,7 +737,9 @@ class ClaudeCode(PrecompiledProgram):
|
||||
# Shouldn't happen, but handle gracefully
|
||||
parsed_output = str(response_text)
|
||||
|
||||
print(f"[ClaudeCode.aforward] Returning Prediction with session_id={self._session_id}\n")
|
||||
print(
|
||||
f"[ClaudeCode.aforward] Returning Prediction with session_id={self._session_id}\n"
|
||||
)
|
||||
|
||||
# return prediction with typed output, trace, and usage
|
||||
return Prediction(
|
||||
@@ -683,9 +758,11 @@ class ClaudeCode(PrecompiledProgram):
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup on deletion."""
|
||||
if self._client and self._is_connected:
|
||||
try:
|
||||
asyncio.run(self.disconnect())
|
||||
except Exception:
|
||||
# best effort cleanup
|
||||
pass
|
||||
# Check attributes exist before accessing (may fail during __init__)
|
||||
if hasattr(self, "_client") and hasattr(self, "_is_connected"):
|
||||
if self._client and self._is_connected:
|
||||
try:
|
||||
asyncio.run(self.disconnect())
|
||||
except Exception:
|
||||
# best effort cleanup
|
||||
pass
|
||||
|
||||
@@ -34,15 +34,15 @@ def is_pydantic_model(type_hint: Any) -> bool:
|
||||
- Generic types containing Pydantic: list[MyModel], dict[str, MyModel]
|
||||
"""
|
||||
try:
|
||||
# Direct Pydantic model
|
||||
# direct Pydantic model
|
||||
if isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
||||
return True
|
||||
|
||||
# Generic type (list, dict, etc.)
|
||||
# generic type (list, dict, etc.)
|
||||
origin = get_origin(type_hint)
|
||||
if origin is not None:
|
||||
args = get_args(type_hint)
|
||||
# Check if any type argument is a Pydantic model
|
||||
# check if any type argument is a Pydantic model
|
||||
for arg in args:
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel):
|
||||
return True
|
||||
@@ -67,7 +67,7 @@ def get_json_schema(type_hint: Any) -> dict[str, Any]:
|
||||
origin = get_origin(type_hint)
|
||||
args = get_args(type_hint)
|
||||
|
||||
# Handle generic types (list, dict, etc.)
|
||||
# handle generic types (list, dict, etc.)
|
||||
if origin is list:
|
||||
# list[Model] - wrap in object since API requires root type = "object"
|
||||
# {"type": "object", "properties": {"items": {"type": "array", "items": {...}}}}
|
||||
@@ -76,13 +76,10 @@ def get_json_schema(type_hint: Any) -> dict[str, Any]:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": model.model_json_schema()
|
||||
}
|
||||
"items": {"type": "array", "items": model.model_json_schema()}
|
||||
},
|
||||
"required": ["items"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported list type: {type_hint}")
|
||||
@@ -90,30 +87,34 @@ def get_json_schema(type_hint: Any) -> dict[str, Any]:
|
||||
elif origin is dict:
|
||||
# dict[str, Model] - wrap in object since API requires root type = "object"
|
||||
# {"type": "object", "properties": {"values": {"type": "object", "additionalProperties": {...}}}}
|
||||
if len(args) >= 2 and isinstance(args[1], type) and issubclass(args[1], BaseModel):
|
||||
if (
|
||||
len(args) >= 2
|
||||
and isinstance(args[1], type)
|
||||
and issubclass(args[1], BaseModel)
|
||||
):
|
||||
model = args[1]
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"values": {
|
||||
"type": "object",
|
||||
"additionalProperties": model.model_json_schema()
|
||||
"additionalProperties": model.model_json_schema(),
|
||||
}
|
||||
},
|
||||
"required": ["values"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported dict type: {type_hint}")
|
||||
|
||||
elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
||||
# Direct Pydantic model - already an object
|
||||
# direct Pydantic model - already an object
|
||||
schema = type_hint.model_json_schema()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported type for structured output: {type_hint}")
|
||||
|
||||
# Recursively set additionalProperties: false for all nested objects
|
||||
# recursively set additionalProperties: false for all nested objects
|
||||
def set_additional_properties(obj: dict[str, Any]) -> None:
|
||||
if isinstance(obj, dict):
|
||||
if obj.get("type") == "object" and "additionalProperties" not in obj:
|
||||
@@ -158,18 +159,18 @@ def parse_json_response(
|
||||
origin = get_origin(type_hint)
|
||||
args = get_args(type_hint)
|
||||
|
||||
# Parse string to dict/list if needed
|
||||
# parse string to dict/list if needed
|
||||
if isinstance(response, str):
|
||||
parsed = json.loads(response)
|
||||
else:
|
||||
parsed = response
|
||||
|
||||
# Handle list[Model]
|
||||
# handle list[Model]
|
||||
if origin is list:
|
||||
if args and isinstance(args[0], type) and issubclass(args[0], BaseModel):
|
||||
model = args[0]
|
||||
|
||||
# Unwrap from {"items": [...]} if present (from structured_output)
|
||||
# unwrap from {"items": [...]} if present (from structured_output)
|
||||
if isinstance(parsed, dict) and "items" in parsed:
|
||||
parsed = parsed["items"]
|
||||
|
||||
@@ -177,12 +178,16 @@ def parse_json_response(
|
||||
raise ValueError(f"Expected list, got {type(parsed)}")
|
||||
return [model.model_validate(item) for item in parsed]
|
||||
|
||||
# Handle dict[str, Model]
|
||||
# handle dict[str, Model]
|
||||
elif origin is dict:
|
||||
if len(args) >= 2 and isinstance(args[1], type) and issubclass(args[1], BaseModel):
|
||||
if (
|
||||
len(args) >= 2
|
||||
and isinstance(args[1], type)
|
||||
and issubclass(args[1], BaseModel)
|
||||
):
|
||||
model = args[1]
|
||||
|
||||
# Unwrap from {"values": {...}} if present (from structured_output)
|
||||
# unwrap from {"values": {...}} if present (from structured_output)
|
||||
if isinstance(parsed, dict) and "values" in parsed:
|
||||
parsed = parsed["values"]
|
||||
|
||||
@@ -190,7 +195,7 @@ def parse_json_response(
|
||||
raise ValueError(f"Expected dict, got {type(parsed)}")
|
||||
return {key: model.model_validate(value) for key, value in parsed.items()}
|
||||
|
||||
# Handle direct Pydantic model
|
||||
# handle direct Pydantic model
|
||||
elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
||||
if isinstance(response, str):
|
||||
return type_hint.model_validate_json(response)
|
||||
|
||||
Reference in New Issue
Block a user