Use structured outputs with Pydantic models instead of text parsing

This commit is contained in:
2025-12-08 14:29:29 -05:00
parent 4df7bb42b4
commit 9e92fafd36
4 changed files with 203 additions and 118 deletions

View File

@@ -128,6 +128,7 @@ print(result.usage) # Token counts
```python ```python
from claude_dspy import ClaudeCode, ClaudeCodeConfig from claude_dspy import ClaudeCode, ClaudeCodeConfig
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import dspy
class BugReport(BaseModel): class BugReport(BaseModel):
severity: str = Field(description="critical, high, medium, or low") severity: str = Field(description="critical, high, medium, or low")
@@ -137,9 +138,23 @@ class BugReport(BaseModel):
# Create config with Pydantic output # Create config with Pydantic output
config = ClaudeCodeConfig() config = ClaudeCodeConfig()
# Option 1: Pre-construct signature in your module (where BugReport is defined)
sig = dspy.Signature("message:str -> report:BugReport")
agent = ClaudeCode( agent = ClaudeCode(
config, config,
signature="message:str -> report:BugReport", signature=sig,
working_directory="."
)
# Option 2: Use class-based signature (recommended)
class BugReportSignature(dspy.Signature):
"""Analyze bugs and generate report."""
message: str = dspy.InputField()
report: BugReport = dspy.OutputField()
agent = ClaudeCode(
config,
signature=BugReportSignature,
working_directory="." working_directory="."
) )
@@ -148,6 +163,10 @@ print(result.report.severity) # Typed access!
print(result.report.affected_files) print(result.report.affected_files)
``` ```
**Note**: String signatures like `"message:str -> report:BugReport"` only work with built-in types unless you use `dspy.Signature()` with `custom_types`. For custom Pydantic models, either:
- Use `dspy.Signature("...", custom_types={...})`
- Use class-based signatures (recommended)
### Push to Modaic Hub ### Push to Modaic Hub
```python ```python

View File

@@ -44,6 +44,7 @@ class ClaudeCodeConfig(PrecompiledConfig):
model: str = "claude-opus-4-5-20251101" model: str = "claude-opus-4-5-20251101"
class ClaudeCodeKwargs(BaseModel): class ClaudeCodeKwargs(BaseModel):
"""Arguments for ClaudeCode initialization. """Arguments for ClaudeCode initialization.
@@ -54,10 +55,10 @@ class ClaudeCodeKwargs(BaseModel):
# DSPy-specific (required) # DSPy-specific (required)
signature: Any # str | dspy.Signature - validated manually in __init__ signature: Any # str | dspy.Signature - validated manually in __init__
# Authentication # auth
api_key: str | None = None api_key: str | None = None
# Basic configuration # basic config
working_directory: str = "." working_directory: str = "."
permission_mode: str | None = None permission_mode: str | None = None
allowed_tools: list[str] | None = None # Any Claude Code tools allowed_tools: list[str] | None = None # Any Claude Code tools
@@ -65,16 +66,16 @@ class ClaudeCodeKwargs(BaseModel):
sandbox: dict[str, Any] | None = None sandbox: dict[str, Any] | None = None
system_prompt: str | dict[str, Any] | None = None system_prompt: str | dict[str, Any] | None = None
# MCP servers # mcp servers
mcp_servers: dict[str, Any] | str | Path | None = None mcp_servers: dict[str, Any] | str | Path | None = None
# Session management # session management
continue_conversation: bool = False continue_conversation: bool = False
resume: str | None = None resume: str | None = None
max_turns: int | None = None max_turns: int | None = None
fork_session: bool = False fork_session: bool = False
# Advanced options # advanced options
permission_prompt_tool_name: str | None = None permission_prompt_tool_name: str | None = None
settings: str | None = None settings: str | None = None
add_dirs: list[str | Path] | None = None add_dirs: list[str | Path] | None = None
@@ -82,21 +83,23 @@ class ClaudeCodeKwargs(BaseModel):
extra_args: dict[str, str | None] | None = None extra_args: dict[str, str | None] | None = None
max_buffer_size: int | None = None max_buffer_size: int | None = None
# Callbacks and hooks # callbacks and hooks
stderr: Any | None = None # Callable[[str], None] - can't type check callables in Pydantic easily stderr: Any | None = (
None # Callable[[str], None] - can't type check callables in Pydantic easily
)
can_use_tool: Any | None = None # CanUseTool callback can_use_tool: Any | None = None # CanUseTool callback
hooks: dict[str, list[dict[str, Any]]] | None = None hooks: dict[str, list[dict[str, Any]]] | None = None
# User and settings # user and settings
user: str | None = None user: str | None = None
include_partial_messages: bool = False include_partial_messages: bool = False
setting_sources: list[str] | None = None # List of "user" | "project" | "local" setting_sources: list[str] | None = None # List of "user" | "project" | "local"
# Subagents and plugins # subagents and plugins
agents: dict[str, dict[str, Any]] | None = None agents: dict[str, dict[str, Any]] | None = None
plugins: list[dict[str, Any]] | None = None plugins: list[dict[str, Any]] | None = None
# CLI configuration # cli configuration
cli_path: str | Path | None = None cli_path: str | Path | None = None
@@ -105,18 +108,6 @@ class ClaudeCode(PrecompiledProgram):
Each agent instance maintains a stateful conversation session. Each agent instance maintains a stateful conversation session.
Perfect for multi-turn agentic workflows with context preservation. Perfect for multi-turn agentic workflows with context preservation.
Example:
>>> config = ClaudeCodeConfig()
>>> agent = ClaudeCode(
... config,
... signature='message:str -> answer:str',
... working_directory="."
... )
>>> result = agent(message="What files are in this directory?")
>>> print(result.answer) # Typed access
>>> print(result.trace) # Execution trace
>>> print(result.usage) # Token usage
""" """
config: ClaudeCodeConfig config: ClaudeCodeConfig
@@ -130,14 +121,36 @@ class ClaudeCode(PrecompiledProgram):
args = ClaudeCodeKwargs(**kwargs) args = ClaudeCodeKwargs(**kwargs)
# Parse and validate signature # validate signature
# Note: Raw string signatures only work with built-in types.
# For custom Pydantic models, users must pass:
# 1. A class-based signature, OR
# 2. Pre-constructed dspy.Signature (in their module where types are defined)
signature = args.signature signature = args.signature
if isinstance(signature, str): if isinstance(signature, str):
try:
self.signature = dspy.Signature(signature) self.signature = dspy.Signature(signature)
except ValueError as e:
if "Unknown name:" in str(e):
type_name = str(e).split("Unknown name: ")[-1]
raise ValueError(
f"Cannot resolve type '{type_name}' in string signature.\n"
f"String signatures only work with built-in types (str, int, list[str], etc.).\n\n"
f"For custom Pydantic models, use one of these approaches:\n\n"
f"Option 1 - Class-based signature (recommended):\n"
f" class MySignature(dspy.Signature):\n"
f" input: str = dspy.InputField()\n"
f" output: {type_name} = dspy.OutputField()\n"
f" agent = ClaudeCode(config, signature=MySignature, ...)\n\n"
f"Option 2 - Pre-construct signature in your module:\n"
f" sig = dspy.Signature('{signature}')\n"
f" agent = ClaudeCode(config, signature=sig, ...)\n"
) from e
raise
else: else:
self.signature = signature self.signature = signature
# Validate signature has exactly 1 input and 1 output # validate signature has exactly 1 input and 1 output TODO: support multiple inputs/outputs
input_fields = list(self.signature.input_fields.keys()) input_fields = list(self.signature.input_fields.keys())
output_fields = list(self.signature.output_fields.keys()) output_fields = list(self.signature.output_fields.keys())
@@ -158,28 +171,28 @@ class ClaudeCode(PrecompiledProgram):
self.input_field = self.signature.input_fields[self.input_field_name] self.input_field = self.signature.input_fields[self.input_field_name]
self.output_field = self.signature.output_fields[self.output_field_name] self.output_field = self.signature.output_fields[self.output_field_name]
# Store all configuration values # store all configuration values
self.api_key = args.api_key or os.getenv("ANTHROPIC_API_KEY") self.api_key = args.api_key or os.getenv("ANTHROPIC_API_KEY")
self.working_directory = Path(args.working_directory).resolve() self.working_directory = Path(args.working_directory).resolve()
self.model = config.model self.model = config.model
# Basic options # basic options
self.permission_mode = args.permission_mode self.permission_mode = args.permission_mode
self.allowed_tools = args.allowed_tools self.allowed_tools = args.allowed_tools
self.disallowed_tools = args.disallowed_tools self.disallowed_tools = args.disallowed_tools
self.sandbox = args.sandbox self.sandbox = args.sandbox
self.system_prompt = args.system_prompt self.system_prompt = args.system_prompt
# MCP servers # mcp servers
self.mcp_servers = args.mcp_servers self.mcp_servers = args.mcp_servers
# Session management # session management
self.continue_conversation = args.continue_conversation self.continue_conversation = args.continue_conversation
self.resume = args.resume self.resume = args.resume
self.max_turns = args.max_turns self.max_turns = args.max_turns
self.fork_session = args.fork_session self.fork_session = args.fork_session
# Advanced options # advanced options
self.permission_prompt_tool_name = args.permission_prompt_tool_name self.permission_prompt_tool_name = args.permission_prompt_tool_name
self.settings = args.settings self.settings = args.settings
self.add_dirs = args.add_dirs self.add_dirs = args.add_dirs
@@ -187,21 +200,21 @@ class ClaudeCode(PrecompiledProgram):
self.extra_args = args.extra_args self.extra_args = args.extra_args
self.max_buffer_size = args.max_buffer_size self.max_buffer_size = args.max_buffer_size
# Callbacks and hooks # callbacks and hooks
self.stderr = args.stderr self.stderr = args.stderr
self.can_use_tool = args.can_use_tool self.can_use_tool = args.can_use_tool
self.hooks = args.hooks self.hooks = args.hooks
# User and settings # user and settings
self.user = args.user self.user = args.user
self.include_partial_messages = args.include_partial_messages self.include_partial_messages = args.include_partial_messages
self.setting_sources = args.setting_sources self.setting_sources = args.setting_sources
# Subagents and plugins # subagents and plugins
self.agents = args.agents self.agents = args.agents
self.plugins = args.plugins self.plugins = args.plugins
# CLI configuration # cli configuration
self.cli_path = args.cli_path self.cli_path = args.cli_path
# determine output format upfront # determine output format upfront
@@ -222,14 +235,14 @@ class ClaudeCode(PrecompiledProgram):
def _create_client(self) -> ClaudeSDKClient: def _create_client(self) -> ClaudeSDKClient:
"""Create ClaudeSDKClient with configured options.""" """Create ClaudeSDKClient with configured options."""
# Build options dict, only including non-None values # build options dict, only including non-None values
options_dict = { options_dict = {
"cwd": str(self.working_directory), "cwd": str(self.working_directory),
"model": self.model, "model": self.model,
"output_format": self.output_format, "output_format": self.output_format,
} }
# Add optional fields only if they're not None # add optional fields only if they're not None
if self.permission_mode is not None: if self.permission_mode is not None:
options_dict["permission_mode"] = self.permission_mode options_dict["permission_mode"] = self.permission_mode
if self.allowed_tools is not None: if self.allowed_tools is not None:
@@ -251,7 +264,9 @@ class ClaudeCode(PrecompiledProgram):
if self.fork_session: if self.fork_session:
options_dict["fork_session"] = self.fork_session options_dict["fork_session"] = self.fork_session
if self.permission_prompt_tool_name is not None: if self.permission_prompt_tool_name is not None:
options_dict["permission_prompt_tool_name"] = self.permission_prompt_tool_name options_dict["permission_prompt_tool_name"] = (
self.permission_prompt_tool_name
)
if self.settings is not None: if self.settings is not None:
options_dict["settings"] = self.settings options_dict["settings"] = self.settings
if self.add_dirs is not None: if self.add_dirs is not None:
@@ -283,7 +298,7 @@ class ClaudeCode(PrecompiledProgram):
options = ClaudeAgentOptions(**options_dict) options = ClaudeAgentOptions(**options_dict)
# Set API key if provided # set API key if provided
if self.api_key: if self.api_key:
os.environ["ANTHROPIC_API_KEY"] = self.api_key os.environ["ANTHROPIC_API_KEY"] = self.api_key
@@ -329,8 +344,7 @@ class ClaudeCode(PrecompiledProgram):
if output_desc: if output_desc:
prompt_parts.append(f"\nPlease produce the following output: {output_desc}") prompt_parts.append(f"\nPlease produce the following output: {output_desc}")
# Don't add JSON instructions - the SDK handles structured outputs via output_format # the schema is passed through ClaudeAgentOptions and enforced by the SDK
# The schema is passed through ClaudeAgentOptions and enforced by the SDK
return "\n\n".join(prompt_parts) return "\n\n".join(prompt_parts)
@@ -363,7 +377,9 @@ class ClaudeCode(PrecompiledProgram):
- trace: Execution trace items - trace: Execution trace items
- usage: Token usage statistics - usage: Token usage statistics
""" """
print(f"[ClaudeCode._run_async] Initializing client (connected={self._is_connected})") print(
f"[ClaudeCode._run_async] Initializing client (connected={self._is_connected})"
)
# create client if needed # create client if needed
if self._client is None: if self._client is None:
@@ -394,26 +410,36 @@ class ClaudeCode(PrecompiledProgram):
# handle assistant messages # handle assistant messages
if isinstance(message, AssistantMessage): if isinstance(message, AssistantMessage):
print(f"[ClaudeCode._run_async] Received AssistantMessage #{message_count} with {len(message.content)} blocks") print(
f"[ClaudeCode._run_async] Received AssistantMessage #{message_count} with {len(message.content)} blocks"
)
for block in message.content: for block in message.content:
if isinstance(block, TextBlock): if isinstance(block, TextBlock):
print(f"[ClaudeCode._run_async] - TextBlock: {len(block.text)} chars") print(
f"[ClaudeCode._run_async] - TextBlock: {len(block.text)} chars"
)
response_text += block.text response_text += block.text
trace.append( trace.append(
AgentMessageItem(text=block.text, model=message.model) AgentMessageItem(text=block.text, model=message.model)
) )
elif isinstance(block, ThinkingBlock): elif isinstance(block, ThinkingBlock):
print(f"[ClaudeCode._run_async] - ThinkingBlock: {len(block.thinking)} chars") print(
f"[ClaudeCode._run_async] - ThinkingBlock: {len(block.thinking)} chars"
)
trace.append( trace.append(
ThinkingItem(text=block.thinking, model=message.model) ThinkingItem(text=block.thinking, model=message.model)
) )
elif isinstance(block, ToolUseBlock): elif isinstance(block, ToolUseBlock):
print(f"[ClaudeCode._run_async] - ToolUseBlock: {block.name} (id={block.id})") print(
f"[ClaudeCode._run_async] - ToolUseBlock: {block.name} (id={block.id})"
)
# handle StructuredOutput tool (contains JSON response) # handle StructuredOutput tool (contains JSON response)
if block.name == "StructuredOutput": if block.name == "StructuredOutput":
# the JSON is directly in the tool input (already a dict) # the JSON is directly in the tool input (already a dict)
response_text = json.dumps(block.input) response_text = json.dumps(block.input)
print(f"[ClaudeCode._run_async] StructuredOutput captured ({len(response_text)} chars)") print(
f"[ClaudeCode._run_async] StructuredOutput captured ({len(response_text)} chars)"
)
trace.append( trace.append(
ToolUseItem( ToolUseItem(
@@ -423,7 +449,9 @@ class ClaudeCode(PrecompiledProgram):
) )
) )
elif isinstance(block, ToolResultBlock): elif isinstance(block, ToolResultBlock):
print(f"[ClaudeCode._run_async] - ToolResultBlock: tool_use_id={block.tool_use_id}, is_error={block.is_error}") print(
f"[ClaudeCode._run_async] - ToolResultBlock: tool_use_id={block.tool_use_id}, is_error={block.is_error}"
)
content_str = "" content_str = ""
if isinstance(block.content, str): if isinstance(block.content, str):
content_str = block.content content_str = block.content
@@ -447,7 +475,9 @@ class ClaudeCode(PrecompiledProgram):
# handle result messages (final message with usage info) # handle result messages (final message with usage info)
elif isinstance(message, ResultMessage): elif isinstance(message, ResultMessage):
print(f"[ClaudeCode._run_async] Received ResultMessage (is_error={getattr(message, 'is_error', False)})") print(
f"[ClaudeCode._run_async] Received ResultMessage (is_error={getattr(message, 'is_error', False)})"
)
# store session ID # store session ID
if hasattr(message, "session_id"): if hasattr(message, "session_id"):
self._session_id = message.session_id self._session_id = message.session_id
@@ -463,7 +493,9 @@ class ClaudeCode(PrecompiledProgram):
), ),
output_tokens=usage_data.get("output_tokens", 0), output_tokens=usage_data.get("output_tokens", 0),
) )
print(f"[ClaudeCode._run_async] - Usage: {usage.input_tokens} in, {usage.output_tokens} out, {usage.cached_input_tokens} cached") print(
f"[ClaudeCode._run_async] - Usage: {usage.input_tokens} in, {usage.output_tokens} out, {usage.cached_input_tokens} cached"
)
# check for errors # check for errors
if hasattr(message, "is_error") and message.is_error: if hasattr(message, "is_error") and message.is_error:
@@ -478,14 +510,21 @@ class ClaudeCode(PrecompiledProgram):
) )
raise RuntimeError(f"Agent execution failed: {error_msg}") raise RuntimeError(f"Agent execution failed: {error_msg}")
# Prefer structured_output over result (when using output_format) # prefer structured_output over result (when using output_format)
if hasattr(message, "structured_output") and message.structured_output is not None: if (
hasattr(message, "structured_output")
and message.structured_output is not None
):
structured_output = message.structured_output structured_output = message.structured_output
print(f"[ClaudeCode._run_async] - Structured output captured: {type(structured_output).__name__} ({len(str(structured_output))} chars)") print(
# Fallback to result field for text outputs f"[ClaudeCode._run_async] - Structured output captured: {type(structured_output).__name__} ({len(str(structured_output))} chars)"
)
# fallback to result field for text outputs
elif hasattr(message, "result") and message.result: elif hasattr(message, "result") and message.result:
response_text = message.result response_text = message.result
print(f"[ClaudeCode._run_async] - Result extracted from message ({len(response_text)} chars)") print(
f"[ClaudeCode._run_async] - Result extracted from message ({len(response_text)} chars)"
)
# handle system messages # handle system messages
elif isinstance(message, SystemMessage): elif isinstance(message, SystemMessage):
@@ -493,14 +532,20 @@ class ClaudeCode(PrecompiledProgram):
# log system messages to trace but don't error # log system messages to trace but don't error
if hasattr(message, "data") and message.data: if hasattr(message, "data") and message.data:
data_str = str(message.data) data_str = str(message.data)
print(f"[ClaudeCode._run_async] - Data: {data_str[:100]}..." if len(data_str) > 100 else f"[ClaudeCode._run_async] - Data: {data_str}") print(
f"[ClaudeCode._run_async] - Data: {data_str[:100]}..."
if len(data_str) > 100
else f"[ClaudeCode._run_async] - Data: {data_str}"
)
trace.append( trace.append(
AgentMessageItem(text=f"[System: {data_str}]", model="system") AgentMessageItem(text=f"[System: {data_str}]", model="system")
) )
print(f"[ClaudeCode._run_async] Completed: {message_count} messages processed, {len(trace)} trace items") print(
f"[ClaudeCode._run_async] Completed: {message_count} messages processed, {len(trace)} trace items"
)
# Return structured_output if available (for Pydantic outputs), otherwise text # return structured_output if available (for Pydantic outputs), otherwise text
if structured_output is not None: if structured_output is not None:
print(f"[ClaudeCode._run_async] Returning structured output") print(f"[ClaudeCode._run_async] Returning structured output")
return structured_output, trace, usage return structured_output, trace, usage
@@ -536,11 +581,17 @@ class ClaudeCode(PrecompiledProgram):
) )
input_value = kwargs[self.input_field_name] input_value = kwargs[self.input_field_name]
print(f"[ClaudeCode.forward] Input field '{self.input_field_name}': {input_value[:100]}..." if len(str(input_value)) > 100 else f"[ClaudeCode.forward] Input field '{self.input_field_name}': {input_value}") print(
f"[ClaudeCode.forward] Input field '{self.input_field_name}': {input_value[:100]}..."
if len(str(input_value)) > 100
else f"[ClaudeCode.forward] Input field '{self.input_field_name}': {input_value}"
)
# build prompt # build prompt
prompt = self._build_prompt(input_value) prompt = self._build_prompt(input_value)
print(f"[ClaudeCode.forward] Built prompt ({len(prompt)} chars): {prompt[:200]}...") print(
f"[ClaudeCode.forward] Built prompt ({len(prompt)} chars): {prompt[:200]}..."
)
# run async execution in event loop # run async execution in event loop
print(f"[ClaudeCode.forward] Starting async execution (model={self.model})") print(f"[ClaudeCode.forward] Starting async execution (model={self.model})")
@@ -562,16 +613,22 @@ class ClaudeCode(PrecompiledProgram):
# no event loop, create one # no event loop, create one
response_text, trace, usage = asyncio.run(self._run_async(prompt)) response_text, trace, usage = asyncio.run(self._run_async(prompt))
# Log response details # log response details
response_type = type(response_text).__name__ response_type = type(response_text).__name__
response_len = len(str(response_text)) if response_text else 0 response_len = len(str(response_text)) if response_text else 0
print(f"[ClaudeCode.forward] Received response (type={response_type}, {response_len} chars, {len(trace)} trace items)") print(
print(f"[ClaudeCode.forward] Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.cached_input_tokens} cached") f"[ClaudeCode.forward] Received response (type={response_type}, {response_len} chars, {len(trace)} trace items)"
)
print(
f"[ClaudeCode.forward] Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.cached_input_tokens} cached"
)
# parse response based on output type # parse response based on output type
output_type = self.output_field.annotation output_type = self.output_field.annotation
if is_pydantic_model(output_type): if is_pydantic_model(output_type):
print(f"[ClaudeCode.forward] Parsing structured output (type: {output_type})") print(
f"[ClaudeCode.forward] Parsing structured output (type: {output_type})"
)
try: try:
# response_text can be dict/list (from structured_output) or str (legacy) # response_text can be dict/list (from structured_output) or str (legacy)
parsed_output = parse_json_response(response_text, output_type) parsed_output = parse_json_response(response_text, output_type)
@@ -584,7 +641,9 @@ class ClaudeCode(PrecompiledProgram):
f"Response: {response_text}" f"Response: {response_text}"
) )
else: else:
print(f"[ClaudeCode.forward] Extracting text response (output type: {output_type})") print(
f"[ClaudeCode.forward] Extracting text response (output type: {output_type})"
)
# string output - extract text # string output - extract text
if isinstance(response_text, str): if isinstance(response_text, str):
parsed_output = extract_text_from_response(response_text) parsed_output = extract_text_from_response(response_text)
@@ -592,7 +651,9 @@ class ClaudeCode(PrecompiledProgram):
# Shouldn't happen, but handle gracefully # Shouldn't happen, but handle gracefully
parsed_output = str(response_text) parsed_output = str(response_text)
print(f"[ClaudeCode.forward] Returning Prediction with session_id={self._session_id}\n") print(
f"[ClaudeCode.forward] Returning Prediction with session_id={self._session_id}\n"
)
# return prediction with typed output, trace, and usage # return prediction with typed output, trace, and usage
return Prediction( return Prediction(
@@ -624,7 +685,11 @@ class ClaudeCode(PrecompiledProgram):
) )
input_value = kwargs[self.input_field_name] input_value = kwargs[self.input_field_name]
print(f"[ClaudeCode.aforward] Input field '{self.input_field_name}': {input_value[:100]}..." if len(str(input_value)) > 100 else f"[ClaudeCode.aforward] Input field '{self.input_field_name}': {input_value}") print(
f"[ClaudeCode.aforward] Input field '{self.input_field_name}': {input_value[:100]}..."
if len(str(input_value)) > 100
else f"[ClaudeCode.aforward] Input field '{self.input_field_name}': {input_value}"
)
# build prompt # build prompt
prompt = self._build_prompt(input_value) prompt = self._build_prompt(input_value)
@@ -637,13 +702,19 @@ class ClaudeCode(PrecompiledProgram):
# Log response details # Log response details
response_type = type(response_text).__name__ response_type = type(response_text).__name__
response_len = len(str(response_text)) if response_text else 0 response_len = len(str(response_text)) if response_text else 0
print(f"[ClaudeCode.aforward] Received response (type={response_type}, {response_len} chars, {len(trace)} trace items)") print(
print(f"[ClaudeCode.aforward] Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.cached_input_tokens} cached") f"[ClaudeCode.aforward] Received response (type={response_type}, {response_len} chars, {len(trace)} trace items)"
)
print(
f"[ClaudeCode.aforward] Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.cached_input_tokens} cached"
)
# parse response based on output type # parse response based on output type
output_type = self.output_field.annotation output_type = self.output_field.annotation
if is_pydantic_model(output_type): if is_pydantic_model(output_type):
print(f"[ClaudeCode.aforward] Parsing structured output (type: {output_type})") print(
f"[ClaudeCode.aforward] Parsing structured output (type: {output_type})"
)
try: try:
# response_text can be dict/list (from structured_output) or str (legacy) # response_text can be dict/list (from structured_output) or str (legacy)
parsed_output = parse_json_response(response_text, output_type) parsed_output = parse_json_response(response_text, output_type)
@@ -656,7 +727,9 @@ class ClaudeCode(PrecompiledProgram):
f"Response: {response_text}" f"Response: {response_text}"
) )
else: else:
print(f"[ClaudeCode.aforward] Extracting text response (output type: {output_type})") print(
f"[ClaudeCode.aforward] Extracting text response (output type: {output_type})"
)
# string output - extract text # string output - extract text
if isinstance(response_text, str): if isinstance(response_text, str):
parsed_output = extract_text_from_response(response_text) parsed_output = extract_text_from_response(response_text)
@@ -664,7 +737,9 @@ class ClaudeCode(PrecompiledProgram):
# Shouldn't happen, but handle gracefully # Shouldn't happen, but handle gracefully
parsed_output = str(response_text) parsed_output = str(response_text)
print(f"[ClaudeCode.aforward] Returning Prediction with session_id={self._session_id}\n") print(
f"[ClaudeCode.aforward] Returning Prediction with session_id={self._session_id}\n"
)
# return prediction with typed output, trace, and usage # return prediction with typed output, trace, and usage
return Prediction( return Prediction(
@@ -683,6 +758,8 @@ class ClaudeCode(PrecompiledProgram):
def __del__(self): def __del__(self):
"""Cleanup on deletion.""" """Cleanup on deletion."""
# Check attributes exist before accessing (may fail during __init__)
if hasattr(self, "_client") and hasattr(self, "_is_connected"):
if self._client and self._is_connected: if self._client and self._is_connected:
try: try:
asyncio.run(self.disconnect()) asyncio.run(self.disconnect())

View File

@@ -34,15 +34,15 @@ def is_pydantic_model(type_hint: Any) -> bool:
- Generic types containing Pydantic: list[MyModel], dict[str, MyModel] - Generic types containing Pydantic: list[MyModel], dict[str, MyModel]
""" """
try: try:
# Direct Pydantic model # direct Pydantic model
if isinstance(type_hint, type) and issubclass(type_hint, BaseModel): if isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
return True return True
# Generic type (list, dict, etc.) # generic type (list, dict, etc.)
origin = get_origin(type_hint) origin = get_origin(type_hint)
if origin is not None: if origin is not None:
args = get_args(type_hint) args = get_args(type_hint)
# Check if any type argument is a Pydantic model # check if any type argument is a Pydantic model
for arg in args: for arg in args:
if isinstance(arg, type) and issubclass(arg, BaseModel): if isinstance(arg, type) and issubclass(arg, BaseModel):
return True return True
@@ -67,7 +67,7 @@ def get_json_schema(type_hint: Any) -> dict[str, Any]:
origin = get_origin(type_hint) origin = get_origin(type_hint)
args = get_args(type_hint) args = get_args(type_hint)
# Handle generic types (list, dict, etc.) # handle generic types (list, dict, etc.)
if origin is list: if origin is list:
# list[Model] - wrap in object since API requires root type = "object" # list[Model] - wrap in object since API requires root type = "object"
# {"type": "object", "properties": {"items": {"type": "array", "items": {...}}}} # {"type": "object", "properties": {"items": {"type": "array", "items": {...}}}}
@@ -76,13 +76,10 @@ def get_json_schema(type_hint: Any) -> dict[str, Any]:
schema = { schema = {
"type": "object", "type": "object",
"properties": { "properties": {
"items": { "items": {"type": "array", "items": model.model_json_schema()}
"type": "array",
"items": model.model_json_schema()
}
}, },
"required": ["items"], "required": ["items"],
"additionalProperties": False "additionalProperties": False,
} }
else: else:
raise ValueError(f"Unsupported list type: {type_hint}") raise ValueError(f"Unsupported list type: {type_hint}")
@@ -90,30 +87,34 @@ def get_json_schema(type_hint: Any) -> dict[str, Any]:
elif origin is dict: elif origin is dict:
# dict[str, Model] - wrap in object since API requires root type = "object" # dict[str, Model] - wrap in object since API requires root type = "object"
# {"type": "object", "properties": {"values": {"type": "object", "additionalProperties": {...}}}} # {"type": "object", "properties": {"values": {"type": "object", "additionalProperties": {...}}}}
if len(args) >= 2 and isinstance(args[1], type) and issubclass(args[1], BaseModel): if (
len(args) >= 2
and isinstance(args[1], type)
and issubclass(args[1], BaseModel)
):
model = args[1] model = args[1]
schema = { schema = {
"type": "object", "type": "object",
"properties": { "properties": {
"values": { "values": {
"type": "object", "type": "object",
"additionalProperties": model.model_json_schema() "additionalProperties": model.model_json_schema(),
} }
}, },
"required": ["values"], "required": ["values"],
"additionalProperties": False "additionalProperties": False,
} }
else: else:
raise ValueError(f"Unsupported dict type: {type_hint}") raise ValueError(f"Unsupported dict type: {type_hint}")
elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel): elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
# Direct Pydantic model - already an object # direct Pydantic model - already an object
schema = type_hint.model_json_schema() schema = type_hint.model_json_schema()
else: else:
raise ValueError(f"Unsupported type for structured output: {type_hint}") raise ValueError(f"Unsupported type for structured output: {type_hint}")
# Recursively set additionalProperties: false for all nested objects # recursively set additionalProperties: false for all nested objects
def set_additional_properties(obj: dict[str, Any]) -> None: def set_additional_properties(obj: dict[str, Any]) -> None:
if isinstance(obj, dict): if isinstance(obj, dict):
if obj.get("type") == "object" and "additionalProperties" not in obj: if obj.get("type") == "object" and "additionalProperties" not in obj:
@@ -158,18 +159,18 @@ def parse_json_response(
origin = get_origin(type_hint) origin = get_origin(type_hint)
args = get_args(type_hint) args = get_args(type_hint)
# Parse string to dict/list if needed # parse string to dict/list if needed
if isinstance(response, str): if isinstance(response, str):
parsed = json.loads(response) parsed = json.loads(response)
else: else:
parsed = response parsed = response
# Handle list[Model] # handle list[Model]
if origin is list: if origin is list:
if args and isinstance(args[0], type) and issubclass(args[0], BaseModel): if args and isinstance(args[0], type) and issubclass(args[0], BaseModel):
model = args[0] model = args[0]
# Unwrap from {"items": [...]} if present (from structured_output) # unwrap from {"items": [...]} if present (from structured_output)
if isinstance(parsed, dict) and "items" in parsed: if isinstance(parsed, dict) and "items" in parsed:
parsed = parsed["items"] parsed = parsed["items"]
@@ -177,12 +178,16 @@ def parse_json_response(
raise ValueError(f"Expected list, got {type(parsed)}") raise ValueError(f"Expected list, got {type(parsed)}")
return [model.model_validate(item) for item in parsed] return [model.model_validate(item) for item in parsed]
# Handle dict[str, Model] # handle dict[str, Model]
elif origin is dict: elif origin is dict:
if len(args) >= 2 and isinstance(args[1], type) and issubclass(args[1], BaseModel): if (
len(args) >= 2
and isinstance(args[1], type)
and issubclass(args[1], BaseModel)
):
model = args[1] model = args[1]
# Unwrap from {"values": {...}} if present (from structured_output) # unwrap from {"values": {...}} if present (from structured_output)
if isinstance(parsed, dict) and "values" in parsed: if isinstance(parsed, dict) and "values" in parsed:
parsed = parsed["values"] parsed = parsed["values"]
@@ -190,7 +195,7 @@ def parse_json_response(
raise ValueError(f"Expected dict, got {type(parsed)}") raise ValueError(f"Expected dict, got {type(parsed)}")
return {key: model.model_validate(value) for key, value in parsed.items()} return {key: model.model_validate(value) for key, value in parsed.items()}
# Handle direct Pydantic model # handle direct Pydantic model
elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel): elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
if isinstance(response, str): if isinstance(response, str):
return type_hint.model_validate_json(response) return type_hint.model_validate_json(response)

22
main.py
View File

@@ -7,7 +7,9 @@ import dspy
class ErrorReport(BaseModel): class ErrorReport(BaseModel):
error: str = Field(..., description="Error message") error: str = Field(..., description="Error message")
timestamp: str = Field(..., description="Timestamp of error") timestamp: str = Field(..., description="Timestamp of error")
line_located: int | None = Field(..., description="Line number where error occurred") line_located: int | None = Field(
..., description="Line number where error occurred"
)
file_located: str | None = Field(..., description="File where error occurred") file_located: str | None = Field(..., description="File where error occurred")
description: str = Field(..., description="Description of errors") description: str = Field(..., description="Description of errors")
reccomedated_fixes: list[str] = Field(..., description="List of recommended fixes") reccomedated_fixes: list[str] = Field(..., description="List of recommended fixes")
@@ -37,24 +39,6 @@ def main():
commit_message="Use structured outputs with Pydantic models instead of text parsing", commit_message="Use structured outputs with Pydantic models instead of text parsing",
) )
# agent = AutoProgram.from_precompiled("farouk1/claude-code", signature=ClaudeCodeSignature, working_directory=".", permission_mode="acceptEdits", allowed_tools=["Read", "Write", "Bash"])
"""
print("Running Claude Code...")
result = cc(log_file="log.txt")
print("Claude Code finished running!")
print("-" * 50)
print("Output: ", result.output)
print("Output type: ", type(result.output))
print("Usage: ", result.usage)
print("Output type: ", type(result.output))
print("-" * 50)
print()
print("Trace:")
for i, item in enumerate(result.trace):
print(f" {i}: {item}")
"""
if __name__ == "__main__": if __name__ == "__main__":
main() main()