move connect debug prints
This commit is contained in:
19
LICENSE
Normal file
19
LICENSE
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
4
auto_classes.json
Normal file
4
auto_classes.json
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"AutoConfig": "claude_dspy.agent.ClaudeCodeConfig",
|
||||||
|
"AutoProgram": "claude_dspy.agent.ClaudeCode"
|
||||||
|
}
|
||||||
24
claude_dspy/__init__.py
Normal file
24
claude_dspy/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
from .agent import ClaudeCode, ClaudeCodeConfig
|
||||||
|
from .trace import (
|
||||||
|
TraceItem,
|
||||||
|
AgentMessageItem,
|
||||||
|
ThinkingItem,
|
||||||
|
ToolUseItem,
|
||||||
|
ToolResultItem,
|
||||||
|
ErrorItem,
|
||||||
|
)
|
||||||
|
from .utils import Usage
|
||||||
|
|
||||||
|
# Version string of the claude_dspy package.
__version__ = "0.1.0"

# Public API of the package: the agent and its config, the trace item
# dataclasses, and the token usage container imported above.
__all__ = [
    "ClaudeCode",
    "ClaudeCodeConfig",
    "TraceItem",
    "AgentMessageItem",
    "ThinkingItem",
    "ToolUseItem",
    "ToolResultItem",
    "ErrorItem",
    "Usage",
]
|
||||||
658
claude_dspy/agent.py
Normal file
658
claude_dspy/agent.py
Normal file
@@ -0,0 +1,658 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from modaic import PrecompiledProgram, PrecompiledConfig
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
import dspy
|
||||||
|
from dspy.primitives.prediction import Prediction
|
||||||
|
|
||||||
|
from claude_agent_sdk import (
|
||||||
|
ClaudeSDKClient,
|
||||||
|
ClaudeAgentOptions,
|
||||||
|
AssistantMessage,
|
||||||
|
ResultMessage,
|
||||||
|
SystemMessage,
|
||||||
|
TextBlock,
|
||||||
|
ThinkingBlock,
|
||||||
|
ToolUseBlock,
|
||||||
|
ToolResultBlock,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .trace import (
|
||||||
|
TraceItem,
|
||||||
|
AgentMessageItem,
|
||||||
|
ThinkingItem,
|
||||||
|
ToolUseItem,
|
||||||
|
ToolResultItem,
|
||||||
|
ErrorItem,
|
||||||
|
)
|
||||||
|
from .utils import (
|
||||||
|
Usage,
|
||||||
|
is_pydantic_model,
|
||||||
|
get_json_schema,
|
||||||
|
parse_json_response,
|
||||||
|
extract_text_from_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeCodeConfig(PrecompiledConfig):
    """Precompiled configuration consumed by the ClaudeCode agent."""

    # Model identifier; ClaudeCode copies it to ``self.model`` and passes it
    # on as the "model" entry of the SDK options.
    model: str = "claude-opus-4-5-20251101"
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeCodeKwargs(BaseModel):
    """Keyword arguments accepted by ``ClaudeCode.__init__``.

    The fields mirror ClaudeAgentOptions from the SDK, plus the DSPy-specific
    ``signature`` field.
    See: https://platform.claude.com/docs/en/agent-sdk/python#claudeagentoptions
    """

    # --- DSPy-specific (required) ---
    # Either a string or a dspy.Signature; typed as Any because ClaudeCode
    # validates it by hand in __init__.
    signature: Any

    # --- authentication ---
    api_key: str | None = None

    # --- basic configuration ---
    working_directory: str = "."
    permission_mode: str | None = None
    allowed_tools: list[str] | None = None  # any Claude Code tools
    disallowed_tools: list[str] | None = None
    sandbox: dict[str, Any] | None = None
    system_prompt: str | dict[str, Any] | None = None

    # --- MCP servers ---
    mcp_servers: dict[str, Any] | str | Path | None = None

    # --- session management ---
    continue_conversation: bool = False
    resume: str | None = None
    max_turns: int | None = None
    fork_session: bool = False

    # --- advanced options ---
    permission_prompt_tool_name: str | None = None
    settings: str | None = None
    add_dirs: list[str | Path] | None = None
    env: dict[str, str] | None = None
    extra_args: dict[str, str | None] | None = None
    max_buffer_size: int | None = None

    # --- callbacks and hooks ---
    # Callable[[str], None]; kept as Any because callables are awkward to
    # type check in Pydantic.
    stderr: Any | None = None
    can_use_tool: Any | None = None  # CanUseTool callback
    hooks: dict[str, list[dict[str, Any]]] | None = None

    # --- user and settings ---
    user: str | None = None
    include_partial_messages: bool = False
    setting_sources: list[str] | None = None  # list of "user" | "project" | "local"

    # --- subagents and plugins ---
    agents: dict[str, dict[str, Any]] | None = None
    plugins: list[dict[str, Any]] | None = None

    # --- CLI configuration ---
    cli_path: str | Path | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeCode(PrecompiledProgram):
|
||||||
|
"""DSPy module that wraps Claude Code SDK.
|
||||||
|
|
||||||
|
Each agent instance maintains a stateful conversation session.
|
||||||
|
Perfect for multi-turn agentic workflows with context preservation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config: ClaudeCodeConfig
|
||||||
|
|
||||||
|
def __init__(
    self,
    config: ClaudeCodeConfig,
    **kwargs: dict,
):
    """Build the agent from a config plus ClaudeCodeKwargs-style options.

    Args:
        config: Precompiled configuration; its ``model`` value is stored on
            the instance.
        **kwargs: Validated through ``ClaudeCodeKwargs``. ``signature`` may
            be a string or an already constructed signature object; every
            other option is copied onto the instance unchanged.

    Raises:
        ValueError: If ``dspy.Signature`` rejects a string signature, or if
            the signature does not have exactly one input field and exactly
            one output field.
    """
    super().__init__(config=config)

    opts = ClaudeCodeKwargs(**kwargs)

    # Resolve the signature. String signatures can only name built-in types;
    # custom Pydantic models require a class-based signature or a
    # dspy.Signature constructed where those types are defined.
    spec = opts.signature
    if not isinstance(spec, str):
        self.signature = spec
    else:
        try:
            self.signature = dspy.Signature(spec)
        except ValueError as e:
            if "Unknown name:" not in str(e):
                raise
            type_name = str(e).split("Unknown name: ")[-1]
            raise ValueError(
                f"Cannot resolve type '{type_name}' in string signature.\n"
                f"String signatures only work with built-in types (str, int, list[str], etc.).\n\n"
                f"For custom Pydantic models, use one of these approaches:\n\n"
                f"Option 1 - Class-based signature (recommended):\n"
                f" class MySignature(dspy.Signature):\n"
                f" input: str = dspy.InputField()\n"
                f" output: {type_name} = dspy.OutputField()\n"
                f" agent = ClaudeCode(config, signature=MySignature, ...)\n\n"
                f"Option 2 - Pre-construct signature in your module:\n"
                f" sig = dspy.Signature('{spec}')\n"
                f" agent = ClaudeCode(config, signature=sig, ...)\n"
            ) from e

    # Exactly one input and one output field are supported for now.
    # TODO: support multiple inputs/outputs
    input_names = list(self.signature.input_fields.keys())
    output_names = list(self.signature.output_fields.keys())
    for kind, names in (("input", input_names), ("output", output_names)):
        if len(names) != 1:
            raise ValueError(
                f"ClaudeCode requires exactly 1 {kind} field, got {len(names)}. "
                f"Found: {names}"
            )

    (self.input_field_name,) = input_names
    (self.output_field_name,) = output_names
    self.input_field = self.signature.input_fields[self.input_field_name]
    self.output_field = self.signature.output_fields[self.output_field_name]

    # Values that need a little processing before being stored.
    self.api_key = opts.api_key or os.getenv("ANTHROPIC_API_KEY")
    self.working_directory = Path(opts.working_directory).resolve()
    self.model = config.model

    # Everything else is stored exactly as validated.
    for option_name in (
        # basic options
        "permission_mode",
        "allowed_tools",
        "disallowed_tools",
        "sandbox",
        "system_prompt",
        # mcp servers
        "mcp_servers",
        # session management
        "continue_conversation",
        "resume",
        "max_turns",
        "fork_session",
        # advanced options
        "permission_prompt_tool_name",
        "settings",
        "add_dirs",
        "env",
        "extra_args",
        "max_buffer_size",
        # callbacks and hooks
        "stderr",
        "can_use_tool",
        "hooks",
        # user and settings
        "user",
        "include_partial_messages",
        "setting_sources",
        # subagents and plugins
        "agents",
        "plugins",
        # cli configuration
        "cli_path",
    ):
        setattr(self, option_name, getattr(opts, option_name))

    # The output format depends only on the output field, so compute it once.
    self.output_format = self._get_output_format()

    # Session state; the client is created lazily on the first run.
    self._client: Optional[ClaudeSDKClient] = None
    self._session_id: Optional[str] = None
    self._is_connected = False
|
||||||
|
|
||||||
|
@property
def session_id(self) -> Optional[str]:
    """Session ID recorded for this agent instance.

    The value is None until a run has stored one (i.e. before the first
    forward() call).
    """
    current = self._session_id
    return current
|
||||||
|
|
||||||
|
def _create_client(self) -> ClaudeSDKClient:
    """Assemble ClaudeAgentOptions from the stored settings and build a client.

    Returns:
        A new ClaudeSDKClient constructed with those options.
    """
    # These three entries are always present.
    options_dict = {
        "cwd": str(self.working_directory),
        "model": self.model,
        "output_format": self.output_format,
    }

    # Options forwarded whenever they hold something other than None.
    forwarded_if_set = (
        "permission_mode",
        "allowed_tools",
        "disallowed_tools",
        "sandbox",
        "system_prompt",
        "mcp_servers",
        "resume",
        "max_turns",
        "permission_prompt_tool_name",
        "settings",
        "add_dirs",
        "env",
        "extra_args",
        "max_buffer_size",
        "stderr",
        "can_use_tool",
        "hooks",
        "user",
        "setting_sources",
        "agents",
        "plugins",
        "cli_path",
    )
    for key in forwarded_if_set:
        value = getattr(self, key)
        if value is not None:
            options_dict[key] = value

    # Boolean switches are forwarded only when truthy.
    forwarded_if_true = (
        "continue_conversation",
        "fork_session",
        "include_partial_messages",
    )
    for key in forwarded_if_true:
        value = getattr(self, key)
        if value:
            options_dict[key] = value

    options = ClaudeAgentOptions(**options_dict)

    # Publish the API key through the process environment when one is known.
    if self.api_key:
        os.environ["ANTHROPIC_API_KEY"] = self.api_key

    return ClaudeSDKClient(options=options)
|
||||||
|
|
||||||
|
def _build_prompt(self, input_value: str) -> str:
|
||||||
|
"""Build prompt from signature docstring, field descriptions, and input value.
|
||||||
|
|
||||||
|
Note: When using structured outputs, the SDK handles JSON formatting automatically
|
||||||
|
via the output_format parameter, so we don't add JSON instructions to the prompt.
|
||||||
|
"""
|
||||||
|
prompt_parts = []
|
||||||
|
|
||||||
|
# add signature docstring if present
|
||||||
|
if self.signature.__doc__:
|
||||||
|
doc = self.signature.__doc__.strip()
|
||||||
|
if doc:
|
||||||
|
prompt_parts.append(f"Task: {doc}")
|
||||||
|
|
||||||
|
# add input field description if present
|
||||||
|
# DSPy fields store desc in json_schema_extra
|
||||||
|
input_desc = None
|
||||||
|
if (
|
||||||
|
hasattr(self.input_field, "json_schema_extra")
|
||||||
|
and self.input_field.json_schema_extra
|
||||||
|
):
|
||||||
|
input_desc = self.input_field.json_schema_extra.get("desc")
|
||||||
|
|
||||||
|
# add the actual input value
|
||||||
|
prompt_parts.append(f"{self.input_field_name}: {input_value}")
|
||||||
|
|
||||||
|
if input_desc:
|
||||||
|
prompt_parts.append(f"({input_desc})")
|
||||||
|
|
||||||
|
# add output field description if present
|
||||||
|
output_desc = None
|
||||||
|
if (
|
||||||
|
hasattr(self.output_field, "json_schema_extra")
|
||||||
|
and self.output_field.json_schema_extra
|
||||||
|
):
|
||||||
|
output_desc = self.output_field.json_schema_extra.get("desc")
|
||||||
|
|
||||||
|
if output_desc:
|
||||||
|
prompt_parts.append(f"\nPlease produce the following output: {output_desc}")
|
||||||
|
|
||||||
|
# the schema is passed through ClaudeAgentOptions and enforced by the SDK
|
||||||
|
|
||||||
|
return "\n\n".join(prompt_parts)
|
||||||
|
|
||||||
|
def _get_output_format(self) -> Optional[dict[str, Any]]:
    """Describe the structured output format for the output field, if any.

    Returns:
        ``{"type": "json_schema", "schema": ...}`` when ``is_pydantic_model``
        accepts the output field's annotation (a Pydantic model, or a generic
        such as list[MyModel] / dict[str, MyModel]); otherwise None.
    """
    annotation = self.output_field.annotation
    if not is_pydantic_model(annotation):
        return None

    return {
        "type": "json_schema",
        "schema": get_json_schema(annotation),
    }
|
||||||
|
|
||||||
|
async def _run_async(
    self, prompt: str
) -> tuple[str | dict | list | None, list[TraceItem], Usage]:
    """Send ``prompt`` through the SDK client and collect the reply.

    The client is created and connected on first use and reused afterwards.

    Args:
        prompt: Prompt text to send.

    Returns:
        A ``(response, trace, usage)`` tuple:
        - response: the result message's ``structured_output`` when it is not
          None; otherwise the collected text (the result message's ``result``
          if non-empty, else the accumulated text blocks or the JSON dump of
          a StructuredOutput tool call).
        - trace: execution trace items in arrival order.
        - usage: token counts taken from the result message.

    Raises:
        RuntimeError: If the result message is flagged as an error.
    """
    # create client if needed
    if self._client is None:
        self._client = self._create_client()

    # connect if not already connected
    if not self._is_connected:
        await self._client.connect()
        self._is_connected = True
        print(
            f"[ClaudeCode._run_async] Client connected (connected={self._is_connected})"
        )

    await self._client.query(prompt)
    print("[ClaudeCode._run_async] Query sent, waiting for response...")

    # collect messages and build trace
    trace: list[TraceItem] = []
    usage = Usage()
    response_text = ""
    structured_output = None

    async for message in self._client.receive_response():
        # assistant messages carry text, thinking and tool blocks
        if isinstance(message, AssistantMessage):
            for block in message.content:
                if isinstance(block, TextBlock):
                    response_text += block.text
                    trace.append(
                        AgentMessageItem(text=block.text, model=message.model)
                    )
                elif isinstance(block, ThinkingBlock):
                    trace.append(
                        ThinkingItem(text=block.thinking, model=message.model)
                    )
                elif isinstance(block, ToolUseBlock):
                    # The StructuredOutput tool carries the JSON answer as
                    # its input (already a dict), so it replaces any text
                    # collected so far.
                    if block.name == "StructuredOutput":
                        response_text = json.dumps(block.input)

                    trace.append(
                        ToolUseItem(
                            tool_name=block.name,
                            tool_input=block.input,
                            tool_use_id=block.id,
                        )
                    )
                elif isinstance(block, ToolResultBlock):
                    content_str = ""
                    if isinstance(block.content, str):
                        content_str = block.content
                    elif isinstance(block.content, list):
                        # keep only the text parts of the content blocks
                        content_str = "".join(
                            item.get("text", "")
                            for item in block.content
                            if isinstance(item, dict) and item.get("type") == "text"
                        )

                    trace.append(
                        ToolResultItem(
                            tool_name="",  # tool name not in ToolResultBlock
                            tool_use_id=block.tool_use_id,
                            content=content_str,
                            is_error=block.is_error or False,
                        )
                    )

        # result message: final message with session, usage and outcome
        elif isinstance(message, ResultMessage):
            # store session ID
            if hasattr(message, "session_id"):
                self._session_id = message.session_id
                print(f"[ClaudeCode._run_async] - Session ID: {self._session_id}")

            # extract usage
            usage_data = getattr(message, "usage", None)
            if usage_data:
                usage = Usage(
                    input_tokens=usage_data.get("input_tokens", 0),
                    cached_input_tokens=usage_data.get(
                        "cache_read_input_tokens", 0
                    ),
                    output_tokens=usage_data.get("output_tokens", 0),
                )

            # check for errors
            if getattr(message, "is_error", False):
                # ``result`` can be present but empty/None; fall back to a
                # readable message instead of reporting the literal "None".
                error_msg = getattr(message, "result", None) or "Unknown error"
                trace.append(
                    ErrorItem(message=error_msg, error_type="execution_error")
                )
                raise RuntimeError(f"Agent execution failed: {error_msg}")

            # prefer structured_output over result (when using output_format)
            if getattr(message, "structured_output", None) is not None:
                structured_output = message.structured_output
            # fallback to result field for text outputs
            elif getattr(message, "result", None):
                response_text = message.result

        # system messages are recorded in the trace but never treated as errors
        elif isinstance(message, SystemMessage):
            if hasattr(message, "data") and message.data:
                data_str = str(message.data)
                trace.append(
                    AgentMessageItem(text=f"[System: {data_str}]", model="system")
                )

    # structured output (for Pydantic outputs) wins over plain text
    if structured_output is not None:
        return structured_output, trace, usage
    return response_text, trace, usage
|
||||||
|
|
||||||
|
def forward(self, **kwargs: Any) -> Prediction:
    """Execute the agent with an input message.

    Args:
        **kwargs: Must contain the input field specified in signature

    Returns:
        Prediction with:
        - Typed output field (named according to signature)
        - trace: list[TraceItem] - Execution trace
        - usage: Usage - Token usage statistics

    Raises:
        ValueError: If the input field is missing from ``kwargs`` or the
            response cannot be parsed into the output type.
        RuntimeError: Propagated unchanged from the agent run.

    Example:
        >>> result = agent(message="Hello")
        >>> print(result.answer)  # Access typed output
        >>> print(result.trace)  # List of execution items
        >>> print(result.usage)  # Token usage stats
    """
    # extract input value
    if self.input_field_name not in kwargs:
        raise ValueError(
            f"Missing required input field: {self.input_field_name}. "
            f"Received: {list(kwargs.keys())}"
        )

    input_value = kwargs[self.input_field_name]

    # build prompt
    prompt = self._build_prompt(input_value)
    print(prompt)

    # Only the loop lookup is guarded. Wrapping the whole run in
    # ``except RuntimeError`` would also catch the RuntimeError that
    # _run_async raises for a failed execution and silently send the
    # prompt a second time.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None

    if loop is None or loop.is_closed():
        # no usable event loop, let asyncio create one
        response_text, trace, usage = asyncio.run(self._run_async(prompt))
    else:
        if loop.is_running():
            # already inside an async context: allow re-entrant use of the loop
            import nest_asyncio

            nest_asyncio.apply()
        response_text, trace, usage = loop.run_until_complete(
            self._run_async(prompt)
        )

    # parse response based on output type
    output_type = self.output_field.annotation
    if is_pydantic_model(output_type):
        try:
            # response_text can be dict/list (from structured_output) or str (legacy)
            parsed_output = parse_json_response(response_text, output_type)
        except Exception as e:
            raise ValueError(
                f"Failed to parse Claude response as {output_type}: {e}\n"
                f"Response type: {type(response_text)}\n"
                f"Response: {response_text}"
            ) from e
    elif isinstance(response_text, str):
        # string output - extract text
        parsed_output = extract_text_from_response(response_text)
    else:
        # Shouldn't happen, but handle gracefully
        parsed_output = str(response_text)

    # return prediction with typed output, trace, and usage
    return Prediction(
        **{
            self.output_field_name: parsed_output,
            "trace": trace,
            "usage": usage,
        }
    )
|
||||||
|
|
||||||
|
async def aforward(self, **kwargs: Any) -> Prediction:
    """Coroutine counterpart of forward().

    Awaits the agent run directly, so it can be used from code that is
    already running inside an event loop.

    Args:
        **kwargs: Must contain the input field specified in signature

    Returns:
        Prediction with typed output, trace, and usage

    Raises:
        ValueError: If the input field is missing from ``kwargs`` or the
            response cannot be parsed into the output type.
    """
    field_name = self.input_field_name
    if field_name not in kwargs:
        raise ValueError(
            f"Missing required input field: {field_name}. "
            f"Received: {list(kwargs)}"
        )

    prompt = self._build_prompt(kwargs[field_name])

    raw, trace, usage = await self._run_async(prompt)

    # Convert the raw response according to the declared output type.
    annotation = self.output_field.annotation
    if not is_pydantic_model(annotation):
        # plain text output; a non-string here is unexpected, so stringify it
        if isinstance(raw, str):
            value = extract_text_from_response(raw)
        else:
            value = str(raw)
    else:
        try:
            # raw may be a dict/list (structured_output) or a str (legacy)
            value = parse_json_response(raw, annotation)
        except Exception as e:
            raise ValueError(
                f"Failed to parse Claude response as {annotation}: {e}\n"
                f"Response type: {type(raw)}\n"
                f"Response: {raw}"
            )

    fields = {self.output_field_name: value}
    fields["trace"] = trace
    fields["usage"] = usage
    return Prediction(**fields)
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
    """Disconnect the client, if one exists and is marked as connected."""
    if not (self._client and self._is_connected):
        return

    await self._client.disconnect()
    self._is_connected = False
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""Cleanup on deletion."""
|
||||||
|
# Check attributes exist before accessing (may fail during __init__)
|
||||||
|
if hasattr(self, "_client") and hasattr(self, "_is_connected"):
|
||||||
|
if self._client and self._is_connected:
|
||||||
|
try:
|
||||||
|
asyncio.run(self.disconnect())
|
||||||
|
except Exception:
|
||||||
|
# best effort cleanup
|
||||||
|
pass
|
||||||
|
|
||||||
52
claude_dspy/trace.py
Normal file
52
claude_dspy/trace.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TraceItem:
    """Common base type for every entry of an execution trace.

    It declares no fields of its own; the concrete item types add them.
    """
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AgentMessageItem(TraceItem):
    """Trace entry holding a piece of text produced by the agent."""

    # the message text
    text: str
    # name of the model the text is attributed to
    model: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ThinkingItem(TraceItem):
    """Trace entry holding the agent's internal reasoning (extended thinking)."""

    # the reasoning text
    text: str
    # name of the model the reasoning is attributed to
    model: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ToolUseItem(TraceItem):
    """Trace entry recording a request to invoke a tool."""

    # name of the requested tool
    tool_name: str
    # arguments passed to the tool
    tool_input: dict[str, Any]
    # identifier of this tool call
    tool_use_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ToolResultItem(TraceItem):
    """Trace entry recording the result of a tool execution."""

    # name of the tool that produced the result
    tool_name: str
    # identifier of the tool call this result belongs to
    tool_use_id: str
    # result payload; defaults to None
    content: str | list[dict[str, Any]] | None = None
    # whether the result is flagged as an error; defaults to False
    is_error: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ErrorItem(TraceItem):
    """Trace entry describing an error raised during execution."""

    # human-readable error message
    message: str
    # optional category label for the error; defaults to None
    error_type: str | None = None
|
||||||
214
claude_dspy/utils.py
Normal file
214
claude_dspy/utils.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, get_origin, get_args
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Usage:
    """Token counts for a run; every counter defaults to zero."""

    input_tokens: int = 0
    cached_input_tokens: int = 0
    output_tokens: int = 0

    @property
    def total_tokens(self) -> int:
        """Sum of input and output tokens (cached tokens are not added)."""
        consumed = self.input_tokens
        produced = self.output_tokens
        return consumed + produced

    def __repr__(self) -> str:
        summary = ", ".join(
            (
                f"input={self.input_tokens}",
                f"cached={self.cached_input_tokens}",
                f"output={self.output_tokens}",
                f"total={self.total_tokens}",
            )
        )
        return f"Usage({summary})"
|
||||||
|
|
||||||
|
|
||||||
|
def is_pydantic_model(type_hint: Any) -> bool:
    """Check if a type hint is a Pydantic model or directly contains one.

    Returns True for:
    - Pydantic models: MyModel
    - Generic types with a Pydantic type argument: list[MyModel],
      dict[str, MyModel]

    Only the hint itself and its immediate type arguments are inspected.

    Args:
        type_hint: Any annotation object.

    Returns:
        True if the hint or one of its type arguments is a BaseModel
        subclass, otherwise False. Never raises TypeError.
    """

    def _is_model_class(candidate: Any) -> bool:
        # Guard each check separately: on interpreters where a parameterized
        # alias such as list[MyModel] passes isinstance(..., type),
        # issubclass() raises TypeError for it. That must not prevent the
        # type arguments from being inspected afterwards.
        try:
            return isinstance(candidate, type) and issubclass(candidate, BaseModel)
        except TypeError:
            return False

    # direct Pydantic model
    if _is_model_class(type_hint):
        return True

    # generic type (list, dict, etc.): look at its type arguments
    if get_origin(type_hint) is None:
        return False
    return any(_is_model_class(arg) for arg in get_args(type_hint))
|
||||||
|
|
||||||
|
|
||||||
|
def get_json_schema(type_hint: Any) -> dict[str, Any]:
    """Generate JSON schema from type hint.

    Handles:
    - Pydantic models: MyModel
    - Generic types: list[MyModel], dict[str, MyModel]

    Note: Claude API requires root type to be "object" for structured outputs (tools).
    For list/dict types, we wrap them in an object with a single property
    ("items" for lists, "values" for dicts).

    Sets additionalProperties to false for every object schema that does not
    already declare it.

    Args:
        type_hint: The output type annotation.

    Returns:
        A JSON schema dict whose root is an object schema.

    Raises:
        ValueError: If ``type_hint`` is not a Pydantic model, list[Model] or
            dict[..., Model].
    """

    def is_model_class(candidate: Any) -> bool:
        # issubclass() raises TypeError for parametrized generics that pass
        # isinstance(..., type) on Python 3.9/3.10 (e.g. list[int]); treat
        # those as "not a model" so callers get the documented ValueError.
        try:
            return isinstance(candidate, type) and issubclass(candidate, BaseModel)
        except TypeError:
            return False

    def split_defs(model: Any) -> tuple[dict[str, Any], Any]:
        # The model schema refers to nested models through root-relative
        # "#/$defs/<Name>" references. Once the schema is embedded below a
        # wrapper property those references only resolve if "$defs" sits at
        # the root of the final document, so separate it from the body here.
        raw = model.model_json_schema()
        body = {key: value for key, value in raw.items() if key != "$defs"}
        return body, raw.get("$defs")

    # recursively set additionalProperties: false for all nested objects
    def set_additional_properties(obj: dict[str, Any]) -> None:
        if not isinstance(obj, dict):
            return
        if obj.get("type") == "object" and "additionalProperties" not in obj:
            obj["additionalProperties"] = False
        for value in obj.values():
            if isinstance(value, dict):
                set_additional_properties(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        set_additional_properties(item)

    origin = get_origin(type_hint)
    args = get_args(type_hint)
    shared_defs = None

    if origin is list:
        # list[Model] - wrap in object since API requires root type = "object"
        # {"type": "object", "properties": {"items": {"type": "array", "items": {...}}}}
        if not (args and is_model_class(args[0])):
            raise ValueError(f"Unsupported list type: {type_hint}")
        model_schema, shared_defs = split_defs(args[0])
        schema = {
            "type": "object",
            "properties": {"items": {"type": "array", "items": model_schema}},
            "required": ["items"],
            "additionalProperties": False,
        }

    elif origin is dict:
        # dict[str, Model] - wrap in object since API requires root type = "object"
        # {"type": "object", "properties": {"values": {"type": "object", "additionalProperties": {...}}}}
        if not (len(args) >= 2 and is_model_class(args[1])):
            raise ValueError(f"Unsupported dict type: {type_hint}")
        model_schema, shared_defs = split_defs(args[1])
        schema = {
            "type": "object",
            "properties": {
                "values": {
                    "type": "object",
                    "additionalProperties": model_schema,
                }
            },
            "required": ["values"],
            "additionalProperties": False,
        }

    elif is_model_class(type_hint):
        # direct Pydantic model - already an object, "$defs" is already at root
        schema = type_hint.model_json_schema()

    else:
        raise ValueError(f"Unsupported type for structured output: {type_hint}")

    if shared_defs:
        # hoist nested-model definitions to the document root
        schema["$defs"] = shared_defs

    set_additional_properties(schema)
    return schema
|
||||||
|
|
||||||
|
|
||||||
|
def parse_json_response(
    response: str | dict | list, type_hint: Any
) -> BaseModel | list[BaseModel] | dict[str, BaseModel]:
    """Turn a JSON response into validated, typed output.

    Supported annotations:
    - a Pydantic model: MyModel
    - list[MyModel] and dict[str, MyModel]

    For list/dict annotations the payload may arrive wrapped as
    {"items": [...]} or {"values": {...}} (the schema root has to be an
    object); such wrappers are removed before validation.

    Args:
        response: JSON text, or an already-decoded dict/list
        type_hint: The output type annotation

    Returns:
        The validated model, list of models, or dict of models

    Raises:
        json.JSONDecodeError: If a string response is not valid JSON
        pydantic.ValidationError: If the data does not satisfy the model
        ValueError: If the payload shape or the annotation is not supported
    """
    import json

    container = get_origin(type_hint)
    type_args = get_args(type_hint)

    # decode text payloads; dict/list payloads are used as they are
    payload = json.loads(response) if isinstance(response, str) else response

    if container is list:
        element_type = type_args[0] if type_args else None
        if isinstance(element_type, type) and issubclass(element_type, BaseModel):
            # strip the {"items": [...]} wrapper when it is there
            if isinstance(payload, dict) and "items" in payload:
                payload = payload["items"]
            if not isinstance(payload, list):
                raise ValueError(f"Expected list, got {type(payload)}")
            return [element_type.model_validate(entry) for entry in payload]

    elif container is dict:
        value_type = type_args[1] if len(type_args) >= 2 else None
        if isinstance(value_type, type) and issubclass(value_type, BaseModel):
            # strip the {"values": {...}} wrapper when it is there
            if isinstance(payload, dict) and "values" in payload:
                payload = payload["values"]
            if not isinstance(payload, dict):
                raise ValueError(f"Expected dict, got {type(payload)}")
            return {
                name: value_type.model_validate(entry)
                for name, entry in payload.items()
            }

    elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
        # plain model: text goes through the model's own JSON validation
        if isinstance(response, str):
            return type_hint.model_validate_json(response)
        return type_hint.model_validate(payload)

    # list/dict annotations without a model argument also end up here
    raise ValueError(f"Unsupported type for parsing: {type_hint}")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text_from_response(response: str) -> str:
    """Return the response text with surrounding whitespace removed.

    Nothing else is altered: any markdown or other formatting inside the
    text is passed through unchanged.
    """
    trimmed = response.strip()
    return trimmed
|
||||||
3
config.json
Normal file
3
config.json
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"model": "claude-opus-4-5-20251101"
|
||||||
|
}
|
||||||
47
main.py
Normal file
47
main.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
from claude_dspy import ClaudeCode, ClaudeCodeConfig
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Literal
|
||||||
|
from modaic import AutoProgram
|
||||||
|
import dspy
|
||||||
|
|
||||||
|
# One entry of OutputResult.affected_files: a file name plus the action taken
# on it, restricted to "created", "updated", "deleted" or "renamed".
# NOTE(review): documented with comments rather than a class docstring on
# purpose; a docstring would presumably surface in the generated JSON schema.
# Confirm before adding one.
class AffectedFile(BaseModel):
    file_name: str = Field(..., description="Name of the file")
    action: Literal["created", "updated", "deleted", "renamed"] = Field(..., description="Action taken on the file")
|
||||||
|
|
||||||
|
|
||||||
|
# Structured result used as the output type of ClaudeCodeSignature: a success
# flag, a message, and the list of files the query affected. All three fields
# are required (declared with `...`).
# NOTE(review): documented with comments rather than a class docstring on
# purpose; a docstring would presumably surface in the generated JSON schema.
# Confirm before adding one.
class OutputResult(BaseModel):
    success: bool = Field(..., description="Whether or not execution of the query was successful")
    message: str = Field(..., description="Message")
    affected_files: list[AffectedFile] = Field(..., description="List of files affected by the query")
|
||||||
|
|
||||||
|
|
||||||
|
# DSPy signature handed to ClaudeCode in main(): one string input, `query`,
# and one structured output, `output`, typed as OutputResult.
# NOTE(review): documented with comments rather than a class docstring on
# purpose; a signature docstring is presumably used by dspy as the task
# instructions. Confirm before adding one.
class ClaudeCodeSignature(dspy.Signature):
    query: str = dspy.InputField(desc="Query to process")
    output: OutputResult = dspy.OutputField(desc="Result of the query")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run one sample query through a ClaudeCode agent, print its output, and call push_to_hub."""
    # agent built from a default config, working in the current directory
    agent = ClaudeCode(
        ClaudeCodeConfig(),
        signature=ClaudeCodeSignature,
        working_directory=".",
        permission_mode="acceptEdits",
        allowed_tools=["Read", "Write", "Bash"],
    )

    prediction = agent(query="list the files in this directory")
    print(prediction.output)

    # upload the program to the "modaic/claude-code" repo
    agent.push_to_hub(
        "modaic/claude-code",
        with_code=True,
        commit_message="move connect debug prints",
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||||
9
program.json
Normal file
9
program.json
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"metadata": {
|
||||||
|
"dependency_versions": {
|
||||||
|
"python": "3.13",
|
||||||
|
"dspy": "3.0.4",
|
||||||
|
"cloudpickle": "3.1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
13
pyproject.toml
Normal file
13
pyproject.toml
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
[project]
|
||||||
|
name = "claude-code"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Claude Code SDK wrapped in a DSPy module"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = ["claude-agent-sdk>=0.1.12", "dspy>=3.0.4", "modaic>=0.8.2", "nest-asyncio>=1.6.0", "pydantic>=2.0.0"]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=7.0.0",
|
||||||
|
"pytest-asyncio>=0.21.0",
|
||||||
|
]
|
||||||
Reference in New Issue
Block a user