Add more functionality to signature description parsing

This commit is contained in:
2025-12-05 23:40:50 -05:00
parent d776acd439
commit e1c81644c3
9 changed files with 827 additions and 76 deletions

479
claude_dspy/agent.py Normal file
View File

@@ -0,0 +1,479 @@
import asyncio
import json
import os
from pathlib import Path
from typing import Any, Optional
from modaic import PrecompiledProgram, PrecompiledConfig
from pydantic import BaseModel
import dspy
from dspy.primitives.prediction import Prediction
from claude_agent_sdk import (
ClaudeSDKClient,
ClaudeAgentOptions,
AssistantMessage,
ResultMessage,
SystemMessage,
TextBlock,
ThinkingBlock,
ToolUseBlock,
ToolResultBlock,
)
from claude_dspy.trace import (
TraceItem,
AgentMessageItem,
ThinkingItem,
ToolUseItem,
ToolResultItem,
ErrorItem,
)
from claude_dspy.utils import (
Usage,
is_pydantic_model,
get_json_schema,
parse_json_response,
extract_text_from_response,
)
class ClaudeCodeConfig(PrecompiledConfig):
    """Configuration for ClaudeCode agent."""

    # identifier of the Claude model the agent runs on; read by
    # ClaudeCode.__init__ and forwarded to the SDK options
    model: str = "claude-opus-4-5-20251101"
class ClaudeCodeKwargs(BaseModel):
    """Keyword arguments accepted by ``ClaudeCode.__init__``.

    ``ClaudeCode`` builds this model from its ``**kwargs`` and then reads
    every field back off the instance. Only ``signature`` is required; all
    other fields have defaults.
    """

    # ``signature`` is typed ``Any``, so arbitrary (non-pydantic) types are allowed
    model_config = dict(arbitrary_types_allowed=True)

    # the task signature
    signature: Any  # str | dspy.Signature (validated manually)

    # explicit API key; ClaudeCode falls back to ANTHROPIC_API_KEY when unset
    api_key: str | None = None

    # directory the agent operates in (resolved to an absolute path by ClaudeCode)
    working_directory: str = "."

    # the remaining fields are passed through to ClaudeAgentOptions
    permission_mode: str | None = None
    allowed_tools: list[str] | None = None
    disallowed_tools: list[str] | None = None
    sandbox: dict[str, Any] | None = None
    system_prompt: str | dict[str, Any] | None = None
class ClaudeCode(PrecompiledProgram):
    """DSPy module that wraps Claude Code SDK.

    Each agent instance maintains a stateful conversation session.
    Perfect for multi-turn agentic workflows with context preservation.

    Example:
        >>> config = ClaudeCodeConfig()
        >>> agent = ClaudeCode(
        ...     config,
        ...     signature='message:str -> answer:str',
        ...     working_directory="."
        ... )
        >>> result = agent(message="What files are in this directory?")
        >>> print(result.answer)  # Typed access
        >>> print(result.trace)   # Execution trace
        >>> print(result.usage)   # Token usage
    """

    config: ClaudeCodeConfig

    def __init__(
        self,
        config: ClaudeCodeConfig,
        **kwargs: Any,
    ):
        """Create an agent bound to a single-input / single-output signature.

        Args:
            config: Agent configuration; ``config.model`` selects the model.
            **kwargs: Fields of ``ClaudeCodeKwargs`` (``signature``,
                ``api_key``, ``working_directory``, ``permission_mode``,
                ``allowed_tools``, ``disallowed_tools``, ``sandbox``,
                ``system_prompt``).

        Raises:
            ValueError: If the signature does not have exactly one input
                field and exactly one output field.
        """
        super().__init__(config=config)
        args = ClaudeCodeKwargs(**kwargs)

        # parse and validate signature
        signature = args.signature
        if isinstance(signature, str):
            self.signature = dspy.Signature(signature)
        else:
            self.signature = signature

        # validate signature has exactly 1 input and 1 output
        input_fields = list(self.signature.input_fields.keys())
        output_fields = list(self.signature.output_fields.keys())
        if len(input_fields) != 1:
            raise ValueError(
                f"ClaudeCode requires exactly 1 input field, got {len(input_fields)}. "
                f"Found: {input_fields}"
            )
        if len(output_fields) != 1:
            raise ValueError(
                f"ClaudeCode requires exactly 1 output field, got {len(output_fields)}. "
                f"Found: {output_fields}"
            )
        self.input_field_name = input_fields[0]
        self.output_field_name = output_fields[0]
        self.input_field = self.signature.input_fields[self.input_field_name]
        self.output_field = self.signature.output_fields[self.output_field_name]

        # store config values
        self.working_directory = Path(args.working_directory).resolve()
        self.model = config.model
        self.permission_mode = args.permission_mode
        self.allowed_tools = args.allowed_tools
        self.disallowed_tools = args.disallowed_tools
        self.sandbox = args.sandbox
        self.system_prompt = args.system_prompt
        self.api_key = args.api_key or os.getenv("ANTHROPIC_API_KEY")

        # No extra options since all kwargs are parsed by ClaudeCodeKwargs
        self.extra_options = {}

        # determine output format upfront
        self.output_format = self._get_output_format()

        # session state
        self._client: Optional[ClaudeSDKClient] = None
        self._session_id: Optional[str] = None
        self._is_connected = False

    @property
    def session_id(self) -> Optional[str]:
        """Get the session ID for this agent instance.

        Returns None until a run has received a result message carrying one.
        """
        return self._session_id

    def _create_client(self) -> ClaudeSDKClient:
        """Create ClaudeSDKClient with configured options.

        Side effect: when an API key is known it is written to the
        process-wide ``ANTHROPIC_API_KEY`` environment variable.
        """
        options = ClaudeAgentOptions(
            cwd=str(self.working_directory),
            model=self.model,
            permission_mode=self.permission_mode,
            allowed_tools=self.allowed_tools or [],
            disallowed_tools=self.disallowed_tools or [],
            sandbox=self.sandbox,
            system_prompt=self.system_prompt,
            output_format=self.output_format,  # include output format
            **self.extra_options,
        )
        # set API key if provided
        if self.api_key:
            os.environ["ANTHROPIC_API_KEY"] = self.api_key
        return ClaudeSDKClient(options=options)

    @staticmethod
    def _field_desc(name: str, field: Any) -> Optional[str]:
        """Return the user-facing description of a signature field, if any."""
        # DSPy fields store desc in json_schema_extra
        extra = getattr(field, "json_schema_extra", None)
        if not isinstance(extra, dict):
            return None
        desc = extra.get("desc")
        # NOTE(review): DSPy appears to fill in a "${<field name>}" placeholder
        # when no description was given - confirm against the DSPy version in
        # use. Such a placeholder carries no information, so it is skipped.
        if not desc or desc == f"${{{name}}}":
            return None
        return desc

    def _build_prompt(self, input_value: Any) -> str:
        """Build prompt from signature docstring, field descriptions, and input value.

        Args:
            input_value: Value of the single input field. Non-string values
                are converted with ``str()`` so the prompt can be joined.

        Returns:
            The prompt sections joined by blank lines.
        """
        prompt_parts: list[str] = []

        # add signature docstring if present
        doc = (self.signature.__doc__ or "").strip()
        if doc:
            prompt_parts.append(f"Task: {doc}")

        # add input field description if present
        input_desc = self._field_desc(self.input_field_name, self.input_field)
        if input_desc:
            prompt_parts.append(f"Input context: {input_desc}")

        # add the actual input value
        prompt_parts.append(
            input_value if isinstance(input_value, str) else str(input_value)
        )

        # add output field description if present
        output_desc = self._field_desc(self.output_field_name, self.output_field)
        if output_desc:
            prompt_parts.append(f"\nPlease produce the following output: {output_desc}")

        # for Pydantic outputs, add explicit JSON instructions
        if self.output_format:
            schema = self.output_format["schema"]
            prompt_parts.append(
                f"\nYou MUST respond with ONLY valid JSON matching this schema:\n"
                f"{json.dumps(schema, indent=2)}\n\n"
                f"Do not include any explanatory text, markdown formatting, or code blocks. "
                f"Return ONLY the raw JSON object."
            )
        return "\n\n".join(prompt_parts)

    def _get_output_format(self) -> Optional[dict[str, Any]]:
        """Get output format configuration for structured outputs.

        Returns:
            A ``json_schema`` format dict when the output field is annotated
            with a Pydantic model, otherwise None.
        """
        output_type = self.output_field.annotation
        if is_pydantic_model(output_type):
            return {
                "type": "json_schema",
                "schema": get_json_schema(output_type),
            }
        return None

    @staticmethod
    def _tool_result_text(content: Any) -> str:
        """Flatten a tool result payload (string or list of blocks) to text."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # only dict items of type "text" contribute
            return "".join(
                item.get("text", "")
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            )
        return ""

    async def _run_async(self, prompt: str) -> tuple[str, list[TraceItem], Usage]:
        """Run the agent asynchronously and collect results.

        Lazily creates and connects the client, sends ``prompt`` and consumes
        the response stream.

        Returns:
            ``(response_text, trace, usage)``. ``response_text`` is, in order
            of preference: the JSON of the last ``StructuredOutput`` tool
            call, the result message's ``result``, or the concatenated
            assistant text.

        Raises:
            RuntimeError: If the result message reports an error.
        """
        # create client if needed
        if self._client is None:
            self._client = self._create_client()
        # connect if not already connected
        if not self._is_connected:
            await self._client.connect()
            self._is_connected = True

        # send query (output_format already configured in options)
        await self._client.query(prompt)

        # collect messages and build trace
        trace: list[TraceItem] = []
        usage = Usage()
        text_parts: list[str] = []
        final_result: Optional[str] = None
        # kept apart from the free text so that later text blocks or the
        # result string cannot corrupt / overwrite the structured JSON
        structured_json: Optional[str] = None
        # tool_use_id -> tool name, so tool results can be labelled
        tool_names: dict[str, str] = {}

        async for message in self._client.receive_response():
            # handle assistant messages
            if isinstance(message, AssistantMessage):
                for block in message.content:
                    if isinstance(block, TextBlock):
                        text_parts.append(block.text)
                        trace.append(
                            AgentMessageItem(text=block.text, model=message.model)
                        )
                    elif isinstance(block, ThinkingBlock):
                        trace.append(
                            ThinkingItem(text=block.thinking, model=message.model)
                        )
                    elif isinstance(block, ToolUseBlock):
                        tool_names[block.id] = block.name
                        # Handle StructuredOutput tool (contains JSON response)
                        if block.name == "StructuredOutput":
                            # The JSON is directly in the tool input
                            structured_json = json.dumps(block.input)
                        trace.append(
                            ToolUseItem(
                                tool_name=block.name,
                                tool_input=block.input,
                                tool_use_id=block.id,
                            )
                        )
                    elif isinstance(block, ToolResultBlock):
                        trace.append(
                            ToolResultItem(
                                # the block itself has no name; look it up
                                # from the matching tool-use block if seen
                                tool_name=tool_names.get(block.tool_use_id, ""),
                                tool_use_id=block.tool_use_id,
                                content=self._tool_result_text(block.content),
                                is_error=block.is_error or False,
                            )
                        )
            # handle result messages (final message with usage info)
            elif isinstance(message, ResultMessage):
                # store session ID
                if hasattr(message, "session_id"):
                    self._session_id = message.session_id
                # extract usage
                usage_data = getattr(message, "usage", None)
                if usage_data:
                    usage = Usage(
                        input_tokens=usage_data.get("input_tokens", 0),
                        cached_input_tokens=usage_data.get(
                            "cache_read_input_tokens", 0
                        ),
                        output_tokens=usage_data.get("output_tokens", 0),
                    )
                # check for errors
                if getattr(message, "is_error", False):
                    # result may be missing or None; never report "None"
                    error_msg = getattr(message, "result", None) or "Unknown error"
                    trace.append(
                        ErrorItem(message=error_msg, error_type="execution_error")
                    )
                    raise RuntimeError(f"Agent execution failed: {error_msg}")
                # extract result if present
                result = getattr(message, "result", None)
                if result:
                    final_result = result
            # handle system messages
            elif isinstance(message, SystemMessage):
                # log system messages to trace but don't error
                if hasattr(message, "data") and message.data:
                    trace.append(
                        AgentMessageItem(
                            text=f"[System: {str(message.data)}]", model="system"
                        )
                    )

        if structured_json is not None:
            response_text = structured_json
        elif final_result:
            response_text = final_result
        else:
            response_text = "".join(text_parts)
        return response_text, trace, usage

    def _extract_input(self, kwargs: dict[str, Any]) -> Any:
        """Return the value of the signature's input field from ``kwargs``.

        Raises:
            ValueError: If the input field is missing.
        """
        if self.input_field_name not in kwargs:
            raise ValueError(
                f"Missing required input field: {self.input_field_name}. "
                f"Received: {list(kwargs.keys())}"
            )
        return kwargs[self.input_field_name]

    def _build_prediction(
        self, response_text: str, trace: list[TraceItem], usage: Usage
    ) -> Prediction:
        """Parse the raw response and wrap it with trace and usage.

        Raises:
            ValueError: If a Pydantic output cannot be parsed from the response.
        """
        output_type = self.output_field.annotation
        if is_pydantic_model(output_type):
            try:
                parsed_output = parse_json_response(response_text, output_type)
            except Exception as e:
                raise ValueError(
                    f"Failed to parse Claude response as {output_type.__name__}: {e}\n"
                    f"Response: {response_text}"
                ) from e
        else:
            # string output - extract text
            parsed_output = extract_text_from_response(response_text)
        # return prediction with typed output, trace, and usage
        return Prediction(
            **{
                self.output_field_name: parsed_output,
                "trace": trace,
                "usage": usage,
            }
        )

    def _run_blocking(self, prompt: str) -> tuple[str, list[TraceItem], Usage]:
        """Drive ``_run_async`` to completion from synchronous code.

        Only the loop lookup is guarded by ``except RuntimeError``; errors
        raised by the agent run itself propagate instead of triggering a
        second execution of the same prompt.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # no current event loop in this thread
            loop = None
        if loop is None or loop.is_closed():
            # keep the new loop installed (rather than a throwaway
            # asyncio.run loop) so later calls run on the same loop the
            # client was connected on
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        elif loop.is_running():
            # already inside an async context: allow re-entrant
            # run_until_complete on the running loop
            import nest_asyncio

            nest_asyncio.apply()
        return loop.run_until_complete(self._run_async(prompt))

    def forward(self, **kwargs: Any) -> Prediction:
        """Execute the agent with an input message.

        Args:
            **kwargs: Must contain the input field specified in signature

        Returns:
            Prediction with:
            - Typed output field (named according to signature)
            - trace: list[TraceItem] - Execution trace
            - usage: Usage - Token usage statistics

        Raises:
            ValueError: If the input field is missing or a Pydantic output
                cannot be parsed.
            RuntimeError: If the agent run reports an error.

        Example:
            >>> result = agent(message="Hello")
            >>> print(result.answer)  # Access typed output
            >>> print(result.trace)   # List of execution items
            >>> print(result.usage)   # Token usage stats
        """
        prompt = self._build_prompt(self._extract_input(kwargs))
        response_text, trace, usage = self._run_blocking(prompt)
        return self._build_prediction(response_text, trace, usage)

    async def aforward(self, **kwargs: Any) -> Prediction:
        """Async version of forward().

        Use this when already in an async context to avoid event loop issues.

        Args:
            **kwargs: Must contain the input field specified in signature

        Returns:
            Prediction with typed output, trace, and usage

        Raises:
            ValueError: If the input field is missing or a Pydantic output
                cannot be parsed.
            RuntimeError: If the agent run reports an error.
        """
        prompt = self._build_prompt(self._extract_input(kwargs))
        response_text, trace, usage = await self._run_async(prompt)
        return self._build_prediction(response_text, trace, usage)

    async def disconnect(self) -> None:
        """Disconnect from Claude Code and clean up resources."""
        if self._client and self._is_connected:
            try:
                await self._client.disconnect()
            finally:
                # never stay flagged as connected after a disconnect attempt
                self._is_connected = False

    def __del__(self):
        """Cleanup on deletion (best effort)."""
        # __init__ may have raised before the session attributes were set
        if not (
            getattr(self, "_client", None) and getattr(self, "_is_connected", False)
        ):
            return
        coro = None
        try:
            coro = self.disconnect()
            asyncio.run(coro)
        except Exception:
            # best effort cleanup; close the coroutine so a refused
            # asyncio.run (e.g. inside a running loop) does not leave it
            # un-awaited
            if coro is not None:
                coro.close()