import asyncio
import json
import logging
import os
from pathlib import Path
from typing import Any, Optional

from modaic import PrecompiledProgram, PrecompiledConfig
from pydantic import BaseModel
import dspy
from dspy.primitives.prediction import Prediction
from claude_agent_sdk import (
    ClaudeSDKClient,
    ClaudeAgentOptions,
    AssistantMessage,
    ResultMessage,
    SystemMessage,
    UserMessage,
    TextBlock,
    ThinkingBlock,
    ToolUseBlock,
    ToolResultBlock,
)

from .trace import (
    TraceItem,
    AgentMessageItem,
    ThinkingItem,
    ToolUseItem,
    ToolResultItem,
    ErrorItem,
)
from .utils import (
    Usage,
    is_pydantic_model,
    get_json_schema,
    parse_json_response,
    extract_text_from_response,
)

logger = logging.getLogger(__name__)

# ClaudeAgentOptions fields that are copied from the agent attribute of the
# same name whenever that attribute is not None.
_OPTIONAL_CLIENT_OPTIONS = (
    "permission_mode",
    "allowed_tools",
    "disallowed_tools",
    "sandbox",
    "system_prompt",
    "mcp_servers",
    "resume",
    "max_turns",
    "permission_prompt_tool_name",
    "settings",
    "add_dirs",
    "extra_args",
    "max_buffer_size",
    "stderr",
    "can_use_tool",
    "hooks",
    "user",
    "setting_sources",
    "agents",
    "plugins",
    "cli_path",
)

# Boolean ClaudeAgentOptions fields that are only forwarded when enabled.
_FLAG_CLIENT_OPTIONS = (
    "continue_conversation",
    "fork_session",
    "include_partial_messages",
)


class ClaudeCodeConfig(PrecompiledConfig):
    """Configuration for ClaudeCode agent."""

    model: str = "claude-opus-4-5-20251101"


class ClaudeCodeKwargs(BaseModel):
    """Arguments for ClaudeCode initialization.

    Matches ClaudeAgentOptions from the SDK with additional DSPy-specific fields.
    See: https://platform.claude.com/docs/en/agent-sdk/python#claudeagentoptions
    """

    # DSPy-specific (required)
    signature: Any  # str | dspy.Signature - validated manually in __init__

    # auth
    # Falls back to the ANTHROPIC_API_KEY environment variable when None.
    api_key: str | None = None

    # basic config
    working_directory: str = "."
    permission_mode: str | None = None
    allowed_tools: list[str] | None = None  # Any Claude Code tools
    disallowed_tools: list[str] | None = None
    sandbox: dict[str, Any] | None = None
    system_prompt: str | dict[str, Any] | None = None

    # mcp servers
    mcp_servers: dict[str, Any] | str | Path | None = None

    # session management
    continue_conversation: bool = False
    resume: str | None = None
    max_turns: int | None = None
    fork_session: bool = False

    # advanced options
    permission_prompt_tool_name: str | None = None
    settings: str | None = None
    add_dirs: list[str | Path] | None = None
    env: dict[str, str] | None = None
    extra_args: dict[str, str | None] | None = None
    max_buffer_size: int | None = None

    # callbacks and hooks
    # Callable[[str], None] - can't type check callables in Pydantic easily
    stderr: Any | None = None
    can_use_tool: Any | None = None  # CanUseTool callback
    hooks: dict[str, list[dict[str, Any]]] | None = None

    # user and settings
    user: str | None = None
    include_partial_messages: bool = False
    setting_sources: list[str] | None = None  # List of "user" | "project" | "local"

    # subagents and plugins
    agents: dict[str, dict[str, Any]] | None = None
    plugins: list[dict[str, Any]] | None = None

    # cli configuration
    cli_path: str | Path | None = None


class ClaudeCode(PrecompiledProgram):
    """DSPy module that wraps Claude Code SDK.

    Each agent instance keeps one ``ClaudeSDKClient`` connected across calls,
    so consecutive ``forward()`` / ``aforward()`` calls share a conversation
    session. Suited to multi-turn agentic workflows with context preservation.

    Synchronous ``forward()`` calls made outside a running event loop are
    executed on a private, per-instance event loop that is kept open between
    calls, so the client connection stays valid from one turn to the next.
    """

    config: ClaudeCodeConfig

    def __init__(
        self,
        config: ClaudeCodeConfig,
        **kwargs: Any,
    ):
        """Validate the signature and store all client options.

        Args:
            config: Agent configuration (provides the model name).
            **kwargs: Fields of ``ClaudeCodeKwargs``; ``signature`` is required.

        Raises:
            ValueError: If a string signature references an unknown type, or
                the signature does not have exactly one input and one output.
        """
        super().__init__(config=config)

        args = ClaudeCodeKwargs(**kwargs)

        self.signature = self._resolve_signature(args.signature)

        # validate signature has exactly 1 input and 1 output
        # TODO: support multiple inputs/outputs
        input_fields = list(self.signature.input_fields.keys())
        output_fields = list(self.signature.output_fields.keys())

        if len(input_fields) != 1:
            raise ValueError(
                f"ClaudeCode requires exactly 1 input field, got {len(input_fields)}. "
                f"Found: {input_fields}"
            )

        if len(output_fields) != 1:
            raise ValueError(
                f"ClaudeCode requires exactly 1 output field, got {len(output_fields)}. "
                f"Found: {output_fields}"
            )

        self.input_field_name = input_fields[0]
        self.output_field_name = output_fields[0]
        self.input_field = self.signature.input_fields[self.input_field_name]
        self.output_field = self.signature.output_fields[self.output_field_name]

        # store all configuration values
        self.api_key = args.api_key or os.getenv("ANTHROPIC_API_KEY")
        self.working_directory = Path(args.working_directory).resolve()
        self.model = config.model

        # basic options
        self.permission_mode = args.permission_mode
        self.allowed_tools = args.allowed_tools
        self.disallowed_tools = args.disallowed_tools
        self.sandbox = args.sandbox
        self.system_prompt = args.system_prompt

        # mcp servers
        self.mcp_servers = args.mcp_servers

        # session management
        self.continue_conversation = args.continue_conversation
        self.resume = args.resume
        self.max_turns = args.max_turns
        self.fork_session = args.fork_session

        # advanced options
        self.permission_prompt_tool_name = args.permission_prompt_tool_name
        self.settings = args.settings
        self.add_dirs = args.add_dirs
        self.env = args.env
        self.extra_args = args.extra_args
        self.max_buffer_size = args.max_buffer_size

        # callbacks and hooks
        self.stderr = args.stderr
        self.can_use_tool = args.can_use_tool
        self.hooks = args.hooks

        # user and settings
        self.user = args.user
        self.include_partial_messages = args.include_partial_messages
        self.setting_sources = args.setting_sources

        # subagents and plugins
        self.agents = args.agents
        self.plugins = args.plugins

        # cli configuration
        self.cli_path = args.cli_path

        # determine output format upfront
        self.output_format = self._get_output_format()

        # session state
        self._client: Optional[ClaudeSDKClient] = None
        self._session_id: Optional[str] = None
        self._is_connected = False
        # Private loop reused by forward() so the client's connection, which is
        # bound to the loop it was opened on, survives across synchronous calls.
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    @staticmethod
    def _resolve_signature(signature: Any) -> Any:
        """Turn a string signature into a ``dspy.Signature``; pass others through.

        Note: Raw string signatures only work with built-in types.
        For custom Pydantic models, users must pass:
          1. A class-based signature, OR
          2. Pre-constructed dspy.Signature (in their module where types are defined)

        Raises:
            ValueError: With a guidance message when the string references a
                type DSPy cannot resolve; other ValueErrors are re-raised as-is.
        """
        if not isinstance(signature, str):
            return signature
        try:
            return dspy.Signature(signature)
        except ValueError as e:
            if "Unknown name:" in str(e):
                type_name = str(e).split("Unknown name: ")[-1]
                raise ValueError(
                    f"Cannot resolve type '{type_name}' in string signature.\n"
                    f"String signatures only work with built-in types (str, int, list[str], etc.).\n\n"
                    f"For custom Pydantic models, use one of these approaches:\n\n"
                    f"Option 1 - Class-based signature (recommended):\n"
                    f"  class MySignature(dspy.Signature):\n"
                    f"      input: str = dspy.InputField()\n"
                    f"      output: {type_name} = dspy.OutputField()\n"
                    f"  agent = ClaudeCode(config, signature=MySignature, ...)\n\n"
                    f"Option 2 - Pre-construct signature in your module:\n"
                    f"  sig = dspy.Signature('{signature}')\n"
                    f"  agent = ClaudeCode(config, signature=sig, ...)\n"
                ) from e
            raise

    @property
    def session_id(self) -> Optional[str]:
        """Get the session ID for this agent instance.

        Returns None until the first forward()/aforward() call completes.
        """
        return self._session_id

    def _create_client(self) -> ClaudeSDKClient:
        """Create a ClaudeSDKClient from the configured options.

        Only options that are set (not None, or True for boolean flags) are
        forwarded, so unset options keep the SDK's own defaults. The API key is
        handed to the client through ``ClaudeAgentOptions.env`` rather than by
        mutating ``os.environ``, so one agent's key never leaks into the
        process-wide environment (or into agents created later).
        """
        options: dict[str, Any] = {
            "cwd": str(self.working_directory),
            "model": self.model,
            "output_format": self.output_format,
        }

        for name in _OPTIONAL_CLIENT_OPTIONS:
            value = getattr(self, name)
            if value is not None:
                options[name] = value

        for name in _FLAG_CLIENT_OPTIONS:
            value = getattr(self, name)
            if value:
                options[name] = value

        env = dict(self.env) if self.env is not None else {}
        if self.api_key:
            # setdefault: an ANTHROPIC_API_KEY given explicitly in `env` wins.
            env.setdefault("ANTHROPIC_API_KEY", self.api_key)
        if self.env is not None or env:
            options["env"] = env

        return ClaudeSDKClient(options=ClaudeAgentOptions(**options))

    @staticmethod
    def _field_desc(name: str, field: Any) -> Optional[str]:
        """Return a field's user-written description, or None.

        DSPy fields store desc in ``json_schema_extra``. DSPy fills in a
        ``${name}`` placeholder when no desc was given; that placeholder is
        treated as "no description" so it is not leaked into the prompt.
        """
        extra = getattr(field, "json_schema_extra", None)
        if not isinstance(extra, dict):
            return None
        desc = extra.get("desc")
        if not desc or desc == f"${{{name}}}":
            return None
        return desc

    @staticmethod
    def _format_input(input_value: Any) -> str:
        """Render the input value as prompt text.

        Strings pass through unchanged, Pydantic models are serialized as JSON,
        anything else goes through ``str()``.
        """
        if isinstance(input_value, str):
            return input_value
        if isinstance(input_value, BaseModel):
            return input_value.model_dump_json(indent=2)
        return str(input_value)

    def _build_prompt(self, input_value: Any) -> str:
        """Build prompt from signature docstring, field descriptions, and input value.

        Note: When using structured outputs, the SDK handles JSON formatting
        automatically via the output_format parameter, so we don't add JSON
        instructions to the prompt.
        """
        prompt_parts = []

        # add signature docstring if present
        doc = (self.signature.__doc__ or "").strip()
        if doc:
            prompt_parts.append(f"Task: {doc}")

        input_desc = self._field_desc(self.input_field_name, self.input_field)
        if input_desc:
            prompt_parts.append(f"Input context: {input_desc}")

        # add the actual input value
        prompt_parts.append(self._format_input(input_value))

        output_desc = self._field_desc(self.output_field_name, self.output_field)
        if output_desc:
            prompt_parts.append(f"\nPlease produce the following output: {output_desc}")

        # the schema is passed through ClaudeAgentOptions and enforced by the SDK
        return "\n\n".join(prompt_parts)

    def _get_output_format(self) -> Optional[dict[str, Any]]:
        """Get output format configuration for structured outputs.

        Returns a ``{"type": "json_schema", "schema": ...}`` dict when the output
        field's annotation is accepted by ``is_pydantic_model`` (intended to
        cover direct Pydantic models and, depending on that helper, generics
        such as ``list[MyModel]`` — TODO confirm in utils), otherwise None.
        """
        output_type = self.output_field.annotation

        if is_pydantic_model(output_type):
            schema = get_json_schema(output_type)
            return {
                "type": "json_schema",
                "schema": schema,
            }

        return None

    @staticmethod
    def _tool_result_item(
        block: ToolResultBlock, tool_names: dict[str, str]
    ) -> ToolResultItem:
        """Convert a ToolResultBlock into a trace item.

        Text is taken verbatim from string content, or concatenated from
        ``{"type": "text"}`` entries of list content. The tool name is looked
        up from the ToolUseBlock with the same id seen earlier in the turn
        ("" if unknown).
        """
        content = block.content
        if isinstance(content, str):
            content_str = content
        elif isinstance(content, list):
            content_str = "".join(
                item.get("text", "")
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            )
        else:
            content_str = ""

        return ToolResultItem(
            tool_name=tool_names.get(block.tool_use_id, ""),
            tool_use_id=block.tool_use_id,
            content=content_str,
            is_error=block.is_error or False,
        )

    async def _run_async(
        self, prompt: str
    ) -> tuple[str | dict | list | None, list[TraceItem], Usage]:
        """Run the agent asynchronously and collect results.

        Connects the client on first use and reuses it afterwards.

        Returns:
            - response: For structured outputs, returns dict/list from structured_output.
                       For text outputs, returns string from result or text blocks.
            - trace: Execution trace items
            - usage: Token usage statistics

        Raises:
            RuntimeError: If the final ResultMessage reports an error.
        """
        logger.debug(
            "[ClaudeCode._run_async] Initializing client (connected=%s)",
            self._is_connected,
        )

        # create client if needed
        if self._client is None:
            self._client = self._create_client()

        # connect if not already connected
        if not self._is_connected:
            await self._client.connect()
            self._is_connected = True

        await self._client.query(prompt)
        logger.debug("[ClaudeCode._run_async] Query sent, waiting for response...")

        trace: list[TraceItem] = []
        usage = Usage()
        response_text = ""
        structured_output = None
        # tool_use_id -> tool name, so tool results can be labelled with their tool
        tool_names: dict[str, str] = {}

        async for message in self._client.receive_response():
            if isinstance(message, AssistantMessage):
                for block in message.content:
                    if isinstance(block, TextBlock):
                        response_text += block.text
                        trace.append(
                            AgentMessageItem(text=block.text, model=message.model)
                        )
                    elif isinstance(block, ThinkingBlock):
                        trace.append(
                            ThinkingItem(text=block.thinking, model=message.model)
                        )
                    elif isinstance(block, ToolUseBlock):
                        tool_names[block.id] = block.name
                        # the StructuredOutput tool carries the JSON response
                        # directly in its input (already a dict)
                        if block.name == "StructuredOutput":
                            response_text = json.dumps(block.input)
                        trace.append(
                            ToolUseItem(
                                tool_name=block.name,
                                tool_input=block.input,
                                tool_use_id=block.id,
                            )
                        )
                    elif isinstance(block, ToolResultBlock):
                        trace.append(self._tool_result_item(block, tool_names))

            elif isinstance(message, UserMessage):
                # Tool results may be reported in user-role messages; only
                # ToolResultBlocks are recorded, other user content is ignored.
                if isinstance(message.content, list):
                    for block in message.content:
                        if isinstance(block, ToolResultBlock):
                            trace.append(self._tool_result_item(block, tool_names))

            elif isinstance(message, ResultMessage):
                # final message with session id, usage info and result
                if hasattr(message, "session_id"):
                    self._session_id = message.session_id
                    logger.debug(
                        "[ClaudeCode._run_async]   - Session ID: %s", self._session_id
                    )

                usage_data = getattr(message, "usage", None)
                if usage_data:
                    usage = Usage(
                        input_tokens=usage_data.get("input_tokens", 0) or 0,
                        cached_input_tokens=usage_data.get(
                            "cache_read_input_tokens", 0
                        )
                        or 0,
                        output_tokens=usage_data.get("output_tokens", 0) or 0,
                    )

                if getattr(message, "is_error", False):
                    error_msg = getattr(message, "result", None) or "Unknown error"
                    trace.append(
                        ErrorItem(message=error_msg, error_type="execution_error")
                    )
                    raise RuntimeError(f"Agent execution failed: {error_msg}")

                # prefer structured_output over result (when using output_format)
                if getattr(message, "structured_output", None) is not None:
                    structured_output = message.structured_output
                # fallback to result field for text outputs
                elif getattr(message, "result", None):
                    response_text = message.result

            elif isinstance(message, SystemMessage):
                # log system messages to trace but don't error
                data = getattr(message, "data", None)
                if data:
                    trace.append(
                        AgentMessageItem(text=f"[System: {data}]", model="system")
                    )

        # return structured_output if available (for Pydantic outputs), otherwise text
        if structured_output is not None:
            return structured_output, trace, usage
        return response_text, trace, usage

    def _get_input_value(self, kwargs: dict[str, Any]) -> Any:
        """Return the signature's single input value from call kwargs.

        Raises:
            ValueError: If the input field is missing.
        """
        if self.input_field_name not in kwargs:
            raise ValueError(
                f"Missing required input field: {self.input_field_name}. "
                f"Received: {list(kwargs.keys())}"
            )
        return kwargs[self.input_field_name]

    def _build_prediction(
        self,
        response: str | dict | list | None,
        trace: list[TraceItem],
        usage: Usage,
    ) -> Prediction:
        """Parse the raw response into the output type and wrap it in a Prediction.

        Raises:
            ValueError: If a Pydantic-typed output cannot be parsed.
        """
        output_type = self.output_field.annotation

        if is_pydantic_model(output_type):
            try:
                # response can be dict/list (from structured_output) or str (legacy)
                parsed_output = parse_json_response(response, output_type)
            except Exception as e:
                raise ValueError(
                    f"Failed to parse Claude response as {output_type}: {e}\n"
                    f"Response type: {type(response)}\n"
                    f"Response: {response}"
                ) from e
        elif isinstance(response, str):
            parsed_output = extract_text_from_response(response)
        else:
            # Shouldn't happen, but handle gracefully
            parsed_output = str(response)

        return Prediction(
            **{
                self.output_field_name: parsed_output,
                "trace": trace,
                "usage": usage,
            }
        )

    def _run_sync(
        self, prompt: str
    ) -> tuple[str | dict | list | None, list[TraceItem], Usage]:
        """Run ``_run_async(prompt)`` to completion from synchronous code.

        Inside an already-running loop (e.g. a notebook) the loop is patched
        with nest_asyncio and reused. Otherwise the instance's private loop is
        used and left open, so the connected client stays usable next call.
        Exceptions from the agent run propagate unchanged (no retry).
        """
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if running_loop is not None:
            # If already in async context, allow re-entrant run_until_complete
            import nest_asyncio

            nest_asyncio.apply(running_loop)
            return running_loop.run_until_complete(self._run_async(prompt))

        if self._loop is None or self._loop.is_closed():
            self._loop = asyncio.new_event_loop()
        return self._loop.run_until_complete(self._run_async(prompt))

    def forward(self, **kwargs: Any) -> Prediction:
        """Execute the agent with an input message.

        Args:
            **kwargs: Must contain the input field specified in signature

        Returns:
            Prediction with:
                - Typed output field (named according to signature)
                - trace: list[TraceItem] - Execution trace
                - usage: Usage - Token usage statistics

        Raises:
            ValueError: If the input field is missing or the output can't be parsed.
            RuntimeError: If the agent reports an execution error.

        Example:
            >>> result = agent(message="Hello")
            >>> print(result.answer)  # Access typed output
            >>> print(result.trace)   # List of execution items
            >>> print(result.usage)   # Token usage stats
        """
        prompt = self._build_prompt(self._get_input_value(kwargs))
        response, trace, usage = self._run_sync(prompt)
        return self._build_prediction(response, trace, usage)

    async def aforward(self, **kwargs: Any) -> Prediction:
        """Async version of forward().

        Use this when already in an async context to avoid event loop issues.

        Args:
            **kwargs: Must contain the input field specified in signature

        Returns:
            Prediction with typed output, trace, and usage
        """
        prompt = self._build_prompt(self._get_input_value(kwargs))
        response, trace, usage = await self._run_async(prompt)
        return self._build_prediction(response, trace, usage)

    async def disconnect(self) -> None:
        """Disconnect from Claude Code and clean up resources.

        The instance is marked disconnected even if the SDK call raises, so the
        next call reconnects instead of reusing a half-closed client.
        """
        if self._client and self._is_connected:
            try:
                await self._client.disconnect()
            finally:
                self._is_connected = False

    def _close_private_loop(self) -> None:
        """Close the private event loop used by forward(), if it is idle."""
        loop = getattr(self, "_loop", None)
        if loop is not None and not loop.is_closed() and not loop.is_running():
            loop.close()
        self._loop = None

    def __del__(self):
        """Best-effort cleanup on deletion.

        Disconnects the client (on the private loop it was most likely opened
        on, else on a fresh loop) unless an event loop is already running in
        this thread, where blocking is impossible. Errors are swallowed on
        purpose: this may run during garbage collection or interpreter shutdown.
        """
        try:
            # attributes may be missing if __init__ failed part-way
            if getattr(self, "_client", None) and getattr(
                self, "_is_connected", False
            ):
                try:
                    asyncio.get_running_loop()
                    in_running_loop = True
                except RuntimeError:
                    in_running_loop = False

                if not in_running_loop:
                    loop = getattr(self, "_loop", None)
                    if loop is not None and not loop.is_closed():
                        loop.run_until_complete(self.disconnect())
                    else:
                        asyncio.run(self.disconnect())
        except Exception:
            # best effort cleanup
            pass
        finally:
            try:
                self._close_private_loop()
            except Exception:
                pass