# File: claude-code/claude_dspy/agent.py
# (Python; 769 lines, 30 KiB in the original listing)

import asyncio
import json
import os
from pathlib import Path
from typing import Any, Optional
from modaic import PrecompiledProgram, PrecompiledConfig
from pydantic import BaseModel
import dspy
from dspy.primitives.prediction import Prediction
from claude_agent_sdk import (
ClaudeSDKClient,
ClaudeAgentOptions,
AssistantMessage,
ResultMessage,
SystemMessage,
TextBlock,
ThinkingBlock,
ToolUseBlock,
ToolResultBlock,
)
from .trace import (
TraceItem,
AgentMessageItem,
ThinkingItem,
ToolUseItem,
ToolResultItem,
ErrorItem,
)
from .utils import (
Usage,
is_pydantic_model,
get_json_schema,
parse_json_response,
extract_text_from_response,
)
class ClaudeCodeConfig(PrecompiledConfig):
    """Precompiled configuration for the ClaudeCode agent.

    ClaudeCode reads ``model`` from this object at construction time and
    forwards it to the SDK options as the ``"model"`` entry.
    """

    # Claude model identifier used when none is supplied explicitly.
    model: str = "claude-opus-4-5-20251101"
class ClaudeCodeKwargs(BaseModel):
    """Keyword arguments accepted by ``ClaudeCode.__init__``.

    Mirrors ClaudeAgentOptions from the SDK and adds the DSPy-specific
    ``signature`` field. Every field except ``signature`` is optional.
    See: https://platform.claude.com/docs/en/agent-sdk/python#claudeagentoptions
    """

    # --- DSPy-specific (required) ---------------------------------------
    # str | dspy.Signature; typed as Any here and checked by hand in __init__
    signature: Any

    # --- authentication -------------------------------------------------
    api_key: str | None = None

    # --- basic configuration --------------------------------------------
    working_directory: str = "."
    permission_mode: str | None = None
    allowed_tools: list[str] | None = None  # any Claude Code tools
    disallowed_tools: list[str] | None = None
    sandbox: dict[str, Any] | None = None
    system_prompt: str | dict[str, Any] | None = None

    # --- MCP servers ----------------------------------------------------
    mcp_servers: dict[str, Any] | str | Path | None = None

    # --- session management ---------------------------------------------
    continue_conversation: bool = False
    resume: str | None = None
    max_turns: int | None = None
    fork_session: bool = False

    # --- advanced options -----------------------------------------------
    permission_prompt_tool_name: str | None = None
    settings: str | None = None
    add_dirs: list[str | Path] | None = None
    env: dict[str, str] | None = None
    extra_args: dict[str, str | None] | None = None
    max_buffer_size: int | None = None

    # --- callbacks and hooks --------------------------------------------
    # Callable[[str], None]; callables are awkward to type-check in Pydantic
    stderr: Any | None = None
    # CanUseTool callback
    can_use_tool: Any | None = None
    hooks: dict[str, list[dict[str, Any]]] | None = None

    # --- user and settings ----------------------------------------------
    user: str | None = None
    include_partial_messages: bool = False
    # list of "user" | "project" | "local"
    setting_sources: list[str] | None = None

    # --- subagents and plugins ------------------------------------------
    agents: dict[str, dict[str, Any]] | None = None
    plugins: list[dict[str, Any]] | None = None

    # --- CLI configuration ----------------------------------------------
    cli_path: str | Path | None = None
class ClaudeCode(PrecompiledProgram):
    """DSPy module that wraps Claude Code SDK.

    Each agent instance maintains a stateful conversation session: the SDK
    client is created on first use and kept on the instance, so later calls
    go through the same client.
    Perfect for multi-turn agentic workflows with context preservation.

    The signature passed at construction must declare exactly one input
    field and exactly one output field.
    """

    # Precompiled configuration; its ``model`` value is read in __init__.
    config: ClaudeCodeConfig
def __init__(
    self,
    config: ClaudeCodeConfig,
    **kwargs: Any,
):
    """Validate the signature and store every SDK option on the instance.

    Args:
        config: Precompiled configuration; supplies the model name.
        **kwargs: Fields of ClaudeCodeKwargs. ``signature`` is required and
            may be a string signature or an already-built DSPy signature.
            All remaining fields are optional SDK options.

    Raises:
        TypeError: If ``signature`` is neither a string nor an object that
            exposes ``input_fields`` and ``output_fields``.
        ValueError: If a string signature cannot be resolved, or if the
            signature does not have exactly one input and one output field.
    """
    super().__init__(config=config)
    args = ClaudeCodeKwargs(**kwargs)

    # validate signature
    # Note: Raw string signatures only work with built-in types.
    # For custom Pydantic models, users must pass:
    #   1. A class-based signature, OR
    #   2. Pre-constructed dspy.Signature (in their module where types are defined)
    signature = args.signature
    if isinstance(signature, str):
        try:
            self.signature = dspy.Signature(signature)
        except ValueError as e:
            if "Unknown name:" in str(e):
                type_name = str(e).split("Unknown name: ")[-1]
                raise ValueError(
                    f"Cannot resolve type '{type_name}' in string signature.\n"
                    f"String signatures only work with built-in types (str, int, list[str], etc.).\n\n"
                    f"For custom Pydantic models, use one of these approaches:\n\n"
                    f"Option 1 - Class-based signature (recommended):\n"
                    f"  class MySignature(dspy.Signature):\n"
                    f"      input: str = dspy.InputField()\n"
                    f"      output: {type_name} = dspy.OutputField()\n"
                    f"  agent = ClaudeCode(config, signature=MySignature, ...)\n\n"
                    f"Option 2 - Pre-construct signature in your module:\n"
                    f"  sig = dspy.Signature('{signature}')\n"
                    f"  agent = ClaudeCode(config, signature=sig, ...)\n"
                ) from e
            raise
    else:
        # The field is typed Any, so anything can arrive here. Fail with a
        # clear message instead of an AttributeError a few lines below.
        if not (
            hasattr(signature, "input_fields")
            and hasattr(signature, "output_fields")
        ):
            raise TypeError(
                "signature must be a string or a dspy.Signature, "
                f"got {type(signature).__name__}"
            )
        self.signature = signature

    # validate signature has exactly 1 input and 1 output
    # TODO: support multiple inputs/outputs
    input_fields = list(self.signature.input_fields.keys())
    output_fields = list(self.signature.output_fields.keys())
    if len(input_fields) != 1:
        raise ValueError(
            f"ClaudeCode requires exactly 1 input field, got {len(input_fields)}. "
            f"Found: {input_fields}"
        )
    if len(output_fields) != 1:
        raise ValueError(
            f"ClaudeCode requires exactly 1 output field, got {len(output_fields)}. "
            f"Found: {output_fields}"
        )
    self.input_field_name = input_fields[0]
    self.output_field_name = output_fields[0]
    self.input_field = self.signature.input_fields[self.input_field_name]
    self.output_field = self.signature.output_fields[self.output_field_name]

    # store all configuration values
    # an explicit api_key wins; otherwise fall back to the environment
    self.api_key = args.api_key or os.getenv("ANTHROPIC_API_KEY")
    self.working_directory = Path(args.working_directory).resolve()
    self.model = config.model

    # basic options
    self.permission_mode = args.permission_mode
    self.allowed_tools = args.allowed_tools
    self.disallowed_tools = args.disallowed_tools
    self.sandbox = args.sandbox
    self.system_prompt = args.system_prompt

    # mcp servers
    self.mcp_servers = args.mcp_servers

    # session management
    self.continue_conversation = args.continue_conversation
    self.resume = args.resume
    self.max_turns = args.max_turns
    self.fork_session = args.fork_session

    # advanced options
    self.permission_prompt_tool_name = args.permission_prompt_tool_name
    self.settings = args.settings
    self.add_dirs = args.add_dirs
    self.env = args.env
    self.extra_args = args.extra_args
    self.max_buffer_size = args.max_buffer_size

    # callbacks and hooks
    self.stderr = args.stderr
    self.can_use_tool = args.can_use_tool
    self.hooks = args.hooks

    # user and settings
    self.user = args.user
    self.include_partial_messages = args.include_partial_messages
    self.setting_sources = args.setting_sources

    # subagents and plugins
    self.agents = args.agents
    self.plugins = args.plugins

    # cli configuration
    self.cli_path = args.cli_path

    # determine output format upfront (needs self.output_field set above)
    self.output_format = self._get_output_format()

    # session state: the client is created lazily on the first run
    self._client: Optional[ClaudeSDKClient] = None
    self._session_id: Optional[str] = None
    self._is_connected = False
@property
def session_id(self) -> Optional[str]:
"""Get the session ID for this agent instance.
Returns None until first forward() call.
"""
return self._session_id
def _create_client(self) -> ClaudeSDKClient:
    """Build a ClaudeSDKClient from the options stored on this instance.

    ``cwd``, ``model`` and ``output_format`` are always passed. Every other
    option is passed only when it was set: value options when they are not
    ``None``, boolean switches when they are truthy.
    """
    # always-present options
    options_dict = {
        "cwd": str(self.working_directory),
        "model": self.model,
        "output_format": self.output_format,
    }

    # options forwarded whenever the stored value is not None
    value_options = (
        "permission_mode",
        "allowed_tools",
        "disallowed_tools",
        "sandbox",
        "system_prompt",
        "mcp_servers",
        "resume",
        "max_turns",
        "permission_prompt_tool_name",
        "settings",
        "add_dirs",
        "env",
        "extra_args",
        "max_buffer_size",
        "stderr",
        "can_use_tool",
        "hooks",
        "user",
        "setting_sources",
        "agents",
        "plugins",
        "cli_path",
    )
    for option_name in value_options:
        stored = getattr(self, option_name)
        if stored is not None:
            options_dict[option_name] = stored

    # boolean switches forwarded only when enabled
    switch_options = (
        "continue_conversation",
        "fork_session",
        "include_partial_messages",
    )
    for option_name in switch_options:
        enabled = getattr(self, option_name)
        if enabled:
            options_dict[option_name] = enabled

    options = ClaudeAgentOptions(**options_dict)

    # set API key if provided
    # NOTE: this writes to the process-wide environment, not just this client
    if self.api_key:
        os.environ["ANTHROPIC_API_KEY"] = self.api_key

    return ClaudeSDKClient(options=options)
def _build_prompt(self, input_value: str) -> str:
"""Build prompt from signature docstring, field descriptions, and input value.
Note: When using structured outputs, the SDK handles JSON formatting automatically
via the output_format parameter, so we don't add JSON instructions to the prompt.
"""
prompt_parts = []
# add signature docstring if present
if self.signature.__doc__:
doc = self.signature.__doc__.strip()
if doc:
prompt_parts.append(f"Task: {doc}")
# add input field description if present
# DSPy fields store desc in json_schema_extra
input_desc = None
if (
hasattr(self.input_field, "json_schema_extra")
and self.input_field.json_schema_extra
):
input_desc = self.input_field.json_schema_extra.get("desc")
if input_desc:
prompt_parts.append(f"Input context: {input_desc}")
# add the actual input value
prompt_parts.append(input_value)
# add output field description if present
output_desc = None
if (
hasattr(self.output_field, "json_schema_extra")
and self.output_field.json_schema_extra
):
output_desc = self.output_field.json_schema_extra.get("desc")
if output_desc:
prompt_parts.append(f"\nPlease produce the following output: {output_desc}")
# the schema is passed through ClaudeAgentOptions and enforced by the SDK
return "\n\n".join(prompt_parts)
def _get_output_format(self) -> Optional[dict[str, Any]]:
    """Return the structured-output configuration for the output field.

    When ``is_pydantic_model`` accepts the output annotation, the result is
    a ``{"type": "json_schema", "schema": ...}`` dict built from
    ``get_json_schema``. Otherwise ``None`` is returned and the agent
    produces plain text.

    Intended to cover direct Pydantic models (``MyModel``) and generic
    containers of them (``list[MyModel]``, ``dict[str, MyModel]``); the
    exact set is decided by ``is_pydantic_model``.
    """
    annotation = self.output_field.annotation
    if not is_pydantic_model(annotation):
        return None
    json_schema = get_json_schema(annotation)
    return {"type": "json_schema", "schema": json_schema}
async def _run_async(
    self, prompt: str
) -> tuple[str | dict | list | None, list[TraceItem], Usage]:
    """Send ``prompt`` to the agent and collect the full response.

    The SDK client is created and connected on first use and then reused
    on later calls.

    Args:
        prompt: Prompt text to send as the query.

    Returns:
        A ``(response, trace, usage)`` tuple:
        - response: ``structured_output`` from the result message when one
          was delivered; otherwise text (the result message's ``result``,
          the JSON of a ``StructuredOutput`` tool call, or the concatenated
          text blocks).
        - trace: Execution trace items in arrival order.
        - usage: Token usage statistics (zeroed if none were reported).

    Raises:
        RuntimeError: If the result message is flagged as an error.
    """
    print(
        f"[ClaudeCode._run_async] Initializing client (connected={self._is_connected})"
    )

    # create client if needed
    if self._client is None:
        print("[ClaudeCode._run_async] Creating new ClaudeSDKClient")
        self._client = self._create_client()

    # connect if not already connected
    if not self._is_connected:
        print("[ClaudeCode._run_async] Connecting to Claude SDK...")
        await self._client.connect()
        self._is_connected = True
        print("[ClaudeCode._run_async] Connected successfully")

    # send query (output_format already configured in options)
    print("[ClaudeCode._run_async] Sending query to agent...")
    await self._client.query(prompt)
    print("[ClaudeCode._run_async] Query sent, waiting for response...")

    # collect messages and build trace
    trace: list[TraceItem] = []
    usage = Usage()
    response_text = ""
    structured_output = None
    message_count = 0

    async for message in self._client.receive_response():
        message_count += 1

        # handle assistant messages
        if isinstance(message, AssistantMessage):
            print(
                f"[ClaudeCode._run_async] Received AssistantMessage #{message_count} with {len(message.content)} blocks"
            )
            for block in message.content:
                if isinstance(block, TextBlock):
                    print(
                        f"[ClaudeCode._run_async] - TextBlock: {len(block.text)} chars"
                    )
                    response_text += block.text
                    trace.append(
                        AgentMessageItem(text=block.text, model=message.model)
                    )
                elif isinstance(block, ThinkingBlock):
                    print(
                        f"[ClaudeCode._run_async] - ThinkingBlock: {len(block.thinking)} chars"
                    )
                    trace.append(
                        ThinkingItem(text=block.thinking, model=message.model)
                    )
                elif isinstance(block, ToolUseBlock):
                    print(
                        f"[ClaudeCode._run_async] - ToolUseBlock: {block.name} (id={block.id})"
                    )
                    # handle StructuredOutput tool (contains JSON response)
                    if block.name == "StructuredOutput":
                        # the JSON is directly in the tool input (already a dict);
                        # this replaces any text accumulated so far
                        response_text = json.dumps(block.input)
                        print(
                            f"[ClaudeCode._run_async] StructuredOutput captured ({len(response_text)} chars)"
                        )
                    trace.append(
                        ToolUseItem(
                            tool_name=block.name,
                            tool_input=block.input,
                            tool_use_id=block.id,
                        )
                    )
                elif isinstance(block, ToolResultBlock):
                    print(
                        f"[ClaudeCode._run_async] - ToolResultBlock: tool_use_id={block.tool_use_id}, is_error={block.is_error}"
                    )
                    content_str = ""
                    if isinstance(block.content, str):
                        content_str = block.content
                    elif isinstance(block.content, list):
                        # extract text from content blocks, ignoring non-text items
                        content_str = "".join(
                            item.get("text", "")
                            for item in block.content
                            if isinstance(item, dict) and item.get("type") == "text"
                        )
                    trace.append(
                        ToolResultItem(
                            tool_name="",  # tool name not in ToolResultBlock
                            tool_use_id=block.tool_use_id,
                            content=content_str,
                            is_error=block.is_error or False,
                        )
                    )

        # handle result messages (final message with usage info)
        elif isinstance(message, ResultMessage):
            print(
                f"[ClaudeCode._run_async] Received ResultMessage (is_error={getattr(message, 'is_error', False)})"
            )

            # store session ID
            if hasattr(message, "session_id"):
                self._session_id = message.session_id
                print(f"[ClaudeCode._run_async] - Session ID: {self._session_id}")

            # extract usage
            usage_data = getattr(message, "usage", None)
            if usage_data:
                usage = Usage(
                    input_tokens=usage_data.get("input_tokens", 0),
                    cached_input_tokens=usage_data.get(
                        "cache_read_input_tokens", 0
                    ),
                    output_tokens=usage_data.get("output_tokens", 0),
                )
                print(
                    f"[ClaudeCode._run_async] - Usage: {usage.input_tokens} in, {usage.output_tokens} out, {usage.cached_input_tokens} cached"
                )

            # check for errors
            if getattr(message, "is_error", False):
                # ``result`` may be present but empty/None on error results;
                # fall back to a readable message in that case as well
                error_msg = getattr(message, "result", None) or "Unknown error"
                print(f"[ClaudeCode._run_async] - ERROR: {error_msg}")
                trace.append(
                    ErrorItem(message=error_msg, error_type="execution_error")
                )
                raise RuntimeError(f"Agent execution failed: {error_msg}")

            # prefer structured_output over result (when using output_format)
            delivered = getattr(message, "structured_output", None)
            if delivered is not None:
                structured_output = delivered
                print(
                    f"[ClaudeCode._run_async] - Structured output captured: {type(structured_output).__name__} ({len(str(structured_output))} chars)"
                )
            # fallback to result field for text outputs
            elif getattr(message, "result", None):
                response_text = message.result
                print(
                    f"[ClaudeCode._run_async] - Result extracted from message ({len(response_text)} chars)"
                )

        # handle system messages
        elif isinstance(message, SystemMessage):
            print("[ClaudeCode._run_async] Received SystemMessage")
            # log system messages to trace but don't error
            if hasattr(message, "data") and message.data:
                data_str = str(message.data)
                print(
                    f"[ClaudeCode._run_async] - Data: {data_str[:100]}..."
                    if len(data_str) > 100
                    else f"[ClaudeCode._run_async] - Data: {data_str}"
                )
                trace.append(
                    AgentMessageItem(text=f"[System: {data_str}]", model="system")
                )

    print(
        f"[ClaudeCode._run_async] Completed: {message_count} messages processed, {len(trace)} trace items"
    )

    # return structured_output if available (for Pydantic outputs), otherwise text
    if structured_output is not None:
        print("[ClaudeCode._run_async] Returning structured output")
        return structured_output, trace, usage

    print("[ClaudeCode._run_async] Returning text response")
    return response_text, trace, usage
def forward(self, **kwargs: Any) -> Prediction:
    """Execute the agent with an input message.

    Runs the async implementation to completion on the current thread's
    event loop. If that loop is already running, ``nest_asyncio`` is applied
    so the call can nest. If there is no usable loop, ``asyncio.run`` is
    used instead.

    A ``RuntimeError`` raised by the agent run itself is propagated to the
    caller; it never triggers a second execution of the query.

    Args:
        **kwargs: Must contain the input field specified in signature

    Returns:
        Prediction with:
        - Typed output field (named according to signature)
        - trace: list[TraceItem] - Execution trace
        - usage: Usage - Token usage statistics

    Raises:
        ValueError: If the input field is missing, or the response cannot
            be parsed into the structured output type.
        RuntimeError: If the agent reports an execution error.

    Example:
        >>> result = agent(message="Hello")
        >>> print(result.answer)  # Access typed output
        >>> print(result.trace)   # List of execution items
        >>> print(result.usage)   # Token usage stats
    """
    print(f"\n[ClaudeCode.forward] Called with fields: {list(kwargs.keys())}")

    # extract input value
    if self.input_field_name not in kwargs:
        raise ValueError(
            f"Missing required input field: {self.input_field_name}. "
            f"Received: {list(kwargs.keys())}"
        )
    input_value = kwargs[self.input_field_name]

    # preview via str() so non-sliceable inputs cannot break the log line
    input_preview = str(input_value)
    print(
        f"[ClaudeCode.forward] Input field '{self.input_field_name}': {input_preview[:100]}..."
        if len(input_preview) > 100
        else f"[ClaudeCode.forward] Input field '{self.input_field_name}': {input_value}"
    )

    # build prompt
    prompt = self._build_prompt(input_value)
    print(
        f"[ClaudeCode.forward] Built prompt ({len(prompt)} chars): {prompt[:200]}..."
    )

    # run async execution in event loop
    print(f"[ClaudeCode.forward] Starting async execution (model={self.model})")

    # Only the loop lookup is guarded. Guarding the run as well would turn an
    # agent failure (a RuntimeError) into a silent second run of the query.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None

    if loop is None or loop.is_closed():
        # no usable event loop, create one
        response_text, trace, usage = asyncio.run(self._run_async(prompt))
    else:
        if loop.is_running():
            # already inside an async context: allow nested run_until_complete
            import nest_asyncio

            nest_asyncio.apply()
        response_text, trace, usage = loop.run_until_complete(
            self._run_async(prompt)
        )

    # log response details
    response_type = type(response_text).__name__
    response_len = len(str(response_text)) if response_text else 0
    print(
        f"[ClaudeCode.forward] Received response (type={response_type}, {response_len} chars, {len(trace)} trace items)"
    )
    print(
        f"[ClaudeCode.forward] Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.cached_input_tokens} cached"
    )

    # parse response based on output type
    output_type = self.output_field.annotation
    if is_pydantic_model(output_type):
        print(
            f"[ClaudeCode.forward] Parsing structured output (type: {output_type})"
        )
        try:
            # response_text can be dict/list (from structured_output) or str (legacy)
            parsed_output = parse_json_response(response_text, output_type)
            print("[ClaudeCode.forward] Successfully parsed structured output")
        except Exception as e:
            print("[ClaudeCode.forward] ERROR: Failed to parse structured output")
            raise ValueError(
                f"Failed to parse Claude response as {output_type}: {e}\n"
                f"Response type: {type(response_text)}\n"
                f"Response: {response_text}"
            ) from e
    else:
        print(
            f"[ClaudeCode.forward] Extracting text response (output type: {output_type})"
        )
        # string output - extract text
        if isinstance(response_text, str):
            parsed_output = extract_text_from_response(response_text)
        else:
            # Shouldn't happen, but handle gracefully
            parsed_output = str(response_text)

    print(
        f"[ClaudeCode.forward] Returning Prediction with session_id={self._session_id}\n"
    )

    # return prediction with typed output, trace, and usage
    return Prediction(
        **{
            self.output_field_name: parsed_output,
            "trace": trace,
            "usage": usage,
        }
    )
async def aforward(self, **kwargs: Any) -> Prediction:
    """Async version of forward().

    Use this when already in an async context to avoid event loop issues.

    Args:
        **kwargs: Must contain the input field specified in signature

    Returns:
        Prediction with typed output, trace, and usage

    Raises:
        ValueError: If the input field is missing, or the response cannot
            be parsed into the structured output type.
        RuntimeError: If the agent reports an execution error.
    """
    print(f"\n[ClaudeCode.aforward] Called with fields: {list(kwargs.keys())}")

    # extract input value
    if self.input_field_name not in kwargs:
        raise ValueError(
            f"Missing required input field: {self.input_field_name}. "
            f"Received: {list(kwargs.keys())}"
        )
    input_value = kwargs[self.input_field_name]

    # preview via str() so non-sliceable inputs cannot break the log line
    input_preview = str(input_value)
    print(
        f"[ClaudeCode.aforward] Input field '{self.input_field_name}': {input_preview[:100]}..."
        if len(input_preview) > 100
        else f"[ClaudeCode.aforward] Input field '{self.input_field_name}': {input_value}"
    )

    # build prompt
    prompt = self._build_prompt(input_value)
    print(f"[ClaudeCode.aforward] Built prompt ({len(prompt)} chars)")

    # run async execution
    print(f"[ClaudeCode.aforward] Starting async execution (model={self.model})")
    response_text, trace, usage = await self._run_async(prompt)

    # log response details
    response_type = type(response_text).__name__
    response_len = len(str(response_text)) if response_text else 0
    print(
        f"[ClaudeCode.aforward] Received response (type={response_type}, {response_len} chars, {len(trace)} trace items)"
    )
    print(
        f"[ClaudeCode.aforward] Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.cached_input_tokens} cached"
    )

    # parse response based on output type
    output_type = self.output_field.annotation
    if is_pydantic_model(output_type):
        print(
            f"[ClaudeCode.aforward] Parsing structured output (type: {output_type})"
        )
        try:
            # response_text can be dict/list (from structured_output) or str (legacy)
            parsed_output = parse_json_response(response_text, output_type)
            print("[ClaudeCode.aforward] Successfully parsed structured output")
        except Exception as e:
            print("[ClaudeCode.aforward] ERROR: Failed to parse structured output")
            # chain explicitly so the original parsing error stays attached
            raise ValueError(
                f"Failed to parse Claude response as {output_type}: {e}\n"
                f"Response type: {type(response_text)}\n"
                f"Response: {response_text}"
            ) from e
    else:
        print(
            f"[ClaudeCode.aforward] Extracting text response (output type: {output_type})"
        )
        # string output - extract text
        if isinstance(response_text, str):
            parsed_output = extract_text_from_response(response_text)
        else:
            # Shouldn't happen, but handle gracefully
            parsed_output = str(response_text)

    print(
        f"[ClaudeCode.aforward] Returning Prediction with session_id={self._session_id}\n"
    )

    # return prediction with typed output, trace, and usage
    return Prediction(
        **{
            self.output_field_name: parsed_output,
            "trace": trace,
            "usage": usage,
        }
    )
async def disconnect(self) -> None:
"""Disconnect from Claude Code and clean up resources."""
if self._client and self._is_connected:
await self._client.disconnect()
self._is_connected = False
def __del__(self):
"""Cleanup on deletion."""
# Check attributes exist before accessing (may fail during __init__)
if hasattr(self, "_client") and hasattr(self, "_is_connected"):
if self._client and self._is_connected:
try:
asyncio.run(self.disconnect())
except Exception:
# best effort cleanup
pass