import asyncio
import json
import os
from pathlib import Path
from typing import Any, Optional

from modaic import PrecompiledProgram, PrecompiledConfig
from pydantic import BaseModel
import dspy
from dspy.primitives.prediction import Prediction

from claude_agent_sdk import (
    ClaudeSDKClient,
    ClaudeAgentOptions,
    AssistantMessage,
    ResultMessage,
    SystemMessage,
    TextBlock,
    ThinkingBlock,
    ToolUseBlock,
    ToolResultBlock,
)

from .trace import (
    TraceItem,
    AgentMessageItem,
    ThinkingItem,
    ToolUseItem,
    ToolResultItem,
    ErrorItem,
)
from .utils import (
    Usage,
    is_pydantic_model,
    get_json_schema,
    parse_json_response,
    extract_text_from_response,
)


class ClaudeCodeConfig(PrecompiledConfig):
    """Configuration for ClaudeCode agent."""

    # Model identifier forwarded to ClaudeAgentOptions(model=...).
    model: str = "claude-opus-4-5-20251101"


class ClaudeCodeKwargs(BaseModel):
    """Keyword arguments accepted by ``ClaudeCode.__init__``.

    ``ClaudeCode`` builds this model from its ``**kwargs`` and reads the
    fields back, so the field list below is the set of keyword names the
    constructor consumes.
    """

    model_config = {"arbitrary_types_allowed": True}

    signature: Any  # str | dspy.Signature (validated manually)
    api_key: str | None = None
    working_directory: str = "."
    permission_mode: str | None = None
    allowed_tools: list[str] | None = None
    disallowed_tools: list[str] | None = None
    sandbox: dict[str, Any] | None = None
    system_prompt: str | dict[str, Any] | None = None


class ClaudeCode(PrecompiledProgram):
    """DSPy module that wraps Claude Code SDK.

    Each agent instance maintains a stateful conversation session.
    Perfect for multi-turn agentic workflows with context preservation.

    Example:
        >>> config = ClaudeCodeConfig()
        >>> agent = ClaudeCode(
        ...     config,
        ...     signature='message:str -> answer:str',
        ...     working_directory="."
        ... )
        >>> result = agent(message="What files are in this directory?")
        >>> print(result.answer)  # Typed access
        >>> print(result.trace)  # Execution trace
        >>> print(result.usage)  # Token usage
    """

    config: ClaudeCodeConfig

    def __init__(
        self,
        config: ClaudeCodeConfig,
        **kwargs: Any,
    ):
        """Validate the signature and store the agent configuration.

        Args:
            config: Supplies the model name.
            **kwargs: Fields of ``ClaudeCodeKwargs`` (``signature`` is
                required; the rest have defaults).

        Raises:
            ValueError: If the signature does not have exactly one input
                field and exactly one output field.
        """
        super().__init__(config=config)

        args = ClaudeCodeKwargs(**kwargs)

        # parse and validate signature
        signature = args.signature
        if isinstance(signature, str):
            signature = dspy.Signature(signature)
        self.signature = signature

        # validate signature has exactly 1 input and 1 output
        input_fields = list(self.signature.input_fields.keys())
        output_fields = list(self.signature.output_fields.keys())

        if len(input_fields) != 1:
            raise ValueError(
                f"ClaudeCode requires exactly 1 input field, got {len(input_fields)}. "
                f"Found: {input_fields}"
            )

        if len(output_fields) != 1:
            raise ValueError(
                f"ClaudeCode requires exactly 1 output field, got {len(output_fields)}. "
                f"Found: {output_fields}"
            )

        self.input_field_name = input_fields[0]
        self.output_field_name = output_fields[0]
        self.input_field = self.signature.input_fields[self.input_field_name]
        self.output_field = self.signature.output_fields[self.output_field_name]

        # store config values
        self.working_directory = Path(args.working_directory).resolve()
        self.model = config.model
        self.permission_mode = args.permission_mode
        self.allowed_tools = args.allowed_tools
        self.disallowed_tools = args.disallowed_tools
        self.sandbox = args.sandbox
        self.system_prompt = args.system_prompt
        # An explicit api_key wins; otherwise fall back to the environment.
        self.api_key = args.api_key or os.getenv("ANTHROPIC_API_KEY")

        # No extra options since all kwargs are parsed by ClaudeCodeKwargs
        self.extra_options = {}

        # determine output format upfront
        self.output_format = self._get_output_format()

        # session state
        self._client: Optional[ClaudeSDKClient] = None
        self._session_id: Optional[str] = None
        self._is_connected = False

    @property
    def session_id(self) -> Optional[str]:
        """Get the session ID for this agent instance.

        Returns None until first forward() call.
        """
        return self._session_id

    def _create_client(self) -> ClaudeSDKClient:
        """Create ClaudeSDKClient with configured options.

        Side effect: when an API key is known, it is written to the
        process-wide ``ANTHROPIC_API_KEY`` environment variable before the
        client is constructed.
        """
        options = ClaudeAgentOptions(
            cwd=str(self.working_directory),
            model=self.model,
            permission_mode=self.permission_mode,
            allowed_tools=self.allowed_tools or [],
            disallowed_tools=self.disallowed_tools or [],
            sandbox=self.sandbox,
            system_prompt=self.system_prompt,
            output_format=self.output_format,  # include output format
            **self.extra_options,
        )

        # set API key if provided
        if self.api_key:
            os.environ["ANTHROPIC_API_KEY"] = self.api_key

        return ClaudeSDKClient(options=options)

    @staticmethod
    def _field_desc(field: Any) -> Optional[str]:
        """Return the ``desc`` entry of a field's ``json_schema_extra``, if any."""
        # DSPy fields store desc in json_schema_extra
        extra = getattr(field, "json_schema_extra", None)
        if not extra:
            return None
        return extra.get("desc")

    def _build_prompt(self, input_value: str) -> str:
        """Build prompt from signature docstring, field descriptions, and input value.

        Args:
            input_value: The value of the single input field. Non-string
                values are converted with ``str()`` so the prompt can always
                be assembled.

        Returns:
            The prompt sections joined by blank lines.
        """
        prompt_parts: list[str] = []

        # add signature docstring if present
        doc = (self.signature.__doc__ or "").strip()
        if doc:
            prompt_parts.append(f"Task: {doc}")

        # add input field description if present
        input_desc = self._field_desc(self.input_field)
        if input_desc:
            prompt_parts.append(f"Input context: {input_desc}")

        # add the actual input value; str.join() below rejects non-strings
        prompt_parts.append(
            input_value if isinstance(input_value, str) else str(input_value)
        )

        # add output field description if present
        output_desc = self._field_desc(self.output_field)
        if output_desc:
            prompt_parts.append(f"\nPlease produce the following output: {output_desc}")

        # for Pydantic outputs, add explicit JSON instructions
        if self.output_format:
            schema = self.output_format["schema"]
            prompt_parts.append(
                f"\nYou MUST respond with ONLY valid JSON matching this schema:\n"
                f"{json.dumps(schema, indent=2)}\n\n"
                f"Do not include any explanatory text, markdown formatting, or code blocks. "
                f"Return ONLY the raw JSON object."
            )

        return "\n\n".join(prompt_parts)

    def _get_output_format(self) -> Optional[dict[str, Any]]:
        """Get output format configuration for structured outputs.

        Returns:
            A ``json_schema`` format dict when the output field's annotation
            is a Pydantic model, otherwise ``None``.
        """
        output_type = self.output_field.annotation

        if is_pydantic_model(output_type):
            schema = get_json_schema(output_type)
            return {
                "type": "json_schema",
                "schema": schema,
            }

        return None

    @staticmethod
    def _tool_result_text(content: Any) -> str:
        """Flatten a tool result's content into plain text.

        A string is returned as-is; for a list, the ``text`` of every dict
        item whose ``type`` is ``"text"`` is concatenated; anything else
        yields an empty string.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return "".join(
                item.get("text", "")
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            )
        return ""

    async def _run_async(self, prompt: str) -> tuple[str, list[TraceItem], Usage]:
        """Run the agent asynchronously and collect results.

        Lazily creates and connects the client, sends ``prompt``, then
        consumes the response stream.

        Returns:
            ``(response_text, trace, usage)``.

        Raises:
            RuntimeError: If the result message reports an error.
        """
        # create client if needed
        if self._client is None:
            self._client = self._create_client()

        # connect if not already connected
        if not self._is_connected:
            await self._client.connect()
            self._is_connected = True

        # send query (output_format already configured in options)
        await self._client.query(prompt)

        # collect messages and build trace
        trace: list[TraceItem] = []
        usage = Usage()
        response_text = ""

        async for message in self._client.receive_response():
            # handle assistant messages
            if isinstance(message, AssistantMessage):
                for block in message.content:
                    if isinstance(block, TextBlock):
                        response_text += block.text
                        trace.append(
                            AgentMessageItem(text=block.text, model=message.model)
                        )
                    elif isinstance(block, ThinkingBlock):
                        trace.append(
                            ThinkingItem(text=block.thinking, model=message.model)
                        )
                    elif isinstance(block, ToolUseBlock):
                        # Handle StructuredOutput tool (contains JSON response)
                        if block.name == "StructuredOutput":
                            # The JSON is directly in the tool input (already a dict)
                            response_text = json.dumps(block.input)

                        trace.append(
                            ToolUseItem(
                                tool_name=block.name,
                                tool_input=block.input,
                                tool_use_id=block.id,
                            )
                        )
                    elif isinstance(block, ToolResultBlock):
                        trace.append(
                            ToolResultItem(
                                tool_name="",  # Tool name not in ToolResultBlock
                                tool_use_id=block.tool_use_id,
                                content=self._tool_result_text(block.content),
                                is_error=block.is_error or False,
                            )
                        )

            # handle result messages (final message with usage info)
            elif isinstance(message, ResultMessage):
                # store session ID
                if hasattr(message, "session_id"):
                    self._session_id = message.session_id

                # extract usage
                usage_data = getattr(message, "usage", None)
                if usage_data:
                    usage = Usage(
                        input_tokens=usage_data.get("input_tokens", 0),
                        cached_input_tokens=usage_data.get(
                            "cache_read_input_tokens", 0
                        ),
                        output_tokens=usage_data.get("output_tokens", 0),
                    )

                result = getattr(message, "result", None)

                # check for errors
                if getattr(message, "is_error", False):
                    # result may be absent or None on failure; never pass
                    # None on as the error message.
                    error_msg = result or "Unknown error"
                    trace.append(
                        ErrorItem(message=error_msg, error_type="execution_error")
                    )
                    raise RuntimeError(f"Agent execution failed: {error_msg}")

                # extract result if present (for structured outputs from result field)
                # Note: structured outputs may come from StructuredOutput tool instead
                if result:
                    response_text = result

            # handle system messages
            elif isinstance(message, SystemMessage):
                # log system messages to trace but don't error
                data = getattr(message, "data", None)
                if data:
                    trace.append(
                        AgentMessageItem(text=f"[System: {data}]", model="system")
                    )

        return response_text, trace, usage

    def _extract_input(self, kwargs: dict[str, Any]) -> Any:
        """Return the value of the signature's single input field from ``kwargs``.

        Raises:
            ValueError: If the input field is missing.
        """
        if self.input_field_name not in kwargs:
            raise ValueError(
                f"Missing required input field: {self.input_field_name}. "
                f"Received: {list(kwargs.keys())}"
            )
        return kwargs[self.input_field_name]

    def _build_prediction(
        self, response_text: str, trace: list[TraceItem], usage: Usage
    ) -> Prediction:
        """Parse ``response_text`` per the output type and wrap it in a Prediction.

        Raises:
            ValueError: If a Pydantic output type is expected and the
                response cannot be parsed into it.
        """
        output_type = self.output_field.annotation
        if is_pydantic_model(output_type):
            try:
                parsed_output = parse_json_response(response_text, output_type)
            except Exception as e:
                # Chain the cause so the underlying parse error stays visible.
                raise ValueError(
                    f"Failed to parse Claude response as {output_type.__name__}: {e}\n"
                    f"Response: {response_text}"
                ) from e
        else:
            # string output - extract text
            parsed_output = extract_text_from_response(response_text)

        # return prediction with typed output, trace, and usage
        return Prediction(
            **{
                self.output_field_name: parsed_output,
                "trace": trace,
                "usage": usage,
            }
        )

    def _run_blocking(self, prompt: str) -> tuple[str, list[TraceItem], Usage]:
        """Drive ``_run_async`` to completion from synchronous code.

        Only the event-loop lookup is guarded by ``except RuntimeError``.
        A ``RuntimeError`` raised by the agent run itself must propagate to
        the caller; catching it here would silently run the prompt a second
        time.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # no event loop available in this thread
            loop = None

        if loop is None or loop.is_closed():
            # No usable loop: let asyncio create (and tear down) one.
            return asyncio.run(self._run_async(prompt))

        if loop.is_running():
            # Already inside an async context: patch the loop so that
            # run_until_complete() may be re-entered.
            import nest_asyncio

            nest_asyncio.apply()

        return loop.run_until_complete(self._run_async(prompt))

    def forward(self, **kwargs: Any) -> Prediction:
        """Execute the agent with an input message.

        Args:
            **kwargs: Must contain the input field specified in signature

        Returns:
            Prediction with:
                - Typed output field (named according to signature)
                - trace: list[TraceItem] - Execution trace
                - usage: Usage - Token usage statistics

        Raises:
            ValueError: If the input field is missing or the response
                cannot be parsed into the Pydantic output type.
            RuntimeError: If the agent reports an execution error.

        Example:
            >>> result = agent(message="Hello")
            >>> print(result.answer)  # Access typed output
            >>> print(result.trace)  # List of execution items
            >>> print(result.usage)  # Token usage stats
        """
        prompt = self._build_prompt(self._extract_input(kwargs))
        response_text, trace, usage = self._run_blocking(prompt)
        return self._build_prediction(response_text, trace, usage)

    async def aforward(self, **kwargs: Any) -> Prediction:
        """Async version of forward().

        Use this when already in an async context to avoid event loop issues.

        Args:
            **kwargs: Must contain the input field specified in signature

        Returns:
            Prediction with typed output, trace, and usage

        Raises:
            ValueError: If the input field is missing or the response
                cannot be parsed into the Pydantic output type.
            RuntimeError: If the agent reports an execution error.
        """
        prompt = self._build_prompt(self._extract_input(kwargs))
        response_text, trace, usage = await self._run_async(prompt)
        return self._build_prediction(response_text, trace, usage)

    async def disconnect(self) -> None:
        """Disconnect from Claude Code and clean up resources."""
        if self._client and self._is_connected:
            await self._client.disconnect()
            self._is_connected = False

    def __del__(self):
        """Cleanup on deletion (best effort)."""
        # __init__ may have raised before these attributes were assigned,
        # so read them defensively.
        if not (
            getattr(self, "_client", None) and getattr(self, "_is_connected", False)
        ):
            return

        coro = self.disconnect()
        try:
            asyncio.run(coro)
        except Exception:
            # best effort cleanup; close the coroutine so it is not
            # reported as "never awaited" when asyncio.run() refused to start
            coro.close()