Add more functionality to signature description parsing
This commit is contained in:
479
claude_dspy/agent.py
Normal file
479
claude_dspy/agent.py
Normal file
@@ -0,0 +1,479 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from modaic import PrecompiledProgram, PrecompiledConfig
|
||||
from pydantic import BaseModel
|
||||
|
||||
import dspy
|
||||
from dspy.primitives.prediction import Prediction
|
||||
|
||||
from claude_agent_sdk import (
|
||||
ClaudeSDKClient,
|
||||
ClaudeAgentOptions,
|
||||
AssistantMessage,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
)
|
||||
|
||||
from claude_dspy.trace import (
|
||||
TraceItem,
|
||||
AgentMessageItem,
|
||||
ThinkingItem,
|
||||
ToolUseItem,
|
||||
ToolResultItem,
|
||||
ErrorItem,
|
||||
)
|
||||
from claude_dspy.utils import (
|
||||
Usage,
|
||||
is_pydantic_model,
|
||||
get_json_schema,
|
||||
parse_json_response,
|
||||
extract_text_from_response,
|
||||
)
|
||||
|
||||
|
||||
class ClaudeCodeConfig(PrecompiledConfig):
    """Precompiled configuration for the ClaudeCode agent.

    Holds the single setting that is stored on the config object: the
    identifier of the Claude model the agent should run.
    """

    # Model identifier used when no other value is supplied.
    model: str = "claude-opus-4-5-20251101"
|
||||
|
||||
class ClaudeCodeKwargs(BaseModel):
    """Validated keyword arguments for constructing a ClaudeCode agent.

    Only ``signature`` is required; every other field has a default.
    """

    # ``signature`` is typed Any, so arbitrary (non-pydantic) types are allowed.
    model_config = {"arbitrary_types_allowed": True}

    # --- required -----------------------------------------------------------
    signature: Any  # str | dspy.Signature (validated manually)

    # --- credentials / location ---------------------------------------------
    # When None, ClaudeCode falls back to the ANTHROPIC_API_KEY env variable.
    api_key: str | None = None
    # Resolved to an absolute path by ClaudeCode before use.
    working_directory: str = "."

    # --- options forwarded to ClaudeAgentOptions by ClaudeCode --------------
    permission_mode: str | None = None
    allowed_tools: list[str] | None = None
    disallowed_tools: list[str] | None = None
    sandbox: dict[str, Any] | None = None
    system_prompt: str | dict[str, Any] | None = None
|
||||
|
||||
class ClaudeCode(PrecompiledProgram):
    """DSPy module that wraps Claude Code SDK.

    Each agent instance maintains a stateful conversation session: the SDK
    client is created lazily on the first call and reused afterwards.
    Perfect for multi-turn agentic workflows with context preservation.

    The signature must declare exactly one input field and one output field.
    When the output field is annotated with a Pydantic model, the agent is
    asked for JSON matching that model's schema and the response is parsed
    into the model; otherwise the response is returned as text.

    Example:
        >>> config = ClaudeCodeConfig()
        >>> agent = ClaudeCode(
        ...     config,
        ...     signature='message:str -> answer:str',
        ...     working_directory="."
        ... )
        >>> result = agent(message="What files are in this directory?")
        >>> print(result.answer)  # Typed access
        >>> print(result.trace)   # Execution trace
        >>> print(result.usage)   # Token usage
    """

    config: ClaudeCodeConfig

    def __init__(
        self,
        config: ClaudeCodeConfig,
        **kwargs: Any,
    ):
        """Validate the arguments and prepare (but do not open) the session.

        Args:
            config: Agent configuration; ``config.model`` selects the model.
            **kwargs: Fields of ``ClaudeCodeKwargs`` (``signature`` is
                required).

        Raises:
            ValueError: If the signature does not have exactly one input
                field and exactly one output field.
        """
        super().__init__(config=config)

        args = ClaudeCodeKwargs(**kwargs)

        # parse and validate signature
        signature = args.signature
        if isinstance(signature, str):
            self.signature = dspy.Signature(signature)
        else:
            self.signature = signature

        # validate signature has exactly 1 input and 1 output
        input_fields = list(self.signature.input_fields.keys())
        output_fields = list(self.signature.output_fields.keys())

        if len(input_fields) != 1:
            raise ValueError(
                f"ClaudeCode requires exactly 1 input field, got {len(input_fields)}. "
                f"Found: {input_fields}"
            )

        if len(output_fields) != 1:
            raise ValueError(
                f"ClaudeCode requires exactly 1 output field, got {len(output_fields)}. "
                f"Found: {output_fields}"
            )

        self.input_field_name = input_fields[0]
        self.output_field_name = output_fields[0]
        self.input_field = self.signature.input_fields[self.input_field_name]
        self.output_field = self.signature.output_fields[self.output_field_name]

        # store config values
        self.working_directory = Path(args.working_directory).resolve()
        self.model = config.model
        self.permission_mode = args.permission_mode
        self.allowed_tools = args.allowed_tools
        self.disallowed_tools = args.disallowed_tools
        self.sandbox = args.sandbox
        self.system_prompt = args.system_prompt
        self.api_key = args.api_key or os.getenv("ANTHROPIC_API_KEY")
        # No extra options since all kwargs are parsed by ClaudeCodeKwargs
        self.extra_options = {}

        # determine output format upfront
        self.output_format = self._get_output_format()

        # session state
        self._client: Optional[ClaudeSDKClient] = None
        self._session_id: Optional[str] = None
        self._is_connected = False

    @property
    def session_id(self) -> Optional[str]:
        """Get the session ID for this agent instance.

        Returns None until a result message carrying a session ID has been
        received, i.e. until the first completed call.
        """
        return self._session_id

    def _create_client(self) -> ClaudeSDKClient:
        """Create ClaudeSDKClient with configured options."""
        options = ClaudeAgentOptions(
            cwd=str(self.working_directory),
            model=self.model,
            permission_mode=self.permission_mode,
            allowed_tools=self.allowed_tools or [],
            disallowed_tools=self.disallowed_tools or [],
            sandbox=self.sandbox,
            system_prompt=self.system_prompt,
            output_format=self.output_format,  # include output format
            **self.extra_options,
        )

        # Export the API key to the process environment if one is known.
        # NOTE(review): this mutates os.environ for the whole process;
        # presumably the SDK reads the key from there - confirm.
        if self.api_key:
            os.environ["ANTHROPIC_API_KEY"] = self.api_key

        return ClaudeSDKClient(options=options)

    @staticmethod
    def _field_desc(field: Any) -> Optional[str]:
        """Return the ``desc`` stored in a field's ``json_schema_extra``.

        DSPy fields store their description there. Returns None when the
        attribute is missing, empty, or not a dict.
        """
        extra = getattr(field, "json_schema_extra", None)
        if not isinstance(extra, dict):
            return None
        return extra.get("desc")

    def _build_prompt(self, input_value: Any) -> str:
        """Build prompt from signature docstring, field descriptions, and input value.

        Sections are emitted in this order and joined by blank lines: task
        (signature docstring), input description, the input value itself,
        output description, and - for structured outputs - the JSON schema
        instructions. Sections without content are omitted.

        Args:
            input_value: Value of the single input field. Non-string values
                are converted with ``str()`` so the join cannot fail.

        Returns:
            The complete prompt text.
        """
        prompt_parts: list[str] = []

        # add signature docstring if present
        doc = (self.signature.__doc__ or "").strip()
        if doc:
            prompt_parts.append(f"Task: {doc}")

        # add input field description if present
        input_desc = self._field_desc(self.input_field)
        if input_desc:
            prompt_parts.append(f"Input context: {input_desc}")

        # add the actual input value
        if isinstance(input_value, str):
            prompt_parts.append(input_value)
        else:
            prompt_parts.append(str(input_value))

        # add output field description if present
        output_desc = self._field_desc(self.output_field)
        if output_desc:
            prompt_parts.append(f"\nPlease produce the following output: {output_desc}")

        # for Pydantic outputs, add explicit JSON instructions
        if self.output_format:
            schema = self.output_format["schema"]
            prompt_parts.append(
                "\nYou MUST respond with ONLY valid JSON matching this schema:\n"
                f"{json.dumps(schema, indent=2)}\n\n"
                "Do not include any explanatory text, markdown formatting, or code blocks. "
                "Return ONLY the raw JSON object."
            )

        return "\n\n".join(prompt_parts)

    def _get_output_format(self) -> Optional[dict[str, Any]]:
        """Get output format configuration for structured outputs.

        Returns:
            A ``{"type": "json_schema", "schema": ...}`` dict when the output
            field is annotated with a Pydantic model, otherwise None.
        """
        output_type = self.output_field.annotation
        if not is_pydantic_model(output_type):
            return None

        return {
            "type": "json_schema",
            "schema": get_json_schema(output_type),
        }

    @staticmethod
    def _tool_result_text(content: Any) -> str:
        """Flatten a tool result's content into plain text.

        A string is returned unchanged. For a list, the ``text`` of every
        dict item whose ``type`` is ``"text"`` is concatenated. Anything
        else yields an empty string.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return "".join(
                item.get("text", "")
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            )
        return ""

    async def _run_async(self, prompt: str) -> tuple[str, list[TraceItem], Usage]:
        """Run the agent asynchronously and collect results.

        Creates and connects the client on first use, sends ``prompt`` and
        consumes the response stream.

        Args:
            prompt: Fully built prompt text.

        Returns:
            ``(response_text, trace, usage)``.

        Raises:
            RuntimeError: If the result message reports an error.
        """
        # create client if needed
        if self._client is None:
            self._client = self._create_client()

        # connect if not already connected
        if not self._is_connected:
            await self._client.connect()
            self._is_connected = True

        # send query (output_format already configured in options)
        await self._client.query(prompt)

        # collect messages and build trace
        trace: list[TraceItem] = []
        usage = Usage()
        response_text = ""
        # tool_use_id -> tool name, so tool results can be labelled
        tool_names: dict[str, str] = {}

        async for message in self._client.receive_response():
            # handle assistant messages
            if isinstance(message, AssistantMessage):
                for block in message.content:
                    if isinstance(block, TextBlock):
                        response_text += block.text
                        trace.append(
                            AgentMessageItem(text=block.text, model=message.model)
                        )
                    elif isinstance(block, ThinkingBlock):
                        trace.append(
                            ThinkingItem(text=block.thinking, model=message.model)
                        )
                    elif isinstance(block, ToolUseBlock):
                        tool_names[block.id] = block.name
                        # Handle StructuredOutput tool (contains JSON response):
                        # the JSON is directly in the tool input, so it
                        # replaces any text accumulated so far.
                        if block.name == "StructuredOutput":
                            response_text = json.dumps(block.input)

                        trace.append(
                            ToolUseItem(
                                tool_name=block.name,
                                tool_input=block.input,
                                tool_use_id=block.id,
                            )
                        )
                    elif isinstance(block, ToolResultBlock):
                        trace.append(
                            ToolResultItem(
                                # ToolResultBlock carries no name; look it up
                                # from the matching tool use seen earlier.
                                tool_name=tool_names.get(block.tool_use_id, ""),
                                tool_use_id=block.tool_use_id,
                                content=self._tool_result_text(block.content),
                                is_error=block.is_error or False,
                            )
                        )

            # handle result messages (final message with usage info)
            elif isinstance(message, ResultMessage):
                # store session ID
                if hasattr(message, "session_id"):
                    self._session_id = message.session_id

                # extract usage
                usage_data = getattr(message, "usage", None)
                if usage_data:
                    usage = Usage(
                        input_tokens=usage_data.get("input_tokens", 0),
                        cached_input_tokens=usage_data.get(
                            "cache_read_input_tokens", 0
                        ),
                        output_tokens=usage_data.get("output_tokens", 0),
                    )

                result = getattr(message, "result", None)

                # check for errors
                if getattr(message, "is_error", False):
                    error_msg = result or "Unknown error"
                    trace.append(
                        ErrorItem(message=error_msg, error_type="execution_error")
                    )
                    raise RuntimeError(f"Agent execution failed: {error_msg}")

                # extract result if present (for structured outputs from result field)
                # NOTE(review): a non-empty result also overwrites JSON captured
                # from a StructuredOutput tool call - confirm this is intended.
                if result:
                    response_text = result

            # handle system messages
            elif isinstance(message, SystemMessage):
                # log system messages to trace but don't error
                data = getattr(message, "data", None)
                if data:
                    trace.append(
                        AgentMessageItem(text=f"[System: {data}]", model="system")
                    )

        return response_text, trace, usage

    def _extract_input(self, kwargs: dict[str, Any]) -> Any:
        """Return the value of the signature's input field from ``kwargs``.

        Raises:
            ValueError: If the input field is missing from ``kwargs``.
        """
        if self.input_field_name not in kwargs:
            raise ValueError(
                f"Missing required input field: {self.input_field_name}. "
                f"Received: {list(kwargs.keys())}"
            )
        return kwargs[self.input_field_name]

    def _build_prediction(
        self, response_text: str, trace: list[TraceItem], usage: Usage
    ) -> Prediction:
        """Parse the raw response and wrap it in a Prediction.

        Raises:
            ValueError: If a Pydantic output type is declared and the
                response cannot be parsed into it (original error chained).
        """
        output_type = self.output_field.annotation
        if is_pydantic_model(output_type):
            try:
                parsed_output = parse_json_response(response_text, output_type)
            except Exception as e:
                raise ValueError(
                    f"Failed to parse Claude response as {output_type.__name__}: {e}\n"
                    f"Response: {response_text}"
                ) from e
        else:
            # string output - extract text
            parsed_output = extract_text_from_response(response_text)

        # return prediction with typed output, trace, and usage
        return Prediction(
            **{
                self.output_field_name: parsed_output,
                "trace": trace,
                "usage": usage,
            }
        )

    def _run_sync(self, prompt: str) -> tuple[str, list[TraceItem], Usage]:
        """Drive ``_run_async`` to completion from synchronous code.

        Only the event-loop lookup is guarded by ``try``: a RuntimeError
        raised by the agent run itself must propagate instead of silently
        triggering a second execution of the prompt.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None

        # no usable event loop, create one
        if loop is None or loop.is_closed():
            return asyncio.run(self._run_async(prompt))

        if loop.is_running():
            # Already inside an async context: patch the loop so it can be
            # re-entered. Imported lazily; only needed on this path.
            import nest_asyncio

            nest_asyncio.apply()

        return loop.run_until_complete(self._run_async(prompt))

    def forward(self, **kwargs: Any) -> Prediction:
        """Execute the agent with an input message.

        Args:
            **kwargs: Must contain the input field specified in signature

        Returns:
            Prediction with:
            - Typed output field (named according to signature)
            - trace: list[TraceItem] - Execution trace
            - usage: Usage - Token usage statistics

        Raises:
            ValueError: If the input field is missing or a structured
                response cannot be parsed.
            RuntimeError: If the agent reports an execution error.

        Example:
            >>> result = agent(message="Hello")
            >>> print(result.answer)  # Access typed output
            >>> print(result.trace)   # List of execution items
            >>> print(result.usage)   # Token usage stats
        """
        prompt = self._build_prompt(self._extract_input(kwargs))
        response_text, trace, usage = self._run_sync(prompt)
        return self._build_prediction(response_text, trace, usage)

    async def aforward(self, **kwargs: Any) -> Prediction:
        """Async version of forward().

        Use this when already in an async context to avoid event loop issues.

        Args:
            **kwargs: Must contain the input field specified in signature

        Returns:
            Prediction with typed output, trace, and usage

        Raises:
            ValueError: If the input field is missing or a structured
                response cannot be parsed.
            RuntimeError: If the agent reports an execution error.
        """
        prompt = self._build_prompt(self._extract_input(kwargs))
        response_text, trace, usage = await self._run_async(prompt)
        return self._build_prediction(response_text, trace, usage)

    async def disconnect(self) -> None:
        """Disconnect from Claude Code and clean up resources."""
        if self._client and self._is_connected:
            await self._client.disconnect()
            self._is_connected = False

    def __del__(self):
        """Cleanup on deletion."""
        # __init__ may have raised before the session attributes were set,
        # so read them defensively.
        client = getattr(self, "_client", None)
        connected = getattr(self, "_is_connected", False)
        if not (client and connected):
            return

        try:
            asyncio.run(self.disconnect())
        except Exception:
            # best effort cleanup: a finalizer must never raise
            pass
|
||||
Reference in New Issue
Block a user